题目大意
有一棵\(n\)(\(n\leq10^5\))个节点的树,每个点有颜色\(c\),一开始所有颜色互不相同
要进行\(m\)(\(m\leq10^5\))次操作,每次操作是以下三种中的一种:
1.给出点\(x\),将点\(x\)到根路径上所有点的染成一种没出现过的颜色
2.给出点\(x\),\(y\),询问点\(x\)到\(y\)的简单路径上有多少种颜色
3.给出点\(x\),询问点\(x\)的子树中到根路径上颜色种类最多的点
题解
把1操作看成LCT的access操作,2操作就相当于询问一条链上有几条LCT里的重链,3操作相当于询问子树中到根LCT重链中最多的点
同时维护LCT和树剖,LCT为树的染色状态(同色的点在同一条重链上),树剖维护每个点\(i\)到根有几条重链,记为\(a_i\)
1操作就是LCT的access,同时树剖修改\(a_i\)
2操作的答案是\(a_x+a_y-2*a_{lca(x,y)}+1\),不用考虑重链在\(lca(x,y)\)处拐弯的情况
3操作是查询子树里\(a_i\)的最小值
代码
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<iomanip>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<stack>
#include<vector>
#define rep(i,x,y) for(register int i=(x);i<=(y);++i)
#define dwn(i,x,y) for(register int i=(x);i>=(y);--i)
#define maxn 100010
#define maxm (maxn<<1)
#define view(u,k) for(int k=fir[u];k!=-1;k=nxt[k])
#define ls (u<<1)
#define rs (u<<1|1)
#define ls2 son[u][0]
#define rs2 son[u][1]
#define mi (l+r>>1)
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)&&ch!='-')ch=getchar();
if(ch=='-')f=-1,ch=getchar();
while(isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return x*f;
}
void write(int x)
{
if(x==0){putchar('0'),putchar('\n');return;}
int f=0;char ch[20];
if(x<0)putchar('-'),x=-x;
while(x)ch[++f]=x%10+'0',x/=10;
while(f)putchar(ch[f--]);
putchar('\n');
return;
}
int fir[maxn],nxt[maxm],v[maxm],now,n,q;
int siz[maxn],dfn[maxn],top[maxn],fth[maxn],wson[maxn],dep[maxn],cnt,tim,to[maxn],tr[maxn<<2],mk[maxn<<2];
int fa[maxn],son[maxn][2],rev[maxn],st[maxn],tp,tmpdep[maxn];
void ade(int u1,int v1){v[cnt]=v1,nxt[cnt]=fir[u1],fir[u1]=cnt++;}
void getson(int u)
{
siz[u]=1;
view(u,k)if(v[k]!=fth[u])
{
fa[v[k]]=fth[v[k]]=u,tmpdep[v[k]]=tmpdep[u]+1,dep[v[k]]=dep[u]+1,getson(v[k]),siz[u]+=siz[v[k]];
if(siz[v[k]]>siz[wson[u]])wson[u]=v[k];
}
}
void gettop(int u,int anc)
{
dfn[u]=++tim,to[tim]=u,top[u]=anc;
if(wson[u])gettop(wson[u],anc);
view(u,k)if(v[k]!=fth[u]&&v[k]!=wson[u])gettop(v[k],v[k]);
}
void build(int u,int l,int r)
{
if(l==r){tr[u]=dep[to[l]]+1;return;}
build(ls,l,mi),build(rs,mi+1,r),tr[u]=max(tr[ls],tr[rs]);return;
}
void mark(int u,int k){tr[u]+=k,mk[u]+=k;}
void pd(int u){if(mk[u]){mark(ls,mk[u]),mark(rs,mk[u]),mk[u]=0;}}
void add(int u,int l,int r,int x,int y,int k)
{
if(x<=l&&r<=y)return mark(u,k);
pd(u);
if(x<=mi)add(ls,l,mi,x,y,k);
if(y>mi)add(rs,mi+1,r,x,y,k);
tr[u]=max(tr[ls],tr[rs]);
return;
}
int ask(int u,int l,int r,int x,int y)
{
if(x<=l&&r<=y)return tr[u];
pd(u);
int res=0;
if(x<=mi)res=ask(ls,l,mi,x,y);
if(y>mi)res=max(res,ask(rs,mi+1,r,x,y));
return res;
}
int getso(int u){return son[fa[u]][0]!=u;}
int nort(int u){return son[fa[u]][0]==u||son[fa[u]][1]==u;}
void rot(int u)
{
int fu=fa[u],ffu=fa[fu],l=getso(u),fl=getso(fu),r=l^1,rson=son[u][r];
if(nort(fu))son[ffu][fl]=u;son[fu][l]=rson,son[u][r]=fu,fa[rson]=fu,fa[u]=ffu,fa[fu]=u;
}
void splay(int u)
{
while(nort(u)){int fu=fa[u];if(nort(fu))rot(getso(u)^getso(fu)?u:fu);rot(u);}
}
int rnxt(int u){u=rs2;while(u){if(!ls2)break;u=ls2;}return u;}
int sroot(int u){while(u&&ls2)u=ls2;return u;}
void acs(int u)
{
for(int vv=0;u;vv=u,u=fa[u])
{
splay(u);
int tmp=rnxt(u),tmp2=sroot(vv);
rs2=vv;
if(tmp){add(1,1,n,dfn[tmp],dfn[tmp]+siz[tmp]-1,1);}
if(tmp2){add(1,1,n,dfn[tmp2],dfn[tmp2]+siz[tmp2]-1,-1);}
}
}
int Lca(int x,int y)
{
while(top[x]!=top[y])
{
if(tmpdep[top[x]]<tmpdep[top[y]])swap(x,y);
x=fth[top[x]];
}
return tmpdep[x]<tmpdep[y]?x:y;
}
int main()
{
n=read(),q=read();
memset(fir,-1,sizeof(fir));
rep(i,1,n-1){int x=read(),y=read();ade(x,y),ade(y,x);}
getson(1),gettop(1,1),build(1,1,n);
while(q--)
{
int f=read();
if(f==1){int x=read();acs(x);}
else if(f==2)
{
int x=read(),y=read(),lca=Lca(x,y),ansx=ask(1,1,n,dfn[x],dfn[x]),ansy=ask(1,1,n,dfn[y],dfn[y]),ansl=ask(1,1,n,dfn[lca],dfn[lca]);
write(ansx+ansy-2*ansl+1);
}
else
{
int x=read();
write(ask(1,1,n,dfn[x],dfn[x]+siz[x]-1));
}
}
return 0;
}