首先可以转化问题,变为对每种颜色分别考虑不含该颜色的简单路径条数。然后把不是当前颜色的点视为白色,是当前颜色的点视为黑色,显然路径数量是每个白色连通块大小的平方和,然后题目变为:黑白两色的树,单点翻转颜色,维护白色连通块大小平方和,然后根据Auuan大佬的题解,我用了LCT。就是对每个点维护子树、儿子大小平方和,在 link/cut 的时候更新答案。初始化所有点是白色,离线处理每个颜色即可。
这题放在2h比赛上,除了lxl其他人都写不出来(况且lxl还是本题出题人呢)
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=4e5+7; int n,m,c[N],f[N],fa[N],ch[N][2],sum[N],sz[N]; ll ans,d[N],sz2[N]; bool vis[N]; vector<int>vec[N][2],G[N]; bool nroot(int x){return x==ch[fa[x]][0]||x==ch[fa[x]][1];} void pushup(int x){sum[x]=sum[ch[x][0]]+sum[ch[x][1]]+sz[x]+1;} void rotate(int x) { int y=fa[x],z=fa[y],w=x==ch[y][1]; if(nroot(y))ch[z][y==ch[z][1]]=x; fa[x]=z,ch[y][w]=ch[x][w^1],fa[ch[x][w^1]]=y,ch[x][w^1]=y,fa[y]=x; pushup(y),pushup(x); } void splay(int x) { while(nroot(x)) { int y=fa[x],z=fa[y]; if(nroot(y)) { if((x==ch[y][1])^(y==ch[z][1]))rotate(x); else rotate(y); } rotate(x); } } void access(int x) { int y=0; while(x) { splay(x); sz[x]+=sum[ch[x][1]]-sum[y]; sz2[x]+=1ll*sum[ch[x][1]]*sum[ch[x][1]]-1ll*sum[y]*sum[y]; ch[x][1]=y; pushup(x); x=fa[y=x]; } } int findrt(int x) { access(x),splay(x); while(ch[x][0])x=ch[x][0]; splay(x); return x; } void link(int x) { int y=f[x],z; splay(x); ans-=sz2[x]+1ll*sum[ch[x][1]]*sum[ch[x][1]]; z=findrt(y); access(x),splay(z); ans-=1ll*sum[ch[z][1]]*sum[ch[z][1]]; fa[x]=y; splay(y); sz[y]+=sum[x],sz2[y]+=1ll*sum[x]*sum[x]; pushup(y),access(x),splay(z); ans+=1ll*sum[ch[z][1]]*sum[ch[z][1]]; } void cut(int x) { int y=f[x],z; access(x); ans+=sz2[x]; z=findrt(y); access(x),splay(z); ans-=1ll*sum[ch[z][1]]*sum[ch[z][1]]; splay(x); ch[x][0]=fa[ch[x][0]]=0; pushup(x),splay(z); ans+=1ll*sum[ch[z][1]]*sum[ch[z][1]]; } void dfs(int u) {for(int i=0;i<G[u].size();i++)if(G[u][i]!=f[u])f[G[u][i]]=u,dfs(G[u][i]);} int main() { scanf("%d%d",&n,&m); ll lst; for(int i=1;i<=n;i++) scanf("%d",&c[i]),vec[c[i]][0].push_back(i),vec[c[i]][1].push_back(0); for(int i=1;i<=n+1;i++)sum[i]=1; for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),G[x].push_back(y),G[y].push_back(x); for(int i=1,u,v;i<=m;i++) { scanf("%d%d",&u,&v); vec[c[u]][0].push_back(u),vec[c[u]][1].push_back(i); c[u]=v; vec[v][0].push_back(u),vec[v][1].push_back(i); } f[1]=n+1; dfs(1); for(int i=1;i<=n;i++)link(i); for(int i=1;i<=n;i++) { if(!vec[i][0].size()){d[0]+=1ll*n*n;continue;} if(vec[i][1][0])d[0]+=1ll*n*n,lst=1ll*n*n;else lst=0; for(int j=0;j<vec[i][0].size();j++) { int u=vec[i][0][j]; if(vis[u]^=1)cut(u);else link(u); if(j==vec[i][0].size()-1||vec[i][1][j+1]!=vec[i][1][j]) d[vec[i][1][j]]+=ans-lst,lst=ans; } for(int j=vec[i][0].size()-1;~j;j--) { int u=vec[i][0][j]; if(vis[u]^=1)cut(u);else link(u); } } ans=1ll*n*n*n; for(int i=0;i<=m;i++)ans-=d[i],printf("%lld\n",ans); }View Code