题面
题解
从数据范围很容易看出是个虚树DP(可惜看出来了也还是不会做)
虚树大家应该都会, 不会的话自己去搜吧, 我懒得讲了, 我们在这里只需要考虑如何DP即可
首先我们需要求出每个点被哪个点所控制, 设\(u\)点被\(bl[u]\)所控制, 两遍DFS即可, 考虑儿子对父亲的影响和父亲对儿子的影响
代码细节相信不要我说, 能做这道题的总不可能不会DFS吧
还是贴一下自己看吧
void dfs1(int u, int fa)
{
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].to; if(v == fa) continue;
dfs1(v, u); int dv = dep[bl[v]] - dep[u], du = bl[u] ? dep[bl[u]] - dep[u] : 0x3f3f3f3f;
if(dv < du || (dv == du && bl[v] < bl[u])) bl[u] = bl[v];
}
}
void dfs2(int u, int fa)
{
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].to; if(v == fa) continue;
int dv = dis(bl[v], v), du = dis(bl[u], v);
if(du < dv || (du == dv && bl[u] < bl[v])) bl[v] = bl[u];
dfs2(v, u);
}
}
//分别在递归前和递归后处理一下就完事了
然后考虑如何计算答案, 对虚树上每一条边讨论
①: 边的两端被同一个节点所控制, 加上这两个点不在虚树中的儿子的sz即可
②: 边的两端被不同点控制, 我们需要找出一个分界点, 满足此分界点归下面那个点的\(bl[]\)所控制, 此分界点的父亲被上面那个点的\(bl[]\)所控制, 用个数据结构维护一下或者倍增跳一下就可以了
至于每个点不在虚树中的儿子的sz, 拿当前点的sz减去所有他在虚树中的儿子的sz即可
具体实现细节参见代码(参考了一下题解的思路嘿嘿嘿)
代码实现
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#define N 300005
using namespace std;
int n, m, Q, head[N], sz[N], son[N], dep[N], dfn[N], f[N][21], top[N], cnt, tp, a[N], b[N], stk[N], bl[N], con[N], ans[N], l[N];
struct edge { int from, to, next; } e[N << 1];
inline int read()
{
int x = 0, w = 1;
char c = getchar();
while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
return x * w;
}
inline void add(int u, int v) { e[++cnt] = (edge) { u, v, head[u] }; head[u] = cnt; }
void dfs_sz(int u, int fa)
{
sz[u] = 1; dep[u] = dep[fa] + 1; f[u][0] = fa;
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].to; if(v == fa) continue;
dfs_sz(v, u); sz[u] += sz[v]; if(sz[son[u]] < sz[v]) son[u] = v;
}
}
void dfs_top(int x, int y)
{
dfn[x] = ++cnt; top[x] = y;
if(!son[x]) return; dfs_top(son[x], y);
for(int i = head[x]; i; i = e[i].next) if(e[i].to != f[x][0] && e[i].to != son[x]) dfs_top(e[i].to, e[i].to);
}
int LCA(int x, int y)
{
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = f[top[x]][0];
}
return dep[x] < dep[y] ? x : y;
}
bool cmp(int x, int y) { return dfn[x] < dfn[y]; }
void dfs1(int u, int fa)
{
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].to; if(v == fa) continue;
dfs1(v, u); int dv = dep[bl[v]] - dep[u], du = bl[u] ? dep[bl[u]] - dep[u] : 0x3f3f3f3f;
if(dv < du || (dv == du && bl[v] < bl[u])) bl[u] = bl[v];
}
}
int dis(int x, int y) { return dep[x] + dep[y] - 2 * dep[LCA(x, y)]; }
void dfs2(int u, int fa)
{
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].to; if(v == fa) continue;
int dv = dis(bl[v], v), du = dis(bl[u], v);
if(du < dv || (du == dv && bl[u] < bl[v])) bl[v] = bl[u];
dfs2(v, u);
}
}
void dp(int u)
{
for(int s, mid, nt, dv, du, i = head[u]; i; i = e[i].next)
{
int v = e[i].to; dp(v); s = mid = v;
for(int j = l[dep[v]]; j >= 0; j--) if(dep[f[s][j]] > dep[u]) s = f[s][j];
con[u] -= sz[s];
if(bl[u] == bl[v]) { ans[bl[u]] += sz[s] - sz[v]; continue; }
for(int j = l[dep[v]]; j >= 0; j--)
{
nt = f[mid][j]; if(dep[nt] <= dep[u]) continue;
dv = dis(bl[v], nt), du = dis(bl[u], nt);
if(dv < du || (dv == du && bl[v] < bl[u])) mid = nt;
}
ans[bl[u]] += sz[s] - sz[mid];
ans[bl[v]] += sz[mid] - sz[v];
}
ans[bl[u]] += con[u];
}
void query()
{
m = read(); cnt = tp = 0;
for(int i = 1; i <= m; i++) bl[a[++cnt] = b[i] = read()] = b[i];
sort(a + 1, a + cnt + 1, cmp);
for(int i = 1; i < m; i++) a[++cnt] = LCA(a[i], a[i + 1]);
a[++cnt] = 1; sort(a + 1, a + cnt + 1, cmp);
int len = unique(a + 1, a + cnt + 1) - a - 1;
cnt = 0; for(int i = 1; i <= len; i++) head[a[i]] = 0, con[a[i]] = sz[a[i]];
for(int i = 1; i <= len; i++)
{
while(tp && dfn[a[i]] >= dfn[stk[tp]] + sz[stk[tp]]) tp--;
if(tp) add(stk[tp], a[i]); stk[++tp] = a[i];
}
dfs1(1, 0); dfs2(1, 0); dp(1);
for(int i = 1; i <= m; i++) printf("%d%c", ans[b[i]], i == m ? '\n' : ' ');
for(int i = 1; i <= len; i++) bl[a[i]] = ans[a[i]] = con[a[i]] = 0;
}
int main()
{
n = read(); for(int i = 2; i <= n; i++) l[i] = l[i >> 1] + 1;
for(int i = 1; i < n; i++) { int u = read(), v = read(); add(u, v); add(v, u); }
cnt = 0; dfs_sz(1, 0); dfs_top(1, 0);
for(int i = 1; i <= n; i++)
for(int j = 1; j <= 20; j++)
f[i][j] = f[f[i][j - 1]][j - 1];
Q = read(); while(Q--) query();
return 0;
}