[CF519E] A and B and Lecture Rooms - LCA
Description
给定一棵树,有 \(m\) 组询问,每次给定两个点 \(u,v\),问到 \(u,v\) 距离相等的点有多少个。
Solution
一定是连接 \(u,v\) 的路径的中点以及它所发出的其它子树。
以下设 \(LCA(u,v)=l, MID(u,v)=c\)。
如果 \(l=c\),那么砍掉与 \(u,v\) 有关的两棵子树即可。
如果 \(l \neq c\),假设 \(DEP(u)>DEP(v)\),则中点的子树砍掉与 \(u\) 有关可的一棵子树即可。
问题 1:到底怎么求中点?
中点存在的充要条件为 \(2|(DEP(u)+DEP(v))\),如果中点存在,那么它的深度,是可以求出的:直链的情况为 \(\frac {DEP(u)+DEP(v)} 2\),曲链的情况下,设 \(len = DEP(u)+DEP(v)-2DEP(l)\),假设 \(DEP(u) \ge DEP(v)\),则 \(MID\) 的深度为 \(DEP(u)-\frac {len} 2\)。求出深度后,我们从较深点倍增向上跳即可。
问题 2:怎样砍掉点 \(p\) 与点 \(q\) 相关的一棵子树?
我们将 \(q\) 倍增向上跳,使它成为 \(p\) 的孩子,那么此时的深度显然是 \(DEP(p)+1\)。
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1000005;
int n, m;
vector<int> g[N];
int dep[N], fa[N][20], siz[N];
// 树,倍增 LCA,求路径长度
void dfs(int p, int from)
{
siz[p] = 1;
for (auto q : g[p])
{
if (q != from)
{
dep[q] = dep[p] + 1;
fa[q][0] = p;
dfs(q, p);
siz[p] += siz[q];
}
}
}
void presolve()
{
dep[1] = 1;
dfs(1, 0);
for (int j = 1; j < 20; j++)
for (int i = 1; i <= n; i++)
fa[i][j] = fa[fa[i][j - 1]][j - 1];
}
int lca(int p, int q)
{
if (dep[p] < dep[q])
swap(p, q);
for (int i = 19; i >= 0; i--)
if (dep[fa[p][i]] >= dep[q])
p = fa[p][i];
for (int i = 19; i >= 0; i--)
if (fa[p][i] != fa[q][i])
{
p = fa[p][i];
q = fa[q][i];
}
if (p != q)
return fa[p][0];
return p;
}
// 跳跃,到某一确定深度
int jump(int p, int depth)
{
for (int i = 19; i >= 0; i--)
if (dep[fa[p][i]] >= depth)
p = fa[p][i];
return p;
}
// 求中点
bool check_mid(int u, int v)
{
return (dep[u] + dep[v]) % 2 == 0;
}
int mid(int u, int v)
{
int l = lca(u, v);
if (l == v)
{
int depth = (dep[u] + dep[v]) / 2;
return jump(u, depth);
}
else
{
int len = dep[u] + dep[v] - 2 * dep[l];
int depth = dep[u] - len / 2;
return jump(u, depth);
}
}
// 求相关孩子
int get_rela_child(int p, int q)
{
return jump(q, dep[p] + 1);
}
// 主程序
signed main()
{
ios::sync_with_stdio(false);
cin >> n;
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
presolve();
cin >> m;
for (int i = 1; i <= m; i++)
{
int u, v;
cin >> u >> v;
if (dep[u] < dep[v])
swap(u, v);
int l = lca(u, v);
int ans = 0;
if (u == v)
{
ans = n;
}
else if (check_mid(u, v))
{
int c = mid(u, v);
if (l == c)
{
int rela_u = get_rela_child(c, u);
int rela_v = get_rela_child(c, v);
ans = n - siz[rela_u] - siz[rela_v];
}
else
{
int rela_u = get_rela_child(c, u);
ans = siz[c] - siz[rela_u];
}
}
cout << ans << endl;
}
}