【2019牛客暑期多校训练营(第一场) - H】XOR(线性基,期望的线性性)

题干:

链接:https://ac.nowcoder.com/acm/contest/881/H
来源:牛客网
 

Bobo has a set A of n integers a1,a2,…,ana1,a2,…,an.
He wants to know the sum of sizes for all subsets of A whose xor sum is zero modulo (109+7)(109+7).
Formally, find (∑S⊆A,⊕x∈Sx=0|S|)mod(109+7)(∑S⊆A,⊕x∈Sx=0|S|)mod(109+7). Note that ⊕⊕ denotes the exclusive-or (XOR).

输入描述:

The input consists of several test cases and is terminated by end-of-file.

The first line of each test case contains an integer n.
The second line contains n integers a1,a2,…,ana1,a2,…,an.

* 1≤n≤1051≤n≤105
* 0≤ai≤10180≤ai≤1018
* The number of test cases does not exceed 100.
* The sum of n does not exceed 2×1062×106.

输出描述:

For each test case, print an integer which denotes the result.

示例1

输入

复制

1
0
3
1 2 3
2
1000000000000000000 1000000000000000000

输出

复制

1
3
2

题目大意:

给你n个数字,然后让你求所有满足 异或和为0的子集 的大小之和。

解题报告:

根据期望的线性性(Orz),转化一下题意:相当于求每个数出现在子集中的次数之和。

那么如何求每个的合法出现次数呢?对于每个数x的出现次数,也就是这个数x必选的情况下,有多少种选择方案可以让剩余n-1个数凑出x来。然后就可以【2019牛客暑期多校训练营(第一场) - H】XOR(线性基,期望的线性性)

先对n个数求线性基b1,设线性基大小为r,可以分别计算线性基内数的贡献和线性基外数的贡献

1.线性基外:共n-r个数,枚举每个数,让他必选,将线性基外剩余的【2019牛客暑期多校训练营(第一场) - H】XOR(线性基,期望的线性性)个数任意排列(也可不选),则共有 【2019牛客暑期多校训练营(第一场) - H】XOR(线性基,期望的线性性)种方案,每种方案再异或x的结果,能被刚刚求出的那组线性基唯一的异或出来,所以每个数的贡献为:【2019牛客暑期多校训练营(第一场) - H】XOR(线性基,期望的线性性)。又因为一共有n-r个数,所以线性基外的数的贡献为:【2019牛客暑期多校训练营(第一场) - H】XOR(线性基,期望的线性性)

2.线性基内:依旧是枚举每个数x(注意枚举的x是a数组中的数而不是b1中的数),求出剩余n-1个数凑出x的方案数,做法是这样的:将所有剩余的n-1个数再求一次线性基,设为b2,分两种情况:

(1) x不能被b2异或出。那么显然x不能在任意一个集合中出现,x的贡献为0。

(2) x可以被b2异或出。此时b2的大小必定也为r,因为b2已经能表示所有n个数了。那么在除去x和b2的情况下,剩余n-r-1个数显然也是任意排列,x贡献为 【2019牛客暑期多校训练营(第一场) - H】XOR(线性基,期望的线性性)

对于在线性基内的方案统计时,还有另一种理解:因为这个数x在线性基内,也就是说他代表了二进制中的一个Bit位在线性基中的存在,那么我们要凑出这个x的话(注意x依旧是原数,而不是线性基内的数字),就必须要找一个可以代表这个Bit位的数字来替代他,也就是:用除去x的剩余n-1个数,重构一组线性基后,如果基的大小仍为b1.r,则说明x在原基b1中不是必须的,可以被替代,此时也才说明x对答案做出了贡献,不然的话只有x能代表这个Bit,那他代表给谁看?和谁去异或成0?所以对答案没有任何贡献。所以写代码最后一部分的时候,既可以这样写:if(b3.r == b1.r) 也可以这样写:if(b3.ins(a[i])==0)也就是说如果插不进去a[i]了,就代表他对答案做出了贡献,因为新基可以凑出所有数字。

AC代码:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<queue>
#include<map>
#include<vector>
#include<set>
#include<string>
#include<cmath>
#include<cstring>
#define F first
#define S second
#define ll long long
#define pb push_back
#define pm make_pair
using namespace std;
typedef pair<int,int> PII;
const int MAX = 2e5 + 5;
const ll mod = 1e9 + 7;
int n,top;
ll a[MAX],ans,tmp,b[MAX];
struct Node {
    ll Base[66];
    int r;
    void init() {
        memset(Base,0,sizeof Base);
        r=0;
    }
    bool ins(ll x) {
        for(int i = 62; i>=0; i--) {
            if(x & (1LL<<i)) {
                if(Base[i] == 0) {
                    Base[i] = x;r++;return 1;
                }
                else x ^= Base[i];
            }
            if(x == 0) return 0;
        }
        return 0;
    }
} b1,b2,b3;
ll qpow(ll a,ll k) {
    ll res = 1;
    while(k) {
        if(k&1) {
            res = (res * a) % mod;
        }
        k>>=1;
        a = (a * a)%mod;
    }
    return res;
}
int main()
{
    while(~scanf("%d",&n)) {
        ans = tmp = 0;b1.init();b2.init();
        top=0;
        for(int i = 1; i<=n; i++) scanf("%lld",a+i);
        for(int i = 1; i<=n; i++) {
            if(b1.ins(a[i])) b[++top] = a[i];
            else b2.ins(a[i]);
        }
        if(n == b1.r) {
            printf("0\n");
            continue;
        }
        ans = (n-b1.r) * qpow(2,n - b1.r - 1);
        for(int i = 1; i<=top; i++) {
            b3 = b2;
            for(int j = 1; j<=top; j++) {
                if(i == j) continue;
                b3.ins(b[j]);
            }
            if(b3.r == b1.r) tmp++;
        }
        ans = ans + tmp*qpow(2,n - b1.r - 1);
        printf("%lld\n",ans%mod);
    }  
    return 0 ;
}

AC代码2:(咖啡鸡的代码,也不知道是维护了个啥)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=2e5+3;
const ll M=1000000007;
ll v;
struct base{
    ll r[64],o[64];
    bool ins(ll x){
        bool flag=0;
        ll tmp=0;
        for (int i=62;i>=0;i--)
            if (x>>i){
                if (!r[i]) o[i]=tmp|(1ll<<i),r[i]=x,flag=1;
                x^=r[i]; tmp^=o[i];
                if (!x) break;
            }
        if (!flag){
            v|=tmp;
        }
        return flag;
    }
    void clear(){
        for (int i=0;i<64;i++) r[i]=0,o[i]=0;
    }
}f;
ll pow_(ll x,ll y){
    ll ret=1;
    while (y){
        if (y&1) ret=ret*x%M;
        x=x*x%M; y>>=1;
    }
    return ret;
}
int n,w;
ll a[maxn],ans;
 
int main(){
    while (scanf("%d",&n)==1){
        for (int i=1;i<=n;i++) scanf("%lld",&a[i]);
        for (int i=0;i<64;i++) v=0; f.clear(); ans=w=0;
        for (int i=1;i<=n;i++){
            ans+=!f.ins(a[i]);
        }
        for (int i=0;i<64;i++) {
            if (v&(1ll<<i)) ++ans;
            if (f.r[i]) ++w;
        }
        cout << ans*pow_(2ll,n-1-w)%M << endl;
    }
    return 0;
}

 

上一篇:类的操作


下一篇:Java面试题_001