미소를뿌리는감자의 코딩

[백준 2024/02/17] 1068번 트리 본문

코딩 테스트/백준

[백준 2024/02/17] 1068번 트리

미뿌감 2024. 2. 17. 21:08
728x90

https://www.acmicpc.net/problem/1068

 

1068번: 트리

첫째 줄에 트리의 노드의 개수 N이 주어진다. N은 50보다 작거나 같은 자연수이다. 둘째 줄에는 0번 노드부터 N-1번 노드까지, 각 노드의 부모가 주어진다. 만약 부모가 없다면 (루트) -1이 주어진다

www.acmicpc.net

 

1. 접근 방법

이번 문제는 리프 노드의 개수를 출력하는 것이 목표이다.

노드들을 부모 노드 : 자식 노드로 정렬을 해주었다.

이 문제에 대해서 defaultdict dictionary로 정리한 모습은 다음과 같다.

defaultdict(<class 'list'>, {-1: [0], 0: [1, 2], 2: [3, 4], 4: [5, 6], 6: [7, 8]})

-1 이 루트 노드를 가리키므로, -1에서 부터 시작한다.

 

-1 이 가지고 있는 value를 확인해본다. 그리고 해당 value를 key로 사용해서, value를 또 search 한다. 즉, 자식 노드를 찾아본다.

 

-1 -> 0 을 search

          0 -> 1 search

          0 -> 2 search

                  2 -> 3 search

                  2 -> 4 search 

 

이런식으로 BFS를 진행한다. 

0 -> 1 을 search 할 때, 1 은 자식 노드가 없다. ( dictionary 에 1을 key로 가지는 것이 없음을 통해 확인할 수 있다. )

즉, leaf node 이다. 이 경우에는, total 이라는 dictionary에 1을 추가해 주었다.

 

if i not in node_dict:
	total.append(i)
	return

 

또한 해당 노드가 k 와 같다면, 즉, 잘라지는 노드에 도달 했다면, return 을 해주었다.

하지만, 하나 확인해야하는 부분이 있다. 해당 노드를 잘랐을 때, 그 전 노드가 자식을 1개만 가지고 있는가 이다.

그 전 노드가 자식을 1개 가지고 있다면, 그리고 그 자식이 잘려져야 하는 노드라면, 잘렸을 때, 해당 노드는 리프 노드가 되기 때문이다.

 

따라서 다음과 같은 코드를 통해 k 노드가 잘린다면, 그 전 노드가 리프 노드가 되는 지 확인해 주었다.

if i == k:
	if len(node_dict[before]) == 1:
		total.append(before)
	return

 

또한 출력 하기 전에, 루트 노드가 잘리는 k 인지 확인하고, 루트 노드가 잘린다면 0 을 출력해주었다.

if root_node():
    print(0)
else:
    total = []
    BFS(-1, k, -2)
    print(len(total))

 

+ 반례를 찾을 때, 아래 링크가 많은 도움이 되었다.

https://www.acmicpc.net/board/view/132447

2. 코드

from collections import defaultdict
input()

def sort_nodes(node_dict):
    input_node = map(int, input().split())
    node_i = 0
    for node in input_node:
        node_dict[node].append(node_i)
        node_i += 1
    return node_dict

def BFS(i, k, before):
    if i == k:
        if len(node_dict[before]) == 1:
            total.append(before)
        return
    if i not in node_dict:
        total.append(i)
        return
    for node in node_dict[i]:
        BFS(node, k, i)

def root_node():
    if node_dict[-1][0] == k:
        return True


node_dict = defaultdict(list)
node_dict = sort_nodes(node_dict)
print(node_dict)
k = int(input())


if root_node():
    print(0)
else:
    total = []
    BFS(-1, k, -2)
    print(len(total))

728x90