접근

파인드유니온 알고리즘을 모르지만 일단 구현되게 하여 풀기는 하였다. 하지만 파인드유니온 알고리즘을 처음 접했으니 제대로 공부하고 최적화 방법도 알아보고 코드를 다시 작성해보았다.

유니온 파인드 알고리즘이란?

그래프 알고리즘의 일종으로 상호배타적 집합, Disjoint-set이라고도 한다. 여러 노드가 존재할 때 어떤 두개의 노드를 묶어주고, 또 어떤 두 노드가 같은 집합에 속해있는지 확인하는 알고리즘이다.

  1. FIND: 노드 x가 어느 집합에 포함되어 있는지 찾는 연산
  2. UNION: 노드 x가 포함된 집합과 노드 y가 포함된 집합을 합치는 연산

유니온 파인드 구현

트리를 통해서 유니온파인드가 구현될 수 있다.
parent[i] 리스트를 작성하여 i노드의 부모노드라고 정의하고 초기화해준다.
parent[i] == i 일 경우, 루트 노드이다.

# parent 초기화
parent = [i for i in range(n + 1)

# find 함수 구현
def find(x):
    if parent[x] == x:
        return x
    return find(parent[x])

parent[x] == x 라면 x가 루트 노드이므로 return 해주고, 루트 노드가 아니라면 재귀적으로 연산하여 루트 노드를 return 해준다. 이와 같이 find를 구현할 경우 문제가 발생한다.

위와 같이 한쪽으로 치우쳐진 tree가 있을 경우, find 함수가 루트 노드를 찾는데 O(N)의 시간복잡도를 가지기 때문에 tree로 구현하는 이점이 없어진다. 이를 해결하기 위해 finde 함수를 아래와 같이 개선할 수 있다.

# find 함수 개선
def find(x):
    if parent[x] == x:
        return x
    p = find(parent[x])
    parent[x] = y
    return y

# 유니온 함수 구현
def union(x, y):
    x = find(x)
    y = find(y)
    if x != y:
        parent[y] = x

유니온 함수는 두개의 값을 받아 두 노드가 포함되어 있는 집합을 합쳐준다. find를 통해 각각의 루트를 찾아준 후, y의 부모 노드를 x로 바꾸어준다.

이 경우 발생하는 효율성 문제가 있다.

높이가 더 높은 트리가 높이가 낮은 트리 밑으로 들어가게 되면 트리가 점점 깊어질 위험이 있다. 따라서 트리의 높이가 낮은 트리 밑으로 들어가야 하는데, 이를 위해서는 트리의 높이를 기록해주어야 한다. 이를 위해 rank라는 리스트를 선언하고 초기화해준다.

# rank 구현
rank = [1 for i in range(n + 1)

이제 union 연산할 때 두개의 집합의 크기를 비교해주고 합쳐줄 때 크기를 갱신해준다.

def union(x, y):
    x = find(x)
    y = find(y)

    if x == y:
        return

    if rank[x] > rank[y]:
        parent[y] = x
        rank[x] += rank[y]
    else:
        parent[x] = y
        rank[y] += rank[x]

이 경우 결국 rank라는 리스트를 새로 사용하기 때문에 메모리를 두배로 사용하게 된다. 이를 개선하기 위해 Weighted Union Find 방법이 고안되었다.

기본적으로 유니온파인드 알고리즘과 비슷하지만 parent 배열에 저장하는 값이 조금 달라진다.

parent[i] 에 부모 노드를 저장하는 것은 동일하지만, i가 루트 노드일 경우 집합의 size를 음수로 저장하게 된다. 즉 parent[i] 가 음수일 경우 그 수의 절대값은 size이고, 양수일 경우 그 값은 부모 노드를 가리킨다.

예를 들어 parent[2] = -3 일 경우 2번 노드 밑에 두개의 노드가 더 있어서 총 3개의 노드가 집합을 이루고 있다는 뜻이고, parent[3] = 5 일 경우 3번 노드의 부모가 5번 노드라는 뜻이다.

이를 코드로 구현하면 아래와 같다.

parent = [-1 for i in range(n + 1)


def find(x):
    if parent[x] < 0:
        return x
    p = find(parent[x])
    parent[x] = p
    return p


def union(x, y):
    x = find(x)
    y = find(y)

    if x == y:
        return

    if parent[x] < parent[y]:
        parent[x] += parent[y]
        parent[y] = x
    else:
        parent[y] += parent[x]
        parent[x] = y

이를 기반으로 문제를 풀면 매우 빠르고 효율적으로 문제를 해결할 수 있다.

파인드 유니온 알고리즘 내용 참고: 유니온 파인드(Union - Find)

 

[Algorithm] 유니온 파인드(Union - Find)

유니온 파인드 알고리즘이란? 그래프 알고리즘의 일종으로서 상호 배타적 집합, Disjoint-set 이라고도 합니다. 여러 노드가 존재할 때 어떤 두 개의 노드를 같은 집합으로 묶어주고, 다시 어떤 두

ssungkang.tistory.com

코드

import sys


def find(x):
    if parent[x] < 0:
        return x
    p = find(parent[x])
    parent[x] = p
    return p


def union(x, y):
    x = find(x)
    y = find(y)

    if x == y:
        return

    if parent[x] < parent[y]:
        parent[x] += parent[y]
        parent[y] = x
    else:
        parent[y] += parent[x]
        parent[x] = y


n, m = map(int, sys.stdin.readline().split())
parent = [-1 for i in range(n + 1)]
for _ in range(m):
    k, a, b = map(int, sys.stdin.readline().split())
    if not k:
        union(a, b)
    if k:
        if find(a) == find(b):
            print("YES")
        else:
            print("NO")

더 생각해 볼 것?

그냥 풀렸다고 넘어갔으면 모를수 있을 방법들을 공부하는 것이 매우 중요한 것 같다.

코드나 내용 관련 조언, 부족한 점 및 질문 언제든 환영합니다!

반응형
  • 네이버 블러그 공유하기
  • 네이버 밴드에 공유하기
  • 페이스북 공유하기
  • 카카오스토리 공유하기