题意
给定一棵 \(n\) 个点的数,每个点都有一个权值。给出 \(q\) 次操作,有两种操作类型。
\(1.\) 给出 \(x\) 和 \(y\),将点 \(x\) 的权值修改为 \(y\)。
\(2.\) 给出点 \(u\) 和 \(v\),令 \(s\) 表示每一种从 \(u\) 号节点到 \(v\) 号节点出现的权值的次数,输出 \(\sum \dfrac{s \times (s-1)}{2}\)。
数据范围
\(1 \leq n,m \leq 10^5\),\(0 \leq a_i,y \leq 10^5\)。
思路
本题实际上就是维护队列和Count on a tree II的结合。可以考虑用树上带修莫队求解。
带修莫队实际上就是给询问增加了时间戳,将块长设置成 \(n^{\frac{2}{3}}\)。这样可以保证总的时间复杂度在 \(O(n^{\frac{5}{3}})\)。具体证明可以自行百度,这里不展开来讲。
树上莫队运用到了欧拉序的特点。简单来说就是在深搜遍历时先记录下当前点,等到遍历完其子树回溯时再记录下当前点。如题目样例给出的树:
记 \(first[u]\) 表示 \(u\) 在欧拉序中第一次出现的位置,\(last[u]\) 表示 \(u\) 在欧拉序中第二次出现的位置,\(seq[u]\) 表示欧拉序中第 \(u\) 个数。
按照出现了两次等于没有出现的思路。设 \(first[x]<first[y]\) ,可以发现,如果 \(x,y\) 的 LCA 是它们其中之一,那么从 \(first[x]\) 遍历到 \(first[y]\) 恰好将其中的每一个数字都出现了一次,不在路径上的点要么没有出现要么就出现了两次。而当它们的 LCA 不是其中之一时,按照 \(last[x]\) 到 \(first[y]\) 的顺序,除了 LCA 不会被遍历到,其余的都被遍历了一次。故只需要特判一下 LCA 即可。
最后,答案记得开 long long。
code:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
const int N=2e5+10;
#define LL long long
int seq[N],top,n,m,len,h[N],lg[N],idx,fa[N][25],first[N],last[N],depth[N],a[N];
int tot,time;LL ans[N],res,cnt[N]; bool vis[N];
struct edge{int v,nex;}e[N]; void add(int u,int v){e[++idx].v=v;e[idx].nex=h[u];h[u]=idx;}
struct query{int l,r,q,t,id;}q[N];
struct update{int x,y;}c[N];
bool cmp(query a,query b)
{
if(a.l/len!=b.l/len) return a.l/len<b.l/len;
if(a.r/len!=b.r/len) return a.r/len<b.r/len;
return a.t<b.t;
}
void dfs(int u,int father)
{
first[u]=++top,seq[top]=u;
depth[u]=depth[father]+1,fa[u][0]=father;
for(int i=1;i<=lg[depth[u]];i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=h[u];i;i=e[i].nex) if(e[i].v!=father) dfs(e[i].v,u);
last[u]=++top,seq[top]=u;
}
int LCA(int x,int y)
{
if(depth[x]<depth[y]) swap(x,y);
while(depth[x]>depth[y]) x=fa[x][lg[depth[x]-depth[y]]-1];
if(x==y) return x;
for(int i=lg[depth[x]];i>=0;i--) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void add(int x)
{
vis[x]^=1;
res-=cnt[a[x]]*(cnt[a[x]]-1ll)/2ll;
cnt[a[x]]+=vis[x]?1:-1;
res+=cnt[a[x]]*(cnt[a[x]]-1ll)/2ll;
}
int main()
{
scanf("%d%d",&n,&m);len=pow(n,2.0/3);for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int u,v,i=1;i<n;i++) scanf("%d%d",&u,&v),u++,v++,add(u,v),add(v,u);
for(int i=1;i<N;i++) lg[i]=lg[i-1]+((1<<lg[i-1])==i);dfs(1,0);
for(int opt,x,y,i=1;i<=m;i++)
{
scanf("%d%d%d",&opt,&x,&y);
if(opt==1) c[++time]=update{x+1,y};
else
{
tot++,x++,y++;if(first[x]>first[y]) swap(x,y);int z=LCA(x,y);
if(z!=x)q[tot]=query{last[x],first[y],z,time,tot};
else q[tot]=query{first[x],first[y],0,time,tot};
}
}
for(int i=1;i<=top;i++) printf("%d ",seq[i]-1);puts("");
m=tot;sort(q+1,q+m+1,cmp);
for(int i=0,j=1,t=0,k=1;k<=m;k++)
{
int l=q[k].l,r=q[k].r,tim=q[k].t;
while(i<r) add(seq[++i]);
while(i>r) add(seq[i--]);
while(j>l) add(seq[--j]);
while(j<l) add(seq[j++]);
if(q[k].q) add(q[k].q);
while(t<tim)
{
t++;
int x=c[t].x;
if(vis[x])
{
res-=cnt[a[x]]*(cnt[a[x]]-1ll)/2ll;
cnt[a[x]]--;
res+=cnt[a[x]]*(cnt[a[x]]-1ll)/2ll;
res-=cnt[c[t].y]*(cnt[c[t].y]-1ll)/2ll;
cnt[c[t].y]++;
res+=cnt[c[t].y]*(cnt[c[t].y]-1ll)/2ll;
}
swap(a[x],c[t].y);
}
while(t>tim)
{
int x=c[t].x;
if(vis[x])
{
res-=cnt[a[x]]*(cnt[a[x]]-1ll)/2ll;
cnt[a[x]]--;
res+=cnt[a[x]]*(cnt[a[x]]-1ll)/2ll;
res-=cnt[c[t].y]*(cnt[c[t].y]-1ll)/2ll;
cnt[c[t].y]++;
res+=cnt[c[t].y]*(cnt[c[t].y]-1ll)/2ll;
}
swap(a[x],c[t].y);
--t;
}
ans[q[k].id]=res;
if(q[k].q) add(q[k].q);
}
for(int i=1;i<=m;i++) printf("%lld\n",ans[i]);
return 0;
}