접근

이전에 풀었던 도로 네트워크 문제와 유사한 문제였다. 오히려 조금 더 간단했던 것 같다. 도로 네트워크 문제에서 거리 저장 개념을 가져오고, k 번째 노드를 구하는 것도 LCA 구하는 것과 마찬가지로 탐색하면 된다. 물론 문제에서 요구하는 것이 많아서 코드는 훨씬 길어지게 되었다.

2021.07.02 - [코딩/백준 (Python)] - 백준 3176번: 도로 네트워크 (Python, PyPy3)

 

백준 3176번: 도로 네트워크 (Python, PyPy3)

접근 바로 이전에 풀었던 LCA 2 문제와 매우 유사하지만, 이번에는 도로의 길이 및 최대 거리, 최소 거리가 추가된 문제였다. 최소공통조상 문제와 동일하게 풀게 되면, 두 도시 사이는 최소공통

ca.ramel.be

기존 문제와 마찬가지로 모든 그래프를 tree 리스트에 거리와 함께 저장해준다. 이후, 1 을 루트로 하는 트리를 탐색하여 각 노드의 부모 노드를 찾아 저장한다. 이전 문제와 마찬가지로 거리를 저장해주었기 때문에 dp[i][j] 를 [i 의 2^j 번째 부모 노드, i 의 2^j 번째 부모 노드까지의 거리] 의 꼴로 저장해줄 수 있다. 이렇게 희소 테이블을 완성하고 나서 쿼리 계산에 들어간다.

쿼리를 입력받고 나면, 일단 두 노드 u, v 를 가지고 lca(최소 공통 조상) 노드가 무엇인지 탐색해준다.

쿼리의 첫번째 값이 1일 경우, lca 노드의 depth를 이용하여 lca-u 거리와 lca-v 거리를 구하여 더해주면 답을 구할 수 있다.

쿼리의 첫번째 값이 2일 경우, 두가지로 나누어 생각할 수 있다.

  1. k <= depth[u] - depth[lca] : 이 경우 u 에서부터 k 번째 노드를 탐색할 때 lca 에 도달하지 못하게 되므로, 간단하게 u 노드에서 k - 1 조상 노드를 찾아주면 된다.
  2. k > depth[u] - depth[lca] : 이 경우에는 u 에서부터 k 번째 노드를 탐색할 때 lca 에 도착한 후 v 쪽 트리를 따라 내려오게 된다. v로 향하는 자식 노드를 탐색하기 보다는 k 번째 노드와 v 노드의 depth 차이를 이용하여 v 노드에서부터 부모 방향으로 탐색해 올라가면 조금 더 간단하게 구할 수 있다.

코드

import sys
from collections import deque
from math import log2

# tree 입력
n = int(sys.stdin.readline())
tree = [[] for _ in range(n + 1)]
for _ in range(n - 1):
    a, b, c = map(int, sys.stdin.readline().split())
    tree[a].append([b, c])
    tree[b].append([a, c])

# tree 정렬, 각 노드의 부모 노드 및 depth 계산
depth = [0] * (n + 1)
parent = [[0, 0] for _ in range(n + 1)]
check = [False] * (n + 1)
q = deque([1])
check[1] = True
while q:
    now = q.popleft()
    for b, c in tree[now]:
        if not check[b]:
            q.append(b)
            depth[b] = depth[now] + 1
            parent[b][0] = now
            parent[b][1] = c
            check[b] = True

# 희소 테이블 초기화
logN = int(log2(n) + 1)
dp = [[[0, 0] for _ in range(logN)] for __ in range(n + 1)]
for i in range(n + 1):
    dp[i][0][0] = parent[i][0]
    dp[i][0][1] = parent[i][1]

# 희소 테이블 작성
for j in range(1, logN):
    for i in range(1, n + 1):
        dp[i][j][0] = dp[dp[i][j - 1][0]][j - 1][0]
        dp[i][j][1] = dp[i][j - 1][1] + dp[dp[i][j - 1][0]][j - 1][1]

# 쿼리
m = int(sys.stdin.readline())
for _ in range(m):
    I = list(map(int, sys.stdin.readline().split()))
    u, v = I[1], I[2]
    # u, v 노드의 최소공통조상 노드 탐색
    u2, v2 = u, v
    if depth[u2] < depth[v2]:
        u2, v2 = v2, u2
    diff = depth[u2] - depth[v2]
    for i in range(logN):
        if diff & 1 << i:
            u2 = dp[u2][i][0]
    if u2 == v2:
        lca = u2
    else:
        for i in range(logN - 1, -1, -1):
            if dp[u2][i][0] != dp[v2][i][0]:
                u2 = dp[u2][i][0]
                v2 = dp[v2][i][0]
        lca = dp[u2][0][0]
    # 쿼리 첫 값이 1일 경우, lca-u 거리 + lca-v 거리 합산을 통해 답 도출
    if I[0] == 1:
        cost = 0
        diff_u = depth[u] - depth[lca]
        diff_v = depth[v] - depth[lca]
        for i in range(logN):
            if diff_u & 1 << i:
                cost += dp[u][i][1]
                u = dp[u][i][0]
            if diff_v & 1 << i:
                cost += dp[v][i][1]
                v = dp[v][i][0]
        print(cost)
    # 쿼리의 첫 값이 2일 경우
    else:
        k = I[3]
        # k 번째 노드로 가는 길이 lca 를 거치지 않을 경우 u 의 k - 1 조상을 계산
        if k <= depth[u] - depth[lca]:
            diff = k - 1
            for i in range(logN):
                if diff & 1 << i:
                    u = dp[u][i][0]
            print(u)
        # k 번째 노드로 가는 길이 lca 를 거치지 않을 경우, 남은 거리를 계산하여 v 에서부터 계산
        else:
            diff = depth[v] + depth[u] - 2 * depth[lca] - k + 1
            for i in range(logN - 1, -1, -1):
                if diff & 1 << i:
                    v = dp[v][i][0]
            print(v)

더 생각해 볼 것?

...

코드나 내용 관련 조언, 부족한 점 및 질문 언제든 환영합니다!

반응형
  • 네이버 블러그 공유하기
  • 네이버 밴드에 공유하기
  • 페이스북 공유하기
  • 카카오스토리 공유하기