题目描述
给你一个数列:
\[f_n=\begin{cases}
a^n&1\leq n\leq k\\
\sum_{i=1}^k(a-1)f_{n-i}&n>k
\end{cases}
\]
a^n&1\leq n\leq k\\
\sum_{i=1}^k(a-1)f_{n-i}&n>k
\end{cases}
\]
记\(g_i\)为当\(k=i\)时\(f_n\)的值,求
\[\sum_{i=1}^mg_i\times {19260817}^i
\]
\]
对于\(60\%\)的数据:\(m\leq 200,n\leq {10}^9\)
对于另外\(40\%\)的数据:\(m\leq {10}^9,n\leq 3\times {10}^6\)
题解
第一部分
直接按常系数线性递推的通用做法来做。
可以不用FFT。
时间复杂度:\(O(m^3\log n)\)或\(O(m^2\log m\log n)\)
第二部分
因为当\(i\geq n\)时\(g_i=a^n\),所以我们只需要求\(g_1\ldots g_{n-1}\)
\[\begin{align}
f_n&=af_{n-1}-(a-1)f_{n-m-1}\\
F(x)&=axF(x)-(a-1)x^{k+1}F(x)+ax-ax^{k+1}\\
(1-ax+(a-1)x^{k+1})F(x)&=ax-ax^{k+1}\\
F(x)&=\frac{ax-ax^{k+1}}{1-ax+(a-1)x^{k+1}}\\
&=(ax-ax^{k+1})\sum_{i=0}^\infty\sum_{j=0}^i\binom{i}{j}{(1-a)}^jx^{j(k+1)}a^{i-j}x^{i-j}\\
\end{align}
\]
f_n&=af_{n-1}-(a-1)f_{n-m-1}\\
F(x)&=axF(x)-(a-1)x^{k+1}F(x)+ax-ax^{k+1}\\
(1-ax+(a-1)x^{k+1})F(x)&=ax-ax^{k+1}\\
F(x)&=\frac{ax-ax^{k+1}}{1-ax+(a-1)x^{k+1}}\\
&=(ax-ax^{k+1})\sum_{i=0}^\infty\sum_{j=0}^i\binom{i}{j}{(1-a)}^jx^{j(k+1)}a^{i-j}x^{i-j}\\
\end{align}
\]
记
\[G(x)=\sum_{i=0}^\infty\sum_{j=0}^i\binom{i}{j}{(1-a)}^jx^{j(k+1)}a^{i-j}x^{i-j}\\
\]
\]
那么
\[\begin{align}
F(x)&=(ax-ax^{k+1})G(x)\\
[x^n]F(x)&=a[x^{n-1}]G(x)-a[x^{n-k-1}]G(x)\\
[x^n]G(x)&=[x^n]\sum_{i=0}^\infty\sum_{j=0}^i\binom{i}{j}{(1-a)}^ja^{i-j}x^{jk+i}\\
&=\sum_{j}\sum_{i=n-jk}\binom{i}{j}{(1-a)}^ja^{i-j}\\
&=\sum_{j}\binom{n-jk}{j}{(1-a)}^ja^{n-j(k+1)}\\
\end{align}
\]
F(x)&=(ax-ax^{k+1})G(x)\\
[x^n]F(x)&=a[x^{n-1}]G(x)-a[x^{n-k-1}]G(x)\\
[x^n]G(x)&=[x^n]\sum_{i=0}^\infty\sum_{j=0}^i\binom{i}{j}{(1-a)}^ja^{i-j}x^{jk+i}\\
&=\sum_{j}\sum_{i=n-jk}\binom{i}{j}{(1-a)}^ja^{i-j}\\
&=\sum_{j}\binom{n-jk}{j}{(1-a)}^ja^{n-j(k+1)}\\
\end{align}
\]
观察到对于所有的\(k\),\(j\)的取值总共有\(O(n\log n)\)种,所以可以暴力枚举\(k,j\)。
时间复杂度:\(O(n\log n+\log m)\)
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
#include<iostream>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
int rd()
{
int s=0,c;
while((c=getchar())<'0'||c>'9');
s=c-'0';
while((c=getchar())>='0'&&c<='9')
s=s*10+c-'0';
return s;
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
char str[100];
sprintf(str,"%s.in",s);
freopen(str,"r",stdin);
sprintf(str,"%s.out",s);
freopen(str,"w",stdout);
#endif
}
const ll p=998244353;
const ll vv=19260817;
ll fp(ll a,ll b)
{
ll s=1;
for(;b;b>>=1,a=a*a%p)
if(b&1)
s=s*a%p;
return s;
}
int n,k,m;
ll pw[100010];
ll c[510];
ll d[510];
ll e[510];
int len;
void mul()
{
static ll f[510];
for(int i=0;i<=2*len;i++)
f[i]=0;
for(int i=0;i<len;i++)
for(int j=0;j<len;j++)
f[i+j]=(f[i+j]+d[i]*d[j])%p;
for(int i=0;i<2*len;i++)
d[i]=f[i];
}
void mod()
{
for(int i=2*len;i>=len;i--)
if(d[i])
{
ll v=d[i];
for(int j=0;j<=len;j++)
d[i-len+j]=(d[i-len+j]-v*c[j])%p;
}
}
void pow(int n)
{
if(!n)
return;
pow(n>>1);
mul();
if(n&1)
{
for(int i=2*len;i>=1;i--)
d[i]=d[i-1];
d[0]=0;
}
mod();
}
ll calc1(int x)
{
len=x;
memset(d,0,sizeof d);
d[0]=1;
c[x]=1;
for(int i=0;i<x;i++)
c[i]=-k+1;
pow(n-1);
ll ans=0;
for(int i=1;i<=x;i++)
ans=(ans+pw[i]*d[i-1])%p;
return ans;
}
void solve1()
{
ll ans=0;
pw[0]=1;
for(int i=1;i<=m;i++)
pw[i]=pw[i-1]*k%p;
for(int i=m;i>=1;i--)
ans=(ans+calc1(i))*vv%p;
ans=(ans+p)%p;
printf("%lld\n",ans);
}
int fac[3000010];
int inv[3000010];
int ifac[3000010];
int s1[3000010];
int s2[3000010];
int getc(int x,int y)
{
return (ll)fac[x]*ifac[y]%p*ifac[x-y]%p;
}
int gao(int n,int m)
{
int s=0;
for(int i=0;i*(m+1)<=n;i++)
s=(s+(ll)fac[n-i*m]*s1[i]%p*s2[n-i*(m+1)])%p;
return s;
}
void solve2()
{
inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;
for(int i=2;i<=3000000;i++)
fac[i]=(ll)fac[i-1]*i%p;
s1[0]=1;
for(int i=1;i<=3000000;i++)
s1[i]=(ll)s1[i-1]*(1-k)%p;
s2[0]=1;
for(int i=1;i<=3000000;i++)
s2[i]=(ll)s2[i-1]*k%p;
for(int i=2;i<=3000000;i++)
{
inv[i]=(ll)-p/i*inv[p%i]%p;
ifac[i]=(ll)ifac[i-1]*inv[i]%p;
s1[i]=(ll)s1[i]*ifac[i]%p;
s2[i]=(ll)s2[i]*ifac[i]%p;
}
ll ans=0;
if(n<=m)
{
for(int i=n-1;i>=1;i--)
ans=(ans+gao(n-1,i)-gao(n-i-1,i))*vv%p;
ans=ans*k%p;
ll v=fp(k,n)*(fp(vv,m+1)-fp(vv,n))%p*fp(vv-1,p-2)%p;
ans=(ans+v)%p;
ans=(ans+p)%p;
}
else
{
for(int i=m;i>=1;i--)
ans=(ans+gao(n-1,i)-gao(n-i-1,i))*vv%p;
ans=ans*k%p;
ans=(ans+p)%p;
}
printf("%lld\n",ans);
}
int main()
{
open("a");
scanf("%d%d%d",&m,&k,&n);
if(m<=200)
solve1();
else
solve2();
return 0;
}