문제 정의
노드 개수가 \\(n\\)개인 트리가 주어집니다. 각 노드 \\(i\\)에는 두 정수값 \\(a_i\\)와 \\(b_i\\)가 할당되어 있습니다. 노드 선택 순서는 다음 조건을 만족해야 합니다: 어떤 노드를 선택하기 전에 그 노드의 모든 조상 노드가 먼저 선택되어야 합니다. 이러한 순서로 선택된 노드 배열 \\(p\\)에 대해 비용은 다음 식으로 계산됩니다:
\[\sum_{i=1}^n \left( b_{p_i} \times \sum_{j=i+1}^n a_{p_j} \right)\]
가능한 최대 비용을 계산하시오.
- \\(1 \leq n \leq 3 \times 10^5\\)
- \\(1 \leq a_i, b_i \leq 5000\\)
- 부모 노드 정보 \\(f_i\\)는 \\(f_i < i\\)를 만족
- 루트 노드의 경우 \\(a_1 = b_1 = 0\\)
입출력 예시
입력:
4
1 1 2
0 0
3 1
5 1
4 1
출력:
14
해결 전략
각 노드의 \\(\frac{b_i}{a_i}\\) 비율이 높을수록 초기에 선택될 때 전체 비용에 긍정적 기여를 합니다. 그러나 조상 선택 제약 조건이 존재하므로 다음과 같은 병합 접근법을 사용합니다:
- 루트 노드를 제외한 모든 노드를 \\(\frac{b_i}{a_i}\\) 기준으로 최대 힙에 추가
- 힙에서 최상위 노드를 추출하여 해당 노드와 조상 노드 병합
- 병합 시 \\(b_{\text{조상}} \times a_i\\) 값을 결과에 누적
- 병합된 노드의 \\(a\\)와 \\(b\\) 값을 조상 노드에 합산
- 트리가 완전히 병합될 때까지 반복
이 과정에서 조상 노드 추적은 Union-Find 자료구조를 활용해 효율적으로 처리합니다.
구현 코드
#include <queue>
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
const int MAX_NODES = 300005;
int nodeParent[MAX_NODES];
LL totalCost;
struct TreeNode {
LL weightA, weightB;
int nodeID;
bool operator < (const TreeNode &target) const {
return weightB * target.weightA < target.weightB * weightA;
}
};
priority_queue<TreeNode> maxHeap;
int ufParent[MAX_NODES];
int findRoot(int x) {
if(ufParent[x] != x) ufParent[x] = findRoot(ufParent[x]);
return ufParent[x];
}
void mergeNodes(int x, int y) {
x = findRoot(x);
y = findRoot(y);
if(x != y) ufParent[x] = y;
}
int main() {
int numNodes;
scanf("%d", &numNodes);
for(int i = 1; i <= numNodes; ++i) ufParent[i] = i;
for(int i = 2; i <= numNodes; ++i)
scanf("%d", &nodeParent[i]);
for(int i = 1; i <= numNodes; ++i)
scanf("%lld %lld", &weightA[i], &weightB[i]);
for(int i = 2; i <= numNodes; ++i)
maxHeap.push({weightA[i], weightB[i], i});
while(!maxHeap.empty()) {
TreeNode cur = maxHeap.top();
maxHeap.pop();
if(cur.weightA != weightA[cur.nodeID] ||
cur.weightB != weightB[cur.nodeID])
continue;
int ancestor = findRoot(nodeParent[cur.nodeID]);
if(ancestor == 0) continue;
mergeNodes(cur.nodeID, ancestor);
totalCost += weightB[ancestor] * weightA[cur.nodeID];
weightA[ancestor] += weightA[cur.nodeID];
weightB[ancestor] += weightB[cur.nodeID];
maxHeap.push({weightA[ancestor], weightB[ancestor], ancestor});
}
printf("%lld\n", totalCost);
return 0;
}