2017年西安区域赛 Sum of xor sum(线段树)

传送门

题意:

给出一个长度为 \(n\) 的数组,有 \(q\) 次查询,每次查询给出一个区间 \([l,r]\) ,求这段区间里面所有子区间的异或和的总和。

题解:

不难想到,要按位考虑贡献,对于第 \(i\) 位的贡献是 \(2^i\) 乘上区间\(1\)的个数为奇数的子区间的数量。

考虑利用线段树维护一个区间中包含奇数个\(1\)的子区间数量。

如何区间合并?

设左区间范围是 \([l,mid]\) ,右区间范围是 \([mid+1,r]\) ,\(ans\)表示区间中包含奇数个\(1\)​的子区间数量。

那么 \(ans=left.ans+right.ans+\) 以\(mid\)为右端点包含奇数个\(1\)的区间数量 \(\times\) 以\(mid+1\)为左端点包含偶数个\(1\)的区间数量 \(+\) 以\(mid\)为右端点包含偶数个\(1\)的区间数量 \(\times\) 以\(mid+1\)为左端点包含奇数个\(1\)​的区间数量。

那么维护三个东西即可:区间中包含奇数个\(1\)的子区间数量 ,以\(mid\)为右端点包含奇数个\(1\)的区间数量,以\(mid+1\)为左端点包含奇数个\(1\)的区间数量。

代码:

#pragma GCC diagnostic error "-std=c++11"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#define iss ios::sync_with_stdio(false)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int, int> pii;
const int mod = 1e9 + 7;
const int MAXN = 2e5 + 5;
const int inf = 0x3f3f3f3f;
int a[MAXN], base[22];
struct Node {
    int l, r, sum[22];
    ll ans[22], lsum[22], rsum[22];
} node[MAXN << 2];
Node combine(Node x,Node y)
{
    Node k;
    k.l = x.l;
    k.r = y.r;
    int len1 = x.r - x.l + 1;
    int len2 = y.r - y.l + 1;
    for (int i = 0; i <= 20;i++)
    {
        k.sum[i] = x.sum[i] + y.sum[i];
        k.ans[i] = (x.ans[i] + y.ans[i] + x.rsum[i] * (len2 - y.lsum[i])%mod + (len1 - x.rsum[i]) * y.lsum[i]%mod)%mod;
        if(x.sum[i]&1)
            k.lsum[i] = x.lsum[i] + (len2 - y.lsum[i]);
        else
            k.lsum[i] = x.lsum[i] + y.lsum[i];
        if(y.sum[i]&1)
            k.rsum[i] = y.rsum[i] + (len1 - x.rsum[i]);
        else
            k.rsum[i] = y.rsum[i] + x.rsum[i];
    }
    return k;
}
void build(int l, int r, int num)
{
    node[num].l = l;
    node[num].r = r;
    if (l == r) {
        for (int i = 20; i >= 0; i--) {
            if((a[l]>>i)&1){
                node[num].ans[i] = node[num].lsum[i] = node[num].rsum[i] =node[num].sum[i]=1;
            }
            else {
                node[num].ans[i] = node[num].lsum[i] = node[num].rsum[i] =node[num].sum[i]= 0;
            }
        }
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, num << 1);
    build(mid + 1, r, num << 1|1);
    node[num] = combine(node[num << 1], node[num << 1 | 1]);
}
Node query(int l,int r,int num)
{
    if(node[num].l>=l&&node[num].r<=r)
    {
        return node[num];
    }
    int mid = (l + r) >> 1;
    if(r<=mid)
        return query(l, r, num << 1);
    else if(l>mid)
        return query(l, r, num << 1 | 1);
    else {
        Node tmp1 = query(l, r, num << 1);
        Node tmp2 = query(l, r, num << 1 | 1);
        Node tmp = combine(tmp1, tmp2);
        return tmp;
    }
}
int main()
{
    base[0] = 1;
    for (int i = 1; i <= 20; i++)
        base[i] = base[i - 1] * 2;
    int t;
    scanf("%d", &t);
    while (t--) {
        int n, q;
        scanf("%d%d", &n, &q);
        for (int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
        }
        build(1, n, 1);
        while ((q--))
        {
            int l, r;
            scanf("%d%d", &l, &r);
            Node ans = query(l, r, 1);
            ll sum = 0;
            for (int i = 0; i <= 20;i++)
            {
                sum = (sum + base[i] * ans.ans[i]%mod)%mod;
            }
            printf("%lld\n", sum);
        }
        
    }

}
上一篇:C. Moamen and XOR[Codeforces Round #737 (Div. 2)]


下一篇:「Leetcode-算法_Easy461」通过「简单」题目学习位运算