题意:给定长度为N,M的序列a,b和t,随机选取x∈[1,N],y∈[1,M],对于i=1,2,...,t,求(ax+by)i的期望
N,M,t≤100000
我为什么要打这么多标签
f(k)=x=1∑ny=1∑m(ax+by)k
暴力拆开
f(k)=x=1∑ny=1∑mi=0∑k(ik)axibyk−i
继续拆
f(k)=x=1∑ny=1∑mi=0∑ki!(k−i)!k!axibyk−i
整理一下
k!f(k)=i=0∑ki!∑x=1naxi(k−i)!∑y=1mbyk−i
显然是个卷积
#undef f
现在只需要求出
f(k)=i=1∑naik
右边同理
接下来是个一周目没法想到的神仙做法
对每个数单独考虑
fi(k)=aik
构造生成函数
fi(x)=1+aix+ai2x2+...
写成封闭形式
fi(x)=1−aix1
原来是
f(x)=i=1∑n1−aix1
加法并不好求其实很好求,分治暴力通分就可以了
考虑转成乘法做分治NTT
自然地想到算ln
而这个式子和ln有关的就只有倒数了
强行解释
即
ln′(1−aix)=1−aix1
(上述式子的自变量是1−aix)
而
[ln(1−aix)]′=ln′(1−aix)(1−aix)′=−1−aixai
我们发现这玩意和f有关系
设
gi(x)=−1−aixaig(x)=i=1∑ngi(x)
有
fi(x)=1−xgi(x)
所以
f(x)=n−xg(x)
现在只需要求出g
继续推之前的式子
g(x)=i=1∑n−1−aixai
=i=1∑n[ln(1−aix)]′
脑补一下,导数是可加的
g(x)=[i=1∑nln(1−aix)]′
拆进去
g(x)=[lni=1∏n(1−aix)]′
分治NTT即可
复杂度O(nlog2n)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#define MAXN 262144+5
using namespace std;
const int MOD=998244353;
typedef long long ll;
int fac[MAXN],finv[MAXN];
inline int add(const int& x,const int& y){return x+y>=MOD? x+y-MOD:x+y;}
inline int dec(const int& x,const int& y){return x<y? x-y+MOD:x-y;}
inline int qpow(int a,int p)
{
int ans=1;
while (p)
{
if (p&1) ans=(ll)ans*a%MOD;
a=(ll)a*a%MOD;p>>=1;
}
return ans;
}
#define inv(x) qpow(x,MOD-2)
int r[MAXN],rt[2][MAXN];
inline void init(const int& l){for (int i=0;i<(1<<l);i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));}
void NTT(int* a,int l,int type)
{
int lim=1<<l;
for (int i=0;i<lim;i++) if (i<r[i]) swap(a[i],a[r[i]]);
for (int L=0;L<l;L++)
{
int mid=1<<L,len=mid<<1;
int Wn=rt[type][L+1];
for (int s=0;s<lim;s+=len)
for (int k=0,w=1;k<mid;k++,w=(ll)w*Wn%MOD)
{
int x=a[s+k],y=(ll)w*a[s+mid+k]%MOD;
a[s+k]=add(x,y);a[s+mid+k]=dec(x,y);
}
}
if (type)
{
int t=inv(lim);
for (int i=0;i<lim;i++) a[i]=(ll)a[i]*t%MOD;
}
}
void getinv(int* A,int* B,int n)
{
static int t[MAXN];
if (n==1) return (void)(*B=inv(*A));
getinv(A,B,(n+1)>>1);
int l=0;
while ((1<<l)<(n<<1)) ++l;
for (int i=0;i<n;i++) t[i]=A[i];
for (int i=n;i<(1<<l);i++) t[i]=B[i]=0;
init(l);
NTT(t,l,0);NTT(B,l,0);
for (int i=0;i<(1<<l);i++) B[i]=(ll)B[i]*(MOD+2-(ll)t[i]*B[i]%MOD)%MOD;
NTT(B,l,1);
for (int i=n;i<(1<<l);i++) B[i]=0;
}
inline void deriv(int* A,int* B,int n)
{
for (int i=0;i<n-1;i++) B[i]=(ll)A[i+1]*(i+1)%MOD;
B[n-1]=0;
}
inline void integ(int* A,int* B,int n)
{
for (int i=1;i<n;i++) B[i]=(ll)A[i-1]*finv[i]%MOD*fac[i-1]%MOD;
B[0]=0;
}
void getln(int* A,int* B,int n)
{
static int f[MAXN],g[MAXN];
deriv(A,f,n);getinv(A,g,n);
int l=0;
while ((1<<l)<(n<<1)) ++l;
init(l);
for (int i=n;i<(1<<l);i++) f[i]=g[i]=0;
NTT(f,l,0);NTT(g,l,0);
for (int i=0;i<(1<<l);i++) f[i]=(ll)f[i]*g[i]%MOD;
NTT(f,l,1);
integ(f,B,n);
}
void solve(int* a,int* f,int l,int r)
{
if (l==r)
{
f[0]=1;f[1]=MOD-a[l];
return;
}
int mid=(l+r)>>1;
int len=0;
while ((1<<len)<=r-l+1) ++len;
int L[(1<<len)+5],R[(1<<len)+5];
memset(L,0,sizeof(L));memset(R,0,sizeof(R));
solve(a,L,l,mid);solve(a,R,mid+1,r);
init(len);
NTT(L,len,0);NTT(R,len,0);
for (int i=0;i<(1<<len);i++) f[i]=(ll)L[i]*R[i]%MOD;
NTT(f,len,1);
}
int a[MAXN],b[MAXN];
int f[MAXN],g[MAXN];
int A[MAXN],B[MAXN];
int main()
{
rt[0][23]=qpow(3,119);rt[1][23]=inv(rt[0][23]);
for (int i=22;i>=0;i--)
{
rt[0][i]=(ll)rt[0][i+1]*rt[0][i+1]%MOD;
rt[1][i]=(ll)rt[1][i+1]*rt[1][i+1]%MOD;
}
int n,m;
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
for (int i=1;i<=m;i++) scanf("%d",&b[i]);
int t;
scanf("%d",&t);
fac[0]=1;
for (int i=1;i<=t;i++) fac[i]=(ll)fac[i-1]*i%MOD;
finv[t]=inv(fac[t]);
for (int i=t-1;i>=0;i--) finv[i]=(ll)finv[i+1]*(i+1)%MOD;
solve(a,g,1,n);
getln(g,f,t+1);deriv(f,g,t+1);
for (int i=1;i<=t;i++) A[i]=MOD-g[i-1];
A[0]=n;
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
solve(b,g,1,m);
getln(g,f,t+1);deriv(f,g,t+1);
for (int i=1;i<=t;i++) B[i]=MOD-g[i-1];
B[0]=m;
// for (int i=0;i<=t;i++) printf("%d%c",A[i]," \n"[i==t]);
// for (int i=0;i<=t;i++)
// {
// int sum=0;
// for (int k=1;k<=n;k++) sum=add(sum,qpow(a[k],i));
// printf("%d%c",sum," \n"[i==t]);
// }
// for (int i=0;i<=t;i++) printf("%d%c",B[i]," \n"[i==t]);
// for (int i=0;i<=t;i++)
// {
// int sum=0;
// for (int k=1;k<=m;k++) sum=add(sum,qpow(b[k],i));
// printf("%d%c",sum," \n"[i==t]);
// }
for (int i=0;i<=t;i++)
{
A[i]=(ll)A[i]*finv[i]%MOD;
B[i]=(ll)B[i]*finv[i]%MOD;
}
int l=0;
while ((1<<l)<=(t<<1)) ++l;
init(l);
NTT(A,l,0);NTT(B,l,0);
for (int i=0;i<(1<<l);i++) A[i]=(ll)A[i]*B[i]%MOD;
NTT(A,l,1);
for (int i=1;i<=t;i++) A[i]=(ll)A[i]*fac[i]%MOD;
int tmp=(ll)inv(n)*inv(m)%MOD;
for (int i=1;i<=t;i++) printf("%d\n",(ll)A[i]*tmp%MOD);
return 0;
}