문제 설명
•
세그먼트 트리를 이용해 빈번한 수 변경과 구간 합 구하기를 수행하는 문제
예제 입력/출력
•
입력1
5 2 2
1
2
3
4
5
1 3 6
2 2 5
1 5 2
2 3 5
Plain Text
복사
•
출력1
17
12
Plain Text
복사
제약 조건
•
•
문제 풀이
접근1 누적합 배열을 이용해 구하기 -
•
접근2 트리 구조 이용하여 구하기 -
◦
세그먼트 트리를 이용하여 구간 합 쿼리와 값 변경 연산을 모두 에 가능하여 효율적으로 문제를 풀 수 있다.
◦
빨간색: 세그먼트 트리 인덱스
◦
파란색: 배열의 구간
◦
원 안 숫자: 특정 구간의 원본 배열의 누적합
입력 받기
N, M, K = map(int, input().split())
arr = []
for _ in range(N):
arr.append(int(input()))
tree = [0] * (4 * N)
Python
복사
◦
수의 개수 N, 수가 변경되는 횟수 M, 구간의 합을 구하는 횟수 K를 입력 받는다.
◦
다음으로 N개의 숫자를 입력받아 arr이라는 리스트에 넣는다.
◦
다음으로 우리가 만들어줄 세그먼트 트리의 크기를 4N으로 만들어 준다.
▪
세그먼트 트리의 크기는 완전 이진트리로 2의 제곱수 형태가 된다.
초기화 하기
def init(node, start, end):
if start == end:
tree[node] = arr[start]
return tree[node]
mid = (start + end) // 2
left_value = init(2 * node, start, mid)
right_value = init(2 * node + 1, mid + 1, end)
tree[node] = left_value + right_value
return tree[node]
init(1, 0, N - 1)
Python
복사
◦
init이라는 함수를 통해 세그먼트 트리를 초기화 한다.
◦
init 함수에는 3개의 매개변수가 들어간다.
1.
node: 현재 노드의 번호
2.
start: 범위의 시작 번호
3.
end: 범위의 끝 번호
◦
init(1, 0, N - 1)은 1번 노드가 0부터 N - 1의 범위를 가지고 있다는 뜻이 된다.
◦
init 함수는 이분 탐색과 비슷하게 재귀적으로 반씩 줄여나가며 트리를 초기화 한다.
◦
start와 end가 같아지는 종료조건이 되면 해당 노드에 초기 배열 값(arr[start])을 저장한다.
◦
그리고 재귀가 돌면서 tree[node] 값을 구성한다.
◦
리프 노드에 arr 리스트의 값이 들어있기 때문에 말단의 합을 구해가며 전체를 구성하게 된다.
구간 값 구하기
◦
인덱스가 1~4인 구간의 데이터 합을 구하기 위해선 위와 같이 색칠된 세 노드의 합만 구하면 된다.
▪
즉, 구간의 합은 ‘범위 안에 있는 경우’에 한해서만 더해주면 된다.
◦
코드로 구현하는 방법은 다음과 같다.
def find_tree(node, start, end, left, right):
# 1. 구하고 싶은 구간(left~right)가 현재 트리 구간(start~end)에 포함되지 않는 경우 (범위를 완전히 벗어났을 때)
if right < start or end < left:
return 0
# 2. 구하고 싶은 구간(left~right) 안에 현재 트리 구간(start~end)이 포함되는 경우 (범위안에 완전히 들어왔을 때)
if left <= start and end <= right:
return tree[node]
# 3. 그 외의 경우
mid = (start + end) // 2
left_value = find_tree(node * 2, start, mid, left, right)
right_value = find_tree(node * 2 + 1, mid + 1, end, left, right)
return left_value + right_value
Python
복사
◦
find_tree 함수를 통해 세그먼트 트리의 구간 합을 구할 수 있다.
◦
find_tree 함수에는 5개의 매개변수가 들어간다.
▪
node: 현재 노드의 번호
▪
start: 현재 트리 구간의 시작 번호
▪
end: 현재 트리 구간의 끝 번호
▪
left: 구하고자 하는 구간의 시작 번호
▪
right: 구하고자 하는 구간의 끝 번호
◦
동작 방식
1.
범위를 완전히 벗어났을 때
•
먼저 범위를 완전히 벗어나는 경우는 0을 리턴한다.
•
예를 들어, 1~4 구간의 합을 구하는데 5~6을 포함하는 노드라면, 해당 노드의 값은 더하지 않고 0을 반환한다.
2.
범위안에 완전히 들어왔을 때
•
반대로 현재 노드의 구간이 구하고자 하는 범위에 완전히 포함되는 경우 노드 값을 반환한다.
•
예를 들어, 1~4의 합을 구하려고 할 때, 현재 탐색 중인 노드가 3~4라면, 3~4 구간은 1~4 범위에 완전히 포함되므로, 자식 노드로 내려갈 필요 없이 현재 노드의 값을 바로 사용하면 된다.
3.
그 외의 경우 (부분적으로 겹칠 때)
•
현재 노드가 탐색하려는 범위와 일부 겹칠 경우, 자식 노드들을 계속 탐색해야 한다.
•
따라서, 왼쪽 자식과 오른쪽 자식을 탐색한 후 결과를 합산한다.
•
예를 들어, 1~4의 합을 구하는데, 현재 노드가 1~3, 4~5로 나뉘면 1~3은 사용하고 4~5에서 4만 선택해야 한다.
업데이트하기
•
특정 인덱스의 값을 수정할 때는 해당 인덱스를 포함하고 있는 모든 구간의 합 노드들을 차이값 만큼 갱신해주면 된다.
•
예를 들어, 인덱스 2의 값에 3을 더해줘서 6으로 수정한다고 하면 해당 인덱스를 포함하고 있는 노드들에 3을 더해주면 된다.
def update_tree(node, start, end, idx, diff):
# 범위를 완전히 벗어나는 경우
if idx < start or end < idx:
return
tree[node] += diff
if start != end:
mid = (start + end) // 2
update_tree(node * 2, start, mid, idx, diff)
update_tree(node * 2 + 1, mid + 1, end, idx, diff)
Python
복사
•
update_tree 함수를 통해 세그먼트 트리의 값을 갱신할 수 있다.
•
update_tree 함수에는 5개의 매개변수가 들어간다.
◦
node: 현재 노드의 번호
◦
start: 현재 트리 구간의 시작 번호
◦
end: 현재 트리 구간의 끝 번호
◦
idx: 변경하려는 값의 원본 배열 내 인덱스
◦
diff: 기존 값에서 변경된 차이값
•
find_tree와 마찬가지로 업데이트 하기 위한 범위를 벗어나면 리턴한다.
•
현재 트리 구간 안에 인덱스가 포함된다면 차이값인 diff만큼 노드를 업데이트 한다.
•
이 작업을 리프 노드에 도달하기 전까지 계속 업데이트하면 전체적으로 업데이트가 된다.
계산하기
for i in range(M + K):
a, b, c = map(int, input().split())
if a == 1:
diff = c - arr[b - 1]
arr[b - 1] = c
update_tree(1, 0, N - 1, b - 1, diff)
if a == 2:
print(find_tree(1, 0, N - 1, b - 1, c - 1))
Python
복사
◦
입력된 내용에 따라 a가 1인 경우는 업데이트를, a가 2인 경우는 구간의 합을 구해준다.
◦
주의할 점은 노드의 값들은 1부터 시작하지만 우리가 만든 세그먼트 트리의 노드는 0이기 때문에, b의 값에서 1을 빼줘야 한다.
◦
또한 a=2일 때 c의 값도 범위를 뜻하기 때문에 c의 값에서 1을 빼줘야 한다.
풀이 코드
import sys
input = sys.stdin.readline
N, M, K = map(int, input().split())
arr = []
for _ in range(N):
arr.append(int(input()))
tree = [0] * (4 * N)
def init(node, start, end):
if start == end:
tree[node] = arr[start]
return tree[node]
mid = (start + end) // 2
left_value = init(2 * node, start, mid)
right_value = init(2 * node + 1, mid + 1, end)
tree[node] = left_value + right_value
return tree[node]
init(1, 0, N - 1)
def find_tree(node, start, end, left, right):
# 1. 구하고 싶은 구간(left~right)가 현재 트리 구간(start~end)에 포함되지 않는 경우 (범위를 완전히 벗어났을 때)
if right < start or end < left:
return 0
# 2. 구하고 싶은 구간(left~right) 안에 현재 트리 구간(start~end)이 포함되는 경우 (범위안에 완전히 들어왔을 때)
if left <= start and end <= right:
return tree[node]
# 3. 그 외의 경우
mid = (start + end) // 2
left_value = find_tree(node * 2, start, mid, left, right)
right_value = find_tree(node * 2 + 1, mid + 1, end, left, right)
return left_value + right_value
def update_tree(node, start, end, idx, diff):
# 범위를 완전히 벗어나는 경우
if idx < start or end < idx:
return
tree[node] += diff
if start != end:
mid = (start + end) // 2
update_tree(node * 2, start, mid, idx, diff)
update_tree(node * 2 + 1, mid + 1, end, idx, diff)
for i in range(M + K):
a, b, c = map(int, input().split())
if a == 1:
diff = c - arr[b - 1]
arr[b - 1] = c
update_tree(1, 0, N - 1, b - 1, diff)
if a == 2:
print(find_tree(1, 0, N - 1, b - 1, c - 1))
Python
복사