题目大意
给你两个长度为 \(3^k\) 的数组。
定义两个数的 \(\text{mex}_3(a,b)\) 为两个数的在三进制下每一位的 \(\text{mex}\) 所组成的数。
求一个新的数组:
\[c_k=\sum_{\text{mex}_3(i,j)=k}a_i\cdot b_j \]题解
考虑分治乘法,将每一位的 \(0,1,2\) 分别作贡献。
那么此时你可以将整个长度分成三份,分别为 \(a_{0},a_{1},a_{2}\) 三份,这三份分别去搞再合并就可以了,具体可以看代码。
代码
#pragma GCC optimize ("Ofast")
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=6e4+5,K=15;
int k,ksm[K];
vector<int> a,b,c;
int mex[3][3]={{1,2,1},{2,0,0},{1,0,0}};
vector<int> solve(int k,vector<int> a,vector<int> b){
vector<int> x,y,z,c;
for(int i=0;i<ksm[k];++i) c.push_back(0);
if(k==1){
int tmp;
tmp=a[0]*b[1],c[2]+=tmp,c[1]-=tmp;
tmp=a[1]*b[0],c[2]+=tmp,c[1]-=tmp;
tmp=(a[1]+a[2])*(b[1]+b[2]),c[0]+=tmp,c[1]-=tmp;
tmp=(a[0]+a[1]+a[2])*(b[0]+b[1]+b[2]),c[1]+=tmp;
return c;
}
for(int i=0;i<ksm[k];++i) x.push_back(0);
for(int i=0;i<ksm[k];++i) y.push_back(0);
for(int i=0;i<ksm[k-1];++i) x[i]=a[i];
for(int i=0;i<ksm[k-1];++i) y[i]=b[ksm[k-1]+i];
z=solve(k-1,x,y);
for(int i=0;i<ksm[k-1];++i) c[ksm[k-1]*2+i]+=z[i],c[ksm[k-1]+i]-=z[i];
//--------------------------------------------------
for(int i=0;i<ksm[k-1];++i) x[i]=a[ksm[k-1]+i];
for(int i=0;i<ksm[k-1];++i) y[i]=b[i];
z=solve(k-1,x,y);
for(int i=0;i<ksm[k-1];++i) c[ksm[k-1]*2+i]+=z[i],c[ksm[k-1]+i]-=z[i];
//--------------------------------------------------
for(int i=0;i<ksm[k-1];++i) x[i]=a[ksm[k-1]+i]+a[ksm[k-1]*2+i];
for(int i=0;i<ksm[k-1];++i) y[i]=b[ksm[k-1]+i]+b[ksm[k-1]*2+i];
z=solve(k-1,x,y);
for(int i=0;i<ksm[k-1];++i) c[i]+=z[i],c[ksm[k-1]+i]-=z[i];
//--------------------------------------------------
for(int i=0;i<ksm[k-1];++i) x[i]=a[i]+a[ksm[k-1]+i]+a[ksm[k-1]*2+i];
for(int i=0;i<ksm[k-1];++i) y[i]=b[i]+b[ksm[k-1]+i]+b[ksm[k-1]*2+i];
z=solve(k-1,x,y);
for(int i=0;i<ksm[k-1];++i) c[ksm[k-1]+i]+=z[i];
//--------------------------------------------------
return c;
}
signed main(){
cin>>k,ksm[0]=1;
for(int i=1;i<=k;++i) ksm[i]=ksm[i-1]*3;
for(int i=0;i<ksm[k];++i) a.push_back(0);
for(int i=0;i<ksm[k];++i) b.push_back(0);
for(int i=0;i<ksm[k];++i) scanf("%lld",&a[i]);
for(int i=0;i<ksm[k];++i) scanf("%lld",&b[i]);
c=solve(k,a,b);
for(int i=0;i<ksm[k];++i) printf("%lld ",c[i]);
printf("\n");
return 0;
}