一、题目:
二、思路:
这道题题面中的“数据范围”部分的 Latex 公式 出了一些问题,向大家致以诚挚的歉意。
现在来说一下这道题怎么搞。
我们先来考虑一个点 \(z\) 能贡献答案的条件。题目中说了,若 \(\exists x,y\in[l,r]\) 满足 \(\mathbb{lca}(x,y)=z\),则 \(z\) 就可以贡献答案。那么我们考虑能不能把这个限制在精确一点,也就是说,我们把 \(x,y\) 的编号差放到最小,每次只需要检查 \(z\) 的子树内编号差最小的 \(x,y\) 即可。
由此想到一种暴力算法。
先将满足 \(l\leq z\leq r\) 的 \(z\) 单独计算,然后对于每个 \(z\) 以及它的某一个儿子 \(son\),枚举 \(\mathbb{subtree}(son)\) 中的节点 \(x\in[l,r]\)。在 \(\overline {\mathbb{subtree}(son)}\) 中,找见离 \(x\) 最近的节点 \(y\)(这里的“最近”指的是编号最近),如果 \(y\in[l,r]\),那么 \(z\) 就可以贡献答案。
当然这样的复杂度是不可接受的,于是我们考虑 dsu on tree。大致地说,就是在枚举 \(x\) 时,不枚举 \(\mathbb{subtree}(\mathbb{heavyson}(z))\) 中的点,只去枚举其他子树中的点。然后预处理出这些点在 \(z\) 处的最近点(在 \(z\) 处的意思就是:以 \(z\) 为最近公共祖先)。用启发式合并的 set 即可。
考虑如何回答询问。我们先将询问离线下来,按照 \(r\) 从小到大排序。对每个点 \(z\) 维护 \(Left[z]\) 数组,表示若 \(l\leq Left[z]\),那么 \(z\) 就可以贡献答案;通俗地讲,\(Left[z]\) 就是 \(z\) 可以产生贡献的最大的 \(l\)。那么随着 \(r\) 的递增,\(Left[z]\) 肯定是单调不降的。
对于当前要更新的点 \(r\),令 \(z\gets r\),让 \(z\) 顺着重链往上跳,每次让 \(z\) 移动到 \(\mathbb{father}(\mathbb{top}(z))\)。
- 找见 \(r\) 在 \(z\) 处的后继 \(y_0\)。将点对 \((r,z)\) 挂在 \(y_0\) 处,表示若 \(r\geq y_0\),则 \(Left[z]\gets\max\{Left[z],r \}\)。
- 找见 \(r\) 在 \(z\) 处的前驱 \(x_0\)。直接令 \(Left[z]\gets \max\{Left[z],x_0 \}\)。
这样为什么就能枚举到所有可能发生改变的 \(z\) 呢?考虑轻重链剖分的性质,一个点 \(z\) 最多只会有一个向下坠的重边,而上述的统计过程是在轻边处统计的。由于 \(x\) 和 \(y\) 一定在不同的子树中,所以 \(z\) 要么在 \(r\) 充当 \(x\) 的角色的时候被枚举到,要么在 \(r\) 充当 \(y\) 的角色的时候被枚举到。
统计答案用一个权值树状数组即可。
三、代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <set>
#include <vector>
#include <map>
#include <algorithm>
using namespace std;
#define FILEIN(s) freopen(s, "r", stdin)
#define FILEOUT(s) freopen(s, "w", stdout)
#define mem(s, v) memset(s, v, sizeof s)
inline int read(void) {
int x = 0, f = 1; char ch = getchar();
while (ch < ‘0‘ || ch > ‘9‘) { if (ch == ‘-‘) f = -1; ch = getchar(); }
while (ch >= ‘0‘ && ch <= ‘9‘) { x = x * 10 + ch - ‘0‘; ch = getchar(); }
return f * x;
}
const int MAXN = 3e5 + 5, INF = 0x3f3f3f3f;
int n, Q;
int siz[MAXN], son[MAXN], fa[MAXN], top[MAXN];
int Left[MAXN], ans[MAXN];
set<int>S[MAXN];
map<int, int>lower[MAXN], upper[MAXN];
vector<int>linker[MAXN];
vector<pair<int, int> >ope[MAXN];
struct Query {
int l, r, id;
inline friend bool operator <(const Query &a, const Query &b) {
return a.r < b.r;
}
}q[MAXN];
namespace BIT {
#define lowbit(x) (x & (-x))
int tr[MAXN];
inline void add(int p, int x) { // 树状数组的值域包含0,所以先将p++。
++ p;
for (; p <= n + 1; p += lowbit(p))
tr[p] += x;
}
inline int sum(int p) {
int res = 0;
++ p;
for (; p; p -= lowbit(p))
res += tr[p];
return res;
}
inline int query(int l, int r) {
return sum(r) - sum(l - 1);
}
}
void dfs1(int x, int father) {
siz[x] = 1;
fa[x] = father;
for (auto &y : linker[x]) {
if (y == father) continue;
dfs1(y, x);
if (siz[y] > siz[son[x]]) son[x] = y;
siz[x] += siz[y];
}
}
void dfs2(int x, int first) {
top[x] = first;
if (son[x]) dfs2(son[x], first);
for (auto &y : linker[x]) {
if (y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}
void dfs3(int x) {
if (!son[x]) { S[x].insert(x); return; }
dfs3(son[x]);
for (int i = 0; i < (int)linker[x].size(); ++ i) {
int y = linker[x][i];
if (y == fa[x] || y == son[x]) continue;
dfs3(y);
for (auto &p : S[y]) {
lower[p][x] = 0;
upper[p][x] = INF;
set<int>::iterator it = S[x].lower_bound(p);
if (it != S[x].end()) upper[p][x] = *it;
if (it != S[x].begin()) { -- it; lower[p][x] = *it; }
it = S[son[x]].lower_bound(p);
if (it != S[son[x]].end()) upper[p][x] = min(upper[p][x], *it);
if (it != S[son[x]].begin()) { -- it; lower[p][x] = max(lower[p][x], *it); }
}
for (auto &p : S[y]) S[x].insert(p);
}
swap(S[x], S[son[x]]);
for (int i = (int)linker[x].size() - 1; i >= 0; -- i) {
int y = linker[x][i];
if (y == fa[x] || y == son[x]) continue;
for (auto &p : S[y]) {
set<int>::iterator it = S[x].lower_bound(p);
if (it != S[x].end()) upper[p][x] = min(upper[p][x], *it);
if (it != S[x].begin()) { -- it; lower[p][x] = max(lower[p][x], *it); }
}
for (auto &p : S[y]) S[x].insert(p);
}
S[x].insert(x);
}
int main() {
FILEIN("party.in"); FILEOUT("party.out");
n = read(); Q = read();
for (int i = 1; i < n; ++ i) {
int x = read(), y = read();
linker[x].push_back(y);
linker[y].push_back(x);
}
dfs1(1, 0);
dfs2(1, 1);
dfs3(1);
for (int i = 1; i <= Q; ++ i) {
q[i].l = read(); q[i].r = read();
q[i].id = i;
}
sort(q + 1, q + Q + 1);
int R = 0;
BIT::add(0, n);
for (int i = 1; i <= Q; ++ i) {
while (R < q[i].r) {
++ R;
// Left[R] = max(Left[R], R);
if (R > Left[R]) {
BIT::add(Left[R], -1);
Left[R] = R;
BIT::add(Left[R], 1);
}
for (auto &p : ope[R]) {
int x = p.first, z = p.second;
// Left[z] = max(Left[z], x);
if (x > Left[z]) {
BIT::add(Left[z], -1);
Left[z] = x;
BIT::add(Left[z], 1);
}
}
int x = fa[top[R]];
while (x) {
int y0 = upper[R][x];
if (y0 <= n) ope[y0].push_back({ R, x });
int x0 = lower[R][x];
// Left[x] = max(Left[x], x0);
if (x0 > Left[x]) {
BIT::add(Left[x], -1);
Left[x] = x0;
BIT::add(Left[x], 1);
}
x = fa[top[x]];
}
}
ans[q[i].id] = BIT::query(q[i].l, n);
}
for (int i = 1; i <= Q; ++ i) {
printf("%d\n", ans[i]);
}
return 0;
}