题目链接:
可以发现每条链上的所有点都要放在不同的段里,那么最多只需要树的深度这么多段就够了。
因为这样可以保证每条链上的点可以放在不同的段中而且一个点放在这些段中一定会比新开一个段更优。
那么我们先考虑一条链的情况,显然是先将较长的一条链计入答案,然后将另一条链上的点分别与长链上的点合并。
假设长链上的两段分别为$A,B$,其中$A>B$,那么对于另一条链上两个点$C,D$(假设$C>D$)一定是$A$与$C$合并、$B$与$D$合并最优(即大的与大的合并,小的与小的合并)。
可以分情况讨论证明一下:
假设$A>C>B>D$,$max(A,C)+max(B,D)=A+B<max(A,D)+max(B,C)=A+C$
假设$A>C>D>B$,$max(A,C)+max(B,D)=A+D<max(A,D)+max(B,C)=A+C$
假设$C>A>B>D$,$max(A,C)+max(B,D)=C+B<max(A,D)+max(B,C)=A+C$
假设$C>A>D>B$,$max(A,C)+max(B,D)=C+D<max(A,D)+max(B,C)=A+C$
假设$A>B>C>D$,$max(A,C)+max(B,D)=A+B=max(A,D)+max(B,C)=A+B$
假设$C>D>A>B$,$max(A,C)+max(B,D)=C+D=max(A,D)+max(B,C)=C+D$
那么我们只需要将两条链上点的权值都从大到小排序然后依次合并取最大值即可。
现在考虑树的情况,可以发现对于一个节点的所有子树,我们依旧可以按照上述从大到小的顺序将两个子树的最优分段方案合并。
所以只需要对于每个点维护一个大根堆维护这个点的子树分的所有段的权值即可。
因为每次合并的时间复杂度取决于两个堆中较小的那个堆的大小,所以可以将原树长链剖分,每个点继承重儿子的堆然后将其他儿子的堆与重儿子的堆合并,最后将这个点的权值加入堆中即可。
因为每个点被合并一次,所以时间复杂度是$O(nlog_{n})$。
因为每个点继承重儿子的堆,所以实际需要堆的数量是长链数。而且每个点的堆在被合并到其他堆之后就没用了,所以可以类似内存回收一样回收堆。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<vector> #include<bitset> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; priority_queue<int>q[200010]; int num[200010]; int son[200010]; int dep[200010]; int tot; int head[200010]; int to[200010]; int next[200010]; int n,x; int val[200010]; int mx[200010]; int cnt; int sum; ll ans; void add(int x,int y) { next[++tot]=head[x]; head[x]=tot; to[tot]=y; } void dfs(int x) { for(int i=head[x];i;i=next[i]) { dfs(to[i]); dep[x]=max(dep[x],dep[to[i]]+1); if(dep[to[i]]>=dep[son[x]]) { son[x]=to[i]; } } } void solve(int x) { if(son[x]) { solve(son[x]); num[x]=num[son[x]]; } for(int i=head[x];i;i=next[i]) { if(to[i]!=son[x]) { solve(to[i]); cnt=0; int a=num[to[i]],b=num[x]; while(!q[a].empty()) { int x=q[a].top(); int y=q[b].top(); q[a].pop(),q[b].pop(); if(x>y) { mx[++cnt]=x; ans-=y; } else { ans-=x; mx[++cnt]=y; } } for(int j=1;j<=cnt;j++) { q[b].push(mx[j]); } } } if(!son[x]) { num[x]=++sum; } ans+=val[x]; q[num[x]].push(val[x]); } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) { scanf("%d",&val[i]); } for(int i=2;i<=n;i++) { scanf("%d",&x); add(x,i); } dfs(1); solve(1); printf("%lld",ans); }