Description
给定一棵 \(n\) 个结点的树,你从点 \(x\) 出发,每次等概率随机选择一条与所在点相邻的边走过去。
有 \(Q\) 次询问,每次询问给定一个集合 \(S\) ,求如果从 \(x\) 出发一直随机游走,直到点集 \(S\) 中所有点都至少经过一次的话,期望游走几步。
特别地,点 \(x\)(即起点)视为一开始就被经过了一次。
答案对 \(998244353\) 取模。
Solution
不妨设 \(f_{i,S}\) 表示在点 \(i\) 时,要遍历集合 \(S\) 的期望步数。那么对于一个询问 \(S\) ,答案就是 \(f_{x,S}\) 。
从两个方面来考虑如何求 \(f\) :
- 如果 \(u\not\in S\) ,由套路,显然满足 \[f_{u,S}=\frac{\sum_{\text{v is the neighbor of u}}f_{v,S}}{degree_u}+1\]
- 如果 \(u\in S\)
- 若 \(\{u\}=S\) ,显然 \(f_{u,S}=0\) ;
- 若 \(\{u\}\neq S\) ,容易得到 \(f_{u,S}=f_{u,S-\{u\}}\)
这样我们对于同一个状态 \(S\) 可以得到若干个方程,那么在这一个状态内高斯消元即可。
由于是树上消元,所以可以用[Codeforces 802L]Send the Fool Further! (hard)的方法化成 \(f_u=k_uf_{fa_u}+b_u\) 的形式 \(O(n)\) 求解。
总复杂度是 \(O(n\log(n)2^n+Q)\) ,其中 \(\log(n)\) 是求逆元的复杂度。
Code
#include <bits/stdc++.h>
using namespace std;
const int N = 20, SIZE = (1<<18)+5, yzh = 998244353;
int n, q, x, u, v, bin[N], dg[N], S;
struct tt {int to, next; }edge[N<<1];
int path[N], top, k[N], b[N], f[N][SIZE];
int quick_pow(int a, int b) {
int ans = 1;
while (b) {
if (b&1) ans = 1ll*ans*a%yzh;
b >>= 1, a = 1ll*a*a%yzh;
}
return ans;
}
void dfs(int u, int fa) {
k[u] = b[u] = 0;
for (int i = path[u], v; i; i = edge[i].next)
if ((v = edge[i].to) != fa) dfs(v, u);
if (!(bin[u-1]&S)) {
if (dg[u] == 1 && x != u) k[u] = b[u] = 1;
else {
k[u] = dg[u], b[u] = dg[u];
for (int i = path[u], v; i; i = edge[i].next)
if ((v = edge[i].to) != fa) {
(k[u] -= k[v]) %= yzh; (b[u] += b[v]) %= yzh;
}
k[u] = quick_pow(k[u], yzh-2);
b[u] = 1ll*b[u]*k[u]%yzh;
}
}else {
if (S^bin[u-1]) {
k[u] = 0; b[u] = f[u][S^bin[u-1]];
}else k[u] = b[u] = 0;
}
}
void cal(int u, int fa) {
f[u][S] = (1ll*k[u]*f[fa][S]%yzh+b[u])%yzh;
for (int i = path[u], v; i; i = edge[i].next)
if ((v = edge[i].to) != fa) cal(v, u);
}
void add(int u, int v) {edge[++top] = (tt){v, path[u]}, path[u] = top; ++dg[v]; }
void work() {
scanf("%d%d%d", &n, &q, &x);
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
bin[0] = 1; for (int i = 1; i < N; i++) bin[i] = (bin[i-1]<<1);
for (int i = 1; i < bin[n]; i++) S = i, dfs(x, 0), cal(x, 0);
while (q--) {
S = 0; scanf("%d", &u);
for (int i = 1; i <= u; i++) scanf("%d", &v), S |= bin[v-1];
printf("%d\n", (f[x][S]+yzh)%yzh);
}
}
int main() {work(); return 0; }