접근

https://www.acmicpc.net/problem/2458

 

2458번: 키 순서

1번부터 N번까지 번호가 붙여져 있는 학생들에 대하여 두 학생끼리 키를 비교한 결과의 일부가 주어져 있다. 단, N명의 학생들의 키는 모두 다르다고 가정한다. 예를 들어, 6명의 학생들에 대하여

www.acmicpc.net

플로이드-와샬 문제라고 알고 풀었는데 아무리 봐도 dfs 로 푸는게 빠른 문제였다.

플로이드-와샬 알고리즘을 이용해서는 속도가 느려 pypy 컴파일러로만 통과할 수 있었고, dfs 를 이용해서는 python 컴파일러로도 충분히 시간 안에 들어올 수 있었다.

플로이드-와샬 알고리즘을 이용하면 각 노드에서 다른 노드로 이동할 수 있는 최단거리를 구할 수 있다. 키의 순서를 그래프의 화살표라고 가정하고 문제를 푼다면 자신의 키가 전체에서 몇번째인지 안다는 것은 다음과 같이 알 수 있다. 먼저 키의 순서를 화살표로 하여 그래프를 입력하고 플로이드-와샬로 풀어준다면, 예를 들어 i 번 학생에서 도달할 수 있는 곳들은 i 보다 크다는 것이고, j 번 학생이 i 번에 도달할 수 있다는 것은 i 번 학생이 더 크다는 것이다. 이를 이용하여 i 번 학생이 j 에 도달할 수 없고, j 번 학생도 i 번 학생에 도달할 수 없다면, 둘은 서로의 키 차이를 비교할 수 없다는 뜻이다. 그러므로 모든 학생과 키를 비교할 수 있다면 카운트 해주면 된다.

dfs 알고리즘 풀이는 더욱 간단해서, 처음에 입력받을 때 큰 쪽의 그래프와 작은 쪽의 그래프를 그려놓고 양쪽으로 dfs 를 하여 자신보다 큰 학생들과 자신보다 작은 학생들을 카운트한다. 둘의 합이 N - 1 이면 자신이 총 몇번째인지 알 수 있다.

코드

플로이드 와샬 알고리즘

import sys

input = sys.stdin.readline
INF = float("inf")

N, M = map(int, input().split(" "))
graph = [[INF] * (N + 1) for _ in range(N + 1)]
for i in range(1, N + 1):
    graph[i][i] = 0
for _ in range(M):
    a, b = map(int, input().split(" "))
    graph[a][b] = 1

for k in range(1, N + 1):
    for i in range(1, N + 1):
        for j in range(1, N + 1):
            graph[i][j] = min(graph[i][j], graph[i][k] + graph[k][j])

cnt = [0] * (N + 1)
for i in range(1, N + 1):
    for j in range(1, N + 1):
        if graph[i][j] != INF:
            cnt[i] += 1
            cnt[j] += 1
print(cnt.count(N + 1))

dfs

import sys

input = sys.stdin.readline
INF = float("inf")

N, M = map(int, input().split(" "))
taller_graph = [[] for _ in range(N + 1)]
shorter_graph = [[] for _ in range(N + 1)]
for _ in range(M):
    a, b = map(int, input().split(" "))
    taller_graph[a].append(b)
    shorter_graph[b].append(a)
taller = [set() for _ in range(N + 1)]
shorter = [set() for _ in range(N + 1)]
taller_num = [0] * (N + 1)
shorter_num = [0] * (N + 1)


def find_taller(i):
    if taller_num[i]:
        return taller[i]
    taller_num[i] = 1
    for j in taller_graph[i]:
        taller[i].add(j)
        taller[i] |= find_taller(j)
    return taller[i]


def find_shorter(i):
    if shorter_num[i]:
        return shorter[i]
    shorter_num[i] = 1
    for j in shorter_graph[i]:
        shorter[i].add(j)
        shorter[i] |= find_shorter(j)
    return shorter[i]


cnt = 0
for i in range(1, N + 1):
    find_taller(i)
    find_shorter(i)
    if len(taller[i]) + len(shorter[i]) == N - 1:
        cnt += 1
print(cnt)

더 생각해 볼 것?

...

코드나 내용 관련 조언, 부족한 점 및 질문 언제든 환영합니다!

반응형
  • 네이버 블러그 공유하기
  • 네이버 밴드에 공유하기
  • 페이스북 공유하기
  • 카카오스토리 공유하기