1 CF832D 玩转地铁
2 题目描述
时间限制 \(2s\) | 空间限制 \(256M\)
\(Misha\) 和 \(Grisha\) 是很有意思的两个男孩,他们很喜欢坐地铁玩,地铁有 \(n\) 个车站,由 \(n-1\) 条线路连接起来,每一条线路连接两个车站,所有的车站都可以到达。 男孩们准备玩游戏,早上 \(Misha\) 将会选择最短的路线从车站 \(s\) 走到车站 \(f\),同时在包括 \(s\) 和 \(f\) 站的每个经过的车站都喷写上一句话 \(Misha\ was\ here\)。在同一天傍晚 \(Grisha\) 坐地铁走最短路线从 \(t\) 站走到 \(f\) 站,他计算看到 \(Misha\) 写的这句话的次数。当天晚上清洁工会清除掉所有喷写的字。男孩们已经选好了未来几天的三个车站 \(a,b,c\),其中有一个是当天的车站 \(s\),有一个是车站 \(f\),另外的那个是车站 \(t\)。请计算如何选择 \(s,f,t\) 使得 \(Grisha\) 能看到那句话的次数最多。
数据范围:\(2 ≤ n ≤ 10^5, 1 ≤ q ≤ 10^5\)
3 题解
容易发现,答案一定是给出三个点中某两个点的最近公共祖先与另一个节点的距离。由于我们只能在有根树上求出 \(LCA\),所以我们求出的 \(x\) 和 \(y\) 的 \(LCA\) 肯定深度不比 \(min(depth_x, depth_y)\) 要大。此时,如果剩下的那一个节点与 \(LCA(x, y)\) 之间的路径经过 \(x\) 或 \(y\) 与 \(LCA(x, y)\) 之间的路径,那么路径之间就出现重复,该路径就不是最短路了。
为了避免这种情况的出现,我们每次取出的是所有 \(LCA\) 中的深度最大的 \(LCA\)。这是因为,如果我们当前找到的 \(LCA\) 不是深度最大的,那么一定存在一个深度更大的 \(LCA\),而那个 \(LCA\) 到我们当前找到的 \(LCA\) 中间的路径使我们在计算时走的并不是最短路,答案也就是错误的。
找到这个 \(LCA\) 之后,我们只需要枚举所有的情况,即计算 \(LCA\) 到 \(x, y, z\) 三个节点之间的距离的最大值即可。由于数据范围只有 \(1\times 10^5\),我们可以用倍增计算两点间的 \(LCA\)。树上任意两点 \(x, y\) 的距离即为 \(depth_x + depth_y - 2 \times depth_{LCA(x, y)}\),因为 \(LCA\) 与根节点之间的路径一共被计算了 \(2\) 次,而这两次都不是我们所需要的。
4 代码(空格警告):
#include <iostream>
#include <queue>
#include <cmath>
using namespace std;
#define d(x,y) dep[lca(x,y)]
const int N = 1e5+10;
int n, q, tot, p, t, a, b, c, LCA;
int head[N], ver[N*2], last[N*2], dep[N], f[N][30];
queue<int> Q;
void add(int x, int y)
{
ver[++tot] = y;
last[tot] = head[x];
head[x] = tot;
}
void bfs()
{
Q.push(1); dep[1] = 1;
while (Q.size())
{
int x = Q.front();Q.pop();
for (int i = head[x]; i; i = last[i])
{
int y = ver[i];
if (dep[y]) continue;
dep[y] = dep[x] + 1;
f[y][0] = x;
for (int j = 1; j <= t; j++) f[y][j] = f[f[y][j-1]][j-1];
Q.push(y);
}
}
}
int lca(int x, int y)
{
if (dep[x] > dep[y]) swap(x, y);
for (int i = t; i >= 0; i--) if (dep[f[y][i]] >= dep[x]) y = f[y][i];
if (x == y) return x;
for (int i = t; i >= 0; i--) if (f[y][i] != f[x][i]) y = f[y][i], x = f[x][i];
return f[y][0];
}
int dis(int x, int y)
{
return dep[x] + dep[y] - 2 * dep[lca(x, y)];
}
int main()
{
cin >> n >> q;
t = (int)(log(n) / log(2)) + 1;
for (int i = 1; i < n; i++)
{
cin >> p;
add(p, i+1);
add(i+1, p);
}
bfs();
for (int i = 1; i <= q; i++)
{
cin >> a >> b >> c;
if (d(a, b) >= d(a, c) && d(a, b) >= d(b, c)) LCA = lca(a, b);
else if (d(a, c) > d(a, b) && d(a, c) > d(b, c)) LCA = lca(a, c);
else LCA = lca(b, c);
cout << max(max(dis(LCA, a), dis(LCA, b)), dis(LCA, c))+1 << '\n';
}
return 0;
}