lct模板
网上有很多lct的bolgs了,这里就不详细论述。
几个函数:
1.access(x)
把x到原树的根的路径打通。
分成3步:
1.x转到当前的根
2.x的右儿子改成上一次的根(改成重边),维护信息
3.x往上跳轻边(x=tr[x].f)
Access后x到根在一颗splay里,但根不一定是splay的根,所以一般要splay(x)一下
2.makeroot(x)
使x成为原树的根。
先Access(x);splay(x);
然后此时x只有左儿子,swap一下x就变成深度最小的点了
3.findroot(x)
找到x所在原树的根
Access(x);splay(x);
一直向左找到深度最小的就行
注意要pushdown
4.split(x,y)
把x到y的链拉出来成为一颗splay
makeroot(x); access(y); splay(y);
5.link(x,y)
makeroot(x) if(findroot(y)!=x)tr[x].f=y;
6.cut
split(x,y); if(findroot(y)==x&&tr[x].f==y&&tr[y].ch[0]==x&&tr[y].ch[1]==0) {
tr[x].f=0;tr[y].ch[0]=0; wh(y);
}
#include<cstdio> #include<iostream> #include<cstdlib> #include<cstring> #include<algorithm> #include<cmath> #define maxn 300005 using namespace std; int n,m,v[maxn]; struct node{ int f,ch[2],s; bool re; }tr[maxn]; void wh(int x){ tr[x].s=tr[tr[x].ch[0]].s^tr[tr[x].ch[1]].s^v[x]; } int get(int x){ return tr[tr[x].f].ch[1]==x; } bool isroot(int x){ return tr[tr[x].f].ch[0]!=x&&tr[tr[x].f].ch[1]!=x; } void upr(int k){ tr[k].re^=1; swap(tr[k].ch[0],tr[k].ch[1]); } void down(int k){ if(tr[k].re){ upr(tr[k].ch[0]);upr(tr[k].ch[1]); tr[k].re=0; } } void rotate(int x){ int y=tr[x].f,z=tr[y].f; int wx=get(x),wy=get(y); if(!isroot(y))tr[z].ch[wy]=x;tr[x].f=z;//attation tr[y].ch[wx]=tr[x].ch[wx^1];tr[tr[x].ch[wx^1]].f=y; tr[x].ch[wx^1]=y;tr[y].f=x; wh(y),wh(x); } int st[maxn]; void splay(int x){ int y=x,top=0; st[++top]=y; while(!isroot(y))y=tr[y].f,st[++top]=y; while(top)down(st[top--]);//at while(!isroot(x)){ int y=tr[x].f; if(!isroot(y))rotate(get(x)==get(y)?y:x);//at rotate(x); } } void access(int x){ for(int y=0;x;y=x,x=tr[x].f) splay(x),tr[x].ch[1]=y,wh(x); } void makeroot(int x){//原树的根 access(x);splay(x); upr(x); wh(x); } int findroot(int x){ access(x);splay(x); int f; while(x)down(x),f=x,x=tr[x].ch[0]; return f; } void split(int x,int y){ makeroot(x); access(y);splay(y); } void link(int x,int y){ makeroot(x); if(findroot(y)!=x)tr[x].f=y; } void cut(int x,int y){ split(x,y); if(findroot(y)==x&&tr[x].f==y&&tr[y].ch[0]==x&&tr[y].ch[1]==0){ tr[x].f=0;tr[y].ch[0]=0; wh(y); } } int main(){ cin>>n>>m; for(int i=1;i<=n;i++)scanf("%d",&v[i]); for(int i=1,op,x,y;i<=m;i++){ scanf("%d%d%d",&op,&x,&y); if(op==0)split(x,y),printf("%d\n",tr[y].s); if(op==1)link(x,y); if(op==2)cut(x,y); if(op==3)splay(x),v[x]=y; } return 0; }View Code