一、题目:
二、思路:
这道题真的把我迷了一个下午,感谢 zYzYzYzYz 大佬的点拨,让我终于明白了这道题。
注意到如果 \(y\) 的一个祖先是 \(x\),那么先 \(\mathbb{access}(x)\) 再 \(\mathbb{access}(y)\) 本质上相当于只 access 了一次。所以我们可以执行一个从下到上的树形 DP。
注意,以下这个 DP 状态比较复杂。
假设我们现在的目标是求出来 \(x\) 相关的信息,现在轮到了用 \(x\) 的儿子 \(y\) 来去更新 \(x\) 的信息。此时,\(tmpf[k]\) 表示"保证 \(y\) 之前的儿子边都不是实边",操作了 \(k\) 次的方案数;\(tmpg[k]\) 表示“保证 \(y\) 之前的儿子边中恰好有一个是实边”,操作了 \(k\) 次的方案数;\(f[x,k]\) 表示保证 \(y\) 及 \(y\) 之前的儿子边都不是实边,操作了 \(k\) 次的方案数;\(g[x,k]\) 表示保证处理完 \(y\) 之后,\(x\) 顶上那条边是实边,操作了 \(k\) 次的方案数。
则有状态转移方程:
在整个更新完 \(x\) 的答案之后,我们将 \(g[x]\) 数组中的值全部赋给 \(f[x]\),此时 \(f[x,k]\) 的意义变成 \(x\) 顶上那条边不是实边,操作了 \(k\) 次的方案数。为什么可以直接将 \(g[x]\) 数组中的值赋给 \(f[x]\) 呢?因为 \(x\) 顶上的边不是实边,只有可能是 \(x\) 的父亲用了一次 access,把原本是实边的边变成了虚边。
最后有一个小细节。就是最后将 \(g\) 赋给 \(f\) 的时候,\(g[x,0]\) 是等于 0 的,而 \(g[x,1]\) 是大于 0 的。赋给 \(f\) 了之后,为了保证意义上的自洽以及树的形态不能重复,要将 \(f[x,0]\gets 1\),\(f[x,1]\gets f[x,1]-1\)。
由于 \(k\leq siz_x\),所以这是一个经典的 \(O(nK)\) 的 DP。
三、代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
#define FILEIN(s) freopen(s, "r", stdin)
#define FILEOUT(s) freopen(s, "w", stdout)
#define mem(s, v) memset(s, v, sizeof s)
inline int read(void) {
int x = 0, f = 1; char ch = getchar();
while (ch < ‘0‘ || ch > ‘9‘) { if (ch == ‘-‘) f = -1; ch = getchar(); }
while (ch >= ‘0‘ && ch <= ‘9‘) { x = x * 10 + ch - ‘0‘; ch = getchar(); }
return f * x;
}
const int MAXN = 10005, MOD = 998244353, MAXK = 505;
int n, K, siz[MAXN];
int head[MAXN], tot;
long long f[MAXN][MAXK], g[MAXN][MAXK];
long long tmpf[MAXK], tmpg[MAXK];
struct Edge {
int y, next;
Edge() {}
Edge(int _y, int _next) : y(_y), next(_next) {}
}e[MAXN << 1];
inline void connect(int x, int y) {
e[++ tot] = Edge(y, head[x]);
head[x] = tot;
}
void dfs(int x, int fa) {
siz[x] = 1;
g[x][1] = 1;
f[x][0] = 1;
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].y;
if (y == fa) continue;
dfs(y, x);
for (int k = 0; k <= min(K, siz[x]); ++ k) {
tmpf[k] = f[x][k];
tmpg[k] = g[x][k];
f[x][k] = g[x][k] = 0;
}
for (int j = 0; j <= min(K, siz[x]); ++ j) {
for (int k = 0; k <= siz[y] && j + k <= K; ++ k) {
(f[x][j + k] += tmpf[j] * f[y][k] % MOD) %= MOD;
(g[x][j + k] += tmpf[j] * g[y][k] % MOD + tmpg[j] * f[y][k] % MOD) %= MOD;
}
}
siz[x] += siz[y];
}
memcpy(f[x], g[x], sizeof f[x]);
f[x][0] = 1; -- f[x][1];
}
int main() {
FILEIN("access.in"); FILEOUT("access.out");
n = read(); K = read();
for (int i = 1; i < n; ++ i) {
int x = read(), y = read();
connect(x, y); connect(y, x);
}
dfs(1, 0);
long long res = 0;
for (int i = 0; i <= min(n, K); ++ i) (res += f[1][i]) %= MOD;
printf("%lld\n", res);
return 0;
}