Skip to main content

834. Sum of Distances in Tree

·341 words·2 mins· loading
Table of Contents

문제링크

설명
#

💀 0~n-1까지 방향이 없는 트리와, 연결고리edges가 주어졌을때,
answer[i] = i번째 노드와 다른 모든 노드들과 거리의 합 를 만족하는 answer을 리턴하기

def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
#  n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]

풀이과정
#

1. 배열 2개 answer,node_cnt 생성
#

answer : 정답용 answer[i] = 다른 모든 노드들과의 거리 합 node_cnt : 현재 정점의 누적 노드 개수

2. root를 기준으로 answer,node_cnt 채우기
#

루트를 기준으로, 밑에서 위로올라오면서 배열 채우기

answer[i] = answer[i의 자식]+node_cnt[i의 자식] 의 합 node_cnt[i] = i의 자식노드들 정점의 수 +1 (자기자신)

## dfs1()

//node_cnt 
     [6]#0       #root =  1+4+1=6
    /   \
 [1]#1    [4]#2  #1=1  #2=1+3 = 4
  		/ | \
      [1][1][1]    
      
      

//answer 
     [8]#0       #root = (0+1)+(3+4) =8
    /   \
 [0]#1    [3]#2   #1=> 0+(0*0) =0   #2=0+(3*1) = 3
  		/ | \
      [0][0][0]   #3,4,5=> 0+0 = 0        

3.각 노드의 answer,node_cnt 채우기
#

한번 순회했으면, node_cnt는 완성된 상태고 answer[0]의 값 (루트) 는 건들지 않아도 됨👍 하지만 answer의다른 노드들은 밑에서 올라온 값만 계산한 상태임 따라서, 이번에는 위에서 내려가면서 배열을 갱신해줌

**answer[i] =(answer[i의부모] - node_cnt[i]) + (n-node_cnt[i]) ** 생각을 해보자 만일 2번 노드에 대한 정답을 구하려면, 루트노드와 2번의 하위 노드들을 토대로 구해야 한다

  • answer[0](루트)는 이미 2번에서 올라온 값을 가지고 있으므로, 이 값들은 중복됨. 따라서 node_cnt[i]만큼을 빼줘야 함
  • 반대로, 2번의 하위 노드가 아닌 얘들은, 길이가 하나씩 늘어날거임 => (전체 개수-node_cnt[2])를 하면 2번에 속하지 않은 애들 개수가 나오니, 이를 더해줌
## dfs2()

//node_cnt 
     [6]#0       
    /   \
 [1]#1  [4]#2
  		/ | \
      [1][1][1] 
      
      

//answer (위에서부터 계산됨)
     [8]#0       #root는 그대로
    /   \
 [12]#1  [6]#2   #1=>  (8-1)+(6-1)= 12   #2=>(8-4)+(6-4)=6
  		/ | \
      [5][5][5]   #3,4,5=> (6-1)+(6-1) = 5         

이러면 answer이 완성됨

전체 코드
#

class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        tree = [[] for _  in range(n)]
        answer = [0 for _ in range(n)]
        node_cnt = [1 for _ in range(n)]

        for a,b in edges:
            tree[a].append(b)
            tree[b].append(a)


        def dfs(now, parent):
            for child_node in tree[now]:
                if child_node == parent: 
                    continue;

                dfs(child_node,now)
                node_cnt[now] += node_cnt[child_node]
                answer[now] += answer[child_node] + node_cnt[child_node]

        dfs(0,-1)

        def dfs2(now,parent):
            for child_node in tree[now]:
                if child_node == parent:
                    continue

                answer[child_node] = (answer[now]-node_cnt[child_node]) + (n-node_cnt[child_node])
                dfs2(child_node, now)


        dfs2(0,-1)

        return answer