题目链接:
题目大意:给出四个数$p,n,l,r$,对于$\forall 0\le a\le p-1$,求$l\le x\le r,C_{x}^{n}\%p=a$的$x$的数量。$p<=3000$且保证$p$是质数,$n,l,r<=10^30$。
对于$10\%$的数据,可以直接杨辉三角推。
对于$20\%$的数据,因为$n$是确定的,可以递推出$C_{x+1}^{n}=C_{x}^{n}*\frac{x+1}{x+1-n}$。
对于另外$20\%$的数据,可以枚举$x$然后用$lucas$定理求。
对于另外$30\%$的数据,可以想到将问题转化成小于等于$r$的个数$-$小于等于$l-1$的个数。由$lucas$定理可知,$C_{x}^{n}\ mod\ p=\prod C_{b_{i}}^{a_{i}}\ mod\ p$,其中$a_{i},b_{i}$分别为$n,x$在$p$进制下的第$i$位。那么我们就可以用数位$DP$求,$f[i][j]$代表从最低为开始的前$i$位,每一位的值都不大于$b_{i}$且$\%p=j$的方案数;$g[i][j]$代表从最低为开始的前$i$位,每一位的值任意且$\%p=j$的方案数。设枚举第$i+1$位为$x$,$C_{x}^{a_{i+1}}=k$。那么可以得到$DP$转移方程$g[i+1][jk\ mod\ p]+=g[i][j]$,若$x<b_{i+1}$,则$f[i+1][jk\ mod\ p]+=g[i][j]$,若$x=b_{i+1}$,则$f[i+1][jk\ mod\ p]+=f[i][j]$。时间复杂度为$O(p^2log_{p})$。
对于$100\%$的数据,我们考虑优化上述$DP$,我们拿其中第一个转移方程来说(后两个同理),我们设$h[k]=\sum\limits_{x=0}^{p-1}[C_{x}^{a_{i+1}}==k]$。可以发现转移可以看成是$G[j*k\ mod\ p]=\sum\limits_{j=0}^{p-1}g[j]\sum\limits_{k=0}^{p-1}h[k]$,这和卷积式子很像,但他是乘法卷积,我们想办法将它变成加法卷积:因为$p$是质数,那么$p$一定有原根(设为$g$),也就是说对于任意$j$,其中$1\le j\le p-1$都有指标。我们设它的指标为$ind(j)$,那么$j*k\ mod\ p$就能转化为$g^{(ind(j)+ind(k))\ mod\ (p-1)}\ mod\ p$。这样我们就能用$FFT$或$NTT$来加速$DP$了,但注意到$0$没有指标,我们在转移时先忽略$0$,在最后输出答案时用总个数减掉其他答案就是$\%p=0$的个数了。注意原根从$1$开始枚举。至于$10^{30}$可以用$\_\_int128$存。时间复杂度为$O(plog_{p}^2)$。
两种写法,读者自选。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<vector> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> #define ll long long typedef __int128 int128; #define MOD 998244353 using namespace std; int p; int128 l,r,n; int pr[10]; int cnt; int G; int mx; ll sum; int ind[30010]; ll f[100000]; ll g[100000]; ll h[100000]; int a[200]; int b[200]; ll ans[30010]; int c[200][30010]; int mask=1; ll s[100000]; char *p1,*p2,buf[100000]; #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++) int read_() { int x=0; char c=nc(); while(c<48) { c=nc(); } while(c>47) { x=(((x<<2)+x)<<1)+(c^48),c=nc(); } return x; } int128 read() { int128 x=0; char c=nc(); while(c<48) { c=nc(); } while(c>47) { x=(((x<<2)+x)<<1)+(c^48),c=nc(); } return x; } ll quick(int x,int y,int mod) { ll res=1ll; while(y) { if(y&1) { res=res*x%mod; } y>>=1; x=1ll*x*x%mod; } return res; } void NTT(ll *a,int len,int miku) { for(int k=0,i=0;i<len;i++) { if(i>k) { swap(a[i],a[k]); } for(int j=len>>1;(k^=j)<j;j>>=1); } for(int k=2;k<=len;k<<=1) { int t=k>>1; int x=quick(3,(MOD-1)/k,MOD); if(miku==-1) { x=quick(x,MOD-2,MOD); } for(int i=0;i<len;i+=k) { ll w=1; for(int j=i;j<i+t;j++) { ll tmp=a[j+t]*w%MOD; a[j+t]=(a[j]-tmp+MOD)%MOD; a[j]=(a[j]+tmp)%MOD; w=w*x%MOD; } } } if(miku==-1) { for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++) { a[i]=a[i]*t%MOD; } } } void solve(int128 num) { memset(f,0,sizeof(f)); memset(g,0,sizeof(g)); memset(h,0,sizeof(h)); memset(a,0,sizeof(a)); int res=0; for(int i=1;num;i++) { a[i]=num%p; num/=p; res=max(res,i); } mx=max(res,mx); g[0]=f[0]=1ll; for(int k=1;k<=mx;k++) { memset(h,0,sizeof(h)); memset(s,0,sizeof(s)); NTT(g,mask,1); NTT(f,mask,1); if(a[k]>=b[k]) { h[ind[c[k][a[k]]]]++; NTT(h,mask,1); for(int i=0;i<mask;i++) { s[i]+=1ll*h[i]*f[i]%MOD; s[i]%=MOD; } NTT(h,mask,-1); h[ind[c[k][a[k]]]]--; } for(int i=b[k];i<a[k];i++) { h[ind[c[k][i]]]++; } NTT(h,mask,1); for(int i=0;i<mask;i++) { s[i]+=1ll*h[i]*g[i]%MOD; s[i]%=MOD; } NTT(h,mask,-1); NTT(s,mask,-1); memset(f,0,sizeof(f)); for(int i=0;i<mask;i++) { f[i%(p-1)]+=s[i]; f[i%(p-1)]%=MOD; } for(int i=max(b[k],a[k]);i<p;i++) { h[ind[c[k][i]]]++; } NTT(h,mask,1); for(int i=0;i<mask;i++) { s[i]=1ll*h[i]*g[i]%MOD; } NTT(s,mask,-1); memset(g,0,sizeof(g)); for(int i=0;i<mask;i++) { g[i%(p-1)]+=s[i]; g[i%(p-1)]%=MOD; } } } int main() { p=read_(),n=read(),l=read(),r=read(); l--; int s=p-1; while(mask<(p<<1)) { mask<<=1; } for(int i=2;i*i<=s;i++) { if(s%i==0) { pr[++cnt]=i; while(s%i==0) { s/=i; } } } if(s!=1) { pr[++cnt]=s; } for(int i=1;i<p;i++) { bool flag=true; for(int j=1;j<=cnt;j++) { if(quick(i,(p-1)/pr[j],p)==1) { flag=false; break; } } if(flag) { G=i; break; } } sum=1ll; for(int i=0;i<p-1;i++) { ind[sum]=i; sum*=G,sum%=p; } int128 N=n; for(int i=1;N;i++) { b[i]=N%p; N/=p; mx=max(mx,i); } for(int i=1;i<=mx;i++) { for(int j=0;j<b[i];j++) { c[i][j]=0; } sum=1ll; for(int j=b[i];j<p;j++) { c[i][j]=sum; sum*=(j+1),sum%=p; sum*=quick(j+1-b[i],p-2,p),sum%=p; } } solve(l); for(int i=0;i<p-1;i++) { ans[quick(G,i,p)]-=f[i]; } for(int i=1;i<=p-1;i++) { ans[i]=(ans[i]%MOD+MOD)%MOD; } solve(r); for(int i=0;i<p-1;i++) { ans[quick(G,i,p)]+=f[i]; } for(int i=1;i<=p-1;i++) { ans[i]%=MOD; } ans[0]=(r-l)%MOD; for(int i=1;i<p;i++) { ans[0]-=ans[i]; ans[0]=(ans[0]%MOD+MOD)%MOD; } for(int i=0;i<p;i++) { printf("%lld\n",ans[i]); } }
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<vector> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> #define ll long long typedef __int128 int128; #define MOD 998244353 using namespace std; int p; int128 l,r,n; int pr[10]; int cnt; int G; int mx; ll sum; int ind[30010]; ll f[100000]; ll g[100000]; ll A[100000]; ll B[100000]; ll C[100000]; int a[200]; int b[200]; ll ans[30010]; int c[200][30010]; int mask=1; int s[100000]; int pw[300010]; int fac[300010]; int inv[300010]; char *p1,*p2,buf[100000]; #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++) int read_() { int x=0; char c=nc(); while(c<48) { c=nc(); } while(c>47) { x=(((x<<2)+x)<<1)+(c^48),c=nc(); } return x; } int128 read() { int128 x=0; char c=nc(); while(c<48) { c=nc(); } while(c>47) { x=(((x<<2)+x)<<1)+(c^48),c=nc(); } return x; } ll quick(int x,int y,int mod) { ll res=1ll; while(y) { if(y&1) { res=res*x%mod; } y>>=1; x=1ll*x*x%mod; } return res; } void NTT(ll *a,int len,int miku) { for(int k=0,i=0;i<len;i++) { if(i>k) { swap(a[i],a[k]); } for(int j=len>>1;(k^=j)<j;j>>=1); } for(int k=2;k<=len;k<<=1) { int t=k>>1; int x=quick(3,(MOD-1)/k,MOD); if(miku==-1) { x=quick(x,MOD-2,MOD); } for(int i=0;i<len;i+=k) { ll w=1; for(int j=i;j<i+t;j++) { ll tmp=a[j+t]*w%MOD; a[j+t]=(a[j]-tmp+MOD)%MOD; a[j]=(a[j]+tmp)%MOD; w=w*x%MOD; } } } if(miku==-1) { for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++) { a[i]=a[i]*t%MOD; } } } void solve(int128 num) { memset(f,0,sizeof(f)); memset(g,0,sizeof(g)); memset(a,0,sizeof(a)); int res=0; for(int i=1;num;i++) { a[i]=num%p; num/=p; res=max(res,i); } mx=max(res,mx); g[1]=f[1]=1ll; for(int k=1;k<=mx;k++) { memset(A,0,sizeof(A)); memset(B,0,sizeof(B)); for(int i=b[k];i<p;i++) { if(c[k][i]) { A[ind[c[k][i]]]++; } } for(int i=1;i<p;i++) { B[ind[i]]+=g[i]; B[ind[i]]%=MOD; } NTT(A,mask,1); NTT(B,mask,1); for(int i=0;i<mask;i++) { C[i]=A[i]*B[i]%MOD; } NTT(C,mask,-1); memset(g,0,sizeof(g)); for(int i=0;i<mask;i++) { (g[quick(G,i%(p-1),p)]+=C[i])%=MOD; } memset(A,0,sizeof(A)); for(int i=b[k];i<a[k];i++) { if(c[k][i]) { A[ind[c[k][i]]]++; } } NTT(A,mask,1); for(int i=0;i<mask;i++) { C[i]=A[i]*B[i]%MOD; } NTT(C,mask,-1); memset(s,0,sizeof(s)); for(int i=0;i<mask;i++) { (s[quick(G,i%(p-1),p)]+=C[i])%=MOD; } if(c[k][a[k]]) { for(int i=1;i<p;i++) { (s[c[k][a[k]]*i%p]+=f[i])%=MOD;; } } for(int i=1;i<p;i++) { f[i]=s[i]; } } } int get_ori(int p) { int s=p-1; for(int i=2;i*i<=s;i++) { if(s%i==0) { pr[++cnt]=i; while(s%i==0) { s/=i; } } } if(s!=1) { pr[++cnt]=s; } for(int i=1;i<p;i++) { bool flag=true; for(int j=1;j<=cnt;j++) { if(quick(i,(p-1)/pr[j],p)==1) { flag=false; break; } } if(flag) { return i; break; } } } int main() { p=read_(),n=read(),l=read(),r=read(); while(mask<(p<<1)) { mask<<=1; } G=get_ori(p); pw[0]=1ll; for(int i=1;i<p;i++) { pw[i]=pw[i-1]*G%p; } sum=1ll; for(int i=0;i<p-1;i++) { ind[sum]=i; sum*=G,sum%=p; } int128 N=n; for(int i=1;N;i++) { b[i]=N%p; N/=p; mx=max(mx,i); } fac[0]=inv[0]=1ll; for(int i=1;i<p;i++) { fac[i]=fac[i-1]*i%p; } inv[p-1]=quick(fac[p-1],p-2,p); for(int i=p-2;i>=1;i--) { inv[i]=inv[i+1]*(i+1)%p; } for(int i=1;i<=120;i++) { for(int j=b[i];j<p;j++) { c[i][j]=fac[j]*inv[j-b[i]]%p*inv[b[i]]%p; } } solve(r); for(int i=1;i<p;i++) { ans[i]=f[i]; } solve(l-1); for(int i=1;i<p;i++) { ans[i]=((ans[i]-f[i])%MOD+MOD)%MOD; } ans[0]=(r-l+1)%MOD; for(int i=1;i<p;i++) { ans[0]=((ans[0]-ans[i])%MOD+MOD)%MOD; } for(int i=0;i<p;i++) { printf("%lld\n",ans[i]); } }