题目
https://codeforces.com/gym/103119/problem/A
分析
https://fanfansann.blog.csdn.net/article/details/118528285
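这篇博客已经讲得很清楚了。
思路概要：用分治 + NTT 求出 ∏(1 + a_i x) 的各项系数 e_i（即 a 的第 i 个初等对称和），答案为 Σ_{i=1}^{n} e_i · i!(n−i)!/n! (mod 998244353)。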
代码
#include <bits/stdc++.h>
using namespace std;
// Problem-wide constants, typed constexpr values instead of macros.
constexpr int MAXN = 100005;   // capacity for n (input is 1-indexed)
constexpr int Ha = 998244353;  // NTT-friendly prime modulus: 119 * 2^23 + 1
constexpr int GHa = 3;         // primitive root modulo Ha (roots of unity base)
// Scratch buffers for NTT_mul and the bit-reversal table; 4*MAXN entries
// give headroom over the largest power-of-two transform length used.
long long A_[MAXN<<2], B_[MAXN<<2], C_[MAXN<<2];
int RR[MAXN<<2];
long long a[MAXN];        // input values a[1..n]
long long jc[MAXN + 1];   // jc[i] = i! mod Ha; MAXN+1 slots so index MAXN is valid
// Square-and-multiply modular exponentiation: returns x^k mod Ha.
// x is not reduced first, so x*x must fit in a long long; every caller in
// this file passes a value already below Ha. k is expected to be >= 0.
long long ksm(long long x, int k)
{
    long long result = 1;
    long long base = x;
    while (k != 0) {
        if (k & 1)
            result = result * base % Ha;
        base = base * base % Ha;
        k >>= 1;
    }
    return result;
}
// Iterative in-place number-theoretic transform of A[0..len-1] modulo Ha.
// Precondition: len is a power of two and the global RR already holds the
// bit-reversal permutation for len (NTT_mul builds it before calling here).
// type == 1 performs the forward transform; any other value performs the
// inverse transform, including the final division by len.
void NTT(int len, long long *A, int type)
{
    // Reorder into bit-reversed index order, swapping each pair exactly once.
    for (int idx = 0; idx < len; ++idx) {
        const int partner = RR[idx];
        if (idx < partner)
            std::swap(A[idx], A[partner]);
    }
    // Butterfly passes: each pass merges sub-blocks of size `half`.
    for (int half = 1; half < len; half <<= 1) {
        const int block = half << 1;
        const long long root = ksm(GHa, (Ha - 1) / block);  // primitive block-th root
        for (int base = 0; base < len; base += block) {
            long long twiddle = 1;
            for (int j = 0; j < half; ++j) {
                const long long lo = A[base + j];
                const long long hi = twiddle * A[base + half + j] % Ha;
                A[base + j] = (lo + hi) % Ha;
                A[base + half + j] = (lo - hi + Ha) % Ha;
                twiddle = twiddle * root % Ha;
            }
        }
    }
    if (type != 1) {
        // Inverse transform: reversing A[1..len-1] turns evaluation at w^k into
        // evaluation at w^{-k}; then scale every entry by len^{-1}.
        for (int lo = 1, hi = len - 1; lo < hi; ++lo, --hi)
            std::swap(A[lo], A[hi]);
        const long long invLen = ksm(len, Ha - 2);
        for (int idx = 0; idx < len; ++idx)
            A[idx] = A[idx] * invLen % Ha;
    }
}
// Multiplies polynomial A (coefficients 0..n) by B (coefficients 0..m)
// modulo Ha; C[0..n+m] receives the product coefficients.
// Side effects: A and B are zero-padded and overwritten with their forward
// transforms, and the global RR is rebuilt as the bit-reversal table for the
// chosen transform length.
void NTT_mul(long long *A, long long *B, long long *C, int n, int m)
{
    // Transform length: smallest power of two strictly greater than n+m.
    int tlen = 1;
    int logLen = 0;
    while (tlen <= n + m) {
        tlen <<= 1;
        ++logLen;
    }
    // Clear every slot above each operand's degree.
    for (int idx = n + 1; idx < tlen; ++idx)
        A[idx] = 0;
    for (int idx = m + 1; idx < tlen; ++idx)
        B[idx] = 0;
    // Bit-reversal of logLen-bit indices, built from the already-known idx/2.
    for (int idx = 0; idx < tlen; ++idx)
        RR[idx] = (RR[idx >> 1] >> 1) | ((idx & 1) << (logLen - 1));
    NTT(tlen, A, 1);
    NTT(tlen, B, 1);
    // Pointwise product in the transformed domain, then transform back.
    for (int idx = 0; idx < tlen; ++idx)
        C[idx] = A[idx] * B[idx] % Ha;
    NTT(tlen, C, -1);
}
// Computes the coefficients of prod_{i=L..R} (1 + a[i] x) mod Ha into f.
// After the call, f[k] is the k-th elementary symmetric sum of a[L..R] for
// k = 0..R-L+1 (so f must hold at least R-L+2 entries). An empty range
// (L > R) yields the constant polynomial 1 and needs only f[0].
// Divide and conquer: multiply the two halves' polynomials with NTT_mul.
// The global scratch buffers A_, B_, C_ are overwritten.
void fun(int L, int R, vector<long long> &f)
{
    if (L > R) {
        // Empty product; without this guard the recursion never terminates.
        f[0] = 1;
        return;
    }
    if (L == R) {
        f[0] = 1;
        f[1] = a[L];
        return;
    }
    const int M = (L + R) >> 1;
    const int len1 = M - L + 1;  // degree of the left half's polynomial
    const int len2 = R - M;      // degree of the right half's polynomial
    vector<long long> f1(len1 + 2), f2(len2 + 2);
    // Both recursive calls use A_/B_/C_ themselves, so the scratch buffers
    // may only be loaded after they have returned.
    fun(L, M, f1);
    fun(M + 1, R, f2);
    for (int i = 0; i <= len1; i++)
        A_[i] = f1[i];
    for (int i = 0; i <= len2; i++)
        B_[i] = f2[i];
    NTT_mul(A_, B_, C_, len1, len2);
    // Copy the whole product, constant term included, so f is the complete
    // polynomial at every recursion level.
    for (int i = 0; i <= R - L + 1; i++)
        f[i] = C_[i];
}
void solve()
{
int n;
scanf("%d",&n);
for (int i=1; i<=n; i++) scanf("%lld",&a[i]);
vector<long long> f(n+1);
fun(1, n, f);
long long ans=0;
for (int i=1; i<=n; i++) {
ans+=f[i]*jc[i] %Ha *jc[n-i] %Ha;
ans%=Ha;
}
ans*=ksm(jc[n], Ha-2);
ans%=Ha;
printf("%lld\n",ans);
}
// Fills the factorial table: jc[i] = i! mod Ha for every slot of jc.
// The loop is bounded by the array's own length rather than by i <= MAXN,
// so it never writes past the end of jc.
void pre()
{
    const int tableSize = static_cast<int>(sizeof(jc) / sizeof(jc[0]));
    jc[0] = 1;
    for (int i = 1; i < tableSize; i++)
        jc[i] = jc[i - 1] * i % Ha;
}
// Entry point: build the factorial table once, then answer each test case.
int main()
{
    pre();
    int remaining;
    scanf("%d", &remaining);
    for (; remaining; --remaining)
        solve();
    return 0;
}