题解
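套路地用线段树合并。对询问 \((p,k)\)，答案为 \(\min(dep_p-1,k)\cdot(siz_p-1)+\sum_b (siz_b-1)\)，其中 \(b\) 取遍 \(p\) 子树中满足 \(dep_p<dep_b\le dep_p+k\) 的点；后一项对每个点维护“以深度为下标”的线段树，自底向上合并子树后区间查询即可。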
注意：递归 dfs 可能爆栈，所以用 \(\text{bfs}\) 处理。
合并时要新开节点，不然会修改子树的信息（子树的线段树之后还要用于查询）。
\(Code\)
#include<algorithm>
#include<cstdio>
#include<iostream>
#define LL long long
using namespace std;
const int N = 3e5 + 5;  // array capacity; assumes every vertex id is < N
int n, q, h[N], tot;    // vertex count, query count, adjacency-list heads, edges used so far
// Adjacency storage: each undirected edge is stored as two directed entries.
struct edge
{
    int to, nxt;        // target vertex, index of the next entry in the same list (0 = end)
} e[N << 1];
// Push a new directed edge x -> y onto the front of x's adjacency list.
inline void add(int x, int y)
{
    ++tot;
    e[tot].to = y;
    e[tot].nxt = h[x];
    h[x] = tot;
}
// Read one decimal integer from `in` (stdin by default) into x.
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If end of input is reached before any
// digit, x is set to 0 instead of looping forever.
inline void read(int &x, FILE *in = stdin)
{
    x = 0;
    int f = 1;
    // Keep the getc() result in an int: storing it in a char truncates EOF and
    // makes it indistinguishable from a 0xFF byte.
    int ch = getc(in);
    // EOF is negative, so without the explicit check it satisfies ch < '0'
    // forever and a truncated input would hang here.
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') f = -1;
        ch = getc(in);
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getc(in);
    }
    x *= f;
}
// Per-vertex data filled in by bfs(): parent, subtree size, depth (root = 1),
// and the root of the vertex's depth-indexed segment tree.
int fa[N], siz[N], dep[N], rt[N];
// Number of segment-tree nodes allocated so far; index 0 is the null node and
// is never written. Deliberately NOT named `size`: with `using namespace std;`
// under C++17, an unqualified `size` is ambiguous with std::size and fails to
// compile.
int seg_cnt;
// Node pool for dynamically allocated segment trees. sum is the total weight
// stored in the node's range; ls/rs are child indices (0 = absent).
// NOTE(review): 60 * N is presumed large enough for the inserts plus merge
// copies; the pool bound is not checked at runtime.
struct Tree{LL sum; int ls, rs;}seg[60 * N];

// Add weight v at position x of the tree rooted at p, where p covers [l, r].
// p is a reference so that a null root/child is replaced by a freshly
// allocated node.
void insert(int &p, int l, int r, int x, int v)
{
    if (!p) p = ++seg_cnt;
    seg[p].sum += v;
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (x <= mid) insert(seg[p].ls, l, mid, x, v);
    else insert(seg[p].rs, mid + 1, r, x, v);
}

// Return the root of a tree holding the position-wise sum of trees x and y.
// Where both trees have a node, a NEW node is created, so neither input tree
// is modified and it stays valid for later queries. Where only one has a node,
// that subtree is shared as-is.
int merge(int x, int y)
{
    if (!x || !y) return x | y;
    int p = ++seg_cnt;
    seg[p].sum = seg[x].sum + seg[y].sum;
    seg[p].ls = merge(seg[x].ls, seg[y].ls);
    seg[p].rs = merge(seg[x].rs, seg[y].rs);
    return p;
}

// Sum of the weights at positions [x, y] in the tree node p covering [l, r].
// Returns 0 for an empty tree (p == 0) or an empty range (x > y).
LL query(int p, int l, int r, int x, int y)
{
    if (!p) return 0;
    if (x <= l && r <= y) return seg[p].sum;
    int mid = (l + r) >> 1;
    LL res = 0;
    if (x <= mid) res += query(seg[p].ls, l, mid, x, y);
    if (y > mid) res += query(seg[p].rs, mid + 1, r, x, y);
    return res;
}
int d[N];  // vertices in BFS order, 1-indexed; d[1] is the root (vertex 1)

// Root the tree at vertex 1 and fill, for every reached vertex x:
//   fa[x]  - parent (0 for the root)
//   dep[x] - depth, with dep[1] == 1
//   siz[x] - number of vertices in x's subtree (x included)
//   rt[x]  - segment tree over depths [1, n]; position D holds the sum of
//            (siz[b] - 1) over the vertices b in x's subtree with dep[b] == D
// Iterative on purpose: a recursive DFS could overflow the stack on a deep
// tree.
void bfs()
{
    int head = 0, tail = 1;
    d[1] = 1, dep[1] = 1;
    while (head < tail)
    {
        int x = d[++head];
        for (int i = h[x]; i; i = e[i].nxt)
        {
            int y = e[i].to;
            if (y == fa[x]) continue;
            fa[y] = x, d[++tail] = y, dep[y] = dep[x] + 1;
        }
    }
    // Reverse BFS order visits every child before its parent. Only the `tail`
    // vertices actually reached are walked (tail == n for a connected tree).
    for (int j = tail; j; j--)
    {
        int x = d[j];
        siz[x] = 1;
        for (int i = h[x]; i; i = e[i].nxt)
            if (e[i].to != fa[x]) siz[x] += siz[e[i].to];
        // Insert x's own contribution BEFORE merging. rt[x] is still empty here,
        // so insert only touches fresh nodes. Inserting after the merges could
        // write into a child's tree, since merge() may return a child's root
        // unchanged.
        insert(rt[x], 1, n, dep[x], siz[x] - 1);
        for (int i = h[x]; i; i = e[i].nxt)
            if (e[i].to != fa[x]) rt[x] = merge(rt[x], rt[e[i].to]);
    }
}
int main()
{
read(n), read(q);
for(register int i = 1, u, v; i < n; i++) read(u), read(v), add(u, v), add(v, u);
bfs();
for(register int i = 1, p, k; i <= q; i++)
{
read(p), read(k);
printf("%lld\n", 1LL * min(dep[p] - 1, k) * (siz[p] - 1) + query(rt[p], 1, n, dep[p] + 1, dep[p] + k));
}
}