题目大意:$n$ 个点的完全图,点 $i$ 和点 $j$ 的边权为 $(i+j)^k$。随机一个生成树,问这个生成树边权和的期望对 $998244353$ 取模的值。
$1\le n\le 998244352,1\le k\le 10^7$。
其实也是一道比较简单的题。(所以就应该把这题和上一道原题调个位置)
考虑一条边在生成树中出现的概率,由于一共有 $\dfrac{n(n-1)}{2}$ 条边,一个生成树有 $n-1$ 条边,而每条边的概率相等,所以为 $\dfrac{2}{n}$。
那么开始推式子:(注:第三步是枚举 $i+j$)
$$\dfrac{2}{n}\sum\limits_{i=1}^n\sum\limits_{j=i+1}^n(i+j)^k$$$$\dfrac{1}{n}(\sum\limits_{i=1}^n\sum\limits_{j=1}^n(i+j)^k-\sum\limits^n_{i=1}(i+i)^k)$$
$$\dfrac{1}{n}(\sum\limits_{s=1}^{2n}s^k\min(s-1,2n+1-s)-2^k\sum\limits^n_{i=1}i^k)$$
$$\dfrac{1}{n}(\sum\limits_{s=1}^{n}s^k(s-1)+\sum\limits_{s=n+1}^{2n}s^k(2n+1-s)-2^k\sum\limits^n_{i=1}i^k)$$
$$\dfrac{1}{n}(\sum\limits_{s=1}^{n}s^{k+1}-\sum\limits_{s=1}^{n}s^{k}+(2n+1)\sum\limits_{s=n+1}^{2n}s^k-\sum\limits_{s=n+1}^{2n}s^{k+1}-2^{k}\sum\limits^n_{i=1}i^k)$$
$$\dfrac{1}{n}(2\sum\limits_{i=1}^{n}i^{k+1}-(2n+2+2^k)\sum\limits_{i=1}^{n}i^{k}+(2n+1)\sum\limits_{i=1}^{2n}i^k-\sum\limits_{i=1}^{2n}i^{k+1})$$
现在问题就是求 $f(n)=\sum\limits_{i=1}^ni^k$ 了。
由于 $f(n)-f(n-1)=n^k$,$f$ 的差值是个 $k$ 次多项式,所以 $f$ 是个 $k+1$ 次多项式。
那么可以拉格朗日插值。(以下内容的代码实现细节比较多,注意要控制复杂度不带 $\log$)
取 $k+2$ 个点为 $1$ 到 $k+2$,发现点值 $y_i$ 可以 $O(k)$ 计算。($y_i=y_{i-1}+i^k$ 不能直接快速幂,不然带 $\log$。可以用欧拉筛筛出所有 $k$ 次方)
$$f(n)=\sum\limits_{i=1}^{k+2}y_i\dfrac{\prod\limits^{k+2}_{j=1,j\ne i}(n-x_j)}{\prod\limits^{k+2}_{j=1,j\ne i}(x_i-x_j)}$$
这样拉格朗日插值公式中的分母就是两个阶乘相乘的形式,可以 $O(1)$。(预处理要注意控制复杂度)
代入一个数算时,先特判 $n\ge mod$(因为会调用到 $f(2n)$),此时 $f(n)=\lfloor\dfrac{n}{mod}\rfloor f(mod-1)+f(n\%mod)$。
否则先算出 $fac=\prod\limits_{i=1}^{k+2}(n-i)$。同时预处理出所有 $n-i$ 的逆元 $inv_i$。(不要一个个快速幂算,复杂度错的。要用 $O(k+\log)$ 的方式)
此时就有:
$$f(n)=\sum\limits_{i=1}^{k+2}y_i\dfrac{fac\times inv_i}{(i-1)!(k-i+2)!(-1)^{k-i+2}}$$
已经可以 $O(k)$ 计算了。
时间复杂度 $O(k+\log)$。
#include<bits/stdc++.h> using namespace std; const int maxn=10001000,mod=998244353; #define FOR(i,a,b) for(int i=(a);i<=(b);i++) #define ROF(i,a,b) for(int i=(a);i>=(b);i--) #define MEM(x,v) memset(x,v,sizeof(x)) inline int read(){ char ch=getchar();int x=0,f=0; while(ch<'0' || ch>'9') f|=ch=='-',ch=getchar(); while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar(); return f?-x:x; } int n,k,ans,kp[maxn],k1p[maxn],pr[maxn/10],pl,ky[maxn],k1y[maxn],fac[maxn],invfac[maxn],tfac[maxn],tinv[maxn]; bool vis[maxn]; int qpow(int a,int b){ int ans=1; for(;b;b>>=1,a=1ll*a*a%mod) if(b&1) ans=1ll*ans*a%mod; return ans; } int kcalc(int x){ if(x<=k+2){ int ans=1ll*ky[x]*fac[x-1]%mod*fac[k-x+2]%mod; if((k-x)&1) return ans?mod-ans:0; else return ans; } if(x>=mod) return (kcalc(mod-1)+kcalc(x%mod))%mod; tfac[0]=1; FOR(i,1,k+2) tfac[i]=1ll*tfac[i-1]*(x-i)%mod; tinv[k+2]=qpow(tfac[k+2],mod-2); ROF(i,k+1,1) tinv[i]=1ll*tinv[i+1]*(x-i-1)%mod; FOR(i,2,k+2) tinv[i]=1ll*tinv[i]*tfac[i-1]%mod; int ans=0; FOR(i,1,k+2) ans=(ans+1ll*ky[i]*tfac[k+2]%mod*tinv[i])%mod; return ans; } int k1calc(int x){ if(x<=k+3){ int ans=1ll*k1y[x]*fac[x-1]%mod*fac[k-x+3]%mod; if((k-x)&1) return ans; else return ans?mod-ans:0; } if(x>=mod) return (k1calc(mod-1)+k1calc(x%mod))%mod; tfac[0]=1; FOR(i,1,k+3) tfac[i]=1ll*tfac[i-1]*(x-i)%mod; tinv[k+3]=qpow(tfac[k+3],mod-2); ROF(i,k+2,1) tinv[i]=1ll*tinv[i+1]*(x-i-1)%mod; FOR(i,2,k+3) tinv[i]=1ll*tinv[i]*tfac[i-1]%mod; int ans=0; FOR(i,1,k+3) ans=(ans+1ll*k1y[i]*tfac[k+3]%mod*tinv[i])%mod; return ans; } int main(){ n=read();k=read(); fac[0]=1; FOR(i,1,k+3) fac[i]=1ll*fac[i-1]*i%mod; invfac[k+3]=qpow(fac[k+3],mod-2); ROF(i,k+2,0) invfac[i]=1ll*invfac[i+1]*(i+1)%mod; kp[1]=k1p[1]=1; FOR(i,2,k+3){ if(!vis[i]){ pr[++pl]=i; kp[i]=qpow(i,k); k1p[i]=qpow(i,k+1); } FOR(j,1,pl){ if(i*pr[j]>k+3) break; vis[i*pr[j]]=true; kp[i*pr[j]]=1ll*kp[i]*kp[pr[j]]%mod; k1p[i*pr[j]]=1ll*k1p[i]*k1p[pr[j]]%mod; if(i%pr[j]==0) break; } } ky[1]=k1y[1]=1; FOR(i,2,k+3) ky[i]=(ky[i-1]+kp[i])%mod,k1y[i]=(k1y[i-1]+k1p[i])%mod; FOR(i,1,k+3){ ky[i]=1ll*ky[i]*invfac[i-1]%mod*invfac[k-i+2]%mod; if((k-i)&1) ky[i]=ky[i]?mod-ky[i]:0; k1y[i]=1ll*k1y[i]*invfac[i-1]%mod*invfac[k-i+3]%mod; if(!((k-i)&1)) k1y[i]=k1y[i]?mod-k1y[i]:0; } ans=2*k1calc(n)%mod; ans=(ans-(2ll*n+2+qpow(2,k))*kcalc(n)%mod+mod)%mod; ans=(ans+1ll*(2*n+1)*kcalc(2*n)%mod)%mod; ans=(ans-k1calc(2*n)+mod)%mod; ans=1ll*ans*qpow(n,mod-2)%mod; printf("%d\n",ans); }View Code