미소를뿌리는감자의 코딩

[백준 2024/10/02] 2098번 외판원 순회 본문

코딩 테스트/백준

[백준 2024/10/02] 2098번 외판원 순회

미뿌감 2024. 10. 3. 00:07
728x90

1. 문제

https://www.acmicpc.net/problem/2098

 

2. 접근 방법

https://www.geeksforgeeks.org/travelling-salesman-problem-using-dynamic-programming/

 

Travelling Salesman Problem using Dynamic Programming - GeeksforGeeks

Travelling Salesman Problem (TSP): Given a set of cities and distance between every pair of cities, the problem is to find the shortest p ossible route that

www.geeksforgeeks.org

 

내가 사랑하는 긱스 뽈 긱스,,, 코드를 참고하였다.

 

메모리 초과를 방지하기 위해 '비트 마스크'를 사용해야 한다.

예를 들어 도시가 6개 있다고 가정하자.

 

0으로 방문하지 않았음을, 1으로 방문하였음을 나타내고, 100101 -> 이라고 마크를 해두면, 1번 도시, 3번 도시 그리고 6번 도시를 방문한 게 되는 것이다. 오른쪽 끝부터 1번 도시라고 생각하며 왼쪽으로 갈수록 2번 도시, 3번 도시, ... 을 count하게 된다.

 

이를 통해서 메모리 사용을 줄일 수 있다.

 

비트 마스크를 사용하기 위한 basic 이론들에 대해서 알아보자.

비트 마스크의 경우 우선 계산되는 것들에 대해서 괄호를 잘 사용해 주어야 한다.

 

1 << (n + 1) 

print(bin(1 << (3+1)))
# 결과 : 0b10000

왼쪽으로 4번 shift 한 결과이다. bin()을 이용하여 2진수로 출력할 수 있다.

0b10000의 10진수 값은 2^4 * 1 + 2 ^ 3 * 0 + 2^2 * 0 + 2^1 * 0 2^0 * 0 = 16이다.

 

memoization을 사용하기 위해 

memo = [[-1] * (1 << (n+1)) for _ in range(n+1)]

이를 선언해 준다. ( 1 << (n+1) ) 을 곱한 이유는 4개의 도시 방문 유무가 16가지가 될 것이기 때문이다.

 

0000 - 방문 안함 (0)

0001 - 1번 도시 방문 (1)

0010 - 2번 도시 방문 (2)

0011 - 1번, 2번 도시 방문 (3)

0100 - 3번 도시 방문 (4)

0101 - 1번, 3번 도시 방문 (5)

0110 - 1번, 2번 도시 방문 (6)

0111 - 1번, 2번, 3번 도시 방문 (7)

1000 - 4번 도시 방문 (8)

1001 - 1번, 4번 도시 방문 (9)

1010 - 1번, 3번 도시 방문 (10)

1011 - 1번, 3번 4번 도시 방문 (11)

1100 - 3번, 4번 도시 방문 (12)

1101 - 1번, 3번, 4번 도시 방문 (13)

1110 - 2번, 3번, 4번 도시 방문 (14)

1111 - 1번, 2번, 3번, 4번 도시 방문 (15)

 

가 되게 되는 것이다.

memo 가 사용되는 방법은, 특정 도시에서 도시를 방문한 기록을 이용하는 것이다.

이전에, 똑같이 특정 도시에 방문하였고, 해당 도시를 방문한 기록이 동일할 때, 더 밑으로 탐색을 진행했을 것이다. 이를 기록으로 남겨둔 것이고, 다시 해당 도시에 똑같은 상황으로 접근하였을 때, memo[i][mask] 를 return 해주어, 중복된 탐색을 방지한다.

 

( 1 << ( n + 1 ) ) - 1)

 

해당 부분은 도시 방문 초기값을 설정해 주는 단계이다.

n이 3일 때의 출력 결과는 다음과 같다.

print(bin((1 << (3+1))- 1))
#결과 : 0b1111

 

0b10000 의 결과에서, -1을 하므로, 2의 보수 값을 가지게 된다. 이에 0b1111이라는 결과를 가지게 되는 것이다.

1은 방문한 상태임을 의미하는데, 여기서 역으로 0000을 찾으로 가는 방식으로 문제를 해결할 것이다.

 

if mask == ((1 << i ) | 3 )

 

이 부분은 방문한 도시가 i 번째 도시인지를 확인하는 if 문이다.

i가 3일 때의 결과는 아래와 같다.

print(bin(((1 << 3) | 3)))
# 결과 : 0b1011

 

1은 방문한 도시를 뜻하기에, 이는 3번 도시, 1번도시 그리고 0번도시를 방문한 상황이다.

(0번 도시는 의미가 없다. index랑 도시 번호를 맞춰주기 위한 용도)

 

만약 1번 도시에서 tsp가 출발했을 때, 처음 도착한 도시가 i인 것이다. 따라서, dist[1][i]를 반환해 준다.

    if mask == ((1 << i) | 3):
        return dist[1][i]

 

 

(mask & (1<< j)) != 0

 

이는 j 번째 도시가 방문했는지 유무를 확인하기 위한 문이다.

1000 3번 도시가 방문 되었음을 의미한다. 이는 000000..... 1000 이므로 mask 과 & 연산을 진행하였을 때, 3번 도시가 0일 경우엔, 1 & 0 의 연산을 통해서 0이 반환되게 된다. 즉 != 0 이라는 것은 3번 도시가 방문된 상태일 때는 False를 반환, 방문 안한 상태이면 0 != 0 으로 False를 반환한다.

 

(mask & (~(1 << i)))

 

이는 i 번째 도시를 방문하지 않았다고 나타내기 위한 용도이다.

3이라 가정했을 때, 

1000 의 역은 1111101111 이므로, 3번 도시의 값이 1이었다면, 0으로 바뀌게 된다. 이를 통해서 3번 도시가 방문되지 않았음을 기록할 수 있다.

 

for i in range(2, n+1, 1):
    min_cost = min(min_cost, get_min(i, (1 << (n+1))- 1) + dist[i][1])

 

이렇게 for문을 돌게 된다. 

시작은 1번 도시로 고정이되, 마지막으로 방문하는 도시를 i 번째 도시라고 가정하였다.

이를 통해서 마지막에 1번 도시를 방문하는 경우 + 2번 도시를 방문하는 경우 + .... 4번 도시를 방문하는 경우 => 모든 경우의 수를 탐색하게 된다.

 

if mask == ((1 << i) | 3):
    return dist[1][i]

 

이는 위에서 언급하였듯이, 모든 도시를 방문한 상태에서 0000..00의 상태를 찾으러 가는 과정이다. 

따라서 000001011 이라면, 3번도시, 1번도시 그리고 0 번 도시를 방문한 상태를 의미하기 때문에 1번도시에서 3번도시를 방문하는 값을 반환해 주면 된다.

 

for j in range(1, n+1, 1):
    if j != i and j != 1 and (mask & (1 << j)) != 0 and dist[j][i] != sys.maxsize:
        res = min(res, get_min(j, (mask & (~(1 << i)))) + dist[j][i])

 

j != i 현재 도시를 다시 방문하지 않도록 하고, j != 1로 시작 도시를 방문하지 않도록 한다.

 

3. 코드

import sys


def get_min(i, mask):

    if mask == ((1 << i) | 3):
        return dist[1][i]

    if memo[i][mask] != -1:
        return memo[i][mask]

    res = sys.maxsize
    for j in range(1, n+1, 1):
        if j != i and j != 1 and (mask & (1 << j)) != 0 and dist[j][i] != sys.maxsize:
            res = min(res, get_min(j, (mask & (~(1 << i)))) + dist[j][i])

    memo[i][mask] = res
    return res


if __name__ == "__main__":
    n = int(input())
    memo = [[-1] * (1 << (n+1)) for _ in range(n+1)]

    dist = [[0] * (n+1) for _ in range(n+1)]

    for i in range(1, n+1):
        d = list(map(int, input().split()))
        for j in range(1, n+1):
            dist[i][j] = d[j-1] if d[j-1] != 0 else sys.maxsize

    min_cost = sys.maxsize
    for i in range(2, n+1, 1):
        min_cost = min(min_cost, get_min(i, (1 << (n+1))- 1) + dist[i][1])

    print(min_cost)
728x90