Description
给定一棵以 \(1\) 为根 \(n\) 个节点的树。
定义 \(f(k)\) :从树上等概率随机选出 \(k\) 个节点,这 \(k\) 个点的虚树大小的期望。
一个点 \(x\) 在这些被选出的 \(k\) 个点的虚树上,当且仅当它满足下列条件至少一个:
- \(x\) 被选出。
- 存在两个被选出的节点 \(a,b\),使得 \(\operatorname{lca}(a,b)=x\)。
给定 \(m\),求 \(f(1),f(2),\cdots,f(m)\)。 对 \(998244353\) 取模。\(n\leq 4\cdot 10^5\)。
Sol
又是套着期望皮的计数题。
对于每个点 \(i\) 求出有多少种方案对答案有贡献即可:
- \(i\) 被选出,总方案数为 \(C(n-1,k-1)\) 。
- \(i\) 至少两个儿子的子树中存在被选出的点。
第二种不太好算,考虑用总方案数减去不合法的方案数。
总方案数就是 \(C(n-1,k)\)。
如果点 \(i\) 的子树中没有被选中的,方案数为 \(C(n-sze[i],k)\)。
只有一个儿子的子树中有被选中的,可以枚举儿子 \(j\),方案数就是 \(\sum\limits_{j} C(n-sze[i]+sze[j],k)\)。
注意到这样的话,\(i\) 子树中没有被选中的方案数被多算了 儿子个数次,所以还需要加上 \(son[i]\times C(n-sze[i],k)\)。
所以
\[f(k)=\sum\limits_{i=1}^n C_{n-1}^{k-1}+C_{n-1}^k+(son[i]-1)\times C_{n-sze[i]}^k-\sum_j C_{n-sze[i]+sze[j]}^k
\]
\]
\[f(k)=\sum\limits_{i=1}^n C_{n}^{k}+(son[i]-1)\times C_{n-sze[i]}^k-\sum_j C_{n-sze[i]+sze[j]}^k
\]
\]
如何对于每个 \(k\) 快速求呢?
观察到式子中的每一项组合数的上标都是 \(k\),所以我们可以开个桶 \(buc[i]\),在形如 \(buc[n-sze[i]]\) 的地方加上 \(son[i]+1\),在 \(buc[n-sze[i]+sze[j]]\) 处 \(-1\)。
好处就是,再推一步式子:
\[f(k)=\sum_{i=0}^n buc[i]\cdot C_i^k
\]
\]
这就是个卷积的形式,\(\mathbf{NTT}\)优化就吼了。
Code
#pragma GCC optimize(2)
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=2e6+5;
const int mod=998244353;
int son[N],sze[N],buc[N];
int n,m,cnt,head[N],fac[N];
int a[N],b[N],lim,rev[N],ifac[N];
struct Edge{
int to,nxt;
}edge[N<<1];
void add(int x,int y){
edge[++cnt].to=y;
edge[cnt].nxt=head[x];
head[x]=cnt;
}
int ksm(int a,int b=mod-2,int ans=1){
while(b){
if(b&1) ans=1ll*ans*a%mod;
a=1ll*a*a%mod;b>>=1;
} return ans;
}
void ntt(int *f,int g){
for(int i=1;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
for(int mid=1;mid<lim;mid<<=1){
int tmp=ksm(g,(mod-1)/(mid<<1));
for(int R=mid<<1,j=0;j<lim;j+=R){
for(int w=1,k=0;k<mid;k++,w=1ll*w*tmp%mod){
int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
}
}
} if(g>3)
for(int in=ksm(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
}
int getint(){
int X=0,w=0;char ch=getchar();
while(!isdigit(ch))w|=ch=='-',ch=getchar();
while( isdigit(ch))X=X*10+ch-48,ch=getchar();
if(w) return -X;return X;
}
void init(int n){
fac[0]=ifac[0]=1;
for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
ifac[n]=ksm(fac[n]);
for(int i=n-1;i;i--) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
}
void dfs(int now,int fa=0){
sze[now]=1; int tot=0; buc[n]++;
for(int i=head[now];i;i=edge[i].nxt){
int to=edge[i].to;
if(sze[to]) continue;
tot++; dfs(to,now);
sze[now]+=sze[to];
}
for(int i=head[now];i;i=edge[i].nxt){
int to=edge[i].to;
if(to==fa) continue;
(buc[n-sze[now]+sze[to]]+=mod-1)%=mod;
} (buc[n-sze[now]]+=tot-1+mod)%=mod;
}
int C(int n,int m){
if(n<m) return 0;
return 1ll*ifac[n]*fac[m]%mod*fac[n-m]%mod;
}
signed main(){
n=getint(),m=getint(),init(N-5);
for(int i=1;i<n;i++){
int x=getint(),y=getint();
add(x,y),add(y,x);
} dfs(1);
lim=1;while(lim<=n+n) lim<<=1;
for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
for(int i=0;i<=n;i++)
a[n-i]=1ll*buc[i]*fac[i]%mod,
b[i]=ifac[i];
ntt(a,3),ntt(b,3);
for(int i=0;i<lim;i++) a[i]=1ll*a[i]*b[i]%mod;
ntt(a,(mod+1)/3);
for(int i=1;i<=m;i++)
printf("%lld\n",1ll*a[n-i]*ifac[i]%mod*C(n,i)%mod);
return 0;
}