BOJ 2042

구간합 구하기

문제출처

문제 해석

  • 중간에 수의 변경이 빈번히 일어나고, 어떤 부분의 합을 구하는 문제입니다.
  • N<= 10^6, M,K<=10^4

설계1(X)

  • 세그먼트 트리를 구현하는 기본문제입니다.
  • 세그먼트 트리의 개념을 정확하게 알지 못해 crocus님의 블로그에서 참조하였습니다.
  • 먼저, 세그먼트 트리의 크기는 2^K>N조건을 만족하는 최소 2^k값의 2배가 됩니다. 이유는 세그먼트 트리는 완전이진트리로 구성이 되어있기 때문입니다.
  • 첫번째로, 세그먼트 트리를 구하는 과정입니다.
    1. 1번 루트로 시작해서 왼쪽과 오른쪽 자식으로 타고 내려갑니다.
    2. 왼쪽자식의 범위와 오른쪽자식의 범위가 같을 때(start==end, 기저사례)가 될 때까지 init(왼쪽)+init(오른쪽)을 수행하게 됩니다.
    3. 2번의 경우가 되면 해당 세그먼트 트리의 노드는 리프노드가 됩니다.
  • 두번째로, 특정 인덱스의 값을 변경하는 과정입니다.
    1. 먼저 diff(변경할 값 - 현재 값)를 구한 후, 1번부터 update를 합니다.
    2. 업데이트를 할 때, 해당 인덱스가 포함되어있지 않은 세그먼트트리는 무시합니다.
    3. 해당 인덱스가 포함되어있고, start!=end이면(자기 자신은 리프노드이기 때문에 더이상 나눌 수 없기 때문입니다.) 해당 노드의 자식들을 탐색합니다. 이 과정에서 diff를 더하게 됩니다. 결과적으로 특정 인덱스가 포함된 구간을 가지고 있는 세그먼트 트리의 노드를 diff씩 더해주는 셈이 됩니다.
  • 마지막으로, 구간합(left,right)을 구하는 과정입니다.
    1. 먼저, 구하고 싶은 구간의 범위에 아예 벗어나는 경우는 무시합니다.
    2. [left,right]범위가 각 세그먼트트리 노드의 구간을 완전히 포함하면 해당 세그먼트 트리 노드의 구간합을 리턴합니다.
    3. 나머지로, 세그먼트 트리 노드의 구간이 [left,right]를 완전히 포함하거나, 서로 걸쳐있는 경우에는 각 왼쪽과 오른쪽 자식들을 탐색하도록 합니다.
    4. 위 3 과정을 거치게 되면 [left,right]에 해당하는 구간합을 구할 수 있습니다.

설계2(X)

  • 세그먼트 트리와 같은 기능을 하되 메모리 소모를 줄일 수 있는 펜윅트리로도 구현이 가능합니다.
  • 펜윅트리는 비트를 이용하는 구조로 되어있습니다.
  • 첫번째로, 업데이트 과정이 있습니다.
    1. 구간이 포함된 노드에 가는 과정에 비트 연산을 살펴보면 bit(현재노드의 비트)를 bit의 1이 있는 최하위 비트에 1을 더하면 다음 상위 노드로 가는 것을 알 수 있습니다.
    2. 이를 구하기 위해서는 bit +=(bit & -bit)를 하면 됩니다. bit의 2의 보수와 and연산을 한 비트값이 최하위 비트에 있는 1을 의미합니다. O(logN)에 업데이트를 할 수 있습니다.
  • 두번째로, 구간합[left,right]을 구하는 과정입니다.
    1. right를 업데이트와는 반대로 1이 있는 최하위 비트를 거꾸로 빼주는 작업을 하면 [1,right]의 구간합을 구할 수 있습니다.
    2. left에 대해서도 1번작업을 똑같이 해주면 [1,left] 구간합을 구할 수 있습니다.
    3. left<=right 가 명시되어있기 때문에 (right구간합)-(left-1구간합)을 구하면 (left,right)구간합을 구할 수 있게 됩니다. O(logN)+O(logN) => O(logN)에 구간합을 구할 수 있게 됩니다.

구현(X)

코드1 (세그먼트트리)
#include <iostream>
#include <vector>
using namespace std;
typedef long long ll;

int N,M,K;
ll init(vector<ll>& arr, vector<ll>& tree,int node ,int start ,int end){
    if (start == end) return tree[node] = arr[start];
    int mid = (start + end)/2;
    return tree[node] = init(arr,tree,node*2,start,mid)+init(arr,tree,node*2+1,mid+1,end);
}
void update(vector<ll>& tree,int node, int start, int end, int index,ll diff){
    if (!(start<=index && index <=end)) return;
    tree[node] += diff;
    if (start != end){
        int mid = (start + end)/2;
        update(tree,node*2,start,mid,index,diff);
        update(tree,node*2+1,mid+1,end,index,diff);
    }
}
ll sum(vector<ll>& tree,int node,int start, int end,int left, int right){
    if (left>end || right <start) return 0;
    if (left<=start && right >= end) return tree[node];
    int mid = (start + end)/2;
    return sum(tree,node*2,start,mid,left,right)+ sum(tree,node*2+1,mid+1,end,left,right);
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> N >> M >> K;
    M+=K;
    int h = 1;
    while (N>h) h <<=1;
    vector<ll> arr(N);
    vector<ll> tree(h*2);
    for (int i = 0 ;i <N; i++){
        cin >> arr[i];
    }
    init(arr,tree,1,0,N-1);
    while (M--){
        int a;
        cin >> a;
        //update
        if (a == 1){
            int pos;
            ll val;
            cin >> pos >> val;
            ll diff = val - arr[pos-1];
            arr[pos-1] = val;
            update(tree,1,0,N-1,pos-1,diff);
        }
        //sum
        else{
            int left,right;
            cin >> left >> right;
            cout << sum(tree,1,0,N-1,left-1,right-1)<<"\n";
        }
    }
    return 0;
}
코드2 (펜윅트리)
#include <iostream>
#include <vector>
using namespace std;
typedef long long ll;

int N,M,K;

void update(vector<ll>& tree,int i, ll diff){
    while (i<tree.size()){
        tree[i] += diff;
        i += (i & -i);
    }
}
ll sum(vector<ll>& tree,int index){
    ll ans = 0;
    while (index >0){
        ans += tree[index];
        index -= (index & -index);
    }
    return ans;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> N >> M >> K;
    M+=K;
    vector<ll> arr(N+1);
    vector<ll> tree(N+1);
    for (int i = 1 ;i <= N; i++){
        cin >> arr[i];
        update(tree,i,arr[i]);
    }

    while (M--){
        int a;
        cin >> a;
        //update
        if (a == 1){
            int pos;
            ll val;
            cin >> pos >> val;
            ll diff = val - arr[pos];
            arr[pos] = val;
            update(tree,pos,diff);
        }
        //sum
        else{
            int left,right;
            cin >> left >> right;
            cout << sum(tree,right) - sum(tree,left-1)<<"\n";
        }
    }
    return 0;
}
디버깅
  • 문제에서 정답이 long long범위를 벗어나지 않고, 변경할 값 또한 long long범위를 벗어나지 않는다고 하여 각 배열의 변수형을 int형으로 구현하였습니다.
  • 그러나, sum을 하는 과정에서 최악(int최대값이 diff이면)의 경우에 int범위를 벗어나는 것을 간과했습니다. 각 배열들을 long long형으로 변경하였습니다.
제출결과
  • 26188KB / 152ms (세그먼트 트리)
  • 17616KB / 160ms (펜윅트리)

마무리

  • 세그먼트 트리를 구현하는 기본문제입니다. 완전2진트리의 특성을 이용해서 구간합을 구하는데에 O(logN)에 구할 수 있는 최적화에 필요한 알고리즘입니다.