以下考虑直接对所有$F(A)$求和,并给出两种做法——
做法1:
枚举答案$A$,对应方案数为${n-A\choose m}^{2}-{n-A-1\choose m}^{2}$,即答案为$\sum_{i=0}^{n-1}({n-i\choose m}^{2}-{n-i-1\choose m}^{2})F(i)$
记$G(n)=\sum_{i=0}^{n-1}{n-i\choose m}^{2}F(i)$,有$ans=G(n)-G(n-1)$,以下仅考虑求$G(n)$
不难证明$G(n)$是一个关于$n$的不超过$3m+1$次的多项式(将$3m$次多项式求前缀和即为$3m+1$次),更具体的,我们只需要求出$G(0),G(1),...,G(3m+1)$就可以确定$G(n)$
根据拉格朗日差值,维护$o(m)$以内的阶乘即逆元就可以$o(m)$的插出一个位置上的值
接下来,我们考虑怎么求出$G(i)$($0\le i\le 3m+1$):
先考虑求出所有$F(i)$($0\le i\le 3m+1$)的值,单次$o(m)$总复杂度即为$o(m^{2})$,无法通过
考虑拉格朗日插值法具体的式子,即为$F(i)=\sum_{j=0}^{m}F(j)\prod_{0\le k\le m,k\ne j}\frac{i-k}{j-k}$
预处理出$H_{j}=F(j)\prod_{0\le k\le m,k\ne j}\frac{1}{j-k}$,则$F(i)=\sum_{j=0}^{m}H_{j}\prod_{0\le k\le m,k\ne j}(i-k)$
当$0\le i\le m$,直接可得结果,那么假设$m<i\le 3m+1$,则$\prod_{0\le k\le m,k\ne j}(i-k)=\frac{\prod_{0\le k\le m}(i-k)}{i-j}$
简单化简,即$F(i)=\prod_{j=0}^{m}(i-j)\sum_{j=0}^{m}\frac{H_{j}}{i-j}$,前者利用逆元即可以从上一个$o(1)$算出,后者考虑$H_{i}$的生成函数$H(x)=\sum_{i=0}^{m}H_{i}x^{i}$和$H'(x)=\sum_{i=1}^{3m+1}\frac{x^{i}}{i}$,那么即$(H\times H')(x)[x^{i}]$
接下来,对于$G(i)=\sum_{j=0}^{i-1}{i-j\choose m}^{2}F(j)$,考虑$F(j)$的生成函数$F'(x)=\sum_{i=0}^{3m+1}F(i)x^{i}$以及$H''(x)=\sum_{i=1}^{3m+1}{i\choose m}^{2}x^{i}$,有$G(i)=(F'\times H'')(x)[x^{i}]$
用ntt计算多项式乘法,复杂度为$o(m\log m)$
最后$G(n)=\sum_{i=0}^{3m+1}G(i)\prod_{0\le j\le 3m+1,j\ne i}\frac{n-j}{i-j}$,$o(m)$差值即可
但这一做法的常数较大,我写不过去QAQ
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N (1<<23) 4 #define mod 998244353 5 int n,m,t,inv[N],Inv[N],rev[N],a[N],b[N],f[N]; 6 int ksm(int n,int m){ 7 n=(n+mod)%mod; 8 int s=n,ans=1; 9 while (m){ 10 if (m&1)ans=1LL*ans*s%mod; 11 s=1LL*s*s%mod; 12 m>>=1; 13 } 14 return ans; 15 } 16 void ntt(int *a,int p){ 17 for(int i=0;i<N;i++) 18 if (i<rev[i])swap(a[i],a[rev[i]]); 19 for(int i=2;i<=N;i<<=1){ 20 int s=ksm(3,(mod-1)/i); 21 if (p)s=ksm(s,mod-2); 22 for(int j=0;j<N;j+=i) 23 for(int k=0,ss=1;k<(i>>1);k++,ss=1LL*ss*s%mod){ 24 int x=a[j+k],y=1LL*a[j+k+(i>>1)]*ss%mod; 25 a[j+k]=(x+y)%mod; 26 a[j+k+(i>>1)]=(x+mod-y)%mod; 27 } 28 } 29 if (p){ 30 int s=ksm(N,mod-2); 31 for(int i=0;i<N;i++)a[i]=1LL*a[i]*s%mod; 32 } 33 } 34 int get(int n){ 35 if (n<=t)return f[n]; 36 int ss=1,ans=0; 37 for(int i=0;i<=t;i++)ss=1LL*ss*(n-i)%mod; 38 for(int i=0;i<=t;i++){ 39 int s=1LL*Inv[t-i]*Inv[i]%mod; 40 if ((t-i)&1)s=mod-s; 41 s=1LL*s*ss%mod*ksm(n-i,mod-2)%mod; 42 ans=(ans+1LL*f[i]*s)%mod; 43 } 44 return ans; 45 } 46 int main(){ 47 scanf("%d%d",&n,&m); 48 t=3*m+1; 49 for(int i=0;i<=m;i++)scanf("%d",&f[i]); 50 inv[0]=inv[1]=Inv[0]=1; 51 for(int i=2;i<=t;i++)inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod; 52 for(int i=1;i<=t;i++)Inv[i]=1LL*Inv[i-1]*inv[i]%mod; 53 for(int i=0;i<N;i++)rev[i]=(rev[i>>1]>>1)+((i&1)*(N>>1)); 54 for(int i=0;i<=m;i++){ 55 int s=1LL*Inv[m-i]*Inv[i]%mod; 56 if ((m-i)&1)s=mod-s; 57 a[i]=1LL*s*f[i]%mod; 58 } 59 for(int i=1;i<=t;i++)b[i]=inv[i]; 60 ntt(a,0); 61 ntt(b,0); 62 for(int i=0;i<N;i++)a[i]=1LL*a[i]*b[i]%mod; 63 ntt(a,1); 64 int s=1; 65 for(int i=1;i<=m;i++)s=1LL*s*i%mod; 66 for(int i=m+1;i<=t;i++){ 67 s=1LL*s*inv[i-m-1]%mod*i%mod; 68 f[i]=1LL*a[i]*s%mod; 69 } 70 memset(a,0,sizeof(a)); 71 s=1; 72 for(int i=m;i<=t;i++){ 73 a[i]=1LL*s*s%mod; 74 s=1LL*s*(i+1)%mod*inv[i-m+1]%mod; 75 } 76 ntt(a,0); 77 ntt(f,0); 78 for(int i=0;i<N;i++)f[i]=1LL*f[i]*a[i]%mod; 79 ntt(f,1); 80 printf("%d",(get(n)-get(n-1)+mod)%mod); 81 }View Code
做法2:
考虑$F(x)=1$的情况,由于$F(A)=1$,答案即为${n\choose m}^{2}$
当然,也可以枚举最终染色的球的数量$k$,对应方案数为${n\choose k}{k\choose m}{m\choose 2m-k}$(最后一个是$m$是因为第二次染色重复只能与第一次重复,内部不能重复),复杂度为$o(m)$
(由于以下还会使用${k\choose m}{m\choose 2m-k}$,将之记作$H_{k}$)
这一做法看上去没有什么意义,但其提示我们具体的位置是不好枚举的,要考虑个数,也就是要让$F(x)$与位置无关(如$F(x)=1$就与位置无关)
考虑$F(x)=x$的情况,构造$A$的组合意义:在两次染色完毕后,再染一个球,且这个球必须是$A$之前的球,那么第三次染色方案数恰好为$A$,三次染色的总方案数即为答案
枚举这三次染色所染的球数量$k+1$,类似的方案数即为${n\choose k+1}H_{k}$
考虑$F(x)=x^{c}$的情况,构造$A^{c}$的组合意义:在两次染色完毕后,再染$c$个球,允许内部重复但同样必须是$A$之前的球的方案数,同样这三次染色的总方案数即为答案
枚举这三次染色所染的球数量$k+t$,前半部分相同为${n\choose k+t}H_{k}$,下面考虑最后$t$个球如何染色
这是一个可以容斥的问题,即枚举强制不能染色的位置个数$i$,即$\sum_{i=0}^{t}(-1)^{i}{t\choose i}(t-i)^{c}$
为了让其可以计算,将组合数展开,即$t!\sum_{i=0}^{t}\frac{(-1)^{i}}{i!}\cdot \frac{(t-i)^{c}}{(t-i)!}$
这又是一个ntt的形式,即$H_{1}(x)=\sum_{i=0}^{m}\frac{(-1)^{i}}{i!}x^{i}$和$H_{2}(x)=\sum_{i=0}^{m}\frac{i^{c}}{i!}x^{i}$,那么$(H_{1}\times H_{2})(x)[x^{i}]$即为$t=i$时除以$t!$的结果,将其乘上$t!$并记作$G_{t}$
之后又是一个$G_{t}$和$H_{k}$的卷积,即$G(x)=\sum_{i=0}^{m}G_{i}x^{i}$以及$H(x)=\sum_{i=m}^{2m}H_{i}x^{i}$,那么$(H\times G)(x)[x^{i}]$即为$k+t=i$时的答案,再乘上${n\choose k+t}$后相加即可
当$F(x)$为普通多项式时,不难发现每一次可以独立,且$H$不变,根据卷积的分配律,可以将所有$G_{i}$相加后做一次,即$G_{i}=\sum_{i=0}^{t}(-1)^{i}{t\choose i}F(t-i)$
类似的构造$H_{2}(x)=\sum_{i=0}^{m}\frac{F(i)}{i!}x^{i}$,求出$G_{i}$后与上面的做法相同
两次ntt即可,时间复杂度为$o(m\log m)$,由于此时多项式次数变为$2m+m$次,常数更小,可以通过
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N (1<<22) 4 #define mod 998244353 5 int n,m,ans,fac[N],inv[N],rev[N],a[N],b[N]; 6 int c(int n,int m){ 7 if (n<m)return 0; 8 return 1LL*fac[n]*inv[m]%mod*inv[n-m]%mod; 9 } 10 int ksm(int n,int m){ 11 n=(n+mod)%mod; 12 int s=n,ans=1; 13 while (m){ 14 if (m&1)ans=1LL*ans*s%mod; 15 s=1LL*s*s%mod; 16 m>>=1; 17 } 18 return ans; 19 } 20 void ntt(int *a,int p){ 21 for(int i=0;i<N;i++) 22 if (i<rev[i])swap(a[i],a[rev[i]]); 23 for(int i=2;i<=N;i<<=1){ 24 int s=ksm(3,(mod-1)/i); 25 if (p)s=ksm(s,mod-2); 26 for(int j=0;j<N;j+=i) 27 for(int k=0,ss=1;k<(i>>1);k++,ss=1LL*ss*s%mod){ 28 int x=a[j+k],y=1LL*a[j+k+(i>>1)]*ss%mod; 29 a[j+k]=(x+y)%mod; 30 a[j+k+(i>>1)]=(x+mod-y)%mod; 31 } 32 } 33 if (p){ 34 int s=ksm(N,mod-2); 35 for(int i=0;i<N;i++)a[i]=1LL*a[i]*s%mod; 36 } 37 } 38 int main(){ 39 scanf("%d%d",&n,&m); 40 for(int i=0;i<=m;i++)scanf("%d",&a[i]); 41 fac[0]=inv[0]=inv[1]=1; 42 for(int i=1;i<=3*m;i++)fac[i]=1LL*fac[i-1]*i%mod; 43 for(int i=2;i<=3*m;i++)inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod; 44 for(int i=1;i<=3*m;i++)inv[i]=1LL*inv[i-1]*inv[i]%mod; 45 for(int i=0;i<N;i++)rev[i]=(rev[i>>1]>>1)+((i&1)*(N>>1)); 46 for(int i=0;i<=m;i++){ 47 a[i]=1LL*a[i]*inv[i]%mod; 48 b[i]=inv[i]; 49 if (i&1)b[i]=mod-b[i]; 50 } 51 ntt(a,0); 52 ntt(b,0); 53 for(int i=0;i<N;i++)a[i]=1LL*a[i]*b[i]%mod; 54 ntt(a,1); 55 memset(b,0,sizeof(b)); 56 for(int i=1;i<=2*m;i++){ 57 if (i>m)a[i]=0; 58 else a[i]=1LL*fac[i]*a[i]%mod; 59 b[i]=1LL*c(i,m)*c(m,2*m-i)%mod; 60 } 61 ntt(a,0); 62 ntt(b,0); 63 for(int i=0;i<N;i++)a[i]=1LL*a[i]*b[i]%mod; 64 ntt(a,1); 65 int s=1; 66 for(int i=0;i<=3*m;i++){ 67 ans=(ans+1LL*a[i]*s%mod*inv[i])%mod; 68 s=1LL*s*(n-i)%mod; 69 } 70 printf("%d",ans); 71 }View Code