目录
LCA
1. 算法分析
1.1 求LCA的四种方法
1.树上倍增法:
倍增思想:\(f[i][j]\)表示i这个位置向上走2^j步后到达x,则有状态转移:\(f[y][j] = f[[y][j-1]][j-1]\),利用这个不断处理出f数组,树上倍增法能够得到 \(f[x][i]\) 数组和 \(d[i]\) 数组,利用这两个数组可以求出很多的东西。这是在线做法
2.tarjan算法:
dfs的特性和并查集的特性。把所有点分成三类,第一类:正在搜索的点,第二类:已经回溯完的点,第三类:还没有搜索过的点,每次搜索的时候,记当前点为x,把点x做个标记,然后判断和这个点对应的点y是否已经回溯过了,如果y已经回溯过了,那么x和y的lca即为y的get(y)得到的节点。这是离线做法
3.dfs+ST
4.树剖求lca
1.2 求lca的两种场景
- 求任意两个点的lca
- 求集合的lca:求一个集合的lca就是求这个集合中dfs序最小的点和dfs序最大的点的lca
2. 板子
2.1 树上倍增法
HDU 2586
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
int f[N][20], d[N], dist[N]; // f[i][j]表示从i开始,往上走2^j步到达的点,d为深度,dist为距离
int e[N], ne[N], h[N], idx, w[N];
int T, n, m, t; // t为数的深度
queue<int> q;
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
// 预处理:得到每个点的深度,距离,f数组
void bfs()
{
q.push(1); // 把根放入队列,注意这里有可能根不是1
d[1] = 1;
while (q.size())
{
int x = q.front();
q.pop();
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (d[y]) continue;
d[y] = d[x] + 1; // 更新深度
dist[y] = dist[x] + w[i]; // 更新距离
// 进行dp更新
f[y][0] = x;
for (int j = 1; j <= t; ++j)
{
f[y][j] = f[f[y][j - 1]][j - 1]; // 分两段处理
}
q.push(y);
}
}
}
// 查找x和y的最近公共祖先
int lca(int x, int y)
{
if (d[x] > d[y]) swap(x, y); // 保证x的深度浅一点
for (int i = t; i >= 0; --i)
if (d[f[y][i]] >= d[x]) y = f[y][i]; // 让x和y到同一个深度
if (x == y) return x;
for (int i = t; i >= 0; --i) // 让x和y之差一步就能相遇
{
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
}
return f[x][0];
}
int main()
{
cin >> T;
while (T--)
{
memset(h, -1, sizeof h);
idx = t = 0;
cin >> n >> m;
t = (int)(log(n) / log(2)) + 1; // 得到树的深度
// 读入一棵树
for (int i = 0; i < n - 1; ++i)
{
int a, b, c;
scanf("%d %d %d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
bfs();
// 回答询问
for (int i = 1; i <= m; ++i)
{
int a, b;
scanf("%d %d", &a, &b);
printf("%d\n", dist[a] + dist[b] - 2 * dist[lca(a, b)]);
}
}
return 0;
}
2.2 tarjan算法
HDU 2586
#include <bits/stdc++.h>
using namespace std;
int const N = 1e5 + 10;
int e[N], ne[N], idx, ans[N], v[N], fa[N], d[N], h[N], w[N];
vector<int> query[N], query_id[N];
int t, n, m;
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a] , h[a] = idx++;
}
// 并查集查询+路径压缩
int get(int x)
{
if (fa[x] != x) fa[x] = get(fa[x]);
return fa[x];
}
// tarjan算法求lca
void tarjan(int x)
{
// 记录这个点走过一次,但是还没有回溯
v[x] = 1;
// 遍历每一个和x相连的点
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (v[y]) continue; // 这个点走过的话,不进行后面的操作
d[y] = d[x] + w[i];
tarjan(y); // 得到y为根节点的所有子树的d
fa[y] = x; // 更新y的父节点
}
// 判断和x有关的lca询问
for (int i = 0; i < query[x].size(); ++i)
{
int y = query[x][i];
int id = query_id[x][i];
if (v[y] == 2)
{
int lca = get(y); // 获得lca:如果y点回溯,那么lca为get(y)
ans[id] = min(ans[id], d[x] + d[y] - 2 * d[lca]); // 更新答案
}
}
v[x] = 2; // 标记x点回溯
}
int main()
{
cin >> t;
while (t--)
{
cin >> n >> m;
// 初始化
memset(h, -1, sizeof h);
idx = 0;
for (int i = 1; i <= n; ++i)
{
fa[i] = i;
query[i].clear(), query_id[i].clear();
}
// 读入树边
for (int i = 1; i < n; ++i)
{
int a, b, c;
scanf("%d %d %d", &a, &b, &c);
add(a , b, c), add(b, a, c);
}
// 读入询问的边
for (int i = 1; i <= m; ++i)
{
int a, b;
scanf("%d %d", &a, &b);
if (a == b)
{
ans[i] = 0;
}
else
{
query[a].push_back(b), query[b].push_back(a);
query_id[a].push_back(i), query_id[b].push_back(i);
ans[i] = (1 << 30);
}
}
// 做dfs求lca
tarjan(1);
// 输出答案
for (int i = 1 ; i <= m; ++i)
{
cout << ans[i] << endl;
}
}
return 0;
}
2.3 dfs+ST
/*
我们需要维护数组oula[i]= j表示j的dfs序为i,pos[i]=j表示i第一次在dfs序中出现的位置是j
len记录dfs序的长度,dp[i][j]表示从i点出发走2^j步的范围内最小的深度的点的坐标,de[i]=j表示i的深度为j
本算法为dfs+st表求lca,预处理时间O(nlogn), 查询O(1)
算法步骤:
1. dfs:预处理出dfs序
2. ST:预处理dp数组
3. 给定任意两个点:通过dp数组得出这两个点间的深度最小的点,即为lca
*/
#include <bits/stdc++.h>
using namespace std;
int const N = 1e5 + 1, M = N * 2;
int e[M], ne[M], h[N], idx;
int t, n, m;
int oula[M], len, pos[N], dp[M][23], de[M];
typedef pair<int, int> PII;
set<PII> s;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
// 求dfs序
void dfs(int u, int fa, int d)
{
oula[++len] = u; // 记录第len个为u点
pos[u] = len; // 记录u点的dfs序为len(只记录u点第一次出现的dfs序即可)
de[len] = d; // 记录深度
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
dfs(j, u, d + 1);
oula[++len] = u; // 回溯时还要记录
de[len] = d;
}
}
// 得到x和y中深度比较小的那个
int Min(int x, int y)
{
return de[x] > de[y]? y: x;
}
// 处理dp数组
void ST()
{
for (int i = 1; i <= len; ++i) dp[i][0] = i;
for (int j = 1; (1 << j) <= len; ++j )
for (int i = 1; i + (1 << j) - 1 <= len; ++i )
dp[i][j] = Min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]); // dp[i][j]是dp[i][j - 1]和dp[i + (1 << (j - 1))][j - 1]中深度更小的那个点的下标
}
// 求lca
int lca(int x, int y)
{
// int x = pos[x], y = pos[y]; // 得到x和y点的dfs序的下标
if (x > y) swap(x, y);
int k = log2(y - x + 1); // 计算在dfs序列中y和x的距离
return Min(dp[x][k], dp[y - (1 << k) + 1][k]); // lca为x~y范围内深度最小的那个dfs序的下标
}
int main()
{
cin >> t;
while (t--)
{
scanf("%d", &n);
s.clear();
memset(h, -1, sizeof h);
idx = 0, len = 0;
memset(dp, 0, sizeof dp);
memset(oula, 0, sizeof oula);
memset(de, 0, sizeof de);
memset(pos, 0, sizeof pos);
for (int i = 1; i <= n - 1; ++i)
{
int a, b;
scanf("%d %d", &a, &b);
add(a, b);
add(b, a);
}
int root = 0;
dfs(root, 0, 1); // dfs得出dfs序
ST(); // 得出dp数组
scanf("%d", &m);
while (m--)
{
getchar();
char op = getchar();
int num;
scanf("%d", &num);
if (op == '+') s.insert({pos[num], num}); // 插入
else // 删除
{
auto it = s.lower_bound({pos[num], num});
s.erase(it);
}
// 输出答案
if (s.size() == 1) printf("%d\n", (*s.begin()).second);
else if (s.size() >= 2)
{
auto left = (*s.begin()).first;
auto right = (*prev(s.end())).first;
printf("%d\n", oula[lca(left, right)]);
}
else printf("-1\n");
}
}
return 0;
}
2.4 树剖求lca
见 树链剖分.md
3. 典型例题
acwing356 次小生成树
题意: 给定一张 N 个点 M 条边的无向图,求无向图的严格次小生成树。设最小生成树的边权之和为sum,严格次小生成树就是指边权之和大于sum的生成树中最小的一个。\(N≤10^5,M≤3*10^5\)
题解: 本题的思路是在求出最小生成树的基础上,找出一条非树边a->b,然后再树上找出a->b的最大值,删除这个最大值,加上非树边。
基于这个思路,目标就是要找出这个a->b在树边的最大值。
找出a->b在树边的最大值,可以先在树上进行预处理,\(fa[i][j]\)表示i向上走\(2^j\)步到达的点,\(d1[i][j]\)表示i向上走\(2^j\)步范围内的最大值,\(d2[i][j]\)表示i向上走\(2^j\)步范围内的次大值,然后每次在找lca时顺便找出,x到lca的最大值和次大值,y到lca的最大值和次大值,比较即可
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 100010, M = 300010, INF = 0x3f3f3f3f;
int n, m;
struct Edge
{
int a, b, w;
bool used;
bool operator< (const Edge &t) const
{
return w < t.w;
}
}edge[M];
int p[N];
int h[N], e[M], w[M], ne[M], idx;
int depth[N], fa[N][17], d1[N][17], d2[N][17];
int q[N];
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}
int find(int x)
{
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
// 找最小生成树
LL kruskal()
{
for (int i = 1; i <= n; i ++ ) p[i] = i;
sort(edge, edge + m);
LL res = 0;
for (int i = 0; i < m; i ++ )
{
int a = find(edge[i].a), b = find(edge[i].b), w = edge[i].w;
if (a != b)
{
p[a] = b;
res += w;
edge[i].used = true;
}
}
return res;
}
// 建树
void build()
{
memset(h, -1, sizeof h);
for (int i = 0; i < m; i ++ )
if (edge[i].used)
{
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
add(a, b, w), add(b, a, w);
}
}
// 预处理fa,d1,d2,depth
void bfs()
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[1] = 1;
q[0] = 1; // 把1当成根节点
int hh = 0, tt = 0;
while (hh <= tt)
{
int t = q[hh ++ ];
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1; // 更新j的深度
q[ ++ tt] = j;
fa[j][0] = t;
d1[j][0] = w[i], d2[j][0] = -INF; // 求出d1和d2
for (int k = 1; k <= 16; k ++ )
{
int anc = fa[j][k - 1];
fa[j][k] = fa[anc][k - 1];
int distance[4] = {d1[j][k - 1], d2[j][k - 1], d1[anc][k - 1], d2[anc][k - 1]};
d1[j][k] = d2[j][k] = -INF;
for (int u = 0; u < 4; u ++ )
{
int d = distance[u];
if (d > d1[j][k]) d2[j][k] = d1[j][k], d1[j][k] = d;
else if (d != d1[j][k] && d > d2[j][k]) d2[j][k] = d;
}
}
}
}
}
}
// 找出a和b的lca,顺便求出a到b之间的最大值和次大值
int lca(int a, int b, int w)
{
static int distance[N * 2];
int cnt = 0;
if (depth[a] < depth[b]) swap(a, b);
// 把a和b拉到同一个深度
for (int k = 16; k >= 0; k -- )
if (depth[fa[a][k]] >= depth[b])
{
distance[cnt ++ ] = d1[a][k];
distance[cnt ++ ] = d2[a][k];
a = fa[a][k];
}
// 把a和b之间的d1和d2的所有备选项求出来
if (a != b)
{
for (int k = 16; k >= 0; k -- )
if (fa[a][k] != fa[b][k])
{
distance[cnt ++ ] = d1[a][k];
distance[cnt ++ ] = d2[a][k];
distance[cnt ++ ] = d1[b][k];
distance[cnt ++ ] = d2[b][k];
a = fa[a][k], b = fa[b][k];
}
distance[cnt ++ ] = d1[a][0];
distance[cnt ++ ] = d1[b][0];
}
// 把a和b之间的d1和d2求出来
int dist1 = -INF, dist2 = -INF;
for (int i = 0; i < cnt; i ++ )
{
int d = distance[i];
if (d > dist1) dist2 = dist1, dist1 = d;
else if (d != dist1 && d > dist2) dist2 = d;
}
if (w > dist1) return w - dist1;
if (w > dist2) return w - dist2;
return INF;
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 0; i < m; i ++ )
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
edge[i] = {a, b, c};
}
LL sum = kruskal(); // 计算最小生成树的值
build(); // 把所有的树边建树
bfs(); // 预处理出d1,d2,fa,depth数组
LL res = 1e18;
for (int i = 0; i < m; i ++ )
if (!edge[i].used) // 找出每条非树边,然后替换掉最大的那条树边
{
int a = edge[i].a, b = edge[i].b, w = edge[i].w;
res = min(res, sum + lca(a, b, w));
}
printf("%lld\n", res);
return 0;
}
Arab Collegiate Programming Contest 2015
题意: 求一个集合的lca
题解: 求一个集合的lca就是求这个集合中dfs序最小的点和dfs序最大的点的lca。dfs+ST处理。
代码:
#include <bits/stdc++.h>
using namespace std;
int const N = 1e5 + 1, M = N * 2;
int e[M], ne[M], h[N], idx;
int t, n, m;
int oula[M], len, pos[N], dp[M][23], de[M];
typedef pair<int, int> PII;
set<PII> s;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
// 求dfs序
void dfs(int u, int fa, int d)
{
oula[++len] = u; // 记录第len个为u点
pos[u] = len; // 记录u点的dfs序为len(只记录u点第一次出现的dfs序即可)
de[len] = d; // 记录深度
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
dfs(j, u, d + 1);
oula[++len] = u; // 回溯时还要记录
de[len] = d;
}
}
// 得到x和y中深度比较小的那个
int Min(int x, int y)
{
return de[x] > de[y]? y: x;
}
// 处理dp数组
void ST()
{
for (int i = 1; i <= len; ++i) dp[i][0] = i;
for (int j = 1; (1 << j) <= len; ++j )
for (int i = 1; i + (1 << j) - 1 <= len; ++i )
dp[i][j] = Min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]); // dp[i][j]是dp[i][j - 1]和dp[i + (1 << (j - 1))][j - 1]中深度更小的那个点的下标
}
// 求lca
int lca(int x, int y)
{
// int x = pos[x], y = pos[y]; // 得到x和y点的dfs序的下标
if (x > y) swap(x, y);
int k = log2(y - x + 1); // 计算在dfs序列中y和x的距离
return Min(dp[x][k], dp[y - (1 << k) + 1][k]); // lca为x~y范围内深度最小的那个dfs序的下标
}
int main()
{
cin >> t;
while (t--)
{
scanf("%d", &n);
s.clear();
memset(h, -1, sizeof h);
idx = 0, len = 0;
memset(dp, 0, sizeof dp);
memset(oula, 0, sizeof oula);
memset(de, 0, sizeof de);
memset(pos, 0, sizeof pos);
for (int i = 1; i <= n - 1; ++i)
{
int a, b;
scanf("%d %d", &a, &b);
add(a, b);
add(b, a);
}
int root = 0;
dfs(root, 0, 1); // dfs得出dfs序
ST(); // 得出dp数组
scanf("%d", &m);
while (m--)
{
getchar();
char op = getchar();
int num;
scanf("%d", &num);
if (op == '+') s.insert({pos[num], num}); // 插入
else // 删除
{
auto it = s.lower_bound({pos[num], num});
s.erase(it);
}
// 输出答案
if (s.size() == 1) printf("%d\n", (*s.begin()).second);
else if (s.size() >= 2)
{
auto left = (*s.begin()).first;
auto right = (*prev(s.end())).first;
printf("%d\n", oula[lca(left, right)]);
}
else printf("-1\n");
}
}
return 0;
}