考虑两个颜色相同的点,如果二者不是祖先后代,那么这个限制等价于起点终点不能分别在两个点的子树中,否则限制相当于两点不能一个在在后代子树中,一个不在祖先的儿子的子树中。因此每个点都能拆成 \(O(1)\) 个限制。因为每种颜色点不超过 20 因此可以直接大力枚举同色点对。
看到子树,显然来一个 dfs 序,这样子树编号就连续了。
然后发现这些限制是若干矩形,那么直接扫描线求矩形面积并就是总不合法路径数,减掉就好了。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<vector>
using namespace std;
#define int long long
struct edge
{
int nxt,to;
}e[1000001<<1];
int n,tot,h[1000001],c[1000001],id[1000001],cnt,fa[1000001][21],dep[1000001],s[1000001],ans[1000001<<2],len[1000001<<2],sum[1000001<<2];
vector<int> col[1000001];
struct element
{
int l,r,h,p;
bool operator <(const element &other) const
{
return h<other.h;
}
}a[10000001];
inline int read()
{
int x=0;
char c=getchar();
while(c<‘0‘||c>‘9‘)
c=getchar();
while(c>=‘0‘&&c<=‘9‘)
{
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x;
}
inline void add(int x,int y)
{
e[++tot].nxt=h[x];
h[x]=tot;
e[tot].to=y;
}
inline int ls(int k)
{
return k<<1;
}
inline int rs(int k)
{
return k<<1|1;
}
inline void push_up(int k)
{
sum[k]=ans[k]? len[k]:sum[ls(k)]+sum[rs(k)];
}
void build(int k,int l,int r)
{
len[k]=r-l+1;
if(l==r)
return;
int mid=(l+r)>>1;
build(ls(k),l,mid);
build(rs(k),mid+1,r);
}
void update(int nl,int nr,int l,int r,int k,int p)
{
if(l>=nl&&r<=nr)
{
ans[k]+=p;
push_up(k);
return;
}
int mid=(l+r)>>1;
if(nl<=mid)
update(nl,nr,l,mid,ls(k),p);
if(nr>mid)
update(nl,nr,mid+1,r,rs(k),p);
push_up(k);
}
inline long long calc()
{
long long res=0;
sort(a+1,a+cnt+1);
build(1,1,n);
a[cnt+1].h=n+1;
for(register int i=1;i<=cnt;++i)
{
update(a[i].l,a[i].r,1,n,1,a[i].p);
if(a[i].h^a[i+1].h)
res+=1ll*(a[i+1].h-a[i].h)*sum[1];
}
return (1ll*n*(n-1)-res)/2+n;
}
inline int find(int x,int y)
{
if(dep[x]<dep[y])
x^=y^=x^=y;
for(register int d=dep[x]-dep[y]-1,i=0;1<<i<=d;++i)
if(d&(1<<i))
x=fa[x][i];
return x;
}
void dfs(int k,int f,int deep)
{
dep[k]=deep;
id[k]=++cnt;
s[k]=1;
fa[k][0]=f;
for(register int i=1;i<=20;++i)
fa[k][i]=fa[fa[k][i-1]][i-1];
for(register int i=h[k];i;i=e[i].nxt)
{
if(e[i].to==f)
continue;
dfs(e[i].to,k,deep+1);
s[k]+=s[e[i].to];
}
}
inline long long solve()
{
cnt=0;
for(register int i=1;i<=n;++i)
col[c[i]].push_back(i);
for(register int k=1;k<=n;++k)
for(register int i=0;i<(int)col[k].size()-1;++i)
for(register int j=i+1;j<(int)col[k].size();++j)
{
int x=col[k][i],y=col[k][j];
if(dep[x]>dep[y])
x^=y^=x^=y;
if(id[y]>=id[x]&&id[y]<=id[x]+s[x]-1)
{
int node=find(x,y);
a[++cnt].h=1;
a[cnt].l=id[y];
a[cnt].r=id[y]+s[y]-1;
a[cnt].p=1;
a[++cnt].h=id[node];
a[cnt].l=id[y];
a[cnt].r=id[y]+s[y]-1;
a[cnt].p=-1;
a[++cnt].h=id[y];
a[cnt].l=1;
a[cnt].r=id[node]-1;
a[cnt].p=1;
a[++cnt].h=id[y]+s[y];
a[cnt].l=1;
a[cnt].r=id[node]-1;
a[cnt].p=-1;
a[++cnt].h=id[node]+s[node];
a[cnt].l=id[y];
a[cnt].r=id[y]+s[y]-1;
a[cnt].p=1;
a[++cnt].h=n+1;
a[cnt].l=id[y];
a[cnt].r=id[y]+s[y]-1;
a[cnt].p=-1;
a[++cnt].h=id[y];
a[cnt].l=id[node]+s[node];
a[cnt].r=n;
a[cnt].p=1;
a[++cnt].h=id[y]+s[y];
a[cnt].l=id[node]+s[node];
a[cnt].r=n;
a[cnt].p=-1;
}
else
{
a[++cnt].h=id[x];
a[cnt].l=id[y];
a[cnt].r=id[y]+s[y]-1;
a[cnt].p=1;
a[++cnt].h=id[x]+s[x];
a[cnt].l=id[y];
a[cnt].r=id[y]+s[y]-1;
a[cnt].p=-1;
a[++cnt].h=id[y];
a[cnt].l=id[x];
a[cnt].r=id[x]+s[x]-1;
a[cnt].p=1;
a[++cnt].h=id[y]+s[y];
a[cnt].l=id[x];
a[cnt].r=id[x]+s[x]-1;
a[cnt].p=-1;
}
}
return calc();
}
signed main()
{
n=read();
for(register int i=1;i<=n;++i)
c[i]=read();
for(register int i=1;i<n;++i)
{
int x=read(),y=read();
add(x,y);
add(y,x);
}
dfs(1,0,1);
printf("%lld\n",solve());
return 0;
}