LCA

目录

LCA

1. 算法分析

1.1 求LCA的四种方法

1.树上倍增法:
    倍增思想:\(f[i][j]\)表示i这个位置向上走2^j步后到达x,则有状态转移:\(f[y][j] = f[[y][j-1]][j-1]\),利用这个不断处理出f数组,树上倍增法能够得到 \(f[x][i]\) 数组和 \(d[i]\) 数组,利用这两个数组可以求出很多的东西。这是在线做法
2.tarjan算法:
    dfs的特性和并查集的特性。把所有点分成三类,第一类:正在搜索的点,第二类:已经回溯完的点,第三类:还没有搜索过的点,每次搜索的时候,记当前点为x,把点x做个标记,然后判断和这个点对应的点y是否已经回溯过了,如果y已经回溯过了,那么x和y的lca即为y的get(y)得到的节点。这是离线做法
3.dfs+ST
4.树剖求lca

1.2 求lca的两种场景

  1. 求任意两个点的lca
  2. 求集合的lca:求一个集合的lca就是求这个集合中dfs序最小的点和dfs序最大的点的lca

2. 板子

2.1 树上倍增法

HDU 2586

#include <bits/stdc++.h>

using namespace std;

const int N = 2e5 + 10;
int f[N][20], d[N], dist[N];  // f[i][j]表示从i开始,往上走2^j步到达的点,d为深度,dist为距离
int e[N], ne[N], h[N], idx, w[N];
int T, n, m, t;  // t为数的深度
queue<int> q;

void add(int a, int b, int c)
{
    e[idx] = b,  w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

// 预处理:得到每个点的深度,距离,f数组
void bfs()
{
    q.push(1);  // 把根放入队列,注意这里有可能根不是1
    d[1] = 1;
    while (q.size())
    {
        int x = q.front();
        q.pop();
        for (int i = h[x]; i != -1; i = ne[i])
        {
            int y = e[i];
            if (d[y]) continue;
            d[y] = d[x] + 1;  // 更新深度
            dist[y] = dist[x] + w[i];  // 更新距离
            
            // 进行dp更新
            f[y][0] = x;
            for (int j = 1; j <= t; ++j)
            {
                f[y][j] = f[f[y][j - 1]][j - 1];  // 分两段处理
            }
            q.push(y);
        }
    }
}

// 查找x和y的最近公共祖先
int lca(int x, int y)
{
    if (d[x] > d[y]) swap(x, y);  // 保证x的深度浅一点
    for (int i = t; i >= 0; --i)
        if (d[f[y][i]] >= d[x]) y = f[y][i];  // 让x和y到同一个深度
    if (x == y) return x;
    for (int i = t; i >= 0; --i)  // 让x和y之差一步就能相遇
    {
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    }
    return f[x][0];
}

int main()
{
     cin >> T;
     while (T--)
     {
         memset(h, -1, sizeof h);
         idx = t = 0;
         cin >> n >> m;
         t = (int)(log(n) / log(2)) + 1; // 得到树的深度
         
         // 读入一棵树
         for (int i = 0; i < n - 1; ++i)
         {
             int a, b, c;
             scanf("%d %d %d", &a, &b, &c);
             add(a, b, c), add(b, a, c);
         }
        
        bfs();
        // 回答询问
        for (int i = 1; i <= m; ++i)
        {
            int a, b;
            scanf("%d %d", &a, &b);
            printf("%d\n", dist[a] + dist[b] - 2 * dist[lca(a, b)]);
        }
     }
    return 0;
}

2.2 tarjan算法

HDU 2586

#include <bits/stdc++.h>

using namespace std;

int const N = 1e5 + 10;
int e[N], ne[N], idx, ans[N], v[N], fa[N], d[N], h[N], w[N];
vector<int> query[N], query_id[N];
int t, n, m;

void add(int a, int b, int c)
{
    e[idx] = b, w[idx] = c, ne[idx] = h[a] , h[a] = idx++;
}

// 并查集查询+路径压缩
int get(int x)
{
    if (fa[x] != x) fa[x] = get(fa[x]);
    return fa[x];
}

// tarjan算法求lca
void tarjan(int x)
{
    // 记录这个点走过一次,但是还没有回溯
    v[x] = 1;
    
    // 遍历每一个和x相连的点
    for (int i = h[x]; i != -1; i = ne[i])
    {
        int y = e[i];
        if (v[y]) continue;  // 这个点走过的话,不进行后面的操作
        d[y] = d[x] + w[i];
        tarjan(y);  // 得到y为根节点的所有子树的d
        fa[y] = x;  // 更新y的父节点
    }
    
    // 判断和x有关的lca询问
    for (int i = 0; i < query[x].size(); ++i)
    {
        int y = query[x][i];
        int id = query_id[x][i];
        if (v[y] == 2)
        {
            int lca = get(y);  // 获得lca:如果y点回溯,那么lca为get(y)
            ans[id] = min(ans[id], d[x] + d[y] - 2 * d[lca]);  // 更新答案
        }
    }
    v[x] = 2;  // 标记x点回溯
}

int main()
{
    cin >> t;
    while (t--)
    {
        cin >> n >> m;
        
        // 初始化
        memset(h, -1, sizeof h);
        idx = 0;
        for (int i = 1; i <= n; ++i)
        {
            fa[i] = i;
            query[i].clear(), query_id[i].clear();
        }
        
        // 读入树边
        for (int i = 1; i < n; ++i)
        {
            int a, b, c;
            scanf("%d %d %d", &a, &b, &c);
            add(a , b, c), add(b, a, c);
         }
         
        // 读入询问的边
        for (int i = 1; i <= m; ++i)
        {
            int a, b;
            scanf("%d %d", &a, &b);
            if (a == b)
            {
                ans[i] = 0;
            }
            else
            {
                query[a].push_back(b), query[b].push_back(a);
                query_id[a].push_back(i), query_id[b].push_back(i);
                ans[i] = (1 << 30);
            }
        }
        
        // 做dfs求lca
        tarjan(1);
        
        // 输出答案
        for (int i = 1 ; i <= m; ++i)
        {
            cout << ans[i] << endl;
        }
    }
    return 0;
}

2.3 dfs+ST

/*
我们需要维护数组oula[i]= j表示j的dfs序为i,pos[i]=j表示i第一次在dfs序中出现的位置是j
len记录dfs序的长度,dp[i][j]表示从i点出发走2^j步的范围内最小的深度的点的坐标,de[i]=j表示i的深度为j

本算法为dfs+st表求lca,预处理时间O(nlogn), 查询O(1)

算法步骤:
1. dfs:预处理出dfs序
2. ST:预处理dp数组
3. 给定任意两个点:通过dp数组得出这两个点间的深度最小的点,即为lca
*/
#include <bits/stdc++.h>

using namespace std;

int const N = 1e5 + 1, M = N * 2;
int e[M], ne[M], h[N], idx;
int t, n, m;
int oula[M], len, pos[N], dp[M][23], de[M];
typedef pair<int, int> PII;
set<PII> s;

void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

// 求dfs序
void dfs(int u, int fa, int d)
{
    oula[++len] = u;  // 记录第len个为u点
    pos[u] = len;  // 记录u点的dfs序为len(只记录u点第一次出现的dfs序即可)
    de[len] = d;  // 记录深度
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if (j == fa) continue;

        dfs(j, u, d + 1);
        oula[++len] = u;  // 回溯时还要记录
        de[len] = d;  
    }
}

// 得到x和y中深度比较小的那个
int Min(int x, int y)
{
    return de[x] > de[y]? y: x;
}

// 处理dp数组
void ST()
{
    for (int i = 1; i <= len; ++i) dp[i][0] = i;

    for (int j = 1; (1 << j) <= len; ++j )
        for (int i = 1; i + (1 << j) - 1 <= len; ++i )
            dp[i][j] = Min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);  // dp[i][j]是dp[i][j - 1]和dp[i + (1 << (j - 1))][j - 1]中深度更小的那个点的下标
}

// 求lca
int lca(int x, int y)
{
    // int x = pos[x], y = pos[y];  // 得到x和y点的dfs序的下标
    if (x > y) swap(x, y);
    int k = log2(y - x + 1);  // 计算在dfs序列中y和x的距离
    return Min(dp[x][k], dp[y - (1 << k) + 1][k]);  // lca为x~y范围内深度最小的那个dfs序的下标
}

int main()
{
    cin >> t;
    while (t--)
    {
        scanf("%d", &n);
        s.clear();
        memset(h, -1, sizeof h);
        idx = 0, len = 0;
        memset(dp, 0, sizeof dp);
        memset(oula, 0, sizeof oula);
        memset(de, 0, sizeof de);
        memset(pos, 0, sizeof pos);
        for (int i = 1; i <= n - 1; ++i)
        {
            int a, b;
            scanf("%d %d", &a, &b);
            add(a, b);
            add(b, a);
        }
        
        int root = 0;
        dfs(root, 0, 1);  // dfs得出dfs序
        ST();  // 得出dp数组

        scanf("%d", &m);
        while (m--)
        {
            getchar();
            char op = getchar();
            int num;
            scanf("%d", &num);
            if (op == '+') s.insert({pos[num], num});  // 插入
            else   // 删除
            {
                auto it = s.lower_bound({pos[num], num});
                s.erase(it);
            }  

            // 输出答案
            if (s.size() == 1) printf("%d\n", (*s.begin()).second);
            else if (s.size() >= 2)
            {
                auto left = (*s.begin()).first;
                auto right = (*prev(s.end())).first;
                printf("%d\n", oula[lca(left, right)]);
            }
            else printf("-1\n");
        }
    }
    return 0;
}

2.4 树剖求lca

树链剖分.md

3. 典型例题

acwing356 次小生成树
题意: 给定一张 N 个点 M 条边的无向图,求无向图的严格次小生成树。设最小生成树的边权之和为sum,严格次小生成树就是指边权之和大于sum的生成树中最小的一个。\(N≤10^5,M≤3*10^5\)
题解: 本题的思路是在求出最小生成树的基础上,找出一条非树边a->b,然后再树上找出a->b的最大值,删除这个最大值,加上非树边。
基于这个思路,目标就是要找出这个a->b在树边的最大值。
找出a->b在树边的最大值,可以先在树上进行预处理,\(fa[i][j]\)表示i向上走\(2^j\)步到达的点,\(d1[i][j]\)表示i向上走\(2^j\)步范围内的最大值,\(d2[i][j]\)表示i向上走\(2^j\)步范围内的次大值,然后每次在找lca时顺便找出,x到lca的最大值和次大值,y到lca的最大值和次大值,比较即可
代码:

#include<bits/stdc++.h>

using namespace std;

typedef long long LL;

const int N = 100010, M = 300010, INF = 0x3f3f3f3f;

int n, m;
struct Edge
{
    int a, b, w;
    bool used;
    bool operator< (const Edge &t) const
    {
        return w < t.w;
    }
}edge[M];
int p[N];
int h[N], e[M], w[M], ne[M], idx;
int depth[N], fa[N][17], d1[N][17], d2[N][17];
int q[N];

void add(int a, int b, int c)
{
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}

int find(int x)
{
    if (p[x] != x) p[x] = find(p[x]);
    return p[x];
}

// 找最小生成树
LL kruskal()
{
    for (int i = 1; i <= n; i ++ ) p[i] = i;
    sort(edge, edge + m);
    LL res = 0;
    for (int i = 0; i < m; i ++ )
    {
        int a = find(edge[i].a), b = find(edge[i].b), w = edge[i].w;
        if (a != b)
        {
            p[a] = b;
            res += w;
            edge[i].used = true;
        }
    }

    return res;
}

// 建树
void build()
{
    memset(h, -1, sizeof h);
    for (int i = 0; i < m; i ++ )
        if (edge[i].used)
        {
            int a = edge[i].a, b = edge[i].b, w = edge[i].w;
            add(a, b, w), add(b, a, w);
        }
}

// 预处理fa,d1,d2,depth
void bfs()
{
    memset(depth, 0x3f, sizeof depth);
    depth[0] = 0, depth[1] = 1;
    q[0] = 1;  // 把1当成根节点
    int hh = 0, tt = 0;
    while (hh <= tt)
    {
        int t = q[hh ++ ];
        for (int i = h[t]; ~i; i = ne[i])
        {
            int j = e[i];
            if (depth[j] > depth[t] + 1)
            {
                depth[j] = depth[t] + 1;  // 更新j的深度
                q[ ++ tt] = j;
                fa[j][0] = t;
                d1[j][0] = w[i], d2[j][0] = -INF;  // 求出d1和d2
                for (int k = 1; k <= 16; k ++ )
                {
                    int anc = fa[j][k - 1];
                    fa[j][k] = fa[anc][k - 1];
                    int distance[4] = {d1[j][k - 1], d2[j][k - 1], d1[anc][k - 1], d2[anc][k - 1]};
                    d1[j][k] = d2[j][k] = -INF;
                    for (int u = 0; u < 4; u ++ )
                    {
                        int d = distance[u];
                        if (d > d1[j][k]) d2[j][k] = d1[j][k], d1[j][k] = d;
                        else if (d != d1[j][k] && d > d2[j][k]) d2[j][k] = d;
                    }
                }
            }
        }
    }
}

// 找出a和b的lca,顺便求出a到b之间的最大值和次大值
int lca(int a, int b, int w)
{
    static int distance[N * 2];
    int cnt = 0;
    if (depth[a] < depth[b]) swap(a, b);

    // 把a和b拉到同一个深度
    for (int k = 16; k >= 0; k -- )
        if (depth[fa[a][k]] >= depth[b])
        {
            distance[cnt ++ ] = d1[a][k];
            distance[cnt ++ ] = d2[a][k];
            a = fa[a][k];
        }

    // 把a和b之间的d1和d2的所有备选项求出来
    if (a != b)
    {
        for (int k = 16; k >= 0; k -- )
            if (fa[a][k] != fa[b][k])
            {
                distance[cnt ++ ] = d1[a][k];
                distance[cnt ++ ] = d2[a][k];
                distance[cnt ++ ] = d1[b][k];
                distance[cnt ++ ] = d2[b][k];
                a = fa[a][k], b = fa[b][k];
            }
        distance[cnt ++ ] = d1[a][0];
        distance[cnt ++ ] = d1[b][0];
    }

    // 把a和b之间的d1和d2求出来
    int dist1 = -INF, dist2 = -INF;
    for (int i = 0; i < cnt; i ++ )
    {
        int d = distance[i];
        if (d > dist1) dist2 = dist1, dist1 = d;
        else if (d != dist1 && d > dist2) dist2 = d;
    }

    if (w > dist1) return w - dist1;
    if (w > dist2) return w - dist2;
    return INF;
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i ++ )
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        edge[i] = {a, b, c};
    }

    LL sum = kruskal();  // 计算最小生成树的值
    build();  // 把所有的树边建树
    bfs();  // 预处理出d1,d2,fa,depth数组

    LL res = 1e18;
    for (int i = 0; i < m; i ++ )
        if (!edge[i].used)  // 找出每条非树边,然后替换掉最大的那条树边
        {
            int a = edge[i].a, b = edge[i].b, w = edge[i].w;
            res = min(res, sum + lca(a, b, w));
        }
    printf("%lld\n", res);

    return 0;
}

Arab Collegiate Programming Contest 2015
题意: 求一个集合的lca
题解: 求一个集合的lca就是求这个集合中dfs序最小的点和dfs序最大的点的lca。dfs+ST处理。
代码:

#include <bits/stdc++.h>

using namespace std;

int const N = 1e5 + 1, M = N * 2;
int e[M], ne[M], h[N], idx;
int t, n, m;
int oula[M], len, pos[N], dp[M][23], de[M];
typedef pair<int, int> PII;
set<PII> s;

void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

// 求dfs序
void dfs(int u, int fa, int d)
{
    oula[++len] = u;  // 记录第len个为u点
    pos[u] = len;  // 记录u点的dfs序为len(只记录u点第一次出现的dfs序即可)
    de[len] = d;  // 记录深度
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if (j == fa) continue;

        dfs(j, u, d + 1);
        oula[++len] = u;  // 回溯时还要记录
        de[len] = d;  
    }
}

// 得到x和y中深度比较小的那个
int Min(int x, int y)
{
    return de[x] > de[y]? y: x;
}

// 处理dp数组
void ST()
{
    for (int i = 1; i <= len; ++i) dp[i][0] = i;

    for (int j = 1; (1 << j) <= len; ++j )
        for (int i = 1; i + (1 << j) - 1 <= len; ++i )
            dp[i][j] = Min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);  // dp[i][j]是dp[i][j - 1]和dp[i + (1 << (j - 1))][j - 1]中深度更小的那个点的下标
}

// 求lca
int lca(int x, int y)
{
    // int x = pos[x], y = pos[y];  // 得到x和y点的dfs序的下标
    if (x > y) swap(x, y);
    int k = log2(y - x + 1);  // 计算在dfs序列中y和x的距离
    return Min(dp[x][k], dp[y - (1 << k) + 1][k]);  // lca为x~y范围内深度最小的那个dfs序的下标
}

int main()
{
    cin >> t;
    while (t--)
    {
        scanf("%d", &n);
        s.clear();
        memset(h, -1, sizeof h);
        idx = 0, len = 0;
        memset(dp, 0, sizeof dp);
        memset(oula, 0, sizeof oula);
        memset(de, 0, sizeof de);
        memset(pos, 0, sizeof pos);
        for (int i = 1; i <= n - 1; ++i)
        {
            int a, b;
            scanf("%d %d", &a, &b);
            add(a, b);
            add(b, a);
        }
        
        int root = 0;
        dfs(root, 0, 1);  // dfs得出dfs序
        ST();  // 得出dp数组

        scanf("%d", &m);
        while (m--)
        {
            getchar();
            char op = getchar();
            int num;
            scanf("%d", &num);
            if (op == '+') s.insert({pos[num], num});  // 插入
            else   // 删除
            {
                auto it = s.lower_bound({pos[num], num});
                s.erase(it);
            }  

            // 输出答案
            if (s.size() == 1) printf("%d\n", (*s.begin()).second);
            else if (s.size() >= 2)
            {
                auto left = (*s.begin()).first;
                auto right = (*prev(s.end())).first;
                printf("%d\n", oula[lca(left, right)]);
            }
            else printf("-1\n");
        }
    }
    return 0;
}
上一篇:LCA(最近公共祖先)


下一篇:[NOIP模拟]相遇/行程的交集