【洛谷5437】【XR-2】约定(拉格朗日插值)
题面
题解
首先发现每条边除了边权之外都是等价的,所以可以考虑每一条边的出现次数。
显然钦定一条边之后构成生成树的方案数是\(2*n^{n-3}\)。可以直接\(purfer\)序列算。
也可以发现每一条边的出现次数相等,树的总数是\(n^{n-2}\),每次出现\(n-1\)条边,每条边又是等价的。
也可以算出上面这个值。
于是要算的东西就变成了
\[\displaystyle \sum_{i=1}^n\sum_{j=i+1}^n(i+j)^k\]
这个东西不对称,很不方便计算,所以可以变成:
\[\frac{1}{2}(\sum_{i=1}^n \sum_{j=1}^n (i+j)^k-\sum_{i=1}^n (i+i)^k)\]
\(k\)次方这个东西显然是个\(k+1\)次多项式,可以套进去直接拉格朗日插值计算。
拆一下变成了\(\displaystyle \sum_{i=1}^{n} (i-1) i^k+\sum_{i=n+1}^{2n}(2n-i+1)i^k-\sum_{i=1}^n 2^ki^k\)。
然后预处理之后,可以用拉格朗日插值可以在\(O(k)\)的复杂度里算出上面的式子,然后带回去算期望就行了。
然后这里怎么拉格朗日插值。
以第一个函数为例。
令\(f(n)=\sum_{i=1}^n (i-1)i^k\),因为\(i^{k+1}\)次方大概是一个\(k+2\)次多项式,所以我们需要\(k+3\)个值,那么显然这个函数的前\(k+3\)项我们在预处理之后是可以提前算出来的。
然后根据拉格朗日插值的公式,对于一个\(k\)次多项式而言:
\[P(x)=\sum_{i=1}^{k+1}P(x_i)\prod_{j=1,j\neq i}^{k+1}\frac{x-x_j}{x_i-x_j}\]
然后因为我们选择的值是连续的若干项,所以可以简单的写成:
\[P(x)=\sum_{i=1}^{k+1}P(x_i)\frac{(-1)^{k+1-i}}{(i-1)!(k+1-i)!}\prod_{j=1,j\neq i}^{k+1}(x-x_j)\]
在这题里,我们都已经知道\(x\)是\(n\)了,所以后半部分的\(prod\)可以用前后缀的方式快速预处理出来,这样子我们就可以\(O(k)\)的计算前面的部分了。
#include<iostream>
#include<cstdio>
using namespace std;
#define MOD 998244353
#define MAX 10000100
inline int read()
{
int x=0;bool t=false;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=true,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return t?-x:x;
}
int n,K,N,ans;
int inv[MAX],jv[MAX],suf[MAX],pre[MAX];
bool zs[MAX];
int pri[MAX],tot,pw[MAX];
int fpow(int a,int b){int s=1;while(b){if(b&1)s=1ll*s*a%MOD;a=1ll*a*a%MOD;b>>=1;}return s;}
void Sieve(int n)
{
pw[1]=1;
for(int i=2;i<=n;++i)
{
if(!zs[i])pri[++tot]=i,pw[i]=fpow(i,K);
for(int j=1;j<=tot&&i*pri[j]<=n;++j)
{
zs[i*pri[j]]=true;
pw[i*pri[j]]=1ll*pw[i]*pw[pri[j]]%MOD;
if(i%pri[j]==0)break;
}
}
}
int P[MAX];
int calc(int n)
{
int ret=0;pre[0]=suf[N+1]=1;
for(int i=1;i<=N;++i)pre[i]=1ll*pre[i-1]*(n-i+MOD)%MOD;
for(int i=N;i;--i)suf[i]=1ll*suf[i+1]*(n-i+MOD)%MOD;
for(int i=1,d=((N+1)&1)?MOD-1:1;i<=N;++i,d=MOD-d)
ret=(ret+1ll*P[i]*d%MOD*jv[i-1]%MOD*jv[N-i]%MOD*pre[i-1]%MOD*suf[i+1])%MOD;
return ret;
}
int main()
{
n=read();K=read();N=K+3;
//for(int i=1;i<=n;++i)ans=(ans+1ll*(i-1)*fpow(i,K))%MOD;
//for(int i=n+1;i<=n+n;++i)ans=(ans+1ll*(n+n-i+1)*fpow(i,K))%MOD;
//for(int i=1;i<=n;++i)ans=(ans+MOD-fpow(2*i,K))%MOD;
Sieve(N);inv[0]=inv[1]=jv[0]=1;
for(int i=2;i<=N;++i)inv[i]=1ll*inv[MOD%i]*(MOD-MOD/i)%MOD;
for(int i=1;i<=N;++i)jv[i]=1ll*jv[i-1]*inv[i]%MOD;
for(int i=1;i<=N;++i)P[i]=(P[i-1]+1ll*(i-1)*pw[i])%MOD;
ans=(ans+calc(n))%MOD;
for(int i=1;i<=N;++i)P[i]=(P[i-1]+1ll*(0ll+n+n-i+1)*pw[i])%MOD;
ans=(ans+calc((n+n)%MOD))%MOD;
ans=(ans+MOD-calc(n))%MOD;
for(int i=1;i<=N;++i)P[i]=(P[i-1]+1ll*pw[i]*pw[2])%MOD;
ans=(ans+MOD-calc(n))%MOD;
ans=1ll*ans*fpow(n,MOD-2)%MOD;
printf("%d\n",ans);
return 0;
}