要统计所有路径的信息,那我们考虑点分治,每次算经过分治中心的路径的贡献。然而路径的颜色数量实在是不好统计,既然只需要求从每个点出发的所有路径的颜色数量之和,那换一种思路,改为求从每个点出发包含某种颜色的路径数量之和。这两者显然是等价的。
考虑在点分治过程中怎么算这个东西。首先算出每种颜色被多少条由根到分治块中的点的路径(特别地,根本身也是一条路径)包含。这个可以dfs求出,dfs时用桶记录一下当前出现了哪些颜色,若出现新颜色就记录并把该颜色的贡献加上当前点的子树大小。之后利用这个统计,计算某子树的答案时先把该子树贡献减去,dfs到某个点时把这个点的颜色的贡献改为由根到其他子树的路径条数,更新总贡献并更新该点的答案。
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<algorithm>
using namespace std;
int read()
{
int x=,f=;char c=getchar();
while (c<''||c>'') {if (c=='-') f=-;c=getchar();}
while (c>=''&&c<='') x=(x<<)+(x<<)+(c^),c=getchar();
return x*f;
}
#define N 100010
int n,color[N],p[N],size[N],cnt[N],tag[N],t=;
long long ans[N],tot;
bool flag[N];
struct data{int to,nxt;
}edge[N<<];
void addedge(int x,int y){t++;edge[t].to=y,edge[t].nxt=p[x],p[x]=t;}
void makes(int k,int from)
{
size[k]=;
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=from&&!flag[edge[i].to])
{
makes(edge[i].to,k);
size[k]+=size[edge[i].to];
}
}
int findroot(int k,int s,int from)
{
int mx=;
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=from&&!flag[edge[i].to]&&size[edge[i].to]>size[mx]) mx=edge[i].to;
if ((size[mx]<<)>s) return findroot(mx,s,k);
else return k;
}
void calc(int k,int from,int v)
{
if (!tag[color[k]]) cnt[color[k]]+=size[k]*v,tot+=size[k]*v;
tag[color[k]]++;
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=from&&!flag[edge[i].to]) calc(edge[i].to,k,v);
tag[color[k]]--;
}
void work(int k,int from,int s)
{
int tmp=cnt[color[k]];tot+=s-cnt[color[k]];cnt[color[k]]=s;
ans[k]+=tot;
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=from&&!flag[edge[i].to]) work(edge[i].to,k,s);
cnt[color[k]]=tmp;tot-=s-cnt[color[k]];
}
void solve(int k)
{
makes(k,k);
k=findroot(k,size[k],k);flag[k]=;
makes(k,k);
tot=;
calc(k,k,);
ans[k]+=tot;
tag[color[k]]=;
for (int i=p[k];i;i=edge[i].nxt)
if (!flag[edge[i].to])
{
calc(edge[i].to,k,-);
cnt[color[k]]=size[k]-size[edge[i].to];tot-=size[edge[i].to];
work(edge[i].to,k,size[k]-size[edge[i].to]);
tot+=size[edge[i].to];cnt[color[k]]=size[k];
calc(edge[i].to,k,);
}
tag[color[k]]=;
calc(k,k,-);
for (int i=p[k];i;i=edge[i].nxt)
if (!flag[edge[i].to]) solve(edge[i].to);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("game.in","r",stdin);
freopen("game.out","w",stdout);
const char LL[]="%I64d\n";
#else
const char LL[]="%lld\n";
#endif
n=read();
for (int i=;i<=n;i++) color[i]=read();
for (int i=;i<n;i++)
{
int x=read(),y=read();
addedge(x,y),addedge(y,x);
}
solve();
for (int i=;i<=n;i++) printf(LL,ans[i]);
return ;
}