구간합 구하기
문제 해석
- 중간에 수의 변경이 빈번히 일어나고, 어떤 부분의 합을 구하는 문제입니다.
- N<= 10^6, M,K<=10^4
설계1(X)
- 세그먼트 트리를 구현하는 기본문제입니다.
- 세그먼트 트리의 개념을 정확하게 알지 못해 crocus님의 블로그에서 참조하였습니다.
- 먼저, 세그먼트 트리의 크기는 2^K>N조건을 만족하는 최소 2^k값의 2배가 됩니다. 이유는 세그먼트 트리는 완전이진트리로 구성이 되어있기 때문입니다.
- 첫번째로, 세그먼트 트리를 구하는 과정입니다.
- 1번 루트로 시작해서 왼쪽과 오른쪽 자식으로 타고 내려갑니다.
- 왼쪽자식의 범위와 오른쪽자식의 범위가 같을 때(start==end, 기저사례)가 될 때까지 init(왼쪽)+init(오른쪽)을 수행하게 됩니다.
- 2번의 경우가 되면 해당 세그먼트 트리의 노드는 리프노드가 됩니다.
- 두번째로, 특정 인덱스의 값을 변경하는 과정입니다.
- 먼저 diff(변경할 값 - 현재 값)를 구한 후, 1번부터 update를 합니다.
- 업데이트를 할 때, 해당 인덱스가 포함되어있지 않은 세그먼트트리는 무시합니다.
- 해당 인덱스가 포함되어있고, start!=end이면(자기 자신은 리프노드이기 때문에 더이상 나눌 수 없기 때문입니다.) 해당 노드의 자식들을 탐색합니다. 이 과정에서 diff를 더하게 됩니다. 결과적으로 특정 인덱스가 포함된 구간을 가지고 있는 세그먼트 트리의 노드를 diff씩 더해주는 셈이 됩니다.
- 마지막으로, 구간합(left,right)을 구하는 과정입니다.
- 먼저, 구하고 싶은 구간의 범위에 아예 벗어나는 경우는 무시합니다.
- [left,right]범위가 각 세그먼트트리 노드의 구간을 완전히 포함하면 해당 세그먼트 트리 노드의 구간합을 리턴합니다.
- 나머지로, 세그먼트 트리 노드의 구간이 [left,right]를 완전히 포함하거나, 서로 걸쳐있는 경우에는 각 왼쪽과 오른쪽 자식들을 탐색하도록 합니다.
- 위 3 과정을 거치게 되면 [left,right]에 해당하는 구간합을 구할 수 있습니다.
설계2(X)
- 세그먼트 트리와 같은 기능을 하되 메모리 소모를 줄일 수 있는 펜윅트리로도 구현이 가능합니다.
- 펜윅트리는 비트를 이용하는 구조로 되어있습니다.
- 첫번째로, 업데이트 과정이 있습니다.
- 구간이 포함된 노드에 가는 과정에 비트 연산을 살펴보면 bit(현재노드의 비트)를 bit의 1이 있는 최하위 비트에 1을 더하면 다음 상위 노드로 가는 것을 알 수 있습니다.
- 이를 구하기 위해서는 bit +=(bit & -bit)를 하면 됩니다. bit의 2의 보수와 and연산을 한 비트값이 최하위 비트에 있는 1을 의미합니다. O(logN)에 업데이트를 할 수 있습니다.
- 두번째로, 구간합[left,right]을 구하는 과정입니다.
- right를 업데이트와는 반대로 1이 있는 최하위 비트를 거꾸로 빼주는 작업을 하면 [1,right]의 구간합을 구할 수 있습니다.
- left에 대해서도 1번작업을 똑같이 해주면 [1,left] 구간합을 구할 수 있습니다.
- left<=right 가 명시되어있기 때문에 (right구간합)-(left-1구간합)을 구하면 (left,right)구간합을 구할 수 있게 됩니다. O(logN)+O(logN) => O(logN)에 구간합을 구할 수 있게 됩니다.
구현(X)
코드1 (세그먼트트리)
#include <iostream>
#include <vector>
using namespace std;
typedef long long ll;
int N,M,K;
ll init(vector<ll>& arr, vector<ll>& tree,int node ,int start ,int end){
if (start == end) return tree[node] = arr[start];
int mid = (start + end)/2;
return tree[node] = init(arr,tree,node*2,start,mid)+init(arr,tree,node*2+1,mid+1,end);
}
void update(vector<ll>& tree,int node, int start, int end, int index,ll diff){
if (!(start<=index && index <=end)) return;
tree[node] += diff;
if (start != end){
int mid = (start + end)/2;
update(tree,node*2,start,mid,index,diff);
update(tree,node*2+1,mid+1,end,index,diff);
}
}
ll sum(vector<ll>& tree,int node,int start, int end,int left, int right){
if (left>end || right <start) return 0;
if (left<=start && right >= end) return tree[node];
int mid = (start + end)/2;
return sum(tree,node*2,start,mid,left,right)+ sum(tree,node*2+1,mid+1,end,left,right);
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cin >> N >> M >> K;
M+=K;
int h = 1;
while (N>h) h <<=1;
vector<ll> arr(N);
vector<ll> tree(h*2);
for (int i = 0 ;i <N; i++){
cin >> arr[i];
}
init(arr,tree,1,0,N-1);
while (M--){
int a;
cin >> a;
//update
if (a == 1){
int pos;
ll val;
cin >> pos >> val;
ll diff = val - arr[pos-1];
arr[pos-1] = val;
update(tree,1,0,N-1,pos-1,diff);
}
//sum
else{
int left,right;
cin >> left >> right;
cout << sum(tree,1,0,N-1,left-1,right-1)<<"\n";
}
}
return 0;
}
코드2 (펜윅트리)
#include <iostream>
#include <vector>
using namespace std;
typedef long long ll;
int N,M,K;
void update(vector<ll>& tree,int i, ll diff){
while (i<tree.size()){
tree[i] += diff;
i += (i & -i);
}
}
ll sum(vector<ll>& tree,int index){
ll ans = 0;
while (index >0){
ans += tree[index];
index -= (index & -index);
}
return ans;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cin >> N >> M >> K;
M+=K;
vector<ll> arr(N+1);
vector<ll> tree(N+1);
for (int i = 1 ;i <= N; i++){
cin >> arr[i];
update(tree,i,arr[i]);
}
while (M--){
int a;
cin >> a;
//update
if (a == 1){
int pos;
ll val;
cin >> pos >> val;
ll diff = val - arr[pos];
arr[pos] = val;
update(tree,pos,diff);
}
//sum
else{
int left,right;
cin >> left >> right;
cout << sum(tree,right) - sum(tree,left-1)<<"\n";
}
}
return 0;
}
디버깅
- 문제에서 정답이 long long범위를 벗어나지 않고, 변경할 값 또한 long long범위를 벗어나지 않는다고 하여 각 배열의 변수형을 int형으로 구현하였습니다.
- 그러나, sum을 하는 과정에서 최악(int최대값이 diff이면)의 경우에 int범위를 벗어나는 것을 간과했습니다. 각 배열들을 long long형으로 변경하였습니다.
제출결과
- 26188KB / 152ms (세그먼트 트리)
- 17616KB / 160ms (펜윅트리)
마무리
- 세그먼트 트리를 구현하는 기본문제입니다. 완전2진트리의 특성을 이용해서 구간합을 구하는데에 O(logN)에 구할 수 있는 최적화에 필요한 알고리즘입니다.