在看懂之前看大家就写了下面三个式子,感觉个个都是谜语人,然而确实只需要下面三个式子就行了...
对于 FFT,我们求的就是多项式卷积,思路是对数组变换一波,然后点乘,然后再逆变换回来
对于数组 A 和 B,求 C 定义为:$$C_i=\sum_{j\oplus k = i}A_j \times B_k$$
其中 \(\oplus\) 分别可以是 or, and, xor
要看具体说明和证明看这里,下面的只要知道一些写法的含义:\(A = (A_0, A_1, ...)\) 这种意思是数组拼接成新数组的表示方式;A_0也是数组,另外FWT(A)也是数组;A+B 表示数组的按位相加...
FMT 处理的是 or 和 and,其实就是二进制的高维前缀和、然后点乘后再差分,这里直接上式子:
设当前有 \(2^n\) 项,\(A_0\) 表示前 \(2^{n-1}\) 项,就是编号最高位为 0 的部分(别忘了下标 0, 1, ..., 2^n-1)
记 \(A'\) 为 \(FWT(A)\)
or 卷积:
\[FWT(A) = \begin{cases} (FWT(A_0), FWT(A_1)+FWT(A_0)) & (n>0) \\ A & (n=0) \end{cases} \] \[IFWT(A') = \begin{cases} (IFWT(A'_0), IWFT(A'_1)-IWFT(A'_0)) & (n>0) \\ A' & (n=0) \end{cases} \]and 卷积:
\[FWT(A) = \begin{cases} (FWT(A_0)+FWT(A_1), FWT(A_1)) & (n>0) \\ A & (n=0) \end{cases} \] \[IFWT(A') = \begin{cases} (IFWT(A'_0)-IFWT(A'_1), IFWT(A'_1)) & (n>0) \\ A' & (n=0) \end{cases} \]xor 卷积的证明以后再看...(
xor 卷积:
\[FWT(A) = \begin{cases} (FWT(A_0)+FWT(A_1), FWT(A_0)-FWT(A_1)) & (n>0) \\ A & (n=0) \end{cases} \] \[IFWT(A') = \begin{cases} (\frac{IFWT(A'_0)+IFWT(A'_1)}{2}, \frac{IWFT(A'_0)-IFWT(A'_1)}{2}) & (n>0) \\ A' & (n=0) \end{cases} \]#include <cstdio>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const ll inv2 = 499122177; // inv_2 of mod
const int MAXN = 200005;
int N; ll A[MAXN], B[MAXN];
struct myFWT {
// n must be power of 2; f [0, n-1]
int n; ll a[MAXN], b[MAXN];
void OR(ll *f, int type) {
for (int mid=1; mid< n; mid<<=1)
for (int blk=(mid<<1), j=0; j< n; j+=blk)
for (int i=j; i< j+mid; i++)
f[i+mid] = (f[i+mid] + f[i]*type + mod) % mod;
}
void AND(ll *f, int type) {
for (int mid=1; mid< n; mid<<=1)
for (int blk=(mid<<1), j=0; j< n; j+=blk)
for (int i=j; i< j+mid; i++)
f[i] = (f[i] + f[i+mid]*type + mod) % mod;
}
void XOR(ll *f, int type) {
for (int mid=1; mid< n; mid<<=1)
for (int blk=(mid<<1), j=0; j< n; j+=blk)
for (int i=j; i< j+mid; i++) {
ll x = f[i], y = f[i+mid];
f[i] = (x+y) % mod * (type==1?1:inv2) % mod;
f[i+mid] = (x-y+mod) % mod * (type==1?1:inv2) % mod;
}
}
void init() {
for (int i=0; i< n; i++) a[i] = A[i], b[i] = B[i];
}
void print() {
for (int i=0; i< n; i++) printf("%lld ", a[i]); printf("\n");
}
void workOR () {
init(), OR(a, 1), OR(b, 1);
for (int i=0; i< n; i++) a[i] = a[i] * b[i] % mod;
OR(a, -1), print();
}
void workAND() {
init(), AND(a, 1), AND(b, 1);
for (int i=0; i< n; i++) a[i] = a[i] * b[i] % mod;
AND(a, -1), print();
}
void workXOR() {
init(), XOR(a, 1), XOR(b, 1);
for (int i=0; i< n; i++) a[i] = a[i] * b[i] % mod;
XOR(a, -1), print();
}
} fwt;
int main()
{
scanf("%d", &N); N = 1<<N, fwt.n = N;
for (int i=0; i< N; i++) scanf("%lld", &A[i]), A[i] %= mod;
for (int i=0; i< N; i++) scanf("%lld", &B[i]), B[i] %= mod;
fwt.workOR(), fwt.workAND(), fwt.workXOR();
}