题意:
记 \(i\) 到 \(j\) 的路径颜色数为 \(s(i,j)\),对每个 \(i\) 求 \(sum_i=\sum\limits_{j=1}^n s(i,j)\)
\(\text{Solution}\)
考虑在点分治时 \(\text{dp}\)
设当前分治重心为 \(x\) , 正在处理子树 \(y\) 中全部点此时的贡献。
对每个颜色考虑 \(x\) 子树内出除 \(y\) 子树外所有点与当前点有多少条路径经过此颜色。
先分析一下 \(\text{dp}\) 时要干些什么。
如果一个点 \(u\) 到 \(x\) 之间没有与 \(x\) 相同颜色的点。
那如果 \(x\) 的另一个儿子的子树中的点 \(v\) 到 \(x\) 之间也没有与 \(u\) 同色的点。
那 \(u\) 的子树中的全部点到 \(v\) 的路径都经过了 \(u\) 的颜色,产生了 \(size(u)\) 的贡献。
但若 \(v\) 到 \(x\) 的路径上有与 \(u\) 同色的点,那当前分治对 \(v\) 产生的贡献就是 \(size(x)-size(y)\) , \(y\) 是 \(x\) 的一个子节点且 \(v\) 在 \(y\) 的子树内。
这样就能知道分治时要如何处理:
先统计出所有 \(x\) 的儿子的子树中,到 \(x\) 的路径上没有与这个点的颜色相同的点的子树大小之和,对每个颜色 \(a\) 统计出来的这个值之和为 \(sw_a\),并记所有的 \(sw\) 之和为 \(res\)。
那对于正在处理的 \(x\) 的一个儿子 \(y\) 来说,希望的到的是 \(x\) 的所有除 \(y\) 的子树的每个颜色的 \(sw\) 之和。
这可以再对 \(y\) 的子树扫一次,减去每个颜色的 \(sw\) 的一部分。
注意 \(x\) 的颜色的 \(sw\) 需减去 \(size(y)\)。
这样对 \(y\) 中的子树内的一个点 \(v\),就可以很方便地按之前分析的统计了。
时间复杂度仅仅是 \(O(n\log n)\)
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5+10;
int n,m,x,y,edg,nn,rt;char ch;
int to[N<<1],nextn[N<<1],h[N];
int a[N],size[N],mxsz[N],cnt[N];
bool b[N];ll ans[N],sw[N],res,psx;
#define add(x,y) to[++edg]=y,nextn[edg]=h[x],h[x]=edg
inline void read(int &x){
x=0;ch=getchar();
while(ch<48||ch>57)ch=getchar();
while(ch>47&&ch<58)x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
}
void write(ll x){if(x>9)write(x/10);putchar(48+x%10);}
void findrt(int x,int anc){
int i,y;size[x]=1;mxsz[x]=0;
for(i=h[x];y=to[i],i;i=nextn[i])if(!b[y]&&y^anc){
findrt(y,x);
size[x]+=size[y];
if(size[y]>mxsz[x])mxsz[x]=size[y];
}
mxsz[x]=max(mxsz[x],nn-size[x]);
if(mxsz[x]<mxsz[rt])rt=x;
}
void dfs(int x,int anc){
int i,y,ax=a[x];
++cnt[ax];size[x]=1;
for(i=h[x];y=to[i],i;i=nextn[i])if(!b[y]&&y^anc){
dfs(y,x);size[x]+=size[y];
}
--cnt[ax];//cnt 都是判断到x路径上是否有与其同色的点
if(!cnt[ax])sw[ax]+=size[x],res+=size[x];//统计sw及res
}
void dfs1(int x,int anc){
int i,y,ax=a[x];
++cnt[ax];
for(i=h[x];y=to[i],i;i=nextn[i])if(!b[y]&&y^anc)dfs1(y,x);
--cnt[ax];
if(!cnt[ax])sw[ax]-=size[x],res-=size[x];
}
void dfs2(int x,int anc){
int i,y,ax=a[x];
++cnt[ax];
for(i=h[x];y=to[i],i;i=nextn[i])if(!b[y]&&y^anc)dfs2(y,x);
--cnt[ax];
if(!cnt[ax])sw[ax]+=size[x],res+=size[x];
}
void dfs_(int x,int anc){
int i,y,ax=a[x];
if(!cnt[ax])res+=psx-sw[ax];
ans[x]+=res;
++cnt[ax];
for(i=h[x];y=to[i],i;i=nextn[i])if(!b[y]&&y^anc)dfs_(y,x);
--cnt[ax];
if(!cnt[ax])res-=psx-sw[ax];
}
void clear(int x,int anc){
int i,y;sw[a[x]]=0;
for(i=h[x];y=to[i],i;i=nextn[i])if(!b[y]&&y^anc)clear(y,x);
}
void work(int x){
int i,y,ax=a[x];
res=0;dfs(x,0);
cnt[ax]=1;
ans[x]+=res;
for(i=h[x];y=to[i],i;i=nextn[i])if(!b[y]){
res-=size[y];
sw[ax]-=size[y];//这两行都是减去x的颜色的sw
dfs1(y,x);//减去y内的信息
psx=size[x]-size[y];
dfs_(y,x);//统计对y子树内的贡献
res+=size[y];
sw[ax]+=size[y];//减完要加回来处理后面的y
dfs2(y,x);
}
cnt[ax]=0;
clear(x,0);
}
void solve(int x){
b[x]=1;
work(x);
int i,y;
for(i=h[x];y=to[i],i;i=nextn[i])if(!b[y]){
rt=0;nn=size[y];
findrt(y,x);
solve(rt);
}
}
main(){
read(n);register int i;
for(i=1;i<=n;++i)read(a[i]);
for(i=1;i^n;++i)read(x),read(y),add(x,y),add(y,x);
rt=0;nn=n;mxsz[0]=n;
findrt(1,0);solve(rt);
for(i=1;i<=n;++i)write(ans[i]),putchar('\n');
}