Search

BOJ 2042 구간 합 구하기

태그
자료 구조
세그먼트 트리
생성일
2025/02/14 04:22

문제 설명

세그먼트 트리를 이용해 빈번한 수 변경과 구간 합 구하기를 수행하는 문제

예제 입력/출력

입력1
5 2 2 1 2 3 4 5 1 3 6 2 2 5 1 5 2 2 3 5
Plain Text
복사
출력1
17 12
Plain Text
복사

제약 조건

1N1,000,0001 ≤ N ≤ 1,000,000
1M,K10,0001 ≤ M, K ≤ 10,000

문제 풀이

접근1 누적합 배열을 이용해 구하기 - O((M+K)N)O((M + K) \cdot N)
접근2 트리 구조 이용하여 구하기 - O((M+K)logN)O((M + K) \cdot log N)
세그먼트 트리를 이용하여 구간 합 쿼리값 변경 연산을 모두 O(logN)O(log N) 에 가능하여 효율적으로 문제를 풀 수 있다.
빨간색: 세그먼트 트리 인덱스
파란색: 배열의 구간
원 안 숫자: 특정 구간의 원본 배열의 누적합

입력 받기

N, M, K = map(int, input().split()) arr = [] for _ in range(N): arr.append(int(input())) tree = [0] * (4 * N)
Python
복사
수의 개수 N, 수가 변경되는 횟수 M, 구간의 합을 구하는 횟수 K를 입력 받는다.
다음으로 N개의 숫자를 입력받아 arr이라는 리스트에 넣는다.
다음으로 우리가 만들어줄 세그먼트 트리의 크기를 4N으로 만들어 준다.
세그먼트 트리의 크기는 완전 이진트리로 2의 제곱수 형태가 된다.

초기화 하기

def init(node, start, end): if start == end: tree[node] = arr[start] return tree[node] mid = (start + end) // 2 left_value = init(2 * node, start, mid) right_value = init(2 * node + 1, mid + 1, end) tree[node] = left_value + right_value return tree[node] init(1, 0, N - 1)
Python
복사
init이라는 함수를 통해 세그먼트 트리를 초기화 한다.
init 함수에는 3개의 매개변수가 들어간다.
1.
node: 현재 노드의 번호
2.
start: 범위의 시작 번호
3.
end: 범위의 끝 번호
init(1, 0, N - 1)은 1번 노드가 0부터 N - 1의 범위를 가지고 있다는 뜻이 된다.
init 함수는 이분 탐색과 비슷하게 재귀적으로 반씩 줄여나가며 트리를 초기화 한다.
start와 end가 같아지는 종료조건이 되면 해당 노드에 초기 배열 값(arr[start])을 저장한다.
그리고 재귀가 돌면서 tree[node] 값을 구성한다.
리프 노드에 arr 리스트의 값이 들어있기 때문에 말단의 합을 구해가며 전체를 구성하게 된다.

구간 값 구하기

인덱스가 1~4인 구간의 데이터 합을 구하기 위해선 위와 같이 색칠된 세 노드의 합만 구하면 된다.
즉, 구간의 합은 ‘범위 안에 있는 경우’에 한해서만 더해주면 된다.
코드로 구현하는 방법은 다음과 같다.
def find_tree(node, start, end, left, right): # 1. 구하고 싶은 구간(left~right)가 현재 트리 구간(start~end)에 포함되지 않는 경우 (범위를 완전히 벗어났을 때) if right < start or end < left: return 0 # 2. 구하고 싶은 구간(left~right) 안에 현재 트리 구간(start~end)이 포함되는 경우 (범위안에 완전히 들어왔을 때) if left <= start and end <= right: return tree[node] # 3. 그 외의 경우 mid = (start + end) // 2 left_value = find_tree(node * 2, start, mid, left, right) right_value = find_tree(node * 2 + 1, mid + 1, end, left, right) return left_value + right_value
Python
복사
find_tree 함수를 통해 세그먼트 트리의 구간 합을 구할 수 있다.
find_tree 함수에는 5개의 매개변수가 들어간다.
node: 현재 노드의 번호
start: 현재 트리 구간의 시작 번호
end: 현재 트리 구간의 끝 번호
left: 구하고자 하는 구간의 시작 번호
right: 구하고자 하는 구간의 끝 번호
동작 방식
1.
범위를 완전히 벗어났을 때
먼저 범위를 완전히 벗어나는 경우는 0을 리턴한다.
예를 들어, 1~4 구간의 합을 구하는데 5~6을 포함하는 노드라면, 해당 노드의 값은 더하지 않고 0을 반환한다.
2.
범위안에 완전히 들어왔을 때
반대로 현재 노드의 구간이 구하고자 하는 범위에 완전히 포함되는 경우 노드 값을 반환한다.
예를 들어, 1~4의 합을 구하려고 할 때, 현재 탐색 중인 노드가 3~4라면, 3~4 구간은 1~4 범위에 완전히 포함되므로, 자식 노드로 내려갈 필요 없이 현재 노드의 값을 바로 사용하면 된다.
3.
그 외의 경우 (부분적으로 겹칠 때)
현재 노드가 탐색하려는 범위와 일부 겹칠 경우, 자식 노드들을 계속 탐색해야 한다.
따라서, 왼쪽 자식과 오른쪽 자식을 탐색한 후 결과를 합산한다.
예를 들어, 1~4의 합을 구하는데, 현재 노드가 1~3, 4~5로 나뉘면 1~3은 사용하고 4~5에서 4만 선택해야 한다.

업데이트하기

특정 인덱스의 값을 수정할 때는 해당 인덱스를 포함하고 있는 모든 구간의 합 노드들을 차이값 만큼 갱신해주면 된다.
예를 들어, 인덱스 2의 값에 3을 더해줘서 6으로 수정한다고 하면 해당 인덱스를 포함하고 있는 노드들에 3을 더해주면 된다.
def update_tree(node, start, end, idx, diff): # 범위를 완전히 벗어나는 경우 if idx < start or end < idx: return tree[node] += diff if start != end: mid = (start + end) // 2 update_tree(node * 2, start, mid, idx, diff) update_tree(node * 2 + 1, mid + 1, end, idx, diff)
Python
복사
update_tree 함수를 통해 세그먼트 트리의 값을 갱신할 수 있다.
update_tree 함수에는 5개의 매개변수가 들어간다.
node: 현재 노드의 번호
start: 현재 트리 구간의 시작 번호
end: 현재 트리 구간의 끝 번호
idx: 변경하려는 값의 원본 배열 내 인덱스
diff: 기존 값에서 변경된 차이값
find_tree와 마찬가지로 업데이트 하기 위한 범위를 벗어나면 리턴한다.
현재 트리 구간 안에 인덱스가 포함된다면 차이값인 diff만큼 노드를 업데이트 한다.
이 작업을 리프 노드에 도달하기 전까지 계속 업데이트하면 전체적으로 업데이트가 된다.

계산하기

for i in range(M + K): a, b, c = map(int, input().split()) if a == 1: diff = c - arr[b - 1] arr[b - 1] = c update_tree(1, 0, N - 1, b - 1, diff) if a == 2: print(find_tree(1, 0, N - 1, b - 1, c - 1))
Python
복사
입력된 내용에 따라 a가 1인 경우는 업데이트를, a가 2인 경우는 구간의 합을 구해준다.
주의할 점은 노드의 값들은 1부터 시작하지만 우리가 만든 세그먼트 트리의 노드는 0이기 때문에, b의 값에서 1을 빼줘야 한다.
또한 a=2일 때 c의 값도 범위를 뜻하기 때문에 c의 값에서 1을 빼줘야 한다.

풀이 코드

import sys input = sys.stdin.readline N, M, K = map(int, input().split()) arr = [] for _ in range(N): arr.append(int(input())) tree = [0] * (4 * N) def init(node, start, end): if start == end: tree[node] = arr[start] return tree[node] mid = (start + end) // 2 left_value = init(2 * node, start, mid) right_value = init(2 * node + 1, mid + 1, end) tree[node] = left_value + right_value return tree[node] init(1, 0, N - 1) def find_tree(node, start, end, left, right): # 1. 구하고 싶은 구간(left~right)가 현재 트리 구간(start~end)에 포함되지 않는 경우 (범위를 완전히 벗어났을 때) if right < start or end < left: return 0 # 2. 구하고 싶은 구간(left~right) 안에 현재 트리 구간(start~end)이 포함되는 경우 (범위안에 완전히 들어왔을 때) if left <= start and end <= right: return tree[node] # 3. 그 외의 경우 mid = (start + end) // 2 left_value = find_tree(node * 2, start, mid, left, right) right_value = find_tree(node * 2 + 1, mid + 1, end, left, right) return left_value + right_value def update_tree(node, start, end, idx, diff): # 범위를 완전히 벗어나는 경우 if idx < start or end < idx: return tree[node] += diff if start != end: mid = (start + end) // 2 update_tree(node * 2, start, mid, idx, diff) update_tree(node * 2 + 1, mid + 1, end, idx, diff) for i in range(M + K): a, b, c = map(int, input().split()) if a == 1: diff = c - arr[b - 1] arr[b - 1] = c update_tree(1, 0, N - 1, b - 1, diff) if a == 2: print(find_tree(1, 0, N - 1, b - 1, c - 1))
Python
복사

참고 자료