点分治 学习笔记

0. 点分治的用途

点分治可以解决树上的关于路径的问题,例如 洛谷P4178 Tree。(题目大意:给定一棵 \(n\) 个节点的树,每条边有边权,求出树上两点距离小于等于 \(k\) 的点对数量)这道题如果使用 \(O(n^2)\) 的暴力算法 显然 会T飞 ,然而之后您就会看到,点分治算法可以在 \(O(n\log^2 n)\) 的时间复杂度内解决它。

1.思想

顾名思义,点分治使用了分治的思想,把原问题拆分成若干个子问题,分别求解后再合并。把大象装进冰箱里需要几步?

1.1. 分

注意到如果指定一个节点为根节点,那么一个路径有可能有以下两个来源:

1.点分治 学习笔记路径经过根节点;
2.点分治 学习笔记路径完全被子树包含。

到这里一个分治算法已经呼之欲出了——可以递归地处理第二种情况,只需要在算法中考虑第一种情况就可以了。

1.2. 治

第二种情况可以递归地处理,并且递归到叶子节点时就不需要考虑第二种情况了(根本没有子树),所以这里主要考虑第一种情况。

按顺序考虑每一个子树,用一个树状数组(或者是一个别的什么数据结构)来维护根节点到已经考虑过的每一个节点,在新加入一个子树时对于新子树的每一个节点求出有多少个根节点到“老节点”的距离小于 (\(k-\) (根节点到新节点的距离))并统计进答案,再把每一个新节点塞进树状数组里。这样就解决了第一种情况了。

最好结合代码理解:

//mark:树状数组 g:图
int dis[MAXN+5],tail;
void get_dis(int u,int fa,int now){
    dis[++tail]=now;
    for(int i=0;i<g[u].size();i++){
        int v=g[u][i].v,w=g[u][i].w;
        if(!removed[v]&&v!=fa){
            get_dis(v,u,now+w);
        }
    }
}
int calc(int u){//处理第一种情况
    int res=0;
    ta=0;
    for(int i=0;i<g[u].size();i++){
        int v=g[u][i].v,w=g[u][i].w;
        if(!removed[v]){
            tail=0;
            get_dis(v,u,w);
            for(int j=1;j<=tail;j++){//统计
                if(dis[j]<=k)res+=mark.query(k-dis[j]);
            }
            for(int j=1;j<=tail;j++){//加入
                if(dis[j]<=k){
                    mark.add(dis[j],1);
                    res++;
                    vis[++ta]=dis[j];
                }
            }
        }
    }
    for(int i=1;i<=ta;i++){//复原树状数组方便下次使用
        if(vis[i]<=k){
            mark.add(vis[i],-1);
        }
    }
    return res;
}

1.3. 合

只需要无脑地把每种情况加起来就可以了(

1.4. “细节”

在递归时一定要用子树的重心作为根节点,这样才能保证时间复杂度最优。

至于证明,关于此,我确信已发现了一种美妙的证法 ,可惜这里空白的地方太小,写不下。 其实是我太菜了不会证TwT 可以感性理解一下,根节点取重心可以使问题分割地尽可能均匀,分治就跑得飞快了。

2. 时间复杂度

据说是 \(O(n\log^2 n)\) 的,然而我不会证啊qaq

如果递归时根节点不取重心,时间复杂度会退化为 \(O(n^2\log n)\),还不如暴力。

3. Code

#include <bits/stdc++.h>
using namespace std;
#define MAXN 40000
#define MAXW 1000
#define INF 0x3fffffff
struct BIT{//一般通过树状数组
    int tree[MAXN*MAXW+5];
    int query(int x){
        int res=0;
        while(x){
            res+=tree[x];
            x-=x&-x;
        }
        return res;
    }
    void add(int x,int k){
        while(x<=MAXN*MAXW+2){
            tree[x]+=k;
            x+=x&-x;
        }
    }
};
int n,k;
struct edge{
    int v,w;
    edge(){v=w=0;}
    edge(int _v,int _w){v=_v;w=_w;}
};
vector<edge> g[MAXN+5];
bool removed[MAXN+5];//由于每个点在递归之后就没用了,所以要打个标记把它移除掉
int zx,zx_maxx=INF,siz[MAXN+5];
void get_zx(int u,int fa,int tot){//求重心
    siz[u]=1;
    int maxx=0;
    for(int i=0;i<g[u].size();i++){
        int v=g[u][i].v;
        if(!removed[v]&&v!=fa){
            get_zx(v,u,tot);
            siz[u]+=siz[v];
            maxx=max(maxx,siz[v]);
        }
    }
    maxx=max(maxx,tot-siz[u]);
    if(zx_maxx>maxx){
        zx=u;
        zx_maxx=maxx;
    }
}
int dis[MAXN+5],tail;
void get_dis(int u,int fa,int now){//求节点到每一个其他点的距离
    dis[++tail]=now;
    for(int i=0;i<g[u].size();i++){
        int v=g[u][i].v,w=g[u][i].w;
        if(!removed[v]&&v!=fa){
            get_dis(v,u,now+w);
        }
    }
}
BIT mark;
int vis[MAXN+5],ta;
int calc(int u){//求当前子树经过根节点的合法路径数
    int res=0;
    ta=0;
    for(int i=0;i<g[u].size();i++){
        int v=g[u][i].v,w=g[u][i].w;
        if(!removed[v]){
            tail=0;
            get_dis(v,u,w);
            for(int j=1;j<=tail;j++){
                if(dis[j]<=k)res+=mark.query(k-dis[j]);
            }
            for(int j=1;j<=tail;j++){
                if(dis[j]<=k){
                    mark.add(dis[j],1);
                    res++;
                    vis[++ta]=dis[j];
                }
            }
        }
    }
    for(int i=1;i<=ta;i++){
        if(vis[i]<=k){
            mark.add(vis[i],-1);
        }
    }
    return res;
}
int work(int u){//分治
    removed[u]=true;
    int res=calc(u);
    //cout<<u<<" "<<res<<endl;
    for(int i=0;i<g[u].size();i++){
        int v=g[u][i].v,w=g[u][i].w;
        if(!removed[v]){
            zx_maxx=INF;
            get_zx(v,0,siz[v]);
            res+=work(zx);
        }
    }
    return res;
}
int main(){
    ios::sync_with_stdio(false);
    cin>>n;
    for(int i=1;i<=n-1;i++){
        int u,v,w;
        cin>>u>>v>>w;
        g[u].push_back(edge(v,w));
        g[v].push_back(edge(u,w));
    }
    cin>>k;
    zx_maxx=INF;
    get_zx(1,0,n);
    cout<<work(zx)<<endl;
    return 0;
}

4. 点一个赞!

上一篇:2021.8.22北高暑训


下一篇:CF351D Jeff and Removing Periods Ⅱ