[UOJ86]mx的组合数——NTT+数位DP+原根与指标+卢卡斯定理

题目链接:

[UOJ86]mx的组合数

题目大意:给出四个数$p,n,l,r$,对于$\forall 0\le a\le p-1$,求$l\le x\le r,C_{x}^{n}\%p=a$的$x$的数量。$p<=3000$且保证$p$是质数,$n,l,r<=10^30$。

对于$10\%$的数据,可以直接杨辉三角推。
对于$20\%$的数据,因为$n$是确定的,可以递推出$C_{x+1}^{n}=C_{x}^{n}*\frac{x+1}{x+1-n}$。
对于另外$20\%$的数据,可以枚举$x$然后用$lucas$定理求。
对于另外$30\%$的数据,可以想到将问题转化成小于等于$r$的个数$-$小于等于$l-1$的个数。由$lucas$定理可知,$C_{x}^{n}\ mod\ p=\prod C_{b_{i}}^{a_{i}}\ mod\ p$,其中$a_{i},b_{i}$分别为$n,x$在$p$进制下的第$i$位。那么我们就可以用数位$DP$求,$f[i][j]$代表从最低为开始的前$i$位,每一位的值都不大于$b_{i}$且$\%p=j$的方案数;$g[i][j]$代表从最低为开始的前$i$位,每一位的值任意且$\%p=j$的方案数。设枚举第$i+1$位为$x$,$C_{x}^{a_{i+1}}=k$。那么可以得到$DP$转移方程$g[i+1][jk\ mod\ p]+=g[i][j]$,若$x<b_{i+1}$,则$f[i+1][jk\ mod\ p]+=g[i][j]$,若$x=b_{i+1}$,则$f[i+1][jk\ mod\ p]+=f[i][j]$。时间复杂度为$O(p^2log_{p})$。
对于$100\%$的数据,我们考虑优化上述$DP$,我们拿其中第一个转移方程来说(后两个同理),我们设$h[k]=\sum\limits_{x=0}^{p-1}[C_{x}^{a_{i+1}}==k]$。可以发现转移可以看成是$G[j*k\ mod\ p]=\sum\limits_{j=0}^{p-1}g[j]\sum\limits_{k=0}^{p-1}h[k]$,这和卷积式子很像,但他是乘法卷积,我们想办法将它变成加法卷积:因为$p$是质数,那么$p$一定有原根(设为$g$),也就是说对于任意$j$,其中$1\le j\le p-1$都有指标。我们设它的指标为$ind(j)$,那么$j*k\ mod\ p$就能转化为$g^{(ind(j)+ind(k))\ mod\ (p-1)}\ mod\ p$。这样我们就能用$FFT$或$NTT$来加速$DP$了,但注意到$0$没有指标,我们在转移时先忽略$0$,在最后输出答案时用总个数减掉其他答案就是$\%p=0$的个数了。注意原根从$1$开始枚举。至于$10^{30}$可以用$\_\_int128$存。时间复杂度为$O(plog_{p}^2)$。

两种写法,读者自选。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<vector>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
typedef __int128 int128;
#define MOD 998244353
using namespace std;
int p;
int128 l,r,n;
int pr[10];
int cnt;
int G;
int mx;
ll sum;
int ind[30010];
ll f[100000];
ll g[100000];
ll h[100000];
int a[200];
int b[200];
ll ans[30010];
int c[200][30010];
int mask=1;
ll s[100000];
char *p1,*p2,buf[100000];
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
int read_()
{
	int x=0; 
	char c=nc(); 
	while(c<48) 
	{
		c=nc(); 
	}
	while(c>47) 
	{
		x=(((x<<2)+x)<<1)+(c^48),c=nc(); 
	}
	return x;
}
int128 read() 
{
	int128 x=0; 
	char c=nc(); 
	while(c<48) 
	{
		c=nc(); 
	}
	while(c>47) 
	{
		x=(((x<<2)+x)<<1)+(c^48),c=nc(); 
	}
	return x;
}
ll quick(int x,int y,int mod)
{
	ll res=1ll;
	while(y)
	{
		if(y&1)
		{
			res=res*x%mod;
		}
		y>>=1;
		x=1ll*x*x%mod;
	}
	return res;
}
void NTT(ll *a,int len,int miku)
{
	for(int k=0,i=0;i<len;i++)
	{
		if(i>k)
		{
			swap(a[i],a[k]);
		}
		for(int j=len>>1;(k^=j)<j;j>>=1);
	}
	for(int k=2;k<=len;k<<=1)
	{
		int t=k>>1;
		int x=quick(3,(MOD-1)/k,MOD);
		if(miku==-1)
		{
			x=quick(x,MOD-2,MOD);
		}
		for(int i=0;i<len;i+=k)
		{
			ll w=1;
			for(int j=i;j<i+t;j++)
			{
				ll tmp=a[j+t]*w%MOD;
				a[j+t]=(a[j]-tmp+MOD)%MOD;
				a[j]=(a[j]+tmp)%MOD;
				w=w*x%MOD;
			}
		}
	}
	if(miku==-1)
	{
		for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++)
		{
			a[i]=a[i]*t%MOD;
		}
	}
}
void solve(int128 num)
{
	memset(f,0,sizeof(f));
	memset(g,0,sizeof(g));
	memset(h,0,sizeof(h));
	memset(a,0,sizeof(a));
	int res=0;
	for(int i=1;num;i++)
	{
		a[i]=num%p;
		num/=p;
		res=max(res,i);
	}
	mx=max(res,mx);
	g[0]=f[0]=1ll;
	for(int k=1;k<=mx;k++)
	{
		memset(h,0,sizeof(h));
		memset(s,0,sizeof(s));
		NTT(g,mask,1);
		NTT(f,mask,1);
		if(a[k]>=b[k])
		{
			h[ind[c[k][a[k]]]]++;
			NTT(h,mask,1);
			for(int i=0;i<mask;i++)
			{
				s[i]+=1ll*h[i]*f[i]%MOD;
				s[i]%=MOD;
			}
			NTT(h,mask,-1);
			h[ind[c[k][a[k]]]]--;
		}
		for(int i=b[k];i<a[k];i++)
		{
			h[ind[c[k][i]]]++;
		}
		NTT(h,mask,1);
		for(int i=0;i<mask;i++)
		{
			s[i]+=1ll*h[i]*g[i]%MOD;
			s[i]%=MOD;
		}
		NTT(h,mask,-1);
		NTT(s,mask,-1);
		memset(f,0,sizeof(f));
		for(int i=0;i<mask;i++)
		{
			f[i%(p-1)]+=s[i];
			f[i%(p-1)]%=MOD;
		}
		for(int i=max(b[k],a[k]);i<p;i++)
		{
			h[ind[c[k][i]]]++;
		}
		NTT(h,mask,1);
		for(int i=0;i<mask;i++)
		{
			s[i]=1ll*h[i]*g[i]%MOD;
		}
		NTT(s,mask,-1);
		memset(g,0,sizeof(g));
		for(int i=0;i<mask;i++)
		{
			g[i%(p-1)]+=s[i];
			g[i%(p-1)]%=MOD;
		}
	}
}
int main()
{
	p=read_(),n=read(),l=read(),r=read();
	l--;
	int s=p-1;
	while(mask<(p<<1))
	{
		mask<<=1;
	}
	for(int i=2;i*i<=s;i++)
	{
		if(s%i==0)
		{
			pr[++cnt]=i;
			while(s%i==0)
			{
				s/=i;
			}
		}
	}
	if(s!=1)
	{
		pr[++cnt]=s;
	}
	for(int i=1;i<p;i++)
	{
		bool flag=true;
		for(int j=1;j<=cnt;j++)
		{
			if(quick(i,(p-1)/pr[j],p)==1)
			{
				flag=false;
				break;
			}
		}
		if(flag)
		{
			G=i;
			break;
		}
	}
	sum=1ll;
	for(int i=0;i<p-1;i++)
	{
		ind[sum]=i;
		sum*=G,sum%=p;
	}
	int128 N=n;
	for(int i=1;N;i++)
	{
		b[i]=N%p;
		N/=p;
		mx=max(mx,i);
	}
	for(int i=1;i<=mx;i++)
	{
		for(int j=0;j<b[i];j++)
		{
			c[i][j]=0;
		}
		sum=1ll;
		for(int j=b[i];j<p;j++)
		{
			c[i][j]=sum;
			sum*=(j+1),sum%=p;
			sum*=quick(j+1-b[i],p-2,p),sum%=p;
		}
	}
	solve(l);
	for(int i=0;i<p-1;i++)
	{
		ans[quick(G,i,p)]-=f[i];
	}
	for(int i=1;i<=p-1;i++)
	{
		ans[i]=(ans[i]%MOD+MOD)%MOD;
	}
	solve(r);
	for(int i=0;i<p-1;i++)
	{
		ans[quick(G,i,p)]+=f[i];
	}
	for(int i=1;i<=p-1;i++)
	{
		ans[i]%=MOD;
	}
	ans[0]=(r-l)%MOD;
	for(int i=1;i<p;i++)
	{
		ans[0]-=ans[i];
		ans[0]=(ans[0]%MOD+MOD)%MOD;
	}
	for(int i=0;i<p;i++)
	{
		printf("%lld\n",ans[i]);
	}
}
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<vector>
#include<bitset>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
typedef __int128 int128;
#define MOD 998244353
using namespace std;
int p;
int128 l,r,n;
int pr[10];
int cnt;
int G;
int mx;
ll sum;
int ind[30010];
ll f[100000];
ll g[100000];
ll A[100000];
ll B[100000];
ll C[100000];
int a[200];
int b[200];
ll ans[30010];
int c[200][30010];
int mask=1;
int s[100000];
int pw[300010];
int fac[300010];
int inv[300010];
char *p1,*p2,buf[100000];
#define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++)
int read_()
{
	int x=0; 
	char c=nc(); 
	while(c<48) 
	{
		c=nc(); 
	}
	while(c>47) 
	{
		x=(((x<<2)+x)<<1)+(c^48),c=nc(); 
	}
	return x;
}
int128 read() 
{
	int128 x=0; 
	char c=nc(); 
	while(c<48) 
	{
		c=nc(); 
	}
	while(c>47) 
	{
		x=(((x<<2)+x)<<1)+(c^48),c=nc(); 
	}
	return x;
}
ll quick(int x,int y,int mod)
{
	ll res=1ll;
	while(y)
	{
		if(y&1)
		{
			res=res*x%mod;
		}
		y>>=1;
		x=1ll*x*x%mod;
	}
	return res;
}
void NTT(ll *a,int len,int miku)
{
	for(int k=0,i=0;i<len;i++)
	{
		if(i>k)
		{
			swap(a[i],a[k]);
		}
		for(int j=len>>1;(k^=j)<j;j>>=1);
	}
	for(int k=2;k<=len;k<<=1)
	{
		int t=k>>1;
		int x=quick(3,(MOD-1)/k,MOD);
		if(miku==-1)
		{
			x=quick(x,MOD-2,MOD);
		}
		for(int i=0;i<len;i+=k)
		{
			ll w=1;
			for(int j=i;j<i+t;j++)
			{
				ll tmp=a[j+t]*w%MOD;
				a[j+t]=(a[j]-tmp+MOD)%MOD;
				a[j]=(a[j]+tmp)%MOD;
				w=w*x%MOD;
			}
		}
	}
	if(miku==-1)
	{
		for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++)
		{
			a[i]=a[i]*t%MOD;
		}
	}
}
void solve(int128 num)
{
	memset(f,0,sizeof(f));
	memset(g,0,sizeof(g));
	memset(a,0,sizeof(a));
	int res=0;
	for(int i=1;num;i++)
	{
		a[i]=num%p;
		num/=p;
		res=max(res,i);
	}
	mx=max(res,mx);
	g[1]=f[1]=1ll;
	for(int k=1;k<=mx;k++)
	{
		memset(A,0,sizeof(A));
		memset(B,0,sizeof(B));
		for(int i=b[k];i<p;i++)
		{
			if(c[k][i])
			{
				A[ind[c[k][i]]]++;
			}
		}
		for(int i=1;i<p;i++)
		{	
			B[ind[i]]+=g[i];
			B[ind[i]]%=MOD;
		}
		NTT(A,mask,1);
		NTT(B,mask,1);
		for(int i=0;i<mask;i++)
		{
			C[i]=A[i]*B[i]%MOD;
		}
		NTT(C,mask,-1);
		memset(g,0,sizeof(g));
		for(int i=0;i<mask;i++)
		{
			(g[quick(G,i%(p-1),p)]+=C[i])%=MOD;
		}
		memset(A,0,sizeof(A));
		for(int i=b[k];i<a[k];i++)
		{
			if(c[k][i])
			{
				A[ind[c[k][i]]]++;
			}
		}
		NTT(A,mask,1);
		for(int i=0;i<mask;i++)
		{
			C[i]=A[i]*B[i]%MOD;
		}
		NTT(C,mask,-1);
		memset(s,0,sizeof(s));
		for(int i=0;i<mask;i++)
		{
			(s[quick(G,i%(p-1),p)]+=C[i])%=MOD;
		}
		if(c[k][a[k]])
		{
			for(int i=1;i<p;i++)
			{
				(s[c[k][a[k]]*i%p]+=f[i])%=MOD;;
			}
		}
		for(int i=1;i<p;i++)
		{
			f[i]=s[i];
		}
	}
}
int get_ori(int p)
{
	int s=p-1;
	for(int i=2;i*i<=s;i++)
	{
		if(s%i==0)
		{
			pr[++cnt]=i;
			while(s%i==0)
			{
				s/=i;
			}
		}
	}
	if(s!=1)
	{
		pr[++cnt]=s;
	}
	for(int i=1;i<p;i++)
	{
		bool flag=true;
		for(int j=1;j<=cnt;j++)
		{
			if(quick(i,(p-1)/pr[j],p)==1)
			{
				flag=false;
				break;
			}
		}
		if(flag)
		{
			return i;
			break;
		}
	}
}
int main()
{
	p=read_(),n=read(),l=read(),r=read();
	while(mask<(p<<1))
	{
		mask<<=1;
	}
	G=get_ori(p);
	pw[0]=1ll;
	for(int i=1;i<p;i++)
	{
		pw[i]=pw[i-1]*G%p;
	}
	sum=1ll;
	for(int i=0;i<p-1;i++)
	{
		ind[sum]=i;
		sum*=G,sum%=p;
	}
	int128 N=n;
	for(int i=1;N;i++)
	{
		b[i]=N%p;
		N/=p;
		mx=max(mx,i);
	}
	fac[0]=inv[0]=1ll;
	for(int i=1;i<p;i++)
	{
		fac[i]=fac[i-1]*i%p;
	}
	inv[p-1]=quick(fac[p-1],p-2,p);
	for(int i=p-2;i>=1;i--)
	{
		inv[i]=inv[i+1]*(i+1)%p;
	}
	for(int i=1;i<=120;i++)
	{
		for(int j=b[i];j<p;j++)
		{
			c[i][j]=fac[j]*inv[j-b[i]]%p*inv[b[i]]%p;
		}
	}
	solve(r);
	for(int i=1;i<p;i++)
	{
		ans[i]=f[i];
	}
	solve(l-1);
	for(int i=1;i<p;i++)
	{
		ans[i]=((ans[i]-f[i])%MOD+MOD)%MOD;
	}
	ans[0]=(r-l+1)%MOD;
	for(int i=1;i<p;i++)
	{
		ans[0]=((ans[0]-ans[i])%MOD+MOD)%MOD;
	}
	for(int i=0;i<p;i++)
	{
		printf("%lld\n",ans[i]);
	}
}
上一篇:Jmeter学习笔记


下一篇:w3c上的SQL 教程---基本语法 语句学习