접근

처음에 아무리 생각해도 내 머리 속에서 떠올릴 수 있는 방법은 플로이드 와샬 알고리즘 하나였다. 그렇게 푸는 것이 아니란 것을 알면서도 코딩했더니 결과는 역시나 메모리초과였다.

방법을 고민하다가 알 방법이 없어 검색해보니 굉장히 놀라운 사실을 알 수 있었다.

임의의 노드에서 각 노드까지의 거리를 측정하여 최대 거리를 가지는 노드는 트리의 지름을 이루는 한 노드이다.

즉, 임의의 한 점에서 DFS나 BFS 알고리즘을 이용하여 각 노드까지의 거리를 구하고, 이 중 최대 거리를 갖는 노드에서 시작하여 다시 한번 각 노드까지의 최대 거리를 구한다면 그 최대 거리가 트리의 지름이 된다.

그 증명은 다음의 링크에서 확인할 수 있다.

트리의 지름 증명: 트리의 지름 구하기

 

트리의 지름 구하기

트리에서 지름이란, 가장 먼 두 정점 사이의 거리 혹은 가장 먼 두 정점을 연결하는 경로를 의미한다. 선형 시간안에 트리에서 지름을 구하는 방법은 다음과 같다: 1. 트리에서 임의의 정점 $x$를

blog.myungwoo.kr

코드

import sys

n = int(sys.stdin.readline())
tree = [[] for _ in range(n + 1)]
for _ in range(n):
    tmp = list(map(int, sys.stdin.readline().split()))
    i = 1
    while i != len(tmp) - 1:
        tree[tmp[0]].append([tmp[i], tmp[i + 1]])
        i += 2


def dfs(start, result):
    for e, d in tree[start]:
        if result[e] == 0:
            result[e] = result[start] + d
            dfs(e, result)


result1 = [0 for _ in range(n + 1)]
dfs(1, result1)  # 임의의 점 1에서 DFS 알고리즘을 이용하여 각 노드들까지의 거리 측정
result1[1] = 0
tmpmax = 0
tmpindex = 0
for i, x in enumerate(result1):
    if tmpmax < x:
        tmpmax = x
        tmpindex = i

result2 = [0 for _ in range(n + 1)]
dfs(tmpindex, result2)  # 첫번째 dfs 로 구해진 최대거리 노드에서 새로운 dfs 수행
result2[tmpindex] = 0
print(max(result2))

bfs 를 이용한 코드

from collections import deque
import sys

n = int(sys.stdin.readline())
tree = [[] for _ in range(n + 1)]
for _ in range(n):
    tmp = list(map(int, sys.stdin.readline().split()))
    i = 1
    while i != len(tmp) - 1:
        tree[tmp[0]].append([tmp[i], tmp[i + 1]])
        i += 2


def bfs(start):
    queue = deque()
    queue.append([start, 0])
    visited = [0 for _ in range(n + 1)]
    maxd = 0
    maxi = start
    visited[start] = 1
    while queue:
        nowi, nowd = queue.popleft()
        if maxd < nowd:
            maxd = nowd
            maxi = nowi
        for e, d in tree[nowi]:
            if visited[e] == 0:
                queue.append([e, nowd + d])
                visited[e] = 1
    return maxi, maxd


tmpindex, tmpmax = bfs(1)
finalindex, finalmax = bfs(tmpindex)
print(finalmax)

더 생각해 볼 것?

...

코드나 내용 관련 조언, 부족한 점 및 질문 언제든 환영합니다!

반응형
  • 네이버 블러그 공유하기
  • 네이버 밴드에 공유하기
  • 페이스북 공유하기
  • 카카오스토리 공유하기