[清华集训2017]生成树计数

在一个\(s\)个点的图中,存在\(s-n\)条边,使图中形成了\(n\)个连通块,第\(i\)个连通块中有\(a_i\)个点。

现在我们需要再连接\(n-1\)条边,使该图变成一棵树。对一种连边方案,设原图中第\(i\)个连通块连出了\(d_i\)条边,那么这棵树\(T\)的价值为:

\[\mathrm{val}(T) = \left(\prod_{i=1}^{n} {d_i}^m\right)\left(\sum_{i=1}^{n} {d_i}^m\right) \]

你的任务是求出所有可能的生成树的价值之和,对\(998244353\)取模。


可能只有我没读出来题目说连通块内的连边方式不计。)

树和每个点的度数可以联想到\(prufer\)序列。那么设\(c_i\)为第\(i\)个点在\(prufer\)序列中出现的次数,则\(c_i=d_i-1\)。考虑对于一个确定的序列\(c_i\),它对答案的贡献就是

\[\frac{(n-2)!}{\prod_{i=1}^nc_i!}\times (\prod_{i=1}^na_i^{c_i+1})\times (\prod_{i=1}^n(c_i+1)^m)\times (\sum_{i=1}^n(c_i+1)^m) \]

第一项是这个序列对应的有标号无根树个树;第二个是因为每个连通块中的点可以任意分配这个连通块的出边;后面两个是题面定义的价值。

那么有一个暴力思路就是递推,在那之前我们把\((n-2)!\prod_{i=1}^na_i\)看作常数项,只考虑剩余的式子。设

\[f_{n,m}=\sum_{\sum_{c_i}=m}\frac{(\prod_{i=1}^na_i^{c_i})\times (\prod_{i=1}^n(c_i+1)^m)}{\prod_{i=1}^nc_i!}\\ g_{n,m}=\sum_{\sum_{c_i}=m}\frac{(\prod_{i=1}^na_i^{c_i})\times (\prod_{i=1}^n(c_i+1)^m)}{\prod_{i=1}^nc_i!}\times (\sum_{i=1}^n(c_i+1)^m) \]

即递推前\(n\)个点的总\(c_i\)为\(m\)的所有情况之和。转移就有:

\[f_{n,m}=\sum_{i=0}^mf_{n-1,m-i}\times \frac{a_i^i\times (i+1)^m}{i!}\\ g_{n,m}=\sum_{i=0}^m(g_{n-1,m-i}+f_{n-1,m-i}\times (i+1)^m)\times \frac{a_i^i\times (i+1)^m}{i!} \]

答案就是\(g_{n,n-2}\times \frac{(n-2)!}{\prod_{i=1}^na_i}\)。

现在就可以有\(20\)分的好成绩了。如果用\(NTT\)实现上面的转移就可以有\(40\)分的好成绩。

然后发现我这个式子并不好优化(懒得优化两个式子)瞟一眼题解之后发现开始那个式子可以化得好看些:

\[\begin{align} \frac{(n-2)!}{\prod_{i=1}^nc_i!}\times (\prod_{i=1}^na_i^{c_i+1})\times (\prod_{i=1}^n(c_i+1)^m)\times (\sum_{i=1}^n(c_i+1)^m) \\=((n-2)!\prod_{i=1}^na_i)\times (\sum_{i=1}^n(c_i+1)^{2m}(\prod_{j=1}^n\frac{a_j^{c_j}}{c_j!})(\prod_{j\neq i}(c_i+1)^m)) \end{align} \]

我们还是不管前面的常数项。可以发现每个\(c_i\)对式子的贡献就是\(\frac{a_i^{c_i}(c_i+1)^m}{c_i!}\)或者一个序列中仅有一个\(c_i\)贡献为\(\frac{a_i^{c_i}(c_i+1)^{2m}}{c_i!}\),那么这个式子就可以由若干个次数表示\(c\)的\(EGF\)乘起来(实际上如果尝试用\(EGF\)推一下那个\(n^3\)递推可以更容易发现这种性质)。乍一看一共有\(n\)个系数不同的\(n\)次多项式,似乎不可做,但是第\(i\)个多项式每一项都有\(a_i\)的若干次方,且与\(x\)次数相同,所以这些多项式都可以写成\(F(a_ix)\)的形式。因此设

\[A(x)=\frac{(i+1)^{2m}x^m}{i!},B(x)=\frac{(i+1)^mx^m}{i!} \]

答案就是\(\sum_{i=1}^n\frac{A(a_ix)}{B(a_ix)}\prod_{j=1}^nB(a_jx)\)。这样有什么好处呢?这里补充一下这种trick。

如果式子可以写成\(\sum_{i=1}^nF(a_ix)\)的形式,并且对任意\(m\)都求出了\(\sum_{i=1}^na_i^m\),那么只要求出\(F(x)\),式子就可以变成

\[\sum_{m=0}([x^m]F(x))\sum_{i=1}^na_i^m \]

因此我们求出\(\frac{A(x)}{B(x)}\)和\(\prod_{i=1}^nB(a_ix)\)即可。但后面这个是\(\prod\),和前面的\(\sum\)不同,这里就要取个\(\ln\):

\[\prod_{i=1}^nB(a_ix)=e^{\sum_{i=1}^n\ln{B(a_ix)}} \]

那么我们算一下\(\ln{B(x)}\)就可以像上面那样算了。

现在唯一的问题就是怎么对每个\(m\)求出\(\sum_{i=1}^na_i^m\)。类似于自然数幂和的推导,我们写出这个东西的\(OGF\)就有

\[G(x)=\sum_{m=0}(\sum_{i=1}^na_i^m)x^m=\sum_{i=1}^n\sum_{m=0}a_i^mx^m=\sum_{i=1}^n\frac{1}{1-a_ix} \]

这就有点像P4705玩游戏这题的技巧,因为

\[x(\ln(1-a_ix))'=-\frac{a_ix}{1-a_ix}=-\frac{1}{1-a_ix}+1 \]

那设\(H(x)=\sum_{i=1}^n(\ln(1-a_ix))'\),那\(G(x)=-xH(x)+n\)。求\(H\)就:

\[H(x)=(\sum_{i=1}^n\ln(1-a_ix))'=(\ln\prod_{i=1}^n(1-a_ix))' \]

分治\(NTT\)即可。至此这题就解决了,复杂度瓶颈为最后分治\(NTT\)的\(\mathcal{O}(n\log^2n)\)。

#include<bits/stdc++.h>
#define rg register
#define il inline
#define cn const
#define gc getchar()
#define fp(i, a, b) for(int i = (a), ed = (b); i <= ed; ++i)
#define fb(i, a, b) for(int i = (a), ed = (b); i >= ed; --i)
#define go(u) for(int i = head[u]; ~i; i = e[i].nxt)
using namespace std;
typedef cn int cint;
typedef long long LL;
il void rd(int &x){
	x = 0;
	rg int f(1); rg char c(gc);
	while(c < '0' || '9' < c){if(c == '-')f = -1; c = gc;}
	while('0' <= c && c <= '9')x = (x<<1)+(x<<3)+(c^48), c = gc;
	x *= f;
}
cint maxn = 30010, mod = 998244353, G = 3, invG = (mod+1)/3;
int n, m, a[maxn], fac[maxn], ifac[maxn], inv[maxn], mul = 1;
int lim, hst, rev[maxn<<2], A[maxn<<2], B[maxn<<2], ln[maxn<<2], invB[maxn<<2];
int H[maxn<<2], E[maxn<<2];
il int fpow(int a, int b, int ans = 1){
	for(; b; b >>= 1, a = 1ll*a*a%mod)if(b&1)ans = 1ll*ans*a%mod;
	return ans;
}
il void ntt(int *a, cint &typ){
	fp(i, 0, lim-1)if(i > rev[i])swap(a[i], a[rev[i]]);
	for(rg int md = 1; md < lim; md <<= 1){
		rg int len = md<<1, Gn = fpow(typ ? invG : G, (mod-1)/len);
		for(rg int l = 0; l < lim; l += len){
			for(rg int nw = 0, Pow = 1; nw < md; ++nw, Pow = 1ll*Pow*Gn%mod){
				rg int x = a[l+nw], y = 1ll*a[l+nw+md]*Pow%mod;
				a[l+nw] = (x+y)%mod, a[l+nw+md] = (x-y+mod)%mod;
			}
		}
	}
	if(typ){
		rg int inv = fpow(lim, mod-2);
		fp(i, 0, lim-1)a[i] = 1ll*a[i]*inv%mod;
	}
}
il void init(int n){
	lim = 1, hst = 0;
	while(lim < n)lim <<= 1, ++hst;
	fp(i, 0, lim-1)rev[i] = (rev[i>>1]>>1)|((i&1)<<hst-1);
}
int inv_ary[maxn<<2];
void get_inv(int *a, int *f, int n){
	if(n == 1)return f[0] = fpow(a[0], mod-2), void();
	get_inv(a, f, n+1>>1), init(2*n-1);
	fp(i, 0, n-1)inv_ary[i] = a[i];
	ntt(f, 0), ntt(inv_ary, 0);
	fp(i, 0, lim-1)f[i] = 1ll*f[i]*(2-1ll*f[i]*inv_ary[i]%mod+mod)%mod;
	ntt(f, 1);
	fp(i, n, lim-1)f[i] = 0;
	fp(i, 0, lim-1)inv_ary[i] = 0;
}
int ln_ary[maxn<<2];
il void get_ln(int *a, int *f, int n){
	get_inv(a, ln_ary, n), init(2*n-2);
	fp(i, 1, n-1)f[i-1] = 1ll*a[i]*i%mod;
	ntt(f, 0), ntt(ln_ary, 0);
	fp(i, 0, lim-1)f[i] = 1ll*f[i]*ln_ary[i]%mod;
	ntt(f, 1);
	fp(i, n-1, lim)f[i] = 0;
	fp(i, 0, lim)ln_ary[i] = 0;
	fb(i, n-1, 1)f[i] = 1ll*f[i-1]*inv[i]%mod;
	f[0] = 0;
}
int exp_ary[maxn<<2];
void get_exp(int *a, int *f, int n){
	if(n == 1)return f[0] = 1, void();
	get_exp(a, f, n+1>>1), get_ln(f, exp_ary, n), init(2*n-1);
	fp(i, 0, n-1)exp_ary[i] = (a[i]-exp_ary[i]+mod)%mod;
	if((++exp_ary[0]) ==  mod)exp_ary[0] = 0;
	ntt(f, 0), ntt(exp_ary, 0);
	fp(i, 0, lim-1)f[i] = 1ll*f[i]*exp_ary[i]%mod;
	ntt(f, 1);
	fp(i, n, lim-1)f[i] = 0;
	fp(i, 0, lim-1)exp_ary[i] = 0;
}
int div_ary[19][maxn<<2];
void divntt(int d, int l, int r){
	if(l == r)return div_ary[d][0] = 1, div_ary[d][1] = mod-a[l], void();
	int md = l+r>>1, len = r-l+2;
	divntt(d, l, md), divntt(d+1, md+1, r), init(len), ntt(div_ary[d], 0), ntt(div_ary[d+1], 0);
	fp(i, 0, lim-1)div_ary[d][i] = 1ll*div_ary[d][i]*div_ary[d+1][i]%mod;
	ntt(div_ary[d], 1);
	fp(i, len, lim-1)div_ary[d][i] = 0;
	fp(i, 0, lim-1)div_ary[d+1][i] = 0;
}
int main(){
//	freopen("in", "r", stdin);
	rd(n), rd(m);
	fp(i, 1, n)rd(a[i]), mul = 1ll*mul*a[i]%mod;
	fac[0] = 1; fp(i, 1, n)fac[i] = 1ll*fac[i-1]*i%mod;
	ifac[n] = fpow(fac[n], mod-2); fb(i, n, 1)ifac[i-1] = 1ll*ifac[i]*i%mod;
	inv[1] = 1; fp(i, 2, n)inv[i] = mod-1ll*(mod/i)*inv[mod%i]%mod;
	fp(i, 0, n)A[i] = 1ll*fpow(i+1, m<<1)*ifac[i]%mod;
	fp(i, 0, n)B[i] = 1ll*fpow(i+1, m)*ifac[i]%mod;

	get_ln(B, ln, n+1), get_inv(B, invB, n+1), init(2*n+1);
//	fp(i, 0, n)printf("%d ", B[i]);puts("");
//	fp(i, 0, n)printf("%d ", ln[i]);puts("");
//	fp(i, 0, n)printf("%d ", invB[i]);puts("");
//	fp(i, 0, n)printf("%d ", A[i]);puts("");
//	fp(i, 0, n)printf("%d ", invB[i]);puts("");
	ntt(invB, 0), ntt(A, 0);
	fp(i, 0, lim-1)A[i] = 1ll*A[i]*invB[i]%mod;
	ntt(A, 1);
	fp(i, n+1, lim)A[i] = 0;
//	fp(i, 0, n)printf("%d ", A[i]);puts("");
	divntt(0, 1, n), get_ln(div_ary[0], H, n+1);
//	fp(i, 0, n)printf("%d ", div_ary[0][i]);puts("");
	fp(i, 1, n)H[i] = mod-1ll*H[i]*i%mod;
	H[0] = n;

	fp(i, 0, n)ln[i] = 1ll*ln[i]*H[i]%mod;
	get_exp(ln, E, n+1);
	fp(i, 0, n)A[i] = 1ll*A[i]*H[i]%mod;
	init(2*n+1), ntt(E, 0), ntt(A, 0);
	fp(i, 0, lim-1)A[i] = 1ll*A[i]*E[i]%mod;
	ntt(A, 1), printf("%lld\n", 1ll*fac[n-2]*mul%mod*A[n-2]%mod);
	return 0;
}
上一篇:MySQL 正则表达式


下一篇:2021.3.21 哥哥的受难日