题目传送门:http://acm.hdu.edu.cn/showproblem.php?pid=6035
题目大意:给你一棵树,树上每个节点都有一个颜色。 现在定义两点间的距离为两点最短路径上颜色集合大小,求该树上所有点对的距离之和。其中树上的节点个数$≤2*10^5$
如果直接处理每一条路径上颜色集合大小,显然比较麻烦。我们不妨换一种思路。
我们用S_i表示经过颜色i的路径的数量,显然答案$=\sum S_i$。
考虑如何求S_i。我们先将所有颜色为i的节点全部找出来,按照dfs序排序。
显然,若树上所有的路径均经过该颜色的节点,则$S_i=\frac{n*(n-1)}{2}$。
对于该点集中的节点x的每一棵子树,不妨设当前子树的根节点为v,找出所有点集中满足$dfn[v]<dfn[u]≤low[v]$且不存在$y$,使得$dfn[v]<dfn[y]<dfn[u]≤low[v]$的所有的$u$,则显然有$\frac{(siz[v]-\sum siz[u]) \times (siz[v]-\sum siz[u]-1)}{2}$个点对不会对答案产生贡献。其中$siz[x]$表示以x为根的子树的节点个数。
该统计方法的时间复杂度为$O(n log n)$
#include<bits/stdc++.h>
#define L long long
#define M 200005
using namespace std;
struct edge{int u,next;}e[M*]={}; int head[M]={},use=;
void add(int x,int y){use++;e[use].u=y;e[use].next=head[x];head[x]=use;} int dfn[M]={},low[M]={},t=;
int siz[M]={},col[M]={}; void dfss(int x,int fa){
dfn[x]=++t; siz[x]=;
for(int i=head[x];i;i=e[i].next) if(e[i].u!=fa){
dfss(e[i].u,x);
siz[x]+=siz[e[i].u];
}
low[x]=t;
}
bool cmp(int x,int y){
if(col[x]==col[y]) return dfn[x]<dfn[y];
return col[x]<col[y];
}
int p[M]={};
L ans=,sum=,n; void dfs(int &x,int y){
int xx=p[x]; x++;
for(int i=head[xx];i;i=e[i].next)
if(dfn[xx]<dfn[e[i].u]){
int v=e[i].u;
L cnt=siz[v];
while(x<=y&&dfn[p[x]]<=low[v]){
cnt-=siz[p[x]];
dfs(x,y);
}
sum-=cnt*(cnt-);
}
}
int hh=;
int Main(){
hh++;
ans=; sum=; t=;
memset(head,,sizeof(head)); use=;
for(int i=;i<=n;i++) scanf("%d",col+i);
for(int i=;i<n;i++){
int x,y; scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dfss(,); add(,);
for(int i=;i<=n;i++) p[i]=i;
sort(p+,p+n+,cmp);
for(int i=,j;i<=n;i=j+){
for(j=i;col[p[i]]==col[p[j]];j++); j--;
sum=n*(n-);
p[--i]=;
dfs(i,j);
ans+=sum;
}
printf("Case #%d: %lld\n",hh,ans/);
//cout<<ans/2<<endl;
} int main(){
freopen("in.txt","r",stdin);
while(cin>>n) Main();
}