树链剖分入门

前置芝士:低复杂度的区间操作算法(线段树等)+dfs+LCA

例题引入:树链剖分模板

题目要求我们对树上的路径和子树进行修改与查询

暴力:

任何算法的优化都是源于暴力,

对于路径的修改,我们可以直接采用LCA求出A和B的公共祖先C,对于(A,C)和(B,C)两条路径进行加和

对于子树的修改,我们可以直接枚举并进行暴力修改

由于数据达到1e5级别,暴力算法被淘汰。

改进:

该种片修改的暴力都逃不过数据结构的优化。

我们每次的修改都是2条边或一颗子树,由此我们很容易想到一种思路:把这颗线段树压成一维,即把线段树的点重新压入数据结构中,那么对于结点A到结点C的(A,C)修改,只要修改区间(A,C)。

但是这样正确吗?

明显是不对的,因为不论如何组合,一个结点A可以有很多个儿子,如儿子B,儿子C,那么保证(A,B)在数据结构中形成连续区间的同时,(A,C)就不可能形成。

但是这真的完全不正确吗?

也不对,如上述,对于一个结点,我们可以保证他与其中一个儿子的区间修改成立

我们先对于上路思路进行初步实现:

首先我们对这棵树进行DFS序的标记,把x结点的DFS序标号记为id[x],以x为根的子树大小(包含结点数量)记为sz[x]

原树:

树链剖分入门

 

 

 dfs后:

树链剖分入门

 

 

 

 

 

 由此我们发现一些特性:

1.对于任意结点x,以x为根的子树中的结点的编号是连续的(设k为该子树中dfs序最大的结点,则id[k]=id[x]+sz[x]-1)

 

 

2.对于任意结点x,x的一个儿子必定为x的编号+1

 

我们再把点维护在一个数据结构上,

那么对于一棵以x为根结点的子树修改和查询,相当于对[id[x],id[x]+sz[x]-1]这个区间进行查询。

如图,对以5为根节点的子树进行查询

树链剖分入门

 

 

那我们来思考:如何对一条路径进行修改

首先,我们可以把x到y路径拆分成两部分,设k为LCA(x,y),则(x,y)=(x,k)+(y,k),两部分的性质相同,所以我们对其中的一部分做分析,暂且把这部分叫做‘单路径’。

对于上面的那棵树,我们直接把他的dfs序标上去,我们发现,以下区间是连续的,连续即可在数据结构中进行修改,那么整棵树的结点,就可以拆为许多区间。

树链剖分入门

 

如果修改或查询(1,5)的路径,那么我们只要在数据结构中访问绿色部分的区间即可。

树链剖分入门

 

 

这就是剖分的部分,但是无规则的剖分带来的只是复杂度的极不稳定。

我们可以想到有一种最坏情况:

树链剖分入门

 

如上图的排列,使得最右边的路径的id编号始终是不连续的,那么查询(1,7)的路径时,就退化O(n)修改和查询了

那么我们必须想到先把连续的id编号分配给子树结点最多的点(即在dfs序的标记时优先搜索子树多的点)

树链剖分入门

 

我们引入几个概念方便理解:

对于一个结点x,他儿子中子树结点最多的结点称为重儿子,其他的儿子称为轻儿子

结点x与重儿子的边为重边,与轻儿子的边为轻边,

几条连续的重边连起来称为重链,(注意重链是一段连续的dfs序,所以一个轻儿子本身也是一个重链)

树链剖分入门

 

 

 在以x为根节点的子树中,设其重儿子为k,任意一个轻儿子为n

那么易得一个性质:sz[n]<sz[x]/2,  (x本身也有一个点,所以>=sz[x]/2的不是轻儿子)

对于每一个单路径(u,v),dep[u]<=dep[v]的操作,我们只需要把(u,v)间的所有重链找出进行访问就行,我们可以如下考虑

1.v是轻儿子,由于性质,所以从u出发到v,每经过一条轻边,子树结点个数少一半,那么轻边最多log(n)条

2.v是重儿子,由于重链与重链之间由轻边相连(看图),所以重链最多也不超过log(n)条

所以对于单路径的查询复杂度为O(logn*数据结构复杂度k)

实现

我们需要进行两次dfs来完成预处理

预处理内容如下

1.id[x]为x的dfs序编号

2.top[x]为x所在重链的顶点

3.fa[x]为x的父亲

4.heavy[x]为x的重儿子

5.sz[x]为x的子树结点个数

6.dep[x]为x的深度

第一遍dfs处理出sz,fa,dep

void Dfs1(int x,int f){
    sz[x]=1,fa[x]=f;dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x];i;i=edge[i].nxt){
        int dr=edge[i].to;
        if(dr==f) continue;
        Dfs1(dr,x);
        sz[x]+=sz[dr];
        if(sz[dr]>maxson) maxson=sz[dr],heavy[x]=dr;
    }
    return;
}

第二遍处理出id,heavy,top

int dfscnt;
void Dfs2(int x,int tops){
    id[x]=++dfscnt;
    Updata(Root,1,n,id[x],id[x],1LL*data[x]);
    top[x]=tops;
    if(!heavy[x]) return;
    Dfs2(heavy[x],tops);
    for(int i=head[x];i;i=edge[i].nxt){
        int dr=edge[i].to;
        if(dr==fa[x]||dr==heavy[x]) continue;
        Dfs2(dr,dr);
    }
}

对于每次访问(x,y),我们把一个点跳到他重链的顶点上,并对该区间在数据结构中进行访问,我们从top值深度较大的点跳,以保证两点最终top值能碰到,两点的top值相等即保证此时两点在同一重链上

void Uplink(int x,int y,ll z){
    ll res=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        Updata(Root,1,n,id[top[x]],id[x],z);
        x=fa[top[x]];
     }
     if(dep[x]>dep[y]) swap(x,y);
     Updata(Root,1,n,id[x],id[y],z);
     return;
}
ll Quelink(int x,int y){
    ll res=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        res=(res+Query(Root,1,n,id[top[x]],id[x]))%P;
        x=fa[top[x]];
     }
     if(dep[x]>dep[y]) swap(x,y);
     return (res+Query(Root,1,n,id[x],id[y]))%P;
}

贴AC代码

#include<bits/stdc++.h>
#define ls(x) t[x].ls
#define rs(x) t[x].rs
#define ll long long
using namespace std;
const int maxn = 1e5+10;

int Root,n,m,rt,P,data[maxn];

inline int read(){
    int x=0,f=1;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=(x<<3)+(x<<1)+(c^48);c=getchar();}
    return x*f;
}

int trcnt;
struct TREE{
    int ls,rs;
    ll sum,lzy;
}t[maxn<<3];
void Push_down(int x,int a,int b){
    if(!ls(x)) ls(x)=++trcnt; 
    if(!rs(x)) rs(x)=++trcnt;
    int mid=a+b>>1;
    t[ls(x)].sum=(t[ls(x)].sum+1LL*(mid-a+1)*t[x].lzy%P)%P;
    t[rs(x)].sum=(t[rs(x)].sum+1LL*(b-mid)*t[x].lzy%P)%P;
    t[ls(x)].lzy=(t[ls(x)].lzy+t[x].lzy)%P;
    t[rs(x)].lzy=(t[rs(x)].lzy+t[x].lzy)%P;
    t[x].lzy=0;
}
void Updata(int &x,int a,int b,int L,int R,ll val){
    if(!x) x=++trcnt;
    if(L<=a&&b<=R){
        t[x].sum+=1LL*(b-a+1)*val;
        t[x].lzy+=val;
        return;
    } 
    Push_down(x,a,b);
    int mid=(a+b)>>1;
    if(L<=mid) Updata(ls(x),a,mid,L,R,val);
    if(R>=mid+1) Updata(rs(x),mid+1,b,L,R,val);
    t[x].sum=t[ls(x)].sum+t[rs(x)].sum;
}
ll Query(int &x,int a,int b,int L,int R){
    if(!x) x=++trcnt;
    if(L<=a&&b<=R) return t[x].sum;
    Push_down(x,a,b);
    int mid=(a+b)>>1;ll res=0;
    if(L<=mid) res+=Query(ls(x),a,mid,L,R);
    res%=P; 
    if(R>=mid+1) res+=Query(rs(x),mid+1,b,L,R);
    res%=P;
    return res;
}
//cut line
int top[maxn],sz[maxn],id[maxn],dep[maxn],fa[maxn],heavy[maxn];
int head[maxn],edge_num;
struct EDGE{
    int to,nxt;
}edge[maxn<<1];
void add_edge(int from,int to){
    edge[++edge_num].to=to;
    edge[edge_num].nxt=head[from];
    head[from]=edge_num;
}
void Dfs1(int x,int f){
    sz[x]=1,fa[x]=f;dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x];i;i=edge[i].nxt){
        int dr=edge[i].to;
        if(dr==f) continue;
        Dfs1(dr,x);
        sz[x]+=sz[dr];
        if(sz[dr]>maxson) maxson=sz[dr],heavy[x]=dr;
    }
    return;
}
int dfscnt;
void Dfs2(int x,int tops){
    id[x]=++dfscnt;
    Updata(Root,1,n,id[x],id[x],1LL*data[x]);
    top[x]=tops;
    if(!heavy[x]) return;
    Dfs2(heavy[x],tops);
    for(int i=head[x];i;i=edge[i].nxt){
        int dr=edge[i].to;
        if(dr==fa[x]||dr==heavy[x]) continue;
        Dfs2(dr,dr);
    }
}
//cut line
void Uplink(int x,int y,ll z){
    ll res=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        Updata(Root,1,n,id[top[x]],id[x],z);
        x=fa[top[x]];
     }
     if(dep[x]>dep[y]) swap(x,y);
     Updata(Root,1,n,id[x],id[y],z);
     return;
}
ll Quelink(int x,int y){
    ll res=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        res=(res+Query(Root,1,n,id[top[x]],id[x]))%P;
        x=fa[top[x]];
     }
     if(dep[x]>dep[y]) swap(x,y);
     return (res+Query(Root,1,n,id[x],id[y]))%P;
}

int main(){
    n=read(),m=read(),rt=read(),P=read();
    for(register int i=1;i<=n;++i) data[i]=read();
    for(register int i=1;i<=n-1;++i){
        int x=read(),y=read();
        add_edge(x,y),add_edge(y,x);
    }
    Dfs1(rt,0);
    Dfs2(rt,rt);
    for(register int i=1;i<=m;++i){
        int opt=read(),x,y,z;
        if(opt==1){
            x=read(),y=read(),z=read();
            z%=P;
            Uplink(x,y,z);
        }
        if(opt==2){
            x=read(),y=read();
            printf("%lld\n",Quelink(x,y));
        }
        if(opt==3){
            x=read(),y=read();
            y%=P;
            Updata(Root,1,n,id[x],id[x]+sz[x]-1,1LL*y);
        }
        if(opt==4){
            x=read();
            printf("%lld\n",Query(Root,1,n,id[x],id[x]+sz[x]-1));
        }
    }
}

 

上一篇:【OpenYurt 深度解析,Java编程教程视频


下一篇:Asp.net获取数据库中所有的表名