Solution
像树点染色一样,看到这种到根染色的题目要想到LCT。
每个splay中的点的颜色都是一样的,然后access的时候提取出每一段颜色,用一个树状数组统计答案即可。复杂度是nlog2n的。
Code
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=100010;
const int inf=2147483647;
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*f;
}
int n,col[Maxn],cnt=0;
struct P{int x,id;}A[Maxn];
bool cmp(P a,P b){return a.x<b.x;}
int fa[Maxn],son[Maxn][2],sz[Maxn],rev[Maxn];
bool is(int x){return(son[fa[x]][0]!=x&&son[fa[x]][1]!=x);}
void up(int x){if(x)sz[x]=sz[son[x][0]]+sz[son[x][1]]+1;}
void Rev(int x)
{
rev[x]^=1;
swap(son[x][0],son[x][1]);
}
void down(int x)
{
int lc=son[x][0],rc=son[x][1];
if(rev[x])
{
rev[x]=0;
if(lc)Rev(lc);if(rc)Rev(rc);
}
if(lc)col[lc]=col[x];if(rc)col[rc]=col[x];
}
void rotate(int x)
{
int y=fa[x],z=fa[y],w=(son[y][0]==x);
son[y][w^1]=son[x][w];if(son[x][w])fa[son[x][w]]=y;
if(!is(y))son[z][son[z][1]==y]=x;fa[x]=z;
son[x][w]=y;fa[y]=x;
up(y),up(x);
}
int sta[Maxn],top=0;
void pushdown(int x)
{
while(1)
{
sta[++top]=x;
if(is(x))break;
x=fa[x];
}
while(top)down(sta[top--]);
}
int s[Maxn];
void add(int x,int y){for(;x<=cnt;x+=(x&-x))s[x]+=y;}
int query(int x){int re=0;for(;x;x-=(x&-x))re+=s[x];return re;}
void splay(int x)
{
pushdown(x);
while(!is(x))
{
int y=fa[x],z=fa[y];
if(is(y))rotate(x);
else rotate(((son[z][1]==y)==(son[y][1]==x))?y:x),rotate(x);
}
}
vector<pa>h;
void access(int x,int op)
{
int last=0;
while(x)
{
splay(x);
if(op)h.push_back(make_pair(col[x],sz[x]-sz[son[x][1]]));
son[x][1]=last;
up(x);last=x;x=fa[x];
}
}
void make_root(int x){access(x,0);splay(x);Rev(x);}
void link(int x,int y){make_root(x);fa[x]=y;}
LL solve()
{
LL re=0;
for(int i=0;i<h.size();i++)
re+=(LL)query(h[i].first-1)*h[i].second,add(h[i].first,h[i].second);
for(int i=0;i<h.size();i++)
add(h[i].first,-h[i].second);
h.clear();
return re;
}
int main()
{
n=read();
for(int i=1;i<=n;i++)A[i].x=read(),A[i].id=i,sz[i]=1;
sort(A+1,A+1+n,cmp);
for(int i=1;i<=n;i++)
{
if(i==1||A[i-1].x!=A[i].x)cnt++;
col[A[i].id]=cnt;
}
for(int i=1;i<n;i++)
{
int x=read(),y=read();
make_root(1);
access(x,1);
printf("%lld\n",solve());
splay(1);col[1]=col[y];link(x,y);
}
}