参考了官方题解的做法。
第 3 个限制本质要求是什么?
或者说,什么样的树可以满足第 3 个限制?
- 叶向树
- 根向树
- 一棵根向树 + 一棵叶向树 + 一条边 拼起来的树
为第三种情况画了一幅直观的图:
用 \(f_i\) 表示 无标号、有根、边同向(即都为叶向或都为根向,但此时只算一种)、\(\text{deg}(root) \le 2\)、最长链为 \(i\) 的树的数量。
转移式为:
\[f_i = f_{i - 1} + \color{blue}{ f_{i - 1} \times \left(\sum_{j=0}^{i - 2}f_j\right) } + \color{red}{ \binom{f_{i - 1} + 1}{2} } \]第一个黑色的部分表示的是 \(\text{deg}(root) = 1\) 的情况,这时候直接取绝于第二层子树的形态。
第二个蓝色的部分表示的是 \(\text{deg}(root) = 2\),且两边子树的深度不相等的情况。显然有一边必须是深度 \(=i-1\),另一边深度 \(<i-1\),直接乘起来就好了。
第三个红色的部分表示的是 \(\text{deg}(root) = 2\),且两边子树深度均为 \(i - 1\) 的情况。注意到儿子不考虑顺序,所以对于方案 \((a, b)\) 和 \((b, a)\) 应当被认为是相同的,那么答案就是 \(\binom{f_{i - 1}}{2} + f_{i - 1} = \frac{f_{i - 1}(f_{i - 1} + 1)}{2} = \binom{f_{i - 1} + 1}{2}\),也就是要么两边方案不同,有 \(\binom{f_{i - 1}}{2}\) 种;要么两边方案相同,有 \(f_{i - 1}\) 种。
可以利用前缀和,在 \(O(n)\) 的时间内预处理出 \(f_i\)。
用 \(f'_i\) 表示除了强制 \(\text{deg}(root) = 2\) 外,其他限制和上面一样 的方案数。显然就是去掉了第一项,可以通过差分得到:\(f_i' = f_i - f_{i - 1}\)。
有了 \(f_i\) 和 \(f'_i\) 后,我们考虑回答询问,先计算前两种情况(根向/叶向树),注意 \(\text{deg}(root)\) 可以为 \(1/2/3\)。
\[\text{Answer} = 2\left(f_n + f_{n - 1} \binom{1 + \sum_{j=0}^{n-2} f_j}{2} + \left(\sum_{j=0}^{n-2} f_j \right) \binom{f_{n - 1} + 1}{2} + \binom{f_{n - 1} + 2}{3} \right) - 1 \]本质上也就是分类讨论 \(root\) 的各个子树的情况,推导和上面 \(f\) 的递推式是相似的。
至于 \(\times 2\) 是因为,当前 \(f\) 没有对边定向,而我们必须钦定根向 / 叶向。
而 \(-1\) 是因为当树的形态为一条链时,两者无区别。
接下来计算第三种情况,我们枚举两边的树的深度 \(i\) 和 \(n-i-1\),得到:
\[\text{Answer} = \sum_{i = 0}^{n - 1} (f_i - 1) f'_{n - i - 1} \]红边为“关键边”,而黄边不能为“关键边”(为了防止重复计数),而 \(i\) 指的是 \(\text{A}\) 部分的深度。所以发现 \(\text{A}\) 部分是一棵 \(\text{deg}(root) \le 2\) 的非链根向树(如果为链的话,整棵树退化为一棵叶向树),方案数为 \(f_i - 1\);而 \(\text{B}\) 部分是一棵 \(\text{deg}(root) = 2\) 的叶向树(因为红边一定是最靠右的),方案数为 \(f_{n - i - 1}'\)。根据乘法原理乘起来即可。
两部分加起来即是答案。
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 5, MOD = 998244353;
typedef long long ll;
int n, ans, f[N], pref[N];
int C2(ll x) { return (ll)x * (x - 1) % MOD * 499122177 % MOD; }
int C3(ll x) { return (ll)x * (x - 1) % MOD * (x - 2) % MOD * 166374059 % MOD; }
int main()
{
ios::sync_with_stdio(false);
cin >> n;
f[0] = 1; f[1] = 2; pref[0] = 1; pref[1] = 3;
for(int i = 2; i <= n; i++)
{
f[i] = (f[i - 1] + (ll)f[i - 1] * pref[i - 2] % MOD + C2(f[i - 1] + 1)) % MOD;
pref[i] = (pref[i - 1] + f[i]) % MOD;
}
ans = (ans + 2ll * f[n] - 1 + MOD) % MOD;
if(n >= 1) ans = (ans + 2ll * C3(f[n - 1] + 2)) % MOD;
if(n >= 2) ans = (ans + 2ll * ((ll)f[n - 1] * C2(pref[n - 2] + 1) + (ll)pref[n - 2] * C2(f[n - 1] + 1))) % MOD;
for(int i = 0; i < n; i++) ans = (ans + (ll)(f[i] + MOD - 1) % MOD * ((n - i - 1 >= 1) ? (f[n - i - 1] - f[n - i - 2] + MOD) % MOD : 0) % MOD) % MOD;
cout << ans << endl;
return 0;
}