树链剖分
先占个坑
#include <bits/stdc++.h>
#define N 500020
#define LS(z) (z<<1)
#define RS(z) (z<<1|1)
#define int long long
using namespace std;
inline int read();
int n,m,o,P;
int A[N],a[N];
int idx,dfn[N],fa[N];
int dep[N],top[N],siz[N],son[N];
int hd[N],cne;
struct edge{int to,nx;} e[N];
void Add(int u,int v)
{
cne++;
e[cne].to=v;
e[cne].nx=hd[u];
hd[u]=cne;
}
struct segment_tree
{int tag,sum;} T[N<<2];
inline void add(int l,int r,int k,int p)
{
T[p].tag=T[p].tag+k;
T[p].sum=T[p].sum+k*(r-l+1);
}
inline void push_tag(int l,int r,int p)
{
int m=(l+r)>>1;
add(l ,m,T[p].tag,LS(p));
add(m+1,r,T[p].tag,RS(p));
T[p].tag=0;
}
void build(int l,int r,int p)
{
T[p].tag=0;
if(l==r) {T[p].sum=a[l];return ;}
int m=(l+r)>>1;
build(l ,m,LS(p));
build(m+1,r,RS(p));
T[p].sum=T[LS(p)].sum+T[RS(p)].sum;
}
// 表示将以 x 为根节点的子树内所有节点值都加上 z
inline void upd(int L,int R,int K,int l,int r,int p)
{
if(L<=l && r<=R)
{
T[p].sum+=K*(r-l+1);
T[p].tag+=K;
return ;
}
push_tag(l,r,p);
int m=(l+r)>>1;
if(L<=m) upd(L,R,K, l ,m,LS(p));
if(R> m) upd(L,R,K, m+1,r,RS(p));
T[p].sum=T[LS(p)].sum+T[RS(p)].sum;
}
// 表示求以 x 为根节点的子树内所有节点值之和
inline int query(int L,int R,int l,int r,int p)
{
int ret=0;
if(L<=l && R>=r) return T[p].sum;
push_tag(l,r,p);
int m=(l+r)>>1;
if(L<=m) ret+=query(L,R, l ,m,LS(p));
if(R> m) ret+=query(L,R, m+1,r,RS(p));
return ret;
}
/*-------------------- 以上是线段树 --------------------*/
void dfs1(int x)
{
siz[x]=1;
for(int i=hd[x];i;i=e[i].nx)
{
int t=e[i].to;
if(!dep[t])
{
dep[t]=dep[x]+1;
fa[t]=x;
dfs1(t);
siz[x]+=siz[t];
if(siz[t]> siz[son[x]])
son[x]=t;
}
}
}
void dfs2(int x,int v)
{
dfn[x]=++idx;
a[idx]=A[x];
top[x]=v;
if(son[x]) dfs2(son[x],v);
for(int i=hd[x];i;i=e[i].nx)
{
int t=e[i].to;
if(t!=fa[x] && t!=son[x])
dfs2(t,t);
}
}
// 表示求树从 x 到 y 结点最短路径上所有节点的值之和
int call_ask(int x,int y)
{
int ret=0;
int tx=top[x],ty=top[y];
while(tx!=ty)
{
if(dep[tx]< dep[ty])
swap(x,y),swap(tx,ty);
ret+=query(dfn[tx],dfn[x], 1,idx,1);
x=fa[tx];tx=top[x];
}
if(dfn[x]> dfn[y])
swap(x,y);
ret+=query(dfn[x],dfn[y], 1,idx,1);
return ret;
}
// 表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z
void call_add(int x,int y,int v)
{
int tx=top[x],ty=top[y];
while(tx!=ty)
{
if(dep[tx]<dep[ty])
swap(x,y),swap(tx,ty);
upd(dfn[tx],dfn[x],v, 1,idx,1);
x=fa[tx],tx=top[x];
}
if(dfn[x]> dfn[y])
swap(x,y);
upd(dfn[x],dfn[y],v, 1,idx,1);
}
signed main()
{
n=read();m=read();
o=read();P=read();
for(int i=1;i<=n;i++)
A[i]=read(),A[i]%=P;
for(int i=1;i< n;i++)
{
int x=read(),y=read();
Add(x,y);Add(y,x);
}
dep[o]=1;
fa[o]=1;
dfs1(o);
dfs2(o,o);
build(1,n,1);
while(m--)
{
int opt=read();
if(opt==1)
{
int x=read(),y=read(),z=read();
call_add(x,y,z%P);
}
if(opt==2)
{
int x=read(),y=read();
printf("%lld\n", call_ask(x,y) %P);
}
if(opt==3)
{
int x=read(),y=read();
upd(dfn[x],dfn[x]+siz[x]-1,y%P, 1,n,1);
}
if(opt==4)
{
int x=read();
printf("%lld\n", query(dfn[x],dfn[x]+siz[x]-1, 1,n,1) %P);
}
}
return 0;
}
inline int read()
{
int x=0,w=1;char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=x*10+ch-'0',ch=getchar();
return x*w;
}