접근

이번 기회를 통해 최소 신장 트리(MST, Minimum Spanning Tree) 알고리즘에 대해 공부하였다.

Spanning Tree란 그래프 내의 모든 정점을 포함하는 트리로써 그래프의 최소 연결 부분 그래프이다. n 개의 정점을 가지는 그래프의 최소 연결 간선의 개수는 n - 1 개이다. 그 중 Minimum Spanning Tree란 그래프의 간선이 가중치를 가지고 있을 때, Spanning Tree 중 그 가중치의 합이 가장 작은 트리를 의미한다.

이 MST를 구하는 알고리즘은 Kruskal 알고리즘과 Prim 알고리즘 두가지가 있다.

Kruskal MST 알고리즘

kruskal 알고리즘은 탐욕적인 방법(Greedy method)를 이용하여 정점을 연결하는 가중치의 최소 값을 구하는 방법이다.

  1. 그래프의 간선들을 가중치를 기준으로 오름차순으로 정렬한다.
  2. 정렬된 간선들을 순서대로 선택하되, 간선들이 싸이클을 이루지 않도록 주의하며 탐색한다.
  3. 간선의 개수가 n - 1 개가 될 때까지 반복한다.

Prim MST 알고리즘

Prim 알고리즘은 시작 정점으로부터 현재의 신장 트리에 연결된 간선 중 가장 작은 가중치를 가진 간선을 선택하여 정점을 확장해나가는 방식이다.

  1. 시작 정점을 기준으로 시작 정점에서 연결된 간선 중 가장 작은 가중치를 가진 간선으로 다음 정점과 연결한다.
  2. 추가된 정점을 포함하여 다시 한번 가장 작은 가중치를 가진 간선을 선택하여 정점을 연결한다.
  3. 간선의 개수가 n - 1 개가 될 때까지 반복한다.

최적화

기본적으로 두가지 알고리즘 모두 간선들의 가중치에 따라 정렬하고 가장 작은 가중치를 가진 간선을 선택해야 하는데, 최소값만 중요한 상황이기 때문에 최소 힙을 이용하여 그 시간을 단축시킬 수 있다. 코딩에는 heapq 모듈을 불러와 문제를 해결하였다.

알고리즘 관련 내용: [Algorithm] 최소 신장 트리란

 

[알고리즘] 최소 신장 트리(MST, Minimum Spanning Tree)란 - Heee's Development Blog

Step by step goes a long way.

gmlwjd9405.github.io

코드1

Kruskal 알고리즘

import sys
import heapq


def find(x):
    if parent[x] < 0:
        return x
    p = find(parent[x])
    parent[x] = p
    return p


def union(x, y):
    if parent[x] < parent[y]:
        parent[x] += parent[y]
        parent[y] = x
    else:
        parent[y] += parent[x]
        parent[x] = y


v, e = map(int, sys.stdin.readline().split())
parent = [-1 for _ in range(v + 1)]
graph = []
for _ in range(e):
    a, b, c = map(int, sys.stdin.readline().split())
    heapq.heappush(graph, [c, a, b])  # 가중치를 기준으로 정렬하기 위해 [c, a, b]로 저장해준다.
cnt = v - 1
ans = 0
while cnt:
    d, a, b = heapq.heappop(graph)  # 그래프 중 가장 작은 가중치를 가진 간선을 선택한다
    ar = find(a)
    br = find(b)
    if ar != br:  # 해당 간선의 루트 노드가 같지 않으면(해당 간선 선택으로 인한 싸이클 성립을 막기 위해)
        union(ar, br)
        ans += d
        cnt -= 1
print(ans)

코드2

Prim 알고리즘

import sys
import heapq


def find(x):
    if parent[x] < 0:
        return x
    p = find(parent[x])
    parent[x] = p
    return p


def union(x, y):
    if parent[x] < parent[y]:
        parent[x] += parent[y]
        parent[y] = x
    else:
        parent[y] += parent[x]
        parent[x] = y


v, e = map(int, sys.stdin.readline().split())
parent = [-1 for _ in range(v + 1)]
graph = [[] for _ in range(v + 1)]
for _ in range(e):
    a, b, c = map(int, sys.stdin.readline().split())
    graph[a].append([c, b])
    graph[b].append([c, a])
cnt = v - 1
tmp = graph[1][::]
heapq.heapify(tmp)
ans = 0
a = 1
while cnt:
    d, n = heapq.heappop(tmp)
    ar, nr = find(a), find(n)
    if ar != nr:
        union(ar, nr)
        ans += d
        for i in graph[n]:
            heapq.heappush(tmp, i)
        cnt -= 1
print(ans)

더 생각해 볼 것?

...

코드나 내용 관련 조언, 부족한 점 및 질문 언제든 환영합니다!

반응형
  • 네이버 블러그 공유하기
  • 네이버 밴드에 공유하기
  • 페이스북 공유하기
  • 카카오스토리 공유하기