P4705 玩游戏
题意:给长为\(n\)的\(\{a_i\}\)和长为\(m\)的\(\{b_i\}\),设
\[f(x)=\sum_{k\ge 0}\sum_{i=1}^n\sum_{j=1}^m\frac{(a_i+a_j)^k}{nm} x^k
\]
\]
求出\(f\)点前\(t\)项
\[\begin{aligned}
nmf(x)&=\sum_{k\ge 0}\sum_{i=1}^n\sum_{j=1}^m\sum_{l=0}^k\binom{k}{l}a_i^lb_j^{k-l}x^k\\
&=\sum_{k\ge 0}\sum_{l=0}^k\binom{k}{l}(\sum_{i=1}^na_i^l)(\sum_{j=1}^mb_j^{k-l})x^k\\
\end{aligned}
\]
nmf(x)&=\sum_{k\ge 0}\sum_{i=1}^n\sum_{j=1}^m\sum_{l=0}^k\binom{k}{l}a_i^lb_j^{k-l}x^k\\
&=\sum_{k\ge 0}\sum_{l=0}^k\binom{k}{l}(\sum_{i=1}^na_i^l)(\sum_{j=1}^mb_j^{k-l})x^k\\
\end{aligned}
\]
定义\(EGF\)
\[A(x)=\sum_{i\ge 0}\sum_{j=1}^na_j^i\frac{x^i}{i!}
\]
\]
对\(B\)同理
设
\[\begin{aligned}
F(x)&=\sum_{i\ge 0}\sum_{j=1}^na_j^ix^i\\
&=\sum_{j=1}^n\frac{1}{1-a_jx}
\end{aligned}
\]
F(x)&=\sum_{i\ge 0}\sum_{j=1}^na_j^ix^i\\
&=\sum_{j=1}^n\frac{1}{1-a_jx}
\end{aligned}
\]
设
\[\begin{aligned}
G(x)&=\sum_{j=1}^n\ln'(1-a_jx)\\
&=\sum_{j=1}^n\frac{-a_j}{1-a_jx}
\end{aligned}
\]
G(x)&=\sum_{j=1}^n\ln'(1-a_jx)\\
&=\sum_{j=1}^n\frac{-a_j}{1-a_jx}
\end{aligned}
\]
因此
\[A(x)=n-G(x)x
\]
\]
把\(G\)化简一下
\[\begin{aligned}
G(x)&=\ln'(\prod_{j=1}^n(-a_jx+1))
\end{aligned}
\]
G(x)&=\ln'(\prod_{j=1}^n(-a_jx+1))
\end{aligned}
\]
下面可以做分治卷积
总复杂度\(O(n\log^2 n)\)
注意上界取\(\max(n,m,t)\)
Code:
#include <cstdio>
#include <cctype>
#include <algorithm>
using std::max;
const int N=(1<<20)+10;
const int mod=998244353,Gi=332748118;
template <class T>
void read(T &x)
{
x=0;char c=getchar();
while(!isdigit(c)) c=getchar();
while(isdigit(c)) x=x*10+c-'0',c=getchar();
}
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
#define mul(a,b) (1ll*(a)*(b)%mod)
int qp(int d,int k){int f=1;while(k){if(k&1)f=mul(f,d);d=mul(d,d),k>>=1;}return f;}
int a[N],b[N],fac[N],inv[N],Inv[N],turn[N],inva[N],invb[N],lna[N],lnb[N],A[N],yuua[N],yuub[N];
void NTT(int *a,int len,int typ)
{
int L=-1;for(int i=1;i<len;i<<=1) ++L;
for(int i=0;i<len;i++)
{
turn[i]=turn[i>>1]>>1|(i&1)<<L;
if(i<turn[i]) std::swap(a[i],a[turn[i]]);
}
for(int le=1;le<len;le<<=1)
{
int wn=qp(typ?3:Gi,(mod-1)/(le<<1));
for(int p=0;p<len;p+=le<<1)
{
int w=1;
for(int i=p;i<p+le;i++,w=mul(w,wn))
{
int x=a[i],y=mul(w,a[i+le]);
a[i]=add(x,y);
a[i+le]=add(x,mod-y);
}
}
}
if(!typ) for(int i=0;i<len;i++) a[i]=mul(a[i],Inv[len]);
}
void polyinv(int *a,int *b,int len)
{
if(len==1) {b[0]=qp(a[0],mod-2);return;}
polyinv(a,b,len>>1);
for(int i=0;i<len<<1;i++) inva[i]=invb[i]=0;
for(int i=0;i<len;i++) inva[i]=a[i],invb[i]=b[i];
NTT(inva,len<<1,1),NTT(invb,len<<1,1);
for(int i=0;i<len<<1;i++) inva[i]=mul(invb[i],add(2,mod-mul(inva[i],invb[i])));
NTT(inva,len<<1,0);
for(int i=0;i<len;i++) b[i]=inva[i];
}
void polyd(int *a,int len)
{
for(int i=0;i<len-1;i++) a[i]=mul(a[i+1],i+1);a[len-1]=0;
}
void polyint(int *a,int len)
{
for(int i=len-1;i;i--) a[i]=mul(a[i-1],Inv[i]);a[0]=0;
}
void polyln(int *a,int len)
{
for(int i=0;i<len<<1;i++) lna[i]=lnb[i]=0;
for(int i=0;i<len;i++) lna[i]=a[i];
polyinv(lna,lnb,len);
polyd(lna,len);
NTT(lna,len<<1,1),NTT(lnb,len<<1,1);
for(int i=0;i<len<<1;i++) lna[i]=mul(lna[i],lnb[i]);
NTT(lna,len<<1,0);
polyint(lna,len);
for(int i=0;i<len;i++) a[i]=lna[i];
}
void CDQ(int *a,int *b,int l,int r)
{
if(l==r){b[l]=add(mod,-a[l]);return;}
int mid=l+r>>1;
CDQ(a,b,l,mid),CDQ(a,b,mid+1,r);
int len=r+1-l;
for(int i=0;i<len<<1;i++) yuua[i]=yuub[i]=!i;
for(int i=l;i<=mid;i++) yuua[i+1-l]=b[i];
for(int i=mid+1;i<=r;i++) yuub[i-mid]=b[i];
NTT(yuua,len<<1,1),NTT(yuub,len<<1,1);
for(int i=0;i<len<<1;i++) yuua[i]=mul(yuua[i],yuub[i]);
NTT(yuua,len<<1,0);
for(int i=l;i<=r;i++) b[i]=yuua[i+1-l];
}
void init(int *a,int len,int n)
{
for(int i=0;i<len<<1;i++) A[i]=!i;
CDQ(a,A,1,len);
polyln(A,len),polyd(A,len);
for(int i=len-1;i;i--) a[i]=mul(A[i-1],mod-1);
a[0]=n;
for(int i=0;i<len;i++) a[i]=mul(a[i],inv[i]);
}
void init(int len)
{
fac[0]=1;for(int i=1;i<=len;i++) fac[i]=mul(fac[i-1],i);
inv[len]=qp(fac[len],mod-2);
for(int i=len-1;~i;i--) inv[i]=mul(inv[i+1],i+1);
for(int i=0;i<=len;i++) Inv[i]=qp(i,mod-2);
}
int main()
{
int n,m,t;
read(n),read(m);
for(int i=1;i<=n;i++) read(a[i]);
for(int i=1;i<=m;i++) read(b[i]);
read(t);
int len=1;
while(len<=max(max(n,m),t)) len<<=1;
init(len<<1);
init(a,len,n);
init(b,len,m);
NTT(a,len<<1,1),NTT(b,len<<1,1);
for(int i=0;i<len<<1;i++) a[i]=mul(a[i],b[i]);
NTT(a,len<<1,0);
int INV=qp(mul(n,m),mod-2);
for(int i=1;i<=t;i++) printf("%lld\n",mul(mul(a[i],fac[i]),INV));
return 0;
}
2019.3.8