多项式模板

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mod=998244353;        // NTT-friendly prime: 119 * 2^23 + 1
const int gg=3;                 // root used by the forward NTT
const int giv=(mod+1)/gg;       // gg^{-1} modulo mod (gg * giv == mod + 1)
const int N=(1<<21)+10;         // capacity of the global work arrays (2^21 plus slack)

namespace iobuff{
	// Buffered stdin/stdout: `in` is refilled by fread, `out` is written on overflow or flush().
	const int LEN=1000000;
	char in[LEN+5], out[LEN+5];
	char *pin=in, *pout=out, *ed=in, *eout=out+LEN;
	/// Next input character, or EOF (as a char) once stdin is exhausted.
	/// With LOCAL defined it reads straight from getchar() instead of the buffer.
	inline char gc(void)
	{
	#ifdef LOCAL
		return getchar();
	#else
		return pin==ed&&(ed=(pin=in)+fread(in, 1, LEN, stdin), ed==in)?EOF:*pin++;
	#endif
	}
	/// Appends c to the output buffer, writing the buffer to stdout first when it is full.
	inline void pc(char c)
	{
		pout==eout&&(fwrite(out, 1, LEN, stdout), pout=out);
		(*pout++)=c;
	}
	/// Writes everything buffered so far to stdout. Must be called before the program exits.
	inline void flush()
	{ fwrite(out, 1, pout-out, stdout), pout=out; }
	/// Reads a decimal integer into x, skipping leading non-digit characters.
	/// A '-' immediately before the digits makes the value negative.
	/// Stops at EOF instead of spinning forever (x is then left as 0).
	template<typename T> inline void read(T &x)
	{
		int f=1;
		char c=gc();
		x=0;
		while((c<'0'||c>'9')&&c!=static_cast<char>(EOF)) f=(c=='-'?-1:1), c=gc();
		while(c>='0'&&c<='9') x=10*x+c-'0', c=gc();
		x*=f;
	}
	/// Writes x in decimal followed by the separator div.
	/// Digits of negative values are taken from the negative remainders, so the minimum value
	/// of a signed type is printed without overflowing. The scratch array holds 48 digits,
	/// which covers any 128-bit value (the original 15 overflowed on 64-bit values).
	template<typename T> inline void putint(T x, char div)
	{
		char s[48];
		int top=0;
		const bool neg=x<0;
		if(neg) pc('-');
		do{
			const int d=static_cast<int>(x%10);   // in (-10, 0] when x is negative
			s[top++]=static_cast<char>(neg?-d:d);
			x/=10;
		}while(x);
		while(top--) pc(static_cast<char>(s[top]+'0'));
		pc(div);
	}
}
using namespace iobuff;

namespace Math{
	// inv[i]: modular inverse of i, valid for i < tp (extended lazily by init).
	// base/jc: only entries 0 and 1 are ever set in this namespace.
	// tp: first index of inv[] that has not been computed yet.
	int inv[N],base[N],jc[N],tp=2;
	/// Makes inv[0..n] valid (inv[0] is a placeholder 1); later calls only extend the table.
	inline void init(int n){
		if(tp==2){
			inv[0]=inv[1]=1;
			base[0]=base[1]=1;
			jc[0]=jc[1]=1;
		}
		while(tp<=n){
			// inv[i] = -(mod / i) * inv[mod % i]  (mod `mod`)
			inv[tp]=1ll*(mod-mod/tp)*inv[mod%tp]%mod;
			++tp;
		}
	}
	/// a^b modulo `mod` by binary exponentiation (b is expected to be non-negative).
	inline int ksm(int a,ll b){
		int res=1;
		while(b){
			if(b&1) res=1ll*res*a%mod;
			a=1ll*a*a%mod;
			b>>=1;
		}
		return res;
	}
	/// (x + y) mod `mod` for x, y already in [0, mod).
	inline int add(int x,int y){
		const int s=x+y;
		return s>=mod?s-mod:s;
	}
	/// (x - y) mod `mod` for x, y already in [0, mod).
	inline int dec(int x,int y){
		const int d=x-y;
		return d<0?d+mod:d;
	}
}
using namespace Math;

namespace Container{
	/// Dense polynomial over Z_mod: v[i] is the coefficient of x^i.
	/// Indexing past the end extends v with zero coefficients, so even a read can grow it
	/// (indices are expected to be non-negative).
	struct poly{
		std::vector<int> v;
		inline poly(int x=0):v(1,x){}
		inline int& operator[](int x){
			if(x>=static_cast<int>(v.size())) v.resize(x+1);   // new slots are 0
			return v[x];
		}
		inline int size(){return static_cast<int>(v.size());}
		inline void resize(int x){v.resize(x);}
		/// Sets coefficients l..r (inclusive) to x; [l, r] must already lie inside v.
		inline void mem(int l,int r,int x){std::fill(v.begin()+l,v.begin()+(r+1),x);}
	};
	/// Coefficient-wise sum; the result has max(x.size(), y.size()) coefficients.
	inline poly operator +(poly x,poly y){
		const int len=std::max(x.size(),y.size());
		x.resize(len);y.resize(len);
		for(int i=0;i<len;++i) x.v[i]=add(x.v[i],y.v[i]);
		return x;
	}
	/// Coefficient-wise difference modulo `mod`, sized like operator+.
	inline poly operator -(poly x,poly y){
		const int len=std::max(x.size(),y.size());
		x.resize(len);y.resize(len);
		for(int i=0;i<len;++i) x.v[i]=dec(x.v[i],y.v[i]);
		return x;
	}
	/// Coefficient-wise (pointwise) product -- NOT a convolution; poly_mul does that.
	inline poly operator *(poly x,poly y){
		const int len=std::max(x.size(),y.size());
		x.resize(len);y.resize(len);
		for(int i=0;i<len;++i) x.v[i]=1ll*x.v[i]*y.v[i]%mod;
		return x;
	}
	/// Scales every coefficient by y modulo `mod`.
	inline poly operator *(poly x,int y){
		for(int &c:x.v) c=1ll*c*y%mod;
		return x;
	}
}
using namespace Container;

namespace basic{
	int r[N],Wn[N];
	inline void NTT(int lim,poly& f,int tp){
		for(int i=0;i<lim;++i) if(i<r[i]) swap(f[i],f[r[i]]);
		for(int mid=1;mid<lim;mid<<=1){
			int len=mid<<1,wn=ksm(tp==1?gg:giv,(mod-1)/len);
			Wn[0]=1;for(int i=1;i<mid;++i) Wn[i]=1ll*Wn[i-1]*wn%mod;
			for(int l=0;l+len-1<lim;l+=len){
				for(int k=l;k<=l+mid-1;++k){
					int w1=f[k],w2=1ll*Wn[k-l]*f[k+mid]%mod;
					f[k]=add(w1,w2);f[k+mid]=dec(w1,w2);
				}
			}
		}
	}
	inline poly poly_mul(int n,int m,poly f,poly g){
		int lim=1,len=0;
		while(lim<(n+m)) lim<<=1,len++;
		for(int i=0;i<lim;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(len-1));
		if(f.size()>n) f.mem(n,min(lim,f.size())-1,0);
		if(g.size()>m) g.mem(m,min(lim,g.size())-1,0);
		NTT(lim,f,1);NTT(lim,g,1);
		for(int i=0;i<lim;++i) f[i]=1ll*f[i]*g[i]%mod;
		NTT(lim,f,-1);
		int iv=ksm(lim,mod-2);
		for(int i=0;i<lim;++i) f[i]=1ll*f[i]*iv%mod;
		return f;
	}
	inline void getdao(int n,poly& f){
		for(int i=1;i<n;++i) f[i-1]=1ll*i*f[i]%mod;
		f[n-1]=0;
	}
	inline void jifen(int n,poly& f){
		init(n);
		for(int i=n-1;i;--i) f[i]=1ll*inv[i]*f[i-1]%mod;
		f[0]=0;
	}
	inline void get(int x,poly& f){
		for(int i=0;i<x;++i) read(f[i]);
	}
	inline void print(int x,poly f){
		for(int i=0;i<x;++i) putint(f[i],i<x-1?' ':'\n');
	}
}
using namespace basic;

namespace Cipolla{
	// I : w^2 for the extension field of the current cipolla() call (a^2 - n, a non-residue).
	// fl: whether srand() has been called yet.
	int I,fl=0;
	/// Element a + b*w of F_mod[w] / (w^2 - I).
	struct pt{
		int a,b;
		pt(int _a=0,int _b=0):a(_a),b(_b){}
	};
	/// (x.a + x.b w)(y.a + y.b w) with w^2 = I.
	inline pt operator *(pt x,pt y){
		const int re=add(1ll*x.a*y.a%mod,1ll*x.b*y.b%mod*I%mod);
		const int im=add(1ll*x.a*y.b%mod,1ll*x.b*y.a%mod);
		return pt(re,im);
	}
	/// Euler's criterion: true iff x^((mod-1)/2) == 1, i.e. x is a nonzero quadratic residue.
	inline bool check(int x){
		const int half=(mod-1)>>1;
		return ksm(x,half)==1;
	}
	/// Pseudo-random value in [0, mod) built from two rand() calls.
	inline int random(){
		long long r1=rand();
		return r1*rand()%mod;
	}
	/// a^b in the extension field by binary exponentiation.
	inline pt qpow(pt a,int b){
		pt res(1,0);
		while(b){
			if(b&1) res=res*a;
			a=a*a;
			b>>=1;
		}
		return res;
	}
	/// Square root of n modulo `mod` (the smaller of the two roots), or 0 when n is not a
	/// nonzero quadratic residue (this includes n == 0, whose root is 0 anyway).
	inline int cipolla(int n){
		if(!fl){srand(time(0));fl=1;}
		if(!check(n)) return 0;
		int a;
		do a=random(); while(a==0||check(dec(1ll*a*a%mod,n)));
		I=dec(1ll*a*a%mod,n);
		const int root=qpow(pt(a,1),(mod+1)/2).a;
		return root<mod-root?root:mod-root;
	}
}
using namespace Cipolla;

namespace Poly{
	/// Inverse of f modulo x^n via Newton iteration g <- 2g - f*g^2 (mod x^n).
	/// Requires f[0] != 0. Returns exactly n coefficients, or the zero poly for n <= 0.
	inline poly getinv(int n,poly f){
		if(n<=0) return poly();  // (n+1)>>1 stays 0 here, so recursing would never terminate
		if(n==1){poly g;g[0]=ksm(f[0],mod-2);return g;}
		poly g=getinv(n+1>>1,f);poly p=g;
		
		int lim=1,len=0;
		while(lim<(n<<1)) lim<<=1,len++;
		for(int i=0;i<lim;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(len-1));
		// Only the first n coefficients of f and of the half-length inverse take part.
		if(f.size()>n) f.mem(n,min(lim,f.size())-1,0);
		if(p.size()>n) p.mem(n,min(lim,p.size())-1,0);
		
		NTT(lim,f,1);NTT(lim,p,1);
		for(int i=0;i<lim;++i) f[i]=1ll*f[i]*p[i]%mod*p[i]%mod;
		NTT(lim,f,-1);
		int iv=ksm(lim,mod-2);
		poly h;
		for(int i=0;i<n;++i) h[i]=dec(2ll*g[i]%mod,1ll*f[i]*iv%mod);
		return h;
	} 
	/// ln f modulo x^n, computed as the integral of f'/f. Assumes f[0] == 1.
	/// The result may carry extra coefficients past index n-1; only the first n are meaningful.
	inline poly getln(int n,poly f){
		poly g=getinv(n,f);getdao(n,f);
		poly h=poly_mul(n,n,f,g);
		jifen(n,h);
		return h;
	}
	/// exp f modulo x^n via Newton iteration g <- g*(1 - ln g + f). Assumes f[0] == 0.
	/// Only the first n coefficients of the result are meaningful.
	inline poly getexp(int n,poly f){
		if(n<=0) return poly();  // avoid endless recursion on (n+1)>>1 == 0
		if(n==1){poly g;g[0]=1;return g;}
		poly g=getexp(n+1>>1,f),B=getln(n,g);
		for(int i=0;i<n;++i) B[i]=dec((i==0),dec(B[i],f[i]));
		return poly_mul(n,n,g,B);
	}
	/// f^k modulo x^n as exp(k * ln f). Assumes f[0] == 1 and 0 <= k < mod.
	inline poly Pow(int n,int k,poly f){
		poly g=getln(n,f);
		for(int i=0;i<n;++i) g[i]=1ll*g[i]*k%mod;
		return getexp(n,g);
	}
	/// Square root of f modulo x^n via Newton iteration g <- (g + f/g) / 2.
	/// f[0] must be a quadratic residue; the constant term is cipolla(f[0]) (the smaller root).
	/// Only the first n coefficients of the result are meaningful.
	inline poly Sqrt(int n,poly f){
		if(n<=0) return poly();  // avoid endless recursion on (n+1)>>1 == 0
		if(n==1){poly g;g[0]=cipolla(f[0]);return g;}
		poly g=Sqrt((n+1)>>1,f);
		poly p=getinv(n,g);
		poly h=poly_mul(n,n,p,f);
		for(int i=0;i<n;++i) h[i]=1ll*((mod+1)/2)*add(h[i],g[i])%mod;
		return h;		
	}
	/// Reverses the first n coefficients of f, first growing f to n coefficients if needed.
	inline poly rev(int n,poly f){
		// Pre-grow: in swap(f[i],f[n-i-1]) a reallocation triggered by one argument could
		// leave the other reference dangling (argument evaluation order is unspecified).
		if(f.size()<n) f.resize(n);
		for(int i=0;i<(n>>1);++i) swap(f[i],f[n-i-1]);
		return f;
	}
	/// Quotient of A (n coefficients) by B (m coefficients, B[m-1] != 0).
	/// The first n-m+1 coefficients are the quotient; returns the zero poly when n < m.
	inline poly divide(int n,int m,poly A,poly B){
		if(n<m) return poly();  // deg A < deg B; getinv(n-m+1 <= 0) would be meaningless
		A=rev(n,A);B=rev(m,B);
		B=getinv(n-m+1,B);
		poly C=poly_mul(n-m+1,n-m+1,A,B);
		return rev(n-m+1,C);
	}
	/// Remainder A - B*C for a precomputed quotient C. The first m-1 coefficients of the
	/// result are the remainder; the remaining entries of A are returned untouched.
	inline poly Mod(int n,int m,poly A,poly B,poly C){
		if(n<m) return A;  // A is already its own remainder; also avoids poly_mul(m, <=0, ...)
		B=poly_mul(m,n-m+1,B,C);
		for(int i=0;i<m-1;++i) A[i]=dec(A[i],B[i]);
		return A;
	}
	/// Remainder of A (n coefficients) modulo B (m coefficients); see the 5-argument overload.
	inline poly Mod(int n,int m,poly A,poly B){
		poly C=divide(n,m,A,B);
		return Mod(n,m,A,B,C);
	}
	/// sin f modulo x^n = (exp(i f) - exp(-i f)) / (2i), where i = gg^((mod-1)/4) squares to -1.
	/// Assumes f[0] == 0. Only the first n coefficients are meaningful.
	inline poly Sin(int n,poly f){
		int II=ksm(gg,(mod-1)/4);
		poly g1=getexp(n,f*II),g2=getexp(n,f*(mod-II));
		return (g1-g2)*ksm(add(II,II),mod-2);
	}
	/// cos f modulo x^n = (exp(i f) + exp(-i f)) / 2, with i = gg^((mod-1)/4).
	/// Assumes f[0] == 0. Only the first n coefficients are meaningful.
	inline poly Cos(int n,poly f){
		int II=ksm(gg,(mod-1)/4);
		poly g1=getexp(n,f*II),g2=getexp(n,f*(mod-II));
		return (g1+g2)*((mod+1)/2);
	}
} 
using namespace Poly;

// Driver-level globals:
//   n     - read by fastpow::readpow1 (exponent >= n check)
//   a[]   - points used by lagrange::init
//   ans[] - filled by DCFFT::solve and lagrange::multi_query
//   g     - kernel convolved by DCFFT::solve
int n,m;
int a[N],ans[N];
poly f,g;
namespace fastpow{
	int flag,k,kk,ne;
	poly f,ret;
	char s[N];
	inline int readpow1(){
		long long x=0,f=1;int len=strlen(s+1);
		for(int i=1;i<=len;++i){char ch=s[i];x=((x<<3)+(x<<1)+(ch^48));if(x>=n) flag=1;while(x>=mod) x-=mod;}
		return f==-1?mod-x:x;
	}
	inline int readpow2(){
		long long x=0,f=1;int m=mod-1,len=strlen(s+1);
		for(int i=1;i<=len;++i){char ch=s[i];x=((x<<3)+(x<<1)+(ch^48));while(x>=m) x-=m;}
		return f==-1?m-x:x;
	}
	inline int init(int &n,poly& f){
		int now=0;
		while(f[now]==0) ++now;
		int iv=ksm(f[now],mod-2),ans=ksm(f[now],kk);
		for(int i=now;i<n;++i) f[i-now]=1ll*f[i]*iv%mod;
		long long x=1ll*now*k;
		if(x>=n||(flag&&now!=0)){
			for(int i=0;i<n;++i) ret[i]=0;
			return -1;
		} 
		for(int i=0;i<x;++i) ret[i]=0;
		ne=x;
		n-=now;
		return ans;
	}//init:将f的最低非零系数变为1 
	inline poly notone_pow(int n,int k,poly f){
		ne=n;kk=k;int t=n;int p=init(n,f);
		if(p==-1) return ret;
		f=Poly::Pow(n,k,f);
		for(int i=ne;i<t;++i) ret[i]=f[i-ne];
		return ret;
	}
}

namespace DCFFT{
	inline void solve(int l,int r){
		if(l>=r) return ;
		int mid=(l+r)>>1;
		solve(l,mid);
		poly f;for(int i=l;i<=mid;++i) f[i-l]=ans[i];
		poly h=poly_mul(mid-l+1,r-l+1,f,g);
		for(int i=mid+1;i<=r;++i) ans[i]=add(ans[i],h[i-l]); 
		solve(mid+1,r);
	}
} 
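//Tip: throughout the code above, n is a coefficient count, i.e. it denotes a polynomial of degree n-1 (上面所有的n都表示n-1次多项式).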

namespace lagrange{
	// Capacity of the point set: the segment tree below has M<<2 nodes.
	const int M=64010;
	// mul[p]: product of (x - a[i]) over segment-tree node p's range [l, r], built by init().
	poly mul[M<<2];
	#define lc (p<<1)
	#define rc (p<<1|1)
	/// Builds the product tree over a[l..r] rooted at node p.
	inline void init(int p,int l,int r){
		if(l==r){
			// x - a[l], with the constant term reduced into [0, mod).
			// The original mod-a[l] stored the unreduced value `mod` when a[l] == 0.
			mul[p][0]=(mod-a[l]%mod)%mod;
			mul[p][1]=1;
			return ;
		}
		int mid=(l+r)>>1;
		init(lc,l,mid);init(rc,mid+1,r);
		mul[p]=poly_mul(mid-l+2,r-mid+1,mul[lc],mul[rc]);
	}
	/// Multipoint evaluation: writes f(a[i]) into ans[i] for every i in [l, r].
	/// f is treated as having r-l+1 coefficients. NOTE(review): the root call that first
	/// reduces f modulo mul[1] is not shown here.
	inline void multi_query(int p,int l,int r,poly f){
		if(l==r){ans[l]=f[0];return ;}
		int mid=(l+r)>>1;
		// Reduce f modulo each child's product polynomial and recurse into that child.
		multi_query(lc,l,mid,Mod(r-l+1,mid-l+2,f,mul[lc]));
		multi_query(rc,mid+1,r,Mod(r-l+1,r-mid+1,f,mul[rc]));
	}
	#undef lc
	#undef rc
}
using namespace lagrange;

int main()
{
	// No driver is wired up yet: everything above is a reusable template.
	return 0;
}
上一篇:多项式模板


下一篇:poly