SSAFY10-Class5-Algorithm / BOJ

📘SSAFY 10기 5반의 백준 문제 알고리즘 스터디
https://www.acmicpc.net/
5 stars 6 forks source link

[Python] 1504 특정한 최단 경로 #6

Open jehunyoo opened 1 year ago

jehunyoo commented 1 year ago

접근 방법

시작 지점이 정해져있고 최단 거리를 구해야 하므로 Dijkstra's algorithm을 사용하는 문제이다. 거쳐가야 하는 두 개의 서로 다른 정점 v1, v2가 있으므로 (1번 정점과 N번 정점이 아니라는 조건이 있는데 왜 있는건지는 모르겠다.) 다음의 두 가지 경우를 고려하면 된다.

  1. 1 -> v1 -> v2 -> N - (경로 1)
  2. 1 -> v2 -> v1 -> N - (경로 2)

따라서 Dijkstra's algorithm을 통해 다음을 구한다.

  1. 1에서 모든 정점까지의 최단 거리 - (1)
  2. v1에서 모든 정점까지의 최단 거리 - (2)
  3. v2에서 모든 정점까지의 최단 거리 - (3) (실제로 문제풀 때에는 1에서 v1까지의 최단 거리, 1에서 v2까지의 최단 거리, v1에서 N까지의 최단거리, v2에서 N까지의 최단거리 v1에서 v2까지의 최단 거리 총 5번을 구했다. (비효율적입니다. 저처럼 풀지 마세요!))

(1), (2), (3)에서 처럼 최단 거리 list를 얻으면,

이제 마지막으로 (경로 1)과 (경로 2)의 최단 거리를 구하고 둘 중에 더 작은 값이 정답이다. 만약 두 값 모두 INFINITY로 설정한 값보다 크거나 같다면 조건을 만족하는 경로가 없으므로 -1을 출력한다.

구현 방법

$E$의 최댓값이 충분히 크기 때문에 입력 시간을 줄이기 위해 sys.stdin.readline을 사용하는 것이 좋다.

import sys
input = sys.stdin.readline # 이렇게 하면 평소처럼 input을 쓰면 된다.

최단 경로를 구하는 과정에서 "방문하지 않은 가장 가까운 노드"를 찾아야하는데 heap(또는 우선순위 큐)를 사용하는 것이 완전 탐색$O(N)$하는 것보다 효율적이다 heap의 삽입, 삭제 연산은 $O(logN)$의 시간 복잡도를 갖기 때문이다.

# python 3.9+
def dijkstra(graph: list[list[int]], start: int, end: int) -> list[int]:
    distance = [INF for _ in range(N + 1)]
    distance[start] = 0
    heap = []
    heapq.heappush(heap, (0, start)) # (거리, 정점 번호)의 튜플 형태로 heap에 삽입하면 튜플의 첫 번째 요소(=거리)를 기준으로 한다.

    while heap:
        dist, node = heapq.heappop(heap)
         if distance[node] < dist:
            continue
         for neighbor, d in enumerate(graph[node][1:], 1): # 여기서 graph는 2차원 배열
            if d != INF and (shorter := dist + d) < distance[neighbor]:
                distance[neighbor] = shorter
                heapq.heappush(heap, (shorter, neighbor))

    return distance

전체 코드

import sys
import heapq
input = sys.stdin.readline
INF = int(1e9)

def shortest(graph, start, end) -> int:
    distance = [INF for _ in range(N + 1)]
    distance[start] = 0
    heap = []
    heapq.heappush(heap, (0, start))

    while heap:
        dist, node = heapq.heappop(heap)
        if distance[node] < dist:
            continue
        for neighbor, d in enumerate(graph[node][1:], 1):
            if d != INF and (shorter := dist + d) < distance[neighbor]:
                distance[neighbor] = shorter
                heapq.heappush(heap, (shorter, neighbor))

    return distance[end]

N, E = map(int, input().split())
graph = [[INF for _ in range(N + 1)] for _ in range(N + 1)]
for i in range(N + 1):
    graph[i][i] = 0
for _ in range(E):
    a, b, c = map(int, input().split())
    graph[a][b] = graph[b][a] = c
v1, v2 = map(int, input().split())

bridge = shortest(graph, v1, v2)
answer = min(shortest(graph, 1, v1) + bridge + shortest(graph, v2, N), shortest(graph, 1, v2) + bridge + shortest(graph, v1, N))

if answer < INF:
    print(answer)
else:
    print(-1)
peppermintt0504 commented 1 year ago

코드가 깔끔하고 좋습니다~ 해당 문제 PR 주세요~