树链剖分

前言

其实就是预处理+线段树。

目的:把树拆成链再用线段树处理(其实还是一种暴力,只是有点优化)。

为什么不直接拆成从根节点到每一个叶子结点的链?

假如更改一个节点(有多个子节点),那么就要修改几个线段树中的值了,会影响效率。

所以在这些链中不能有重叠部分。

为了提高线段树的效率,我们要尽量把一条链搞长一点,而不是更多的链。

所以我们就想到了下方预处理的办法。

先来回顾几个问题:

  1. 将树从\(x\)到\(y\)结点最短路径上所有节点的值都加上\(z\)(树上差分即可);

  2. 求树从\(x\)到\(y\)结点最短路径上所有节点的值之和(LCA即可);

  3. 将以\(x\)为根节点的子树内所有节点值都加上z(dfs序+差分即可);

  4. 求以\(x\)为根节点的子树内所有节点值之和(同3)。

但是:假如把几个问题放在一起咋做?

于是树链剖分闪亮登场!

准备

先说些概念:

  • 重儿子:父亲节点的所有儿子中子树结点数目最多(\(size\)最大)的结点;

  • 轻儿子:父亲节点中除了重儿子以外的儿子;

  • 重边:父亲结点和重儿子连成的边;(下图加粗)

  • 轻边:父亲节点和轻儿子连成的边;

  • 重链:由多条重边连接而成的路径;

  • 轻链:由多条轻边连接而成的路径;

树链剖分

对数组的一些解释:

名称 解释
\(f[u]\) 保存结点u的父亲节点
\(dep[u]\) 保存结点u的深度值
\(size[u]\) 保存以u为根的子树节点个数
\(son[u]\) 保存重儿子
\(top[u]\) 保存当前节点所在链的顶端节点(上图红点)
\(id[u]\) 保存树中每个节点剖分以后的新编号(DFS的执行顺序)

我们的目标就是把上图拆成一下几条链:

\(1\rightarrow4\rightarrow9\rightarrow13\rightarrow14\)

\(2−>6−>11\)

\(3−>7\)

\(5\)

\(8\)

\(10\)

\(12\)

处理

预处理:求f、d、size、son数组

void dfs1(int u,int fa){
	f[u]=fa;
	dep[u]=dep[fa]+1;
	size[u]=1;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa)continue;
		dfs1(v,u);
		size[u]+=size[v];
		if(size[v]>size[son[u]])son[u]=v;
	}
}

结果:

树链剖分

预处理:求出top、rk、id数组(dfs序)

void dfs2(int u,int t){
	top[u]=t;
	id[u]=++cnt;
	a[cnt]=w[u];
	if(!son[u])return;
	dfs2(son[u],t);
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v!=son[u]&&v!=f[u])dfs2(v,v);
	}
}

结果:

树链剖分

LCA操作咋办?有没有注意到top数组?它就是LCA中的"跳"的变形。

LCA

其实也可以不写。

妈妈再也不用担心我不会倍增啦!

这里使用了\(top\)来进行加速,因为\(top\)可以直接跳转到该重链的起始结点,轻链没有起始结点之说,他们的\(top\)就是自己。需要注意的是,每次循环只能跳一次,并且让结点深的那个来跳到\(top\)的位置,避免两个一起跳从而擦肩而过。

int lca(int x,int y){
	int fx=top[x],fy=top[y];
	while(fx!=fy){
		if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy);
		x=f[fx],fx=top[x];
	}
	return dep[x]<dep[y]?x:y;
}

修改链

在LCA的基础上也可以这么写:

void updata_lian(int x,int y,int z){
	int fx=top[x],fy=top[y];
	while(fx!=fy){
		if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy);
		updata(id[fx],id[x],z,1,cnt,1);
		x=f[fx],fx=top[x];
	}
	if(id[x]>id[y])swap(x,y);
	updata(id[x],id[y],z,1,cnt,1);
}

计贡献

在LCA的基础上也可以这么写:

int query_lian(int x,int y){
	int fx=top[x],fy=top[y],sum=0;
	while(fx!=fy){
		if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy);
		inc(sum,query(id[fx],id[x],1,cnt,1));
		x=f[fx],fx=top[x];
	}
	if(id[x]>id[y])swap(x,y);
	inc(sum,query(id[x],id[y],1,cnt,1));
	return sum;
}

这样就差不多了。

上题:

P3384 【模板】轻重链剖分

没什么好说的。

注意一个点:\(mod\)不是题目固定的,也就是说:\(a[i]\)可能大于\(mod\),甚至是几倍。

最好双重保险,函数返回前加一个取模。

出题人不要脸。

#include<bits/stdc++.h>
#define ll long long
#define LC x<<1
#define RC x<<1|1
#define N 1100000
using namespace std;
ll n,m,r,mod,op,x,y,z,cnt,tot,head[N],size[N],top[N],dep[N],son[N],id[N],f[N],a[N],w[N];
struct node{
	ll to,nxt;
}e[N*5];
struct nd_tree{
	ll sum,lazy,len;
}t[N*5];
void inc(ll &a,ll b){a+=b;if(a>=mod)a-=mod;}
void add(ll f,ll to){
	e[++tot].to=to;
	e[tot].nxt=head[f];
	head[f]=tot;
}
ll lca(ll x,ll y){
	ll fx=top[x],fy=top[y];
	while(fx!=fy){
		if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy);
		x=f[fx],fx=top[x];
	}
	return dep[x]<dep[y]?x:y;
}
void dfs1(ll u,ll fa){
	f[u]=fa;
	dep[u]=dep[fa]+1;
	size[u]=1;
	for(ll i=head[u];i;i=e[i].nxt){
		ll v=e[i].to;
		if(v==fa)continue;
		dfs1(v,u);
		size[u]+=size[v];
		if(size[v]>size[son[u]])son[u]=v;
	}
}
void dfs2(ll u,ll t){
	top[u]=t;
	id[u]=++cnt;
	a[cnt]=w[u];
	if(son[u])dfs2(son[u],t);
	for(ll i=head[u];i;i=e[i].nxt){
		ll v=e[i].to;
		if(v!=son[u]&&v!=f[u])dfs2(v,v);
	}
}
void pushup(ll x){
	t[x].sum=t[LC].sum+t[RC].sum;
}
void pushdown(ll x){
	if(t[x].lazy==0)return;
	inc(t[LC].lazy,t[x].lazy);
	inc(t[RC].lazy,t[x].lazy);
	inc(t[LC].sum,t[LC].len*t[x].lazy);
	inc(t[RC].sum,t[RC].len*t[x].lazy);
	t[x].lazy=0;
}
void build(ll l,ll r,ll x){
	t[x].len=r-l+1;
	if(l==r){t[x].sum=a[l];return;}
	ll mid=(l+r)>>1;
	build(l,mid,LC);
	build(mid+1,r,RC);
	pushup(x);
}
void updata(ll ql,ll qr,ll c,ll l,ll r,ll x){
	if(ql<=l&&qr>=r){
		inc(t[x].sum,t[x].len*c);
		inc(t[x].lazy,c);
		return;
	}
	pushdown(x);
	ll mid=(l+r)>>1;
	if(ql<=mid)updata(ql,qr,c,l,mid,LC);
	if(qr>mid)updata(ql,qr,c,mid+1,r,RC);
	pushup(x);
}
ll query(ll ql,ll qr,ll l,ll r,ll x){
	if(ql<=l&&qr>=r)return t[x].sum;
	pushdown(x);
	ll mid=(l+r)>>1,sum=0;
	if(ql<=mid)sum+=query(ql,qr,l,mid,LC);
	if(qr>mid)sum+=query(ql,qr,mid+1,r,RC);
	return sum%mod;
}
void updata_lian(ll x,ll y,ll z){
	ll fx=top[x],fy=top[y];
	while(fx!=fy){
		if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy);
		updata(id[fx],id[x],z,1,cnt,1);
		x=f[fx],fx=top[x];
	}
	if(id[x]>id[y])swap(x,y);
	updata(id[x],id[y],z,1,cnt,1);
}
ll query_lian(ll x,ll y){
	ll fx=top[x],fy=top[y],sum=0;
	while(fx!=fy){
		if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy);
		inc(sum,query(id[fx],id[x],1,cnt,1));
		x=f[fx],fx=top[x];
	}
	if(id[x]>id[y])swap(x,y);
	inc(sum,query(id[x],id[y],1,cnt,1));
	return sum;
}
int main(){
	scanf("%lld%lld%lld%lld",&n,&m,&r,&mod);
	for(ll i=1;i<=n;i++)scanf("%lld",&w[i]),w[i]%=mod;
	for(ll i=1;i<n;i++)scanf("%lld%lld",&x,&y),add(x,y),add(y,x);
	dep[r]=0;f[r]=1;
	dfs1(r,0);
	dfs2(r,r);
	build(1,n,1);
	for(ll i=1;i<=m;i++){
		scanf("%lld",&op);
		if(op==1){
			scanf("%lld%lld%lld",&x,&y,&z);z%=mod;
			updata_lian(x,y,z);
		}
		if(op==2){
			scanf("%lld%lld",&x,&y);
			printf("%lld\n",query_lian(x,y));
		}
		if(op==3){
			scanf("%lld%lld",&x,&y);y%=mod;
			updata(id[x],id[x]+size[x]-1,y,1,n,1);
		}
		if(op==4){
			scanf("%lld",&x);
			printf("%lld\n",query(id[x],id[x]+size[x]-1,1,n,1));
		}
	}
}

后记

参考资料强力推荐

上一篇:P1101 单词方阵


下一篇:javascript基础函数4.1