题目
给出 \(n\) 个三元组\(\{ a_i,b_i,c_i \}\)和\(x,y,z\);
将每个三元组扩展成(\(x\)个\(a_i\),\(y\)个\(b_i\),\(z\)个\(c_i\));
问从\(n\)组里面每组选一个数,这\(n\)个数异或值为x 的方案数\(mod \ 998244353\)是多少;
\(1 \le n \le 10^5 \ , \ 1 \le k \le 17 \ , \ 0 \le x,y,z \le 10^9 \ , 0 \le \ a_i,b_i,c_i \lt 2^k\) ;
题解
最后的答案异或一个 \(\oplus_{i=1}^{n} a_i\) ,令\(\{a_i,b_i,c_i\}\)变成$ { 0 , a_i \wedge b_i , a_i \wedge c_i } $ ;
令\(F_{i,0}+=x \ , \ F_{i,b_i}+=y \ , \ F_{i,c_i}+=z\) ,把所有\(fwt(F_i)\)点乘起来再\(ifwt\)回去即可;
考虑如何求最后的乘积\(\Pi F_i\);
-
对于\(fwt(F_i)\),每一项一定都是\(x+y+z \ , \ x+y-z \ , \ x-y+z \ , x - y - z\) 之一;
设纵向的个数为\(i,j,k,l\),解出每一位\(i,j,k,l\)即可快速算出最后的乘积,首先:
\[ \begin{align} i +j +k + l = n \end{align} \]
令只考虑\(F_i,b_i=1\),设所有\(F\)加起来\(fwt\)到得到对应位值上的值为\(p\)(x=0,y=1,z=0):
\[ i + j - k - l = p \]
同理只令\(F_i,c_i = 1\),有(x=0,y=0,z=1):
\[ i - j + k - l = p \]
令\(F_{i,b_i \wedge c_i}=1\),相当于上面两个的点值乘法,有
\[ i - j - k + l = p \]
解方程即可; -
最后\(ifwt\)回来;
#include<bits/stdc++.h> #define mod 998244353 #define ll long long using namespace std; const int N=1<<17; int n,X,Y,Z,l,s; int A[N],B[N],C[N],ans[N]; char gc(){ static char*p1,*p2,S[1000000]; if(p1==p2)p2=(p1=S)+fread(S,1,1000000,stdin); return(p1==p2)?EOF:*p1++; } int rd(){ int x=0;char c=gc(); while(c<'0'||c>'9')c=gc(); while(c>='0'&&c<='9')x=(x<<1)+(x<<3)+c-'0',c=gc(); return x; } int pw(int x,int y){ int re=1; while(y){ if(y&1)re=(ll)re*x%mod; y>>=1;x=(ll)x*x%mod; } return re; } void fwt(int*a){ for(int i=1;i<l;i<<=1) for(int j=0;j<l;j+=i<<1) for(int k=0;k<i;++k){ int t1=a[j+k],t2=a[j+k+i]; a[j+k]=t1+t2; a[j+k+i]=t1-t2; } } void dec(int&x,int y){x-=y;if(x<0)x+=mod;} void ifwt(int*a){ for(int i=1;i<l;i<<=1) for(int j=0;j<l;j+=i<<1) for(int k=0;k<i;++k){ int iv2=(mod+1)/2; int t1=a[j+k],t2=a[j+k+i]; a[j+k]=(ll)(t1+t2)*iv2%mod; a[j+k+i]=(ll)(t1-t2+mod)*iv2%mod; } } int main(){ //freopen("H.in","r",stdin); //freopen("H.out","w",stdout); n=rd();l=1<<rd(); X=rd();Y=rd();Z=rd(); for(int i=1;i<=n;++i){ int a=rd(),b=rd(),c=rd(); s^=a;b^=a;c^=a;a=b^c; A[b]++,B[c]++,C[a]++; } fwt(A);fwt(B);fwt(C); int t1=((ll)X+Y+Z)%mod; int t2=((ll)X+Y-Z+mod)%mod; int t3=((ll)X-Y+Z+mod)%mod; int t4=((ll)X-Y-Z+mod+mod)%mod; for(int i=0;i<l;++i){ ans[i] = (ll)pw(t1,(n+A[i]+B[i]+C[i])>>2) *pw(t2,(n+A[i]-B[i]-C[i])>>2)%mod *pw(t3,(n-A[i]+B[i]-C[i])>>2)%mod *pw(t4,(n-A[i]-B[i]+C[i])>>2)%mod; } ifwt(ans); for(int i=0;i<l;++i)printf("%d ",ans[i^s]); return 0; }