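// LCT裸题 (a straightforward link-cut tree problem)
// 注意打标记之间的影响就是了 (mind how the lazy tags interact: the multiply tag must be pushed before the add tag, and it also scales the pending add tag)
// 这个模数不会爆unsigned int (51061^2 < 2^32, so the product of two values already reduced mod 51061 fits in unsigned int)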
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cctype>
using namespace std;
// Contest-style shorthands used throughout the file.
// `register` was deprecated in C++11 and removed in C++17 (clang rejects it),
// so rg now expands to nothing; it was only an optimisation hint.
#define rg
#define il inline
#define sta static
#define vd void
// Every later `int` token in this file means `unsigned int`, so arithmetic
// wraps modulo 2^32 instead of overflowing with undefined behaviour.
// (A consequence: writing `int main()` would declare an unsigned main.)
#define int unsigned int
// Modulus applied to all node values, aggregates and tags.
#define mod 51061
// Read one unsigned integer from stdin, skipping any non-digit prefix.
// A '-' anywhere in that prefix negates the result (in unsigned arithmetic).
// If EOF is reached, the value read so far is returned (0 if no digits).
// The character is kept in a signed variable so that EOF stays
// distinguishable and isdigit() only ever receives EOF or an unsigned-char
// value (passing a negative plain char is undefined behaviour).
il int gi(){
    sta int x,flg;
    sta signed c;
    x=flg=0;
    c=getchar();
    while(c!=EOF&&!isdigit(c)){
        if(c=='-')flg=-1;
        c=getchar();
    }
    while(isdigit(c)){
        x=x*10+c-'0';
        c=getchar();
    }
    return flg?-x:x;
}
// Node capacity: nodes are numbered from 1, and index 0 acts as the null node.
int const maxn=100001;
// Splay links: ch[x][0] / ch[x][1] are the children of x; fa[x] is its splay
// parent, or the path-parent pointer when x is the root of its splay tree.
int ch[maxn][2];
int fa[maxn];
// w[x]: the node's own value; sum[x] / siz[x]: aggregates over x's splay subtree.
int w[maxn];
int sum[maxn];
int siz[maxn];
// Lazy tags still owed to x's children: pending add, pending multiply, reverse.
int add[maxn];
int mul[maxn];
bool rev[maxn];
// Read-only reference parameter type used by the tree helpers.
typedef int const& fast;
// Recompute x's subtree size and its subtree value sum (mod 51061) from its
// two children. Node 0 is the null node and is left untouched.
il vd upd(fast x){
    if(!x)return;
    const int l=ch[x][0],r=ch[x][1];
    siz[x]=1+siz[l]+siz[r];
    sum[x]=(w[x]+sum[l]+sum[r])%mod;
}
// True when x is the root of its splay tree: x is not a child of fa[x]
// (so fa[x] is a path-parent pointer, or 0).
il bool isrt(fast x){
    const int p=fa[x];
    return !(ch[p][0]==x||ch[p][1]==x);
}
// Multiply the values in the splay subtree rooted at x by y (mod 51061).
// x's own value and aggregate are updated now, and the factor is recorded in
// x's tags for its children. The pending add tag is scaled too, because
// (v*m+a)*y == v*(m*y) + a*y.
// y is reduced first so every product stays below 51061^2 < 2^32, even when
// a caller passes an unreduced value. Node 0 is the null node and is ignored.
il vd Mul(fast x,fast y){
    if(!x)return;
    const int k=y%mod;
    w[x]=w[x]*k%mod;
    sum[x]=sum[x]*k%mod;
    add[x]=add[x]*k%mod;
    mul[x]=mul[x]*k%mod;
}
// Add y (mod 51061) to every value in the splay subtree rooted at x.
// x's own value and aggregate are updated now, and y is accumulated into x's
// add tag for its children. Node 0 is the null node and is ignored.
il vd Add(fast x,fast y){
    if(!x)return;
    const int k=y%mod;
    w[x]=(w[x]+k)%mod;
    // k can be up to 51060 (e.g. a pushed-down add tag), and siz[x] can
    // exceed 84115 with up to 100000 nodes, so k*siz[x] could wrap past 2^32.
    // Reducing siz first keeps the product <= 51060^2 < 2^32.
    sum[x]=(sum[x]+k*(siz[x]%mod))%mod;
    add[x]=(add[x]+k)%mod;
}
// Reverse the in-order sequence of x's splay subtree. x's own children are
// swapped immediately; the toggled flag tells down() to do the same for the
// children later. Node 0 is the null node and is ignored.
il vd Rev(fast x){
    if(!x)return;
    std::swap(ch[x][0],ch[x][1]);
    rev[x]=!rev[x];
}
// Push x's pending tags onto its two children, then clear them.
// The multiply tag goes first: a node's values transform as v*mul + add, and
// Mul() also scales each child's existing add tag before this add arrives.
il vd down(fast x){
    const int l=ch[x][0],r=ch[x][1];
    if(mul[x]!=1){
        Mul(l,mul[x]);
        Mul(r,mul[x]);
        mul[x]=1;
    }
    if(add[x]!=0){
        Add(l,add[x]);
        Add(r,add[x]);
        add[x]=0;
    }
    if(rev[x]){
        Rev(l);
        Rev(r);
        rev[x]=false;
    }
}
// Rotate x above its parent y inside their splay tree, preserving in-order.
// Only y is re-aggregated here. x's aggregates are left stale; splay()
// refreshes them with upd(x) once all rotations are done.
il vd rotate(fast x){
// o == 1 when x is y's right child; z is y's parent (or path-parent)
sta int y,z,o;y=fa[x],z=fa[y],o=ch[y][1]==x;
// x takes y's slot under z. If y was a splay root, z is only a path-parent
// and keeps no child link, so only fa[x] is set.
if(!isrt(y))ch[z][y==ch[z][1]]=x;fa[x]=z;
// x's inner subtree (the side facing y) moves under y
// (this may write fa[0] if that subtree is empty)
ch[y][o]=ch[x][!o];fa[ch[x][!o]]=y;
fa[y]=x;ch[x][!o]=y;
upd(y);
}
// Splay x to the root of its splay tree.
// First, pending lazy tags on the path from the splay root down to x are
// pushed top-down, so every rotation sees up-to-date children.
il vd splay(fast x){
// stk collects x and its splay ancestors; stk[top] ends up as the splay root
sta int stk[maxn],top;stk[top=1]=x;
for(rg int i=x;!isrt(i);i=fa[i])stk[++top]=fa[i];
while(top)down(stk[top--]);
sta int y,z;
// zig / zig-zig / zig-zag: when x and y are on the same side of their
// parents (the XOR is 0), rotate y first; otherwise rotate x twice.
for(y=fa[x],z=fa[y];!isrt(x);rotate(x),y=fa[x],z=fa[y])
if(!isrt(y))rotate(((ch[y][0]==x)^(ch[z][0]==y))?x:y);
// rotate() only refreshed the nodes that moved down; refresh x last
upd(x);
}
// Make the root-to-x path preferred. Walk upward through path-parent links;
// at each node, splay it and replace its right (deeper) child with the
// splay tree built so far.
il vd access(int x){
    int prev=0;
    while(x){
        splay(x);
        ch[x][1]=prev;
        upd(x);
        prev=x;
        x=fa[x];
    }
}
// Re-root x's tree at x: expose the root-to-x path, splay x to its top, then
// reverse the path's orientation with a lazy reverse tag.
il vd makert(fast x){
    access(x);
    splay(x);
    Rev(x);
}
// Add edge (x,y): re-root x's tree at x and hang it below y via a
// path-parent pointer. x and y must be in different trees (not checked here).
il vd link(fast x,fast y){
    makert(x);
    fa[x]=y;
}
// Isolate the x..y path. Afterwards y is the root of the splay tree holding
// exactly that path, so sum[y] and siz[y] describe the whole path.
il vd split(fast x,fast y){
    makert(x);
    access(y);
    splay(y);
}
// Remove edge (x,y). After split(x,y), x is the real root and y the splay
// root of the x..y path. The two are adjacent exactly when x is y's left
// child and x has no right child. Only then are they detached; otherwise the
// call is a no-op instead of tearing an unrelated subtree out of the tree.
// y is re-aggregated so its sum/siz no longer include x.
il vd cut(fast x,fast y){
    split(x,y);
    if(ch[y][0]!=x||ch[x][1])return;
    fa[x]=ch[y][0]=0;
    upd(y);
}
// Entry point. Input: n and q; then n-1 edges "u v" (every node starts with
// value 1); then q operations:
//   + u v c         add c to every node on the path u..v
//   - u1 v1 u2 v2   cut edge (u1,v1), then link (u2,v2)
//   * u v c         multiply every node on the path u..v by c
//   any other op, followed by u v: print the path sum mod 51061
// Declared `signed main` because a main with no return type is not valid C++,
// and the file-wide `#define int unsigned int` rules out writing `int main`.
signed main(){
    freopen("nt2012_wym_tree.in","r",stdin);
    freopen("nt2012_wym_tree.out","w",stdout);
    int n=gi(),q=gi();
    char opt[3];
    for(int i=1;i<=n;++i)w[i]=sum[i]=siz[i]=1,add[i]=0,mul[i]=1,rev[i]=0;
    int u,v;
    // Read the endpoints into named locals: in link(gi(),gi()) the order in
    // which the two reads happen would be unspecified.
    for(int i=1;i<n;++i){
        u=gi();
        v=gi();
        link(u,v);
    }
    while(q--){
        // The field width keeps the read inside opt; stop cleanly at EOF
        // instead of acting on an uninitialised opcode.
        if(scanf("%2s",opt)!=1)break;
        if(opt[0]=='+'){u=gi();v=gi();split(u,v);Add(v,gi());}
        else if(opt[0]=='-'){u=gi();v=gi();cut(u,v);u=gi();v=gi();link(u,v);}
        else if(opt[0]=='*'){u=gi();v=gi();split(u,v);Mul(v,gi());}
        else{u=gi();v=gi();split(u,v);printf("%u\n",sum[v]);}
    }
    return 0;
}