트리 병합을 이용한 최대 비용 계산 알고리즘

문제 정의

노드 개수가 \\(n\\)개인 트리가 주어집니다. 각 노드 \\(i\\)에는 두 정수값 \\(a_i\\)와 \\(b_i\\)가 할당되어 있습니다. 노드 선택 순서는 다음 조건을 만족해야 합니다: 어떤 노드를 선택하기 전에 그 노드의 모든 조상 노드가 먼저 선택되어야 합니다. 이러한 순서로 선택된 노드 배열 \\(p\\)에 대해 비용은 다음 식으로 계산됩니다:

\[\sum_{i=1}^n \left( b_{p_i} \times \sum_{j=i+1}^n a_{p_j} \right)\]

가능한 최대 비용을 계산하시오.

  • \\(1 \leq n \leq 3 \times 10^5\\)
  • \\(1 \leq a_i, b_i \leq 5000\\)
  • 부모 노드 정보 \\(f_i\\)는 \\(f_i < i\\)를 만족
  • 루트 노드의 경우 \\(a_1 = b_1 = 0\\)

입출력 예시

입력:
4
1 1 2
0 0
3 1
5 1
4 1

출력:
14

해결 전략

각 노드의 \\(\frac{b_i}{a_i}\\) 비율이 높을수록 초기에 선택될 때 전체 비용에 긍정적 기여를 합니다. 그러나 조상 선택 제약 조건이 존재하므로 다음과 같은 병합 접근법을 사용합니다:

  1. 루트 노드를 제외한 모든 노드를 \\(\frac{b_i}{a_i}\\) 기준으로 최대 힙에 추가
  2. 힙에서 최상위 노드를 추출하여 해당 노드와 조상 노드 병합
  3. 병합 시 \\(b_{\text{조상}} \times a_i\\) 값을 결과에 누적
  4. 병합된 노드의 \\(a\\)와 \\(b\\) 값을 조상 노드에 합산
  5. 트리가 완전히 병합될 때까지 반복

이 과정에서 조상 노드 추적은 Union-Find 자료구조를 활용해 효율적으로 처리합니다.

구현 코드

#include <queue>
#include <cstdio>
#include <algorithm>
using namespace std;

typedef long long LL;
const int MAX_NODES = 300005;
int nodeParent[MAX_NODES];
LL totalCost;

struct TreeNode {
    LL weightA, weightB;
    int nodeID;
    bool operator < (const TreeNode &target) const {
        return weightB * target.weightA < target.weightB * weightA;
    }
};
priority_queue<TreeNode> maxHeap;

int ufParent[MAX_NODES];
int findRoot(int x) {
    if(ufParent[x] != x) ufParent[x] = findRoot(ufParent[x]);
    return ufParent[x];
}
void mergeNodes(int x, int y) {
    x = findRoot(x); 
    y = findRoot(y);
    if(x != y) ufParent[x] = y;
}

int main() {
    int numNodes;
    scanf("%d", &numNodes);
    for(int i = 1; i <= numNodes; ++i) ufParent[i] = i;
    for(int i = 2; i <= numNodes; ++i) 
        scanf("%d", &nodeParent[i]);
    for(int i = 1; i <= numNodes; ++i) 
        scanf("%lld %lld", &weightA[i], &weightB[i]);
    for(int i = 2; i <= numNodes; ++i) 
        maxHeap.push({weightA[i], weightB[i], i});
    
    while(!maxHeap.empty()) {
        TreeNode cur = maxHeap.top(); 
        maxHeap.pop();
        if(cur.weightA != weightA[cur.nodeID] || 
           cur.weightB != weightB[cur.nodeID]) 
            continue;
        
        int ancestor = findRoot(nodeParent[cur.nodeID]);
        if(ancestor == 0) continue;
        
        mergeNodes(cur.nodeID, ancestor);
        totalCost += weightB[ancestor] * weightA[cur.nodeID];
        weightA[ancestor] += weightA[cur.nodeID];
        weightB[ancestor] += weightB[cur.nodeID];
        maxHeap.push({weightA[ancestor], weightB[ancestor], ancestor});
    }
    printf("%lld\n", totalCost);
    return 0;
}

태그: 그리디 알고리즘 유니온-파인드 트리 구조 C++

5월 23일 15:28에 게시됨