板子题
题目传送门
这道题目要求在树上修改和查询点一条链上或者是一棵子树的点的权值。
算法解析
考虑使用LCA,但是不能使用倍增的解法(倍增只能查询不能修改),所以我们要使用一种新的算法——轻重链剖分。
建议先学完LCA在来看这篇文章。
定义
这里列出一些必要的定义:
- 重儿子:子节点最多的一个儿子。
- 轻儿子:一个节点的儿子除重儿子之外的儿子。
- 重边:一段为重儿子的边。
- 轻边:除重边以外的边。
- 重链:由重边组成的链,由轻儿子为起点。
实现
预处理
预处理由两次dfs组成。
第一次dfs需要处理以下数组:
\(fa,son,siz,d\) 分别代表节点的 父亲,重儿子,子树节点数,深度。
第二次dfs需要处理以下数组:
\(id,top,a\) 分别代表节点的新编号(按照dfs达到的顺序),这个节点所在的重链的顶端,新编号的点权,注意一定要先处理重儿子并且使用dfs。原因后面讲。
修改
怎么修改或者查询一条链的点权呢?
我们可以让更深的点沿着重链向上跳,直到两个点处于同一条重链上就可以了。
然后我们会发现,如果先处理重儿子并且使用dfs的时候,这样一条重链上的节点和一棵子树的节点的编号是连续的,我们就可以使用线段树来解决这个问题了。
而处理一棵子树权值的时候,我们只要将 \(id_i\) 到 \(id_i+siz_i-1]\) 这段区间进行修改或者是查询就可以了。
因为修改的点的 \(id\) 都是连续的,所以我们就可以使用线段树大法来解决了。
复杂度
处理一条链的复杂度是 \(\Theta\left(\log^2n\right)\) ,处理子树的复杂度是 \(\Theta\left(\log n\right)\) 。
代码
这里用递归式线段树来解决这个问题。
#include<cstdio>
#define maxn 100039
#define emaxn 200039
using namespace std;
typedef long long ll;
int n,T,root;
ll MOD;
int head[maxn],nex[emaxn],to[emaxn],k;
#define add(x,y) nex[++k]=head[x];\
head[x]=k;\
to[k]=y;
int tmp,x,y;
ll z;
//链式前向星
int w[maxn],a[maxn];
int siz[maxn],son[maxn],d[maxn],fa[maxn];
//size是Linux保留字
int id[maxn],top[maxn],cnt;
void dfs1(int num,int pre){
siz[num]=1;
fa[num]=pre;
d[num]=d[pre]+1;
int maxx=0;
for(int i=head[num];i;i=nex[i])
if(to[i]!=pre){
dfs1(to[i],num);
if(siz[to[i]]>maxx) maxx=siz[to[i]],son[num]=to[i];
siz[num]+=siz[to[i]];
}
return;
}
void dfs2(int num,int topf){
top[num]=topf; id[num]=++cnt;
if(!son[num]) return;
dfs2(son[num],topf);
for(int i=head[num];i;i=nex[i])
if(to[i]!=fa[num] && to[i]!=son[num])
dfs2(to[i],to[i]);
return;
}
//预处理
//以下为线段树
int L,R;
ll C;
ll sum[maxn<<2],f[maxn<<2];
void up(int rt){sum[rt]=(sum[rt<<1]+sum[rt<<1|1])%MOD;return;}
void build(int l,int r,int rt){
if(l==r){
sum[rt]=a[l]%MOD;
return;
}
int m=(l+r)>>1;
build(l,m,rt<<1);
build(m+1,r,rt<<1|1);
up(rt);
}
void down(int ln,int rn,int rt){
if(f[rt]){
f[rt<<1]+=f[rt]; f[rt<<1]%=MOD;
f[rt<<1|1]+=f[rt]; f[rt<<1|1]%=MOD;
sum[rt<<1]+=f[rt]*ln; sum[rt<<1]%=MOD;
sum[rt<<1|1]+=f[rt]*rn; sum[rt<<1|1]%=MOD;
f[rt]=0;
}
return;
}
void update(int l,int r,int rt){
if(L<=l&&r<=R){
f[rt]+=C;
sum[rt]+=C*(r-l+1);
return;
}
int m=(l+r)>>1;
down(m-l+1,r-m,rt);
if(m>=L) update(l,m,rt<<1);
if(m<R) update(m+1,r,rt<<1|1);
up(rt);
}
ll find(int l,int r,int rt){
if(L<=l&&r<=R)
return sum[rt];
int m=(l+r)>>1;
ll s=0;
down(m-l+1,r-m,rt);
if(m>=L) s=(s+find(l,m,rt<<1))%MOD;
if(m<R) s=(s+find(m+1,r,rt<<1|1))%MOD;
return s;
}
//以上为线段树
void init(){
dfs1(root,0);
dfs2(root,root);
for(int i=1;i<=n;i++)
a[id[i]]=w[i];
build(1,n,1);
return;
}
void swap(int &x,int &y) {x^=y;y^=x;x^=y;}//交换
void updateR(int u,int v,ll c){
C=c;
while(top[u]!=top[v]){
if(d[top[u]]<d[top[v]]) swap(u,v);
L=id[top[u]]; R=id[u];
update(1,n,1);
u=fa[top[u]];
}
if(d[u]<d[v]) swap(u,v);
L=id[v]; R=id[u];
update(1,n,1);
return;
}
ll findR(int u,int v){
ll ans=0;
while(top[u]!=top[v]){
if(d[top[u]]<d[top[v]]) swap(u,v);
L=id[top[u]]; R=id[u];
ans+=find(1,n,1);
ans%=MOD;
u=fa[top[u]];
}
if(d[u]<d[v]) swap(u,v);
L=id[v]; R=id[u];
ans+=find(1,n,1);
return ans;
}
void updateS(int rt,ll c){
L=id[rt]; R=id[rt]+siz[rt]-1; C=c;
update(1,n,1);
}
ll findS(int rt){
L=id[rt]; R=id[rt]+siz[rt]-1;
return find(1,n,1);
}
int main(){
scanf("%d%d%d%lld",&n,&T,&root,&MOD);
for(int i=1;i<=n;i++)
scanf("%d",&w[i]);
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y) add(y,x)
}
init();
while(T--){
scanf("%d",&tmp);
if(tmp==1){
scanf("%d%d%lld",&x,&y,&z);
updateR(x,y,z);
}
else if(tmp==2){
scanf("%d%d",&x,&y);
printf("%lld\n",findR(x,y)%MOD);
}
else if(tmp==3){
scanf("%d%lld",&x,&z);
updateS(x,z);
}
else{
scanf("%d",&x);
printf("%lld\n",findS(x)%MOD);
}
}
return 0;
}