【数论】FFT,NTT,FWT

FFT-快速傅里叶变换

一,时间复杂度

时间复杂度: \(\mathcal{O(nlogn)}\)

二,限制

  • 计算过程(即 多项式相乘过程)不能取模;
  • 常数较大;
  • 会存在精度差。

三,模板

计算多项式 \(C=A\times B\), 其中 多项式 \(A\) 的长度是 \(n+1\), 多项式 \(B\) 的长度是 \(m+1\) ,得到多项式 \(C\) 的长度是 \(n+m+1\)。

#include <bits/stdc++.h>
using namespace std;
const int maxn=2e6+5;
const double PI=acos(-1);
struct comp{
    double x,y;
    comp operator +(comp b){return comp{x+b.x,y+b.y};}
    comp operator -(comp b){return comp{x-b.x,y-b.y};}
    comp operator *(comp b){return comp{x*b.x-y*b.y,x*b.y+y*b.x};}
};
int up,L,R[maxn<<1];
void fft_init(int n,int m){
    up=1;L=0;
    while(up<n+m+2)up<<=1,L++;
    for(int i=0;i<up;i++)
        R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
}
comp a[maxn<<1],b[maxn<<1],c[maxn<<1];
void fft(comp a[],int type)
{
    for(int i=0;i<up;i++)
        if(i<R[i])swap(a[i],a[R[i]]);
    for(int mid=1;mid<up;mid<<=1)
    {
        comp wn=comp{cos(PI/mid),type*sin(PI/mid)};
        for(int r=mid<<1,j=0;j<up;j+=r)
        {
            comp w{1,0};
            for(int k=0;k<mid;k++,w=w*wn)
            {
                comp x=a[j+k],y=w*a[j+mid+k];
                a[j+k]=x+y;
                a[j+mid+k]=x-y;
            }
        }
    }
}
int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%lf",&a[i].x);
    for(int i=0;i<=m;i++)scanf("%lf",&b[i].x);
    fft_init(n,m);
    fft(a,1);fft(b,1);
    for(int i=0;i<=up;i++)c[i]=a[i]*b[i];
    fft(c,-1);
    for(int i=0;i<=n+m;i++)printf("%d ",(int)(c[i].x/up+0.5));
}

NTT-快速数论变换

一, 时间复杂度

时间复杂度: \(\mathcal{O(nlogn)}\)

二,限制

  • 系数必须是整数;

  • 系数取模时, 模数有限制, 需要知道模数的原根。

    • 常见模数对应原根:

      ​ 998244353 \(\to\) 3

      ​ 100000009 \(\to\) 5

三,模板

计算多项式 \(C=A\times B\), 其中 多项式 \(A\) 的长度是 \(n+1\), 多项式 \(B\) 的长度是 \(m+1\) ,得到多项式 \(C\) 的长度是 \(n+m+1\)。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=2e6+5;
const int mod=998244353;
const int g=3;
const int gi=332748118;//invg
ll kpow(ll a,ll b){
    ll ans=1;
    while(b){
        if(b&1)ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}
int up,L,R[maxn<<1];
void ntt_init(int n,int m){
    up=1;L=0;
    while(up<n+m+2)up<<=1,L++;
    for(int i=0;i<up;i++)
        R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
}
int a[maxn<<1],b[maxn<<1],c[maxn<<1];
void ntt(int a[],int type)
{
    for(int i=0;i<up;i++)
        if(i<R[i])swap(a[i],a[R[i]]);
    for(int mid=1;mid<up;mid<<=1){
        ll wn=kpow(type==1?g:gi,(mod-1)/(mid<<1));
        for(int r=mid<<1,j=0;j<up;j+=r)
        {
            ll w=1;
            for(int k=0;k<mid;k++,w=w*wn%mod)
            {
                int x=a[j+k],y=w*a[j+k+mid]%mod;
                a[j+k]=(x+y)%mod;
                a[j+k+mid]=(x-y+mod)%mod;
            }
        }
    }
}
int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++)scanf("%d",&a[i]);
    for(int i=0;i<=m;i++)scanf("%d",&b[i]);
    ntt_init(n,m);
    ntt(a,1);ntt(b,1);
    for(int i=0;i<up;i++)c[i]=1ll*a[i]*b[i]%mod;
    ntt(c,-1);
    int inv=kpow(up,mod-2);
    for(int i=0;i<=n+m;i++)printf("%lld ",1ll*c[i]*inv%mod);
}

FWT-快速沃尔什变换

一,作用

FFT可以解决多项式卷积, 即:

\[C_k=\sum_{k=i+j}\,A_i*B_j \]

FWT可以解决 或/与/异或 卷积, 即:

\[C_k=\sum_{k=i|j}\,A_i*B_j\\ C_k=\sum_{k=i\&j}\,A_i*B_j\\ C_k=\sum_{k=i\bigotimes j}\,A_i*B_j \]

二,时间复杂度

时间复杂度: \(\mathcal{O}(nlogn)\)

三,做法

1, 将 \(A\) 转换成 \(FWT[A]\) , 将 \(B\) 转换成 \(FWT[B]\)

2, 计算

\[FWT[C][i]=FWT[A][i]*FWT[B][i] \]

3, 将 \(FWT[C]\) 转换成 \(C\) .

四,模板

\(code:\)

void FWT_or(long long a[],int len)//A -> FAT[A]
{
	for(int mid=2;mid<=len;mid<<=1)
	for(int i=0;i<len;i+=mid)
	for(int j=i;j<i+(mid>>1);j++)a[j+(mid>>1)]+=a[j];
}
void IFWT_or(long long a[],int len)// FAT[A] -> A
{
	for(int mid=2;mid<=len;mid<<=1)
	for(int i=0;i<len;i+=mid)
	for(int j=i;j<i+(mid>>1);j++)a[j+(mid>>1)]-=a[j];
}
void FWT_and(long long a[],int len)
{
	for(int mid=2;mid<=len;mid<<=1)
	for(int i=0;i<len;i+=mid)
	for(int j=i;j<i+(mid>>1);j++)a[j]+=a[j+(mid>>1)];
}
void IFWT_and(long long a[],int len)
{
	for(int mid=2;mid<=len;mid<<=1)
	for(int i=0;i<len;i+=mid)
	for(int j=i;j<i+(mid>>1);j++)a[j]-=a[j+(mid>>1)];
}
void FWT_xor(long long a[],int len)
{
	for(int mid=2;mid<=len;mid<<=1)
	for(int i=0;i<len;i+=mid)
	for(int j=i;j<i+(mid>>1);j++)
	{
		long long x=a[j],y=a[j+(mid>>1)];
		a[j]=x+y,a[j+(mid>>1)]=x-y;
	}
}
inline void IFWT_xor(long long a[],int len)
{
	for(int mid=2;mid<=len;mid<<=1)
	for(int i=0;i<len;i+=mid)
	for(int j=i;j<i+(mid>>1);j++)
	{
		long long x=a[j],y=a[j+(mid>>1)];
		a[j]=(x+y)>>1,a[j+(mid>>1)]=(x-y)>>1;
	}
}

五,例题

2020牛客多校第二场 E- Exclusive OR

题意:

给定长度为 \(n\) 的数组 \(A\) ,其中 \(1\le n\le 2\times 10^5,\,0\le A_i<2^{18}\) 。

对于所有的 \(1\le i\le n\) ,要求:

从数组 \(A\) 中取出 \(i\) 个数(可重复取同一个数) , 计算出 \(a_1\bigotimes a_2\bigotimes...\bigotimes a_i\) 的值。

题解:

当 \(i> 19\) 之后 , 有 \(ans_i=ans_{i-2}\) 。

只需用异或卷积求出前 \(19\) 个答案。

代码:

#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
const int mod=1e9+7;
const int inf=0x3f3f3f3f;
const int maxn=1e6+5;

void read(int&x)
{
    char c;
    while(!isdigit(c=getchar()));x=c-'0';
    while(isdigit(c=getchar()))x=(x<<3)+(x<<1)+c-'0';
}
void FWT_xor(int a[],int n){
    for(int mid=2;mid<=n;mid<<=1)
        for(int i=0;i<n;i+=mid)
            for(int j=i,x,y;j<i+(mid>>1);j++)
            {
                x=a[j];y=a[j+(mid>>1)];
                a[j]=x+y;a[j+(mid>>1)]=x-y;
            }
}
void IFWT_xor(int a[],int n){
    for(int mid=2;mid<=n;mid<<=1)
        for(int i=0;i<n;i+=mid)
            for(int j=i,x,y;j<i+(mid>>1);j++)
            {
                x=a[j];y=a[j+(mid>>1)];
                a[j]=(x+y)>>1;a[j+(mid>>1)]=(x-y)>>1;
            }
}
int a[maxn],b[maxn],c[maxn];
int ans[maxn];

int main()
{
    int n,up,ed=1<<18;
    read(n);
    for(int i=1,x;i<=n;i++){
        read(x);
        a[x]=1;ans[1]=max(ans[1],x);
    }
    up=min(n,19);
    FWT_xor(a,ed);
    for(int i=0;i<ed;i++)b[i]=a[i];
    for(int i=2;i<=up;i++){
        for(int j=0;j<ed;j++)
            c[j]=a[j]*b[j];
        IFWT_xor(c,ed);
        for(int j=0;j<ed;j++)
            if(c[j])ans[i]=j,c[j]=1;
        FWT_xor(c,ed);
        for(int j=0;j<ed;j++)b[j]=c[j];
    }
    for(int i=up+1;i<=n;i++)ans[i]=ans[i-2];
    for(int i=1;i<=n;i++)printf("%d ",ans[i]);putchar(10);
}
上一篇:Luogu P4717 【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)


下一篇:找树