题意:
给出一个长度为 \(n\) 的数组,有 \(q\) 次查询,每次查询给出一个区间 \([l,r]\) ,求这段区间里面所有子区间的异或和的总和。
题解:
不难想到,要按位考虑贡献,对于第 \(i\) 位的贡献是 \(2^i\) 乘上区间\(1\)的个数为奇数的子区间的数量。
考虑利用线段树维护一个区间中包含奇数个\(1\)的子区间数量。
如何区间合并?
设左区间范围是 \([l,mid]\) ,右区间范围是 \([mid+1,r]\) ,\(ans\)表示区间中包含奇数个\(1\)的子区间数量。
那么 \(ans=left.ans+right.ans+\) 以\(mid\)为右端点包含奇数个\(1\)的区间数量 \(\times\) 以\(mid+1\)为左端点包含偶数个\(1\)的区间数量 \(+\) 以\(mid\)为右端点包含偶数个\(1\)的区间数量 \(\times\) 以\(mid+1\)为左端点包含奇数个\(1\)的区间数量。
那么维护三个东西即可:区间中包含奇数个\(1\)的子区间数量 ,以\(mid\)为右端点包含奇数个\(1\)的区间数量,以\(mid+1\)为左端点包含奇数个\(1\)的区间数量。
代码:
#pragma GCC diagnostic error "-std=c++11"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#define iss ios::sync_with_stdio(false)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int, int> pii;
const int mod = 1e9 + 7;
const int MAXN = 2e5 + 5;
const int inf = 0x3f3f3f3f;
int a[MAXN], base[22];
struct Node {
int l, r, sum[22];
ll ans[22], lsum[22], rsum[22];
} node[MAXN << 2];
Node combine(Node x,Node y)
{
Node k;
k.l = x.l;
k.r = y.r;
int len1 = x.r - x.l + 1;
int len2 = y.r - y.l + 1;
for (int i = 0; i <= 20;i++)
{
k.sum[i] = x.sum[i] + y.sum[i];
k.ans[i] = (x.ans[i] + y.ans[i] + x.rsum[i] * (len2 - y.lsum[i])%mod + (len1 - x.rsum[i]) * y.lsum[i]%mod)%mod;
if(x.sum[i]&1)
k.lsum[i] = x.lsum[i] + (len2 - y.lsum[i]);
else
k.lsum[i] = x.lsum[i] + y.lsum[i];
if(y.sum[i]&1)
k.rsum[i] = y.rsum[i] + (len1 - x.rsum[i]);
else
k.rsum[i] = y.rsum[i] + x.rsum[i];
}
return k;
}
void build(int l, int r, int num)
{
node[num].l = l;
node[num].r = r;
if (l == r) {
for (int i = 20; i >= 0; i--) {
if((a[l]>>i)&1){
node[num].ans[i] = node[num].lsum[i] = node[num].rsum[i] =node[num].sum[i]=1;
}
else {
node[num].ans[i] = node[num].lsum[i] = node[num].rsum[i] =node[num].sum[i]= 0;
}
}
return;
}
int mid = (l + r) >> 1;
build(l, mid, num << 1);
build(mid + 1, r, num << 1|1);
node[num] = combine(node[num << 1], node[num << 1 | 1]);
}
Node query(int l,int r,int num)
{
if(node[num].l>=l&&node[num].r<=r)
{
return node[num];
}
int mid = (l + r) >> 1;
if(r<=mid)
return query(l, r, num << 1);
else if(l>mid)
return query(l, r, num << 1 | 1);
else {
Node tmp1 = query(l, r, num << 1);
Node tmp2 = query(l, r, num << 1 | 1);
Node tmp = combine(tmp1, tmp2);
return tmp;
}
}
int main()
{
base[0] = 1;
for (int i = 1; i <= 20; i++)
base[i] = base[i - 1] * 2;
int t;
scanf("%d", &t);
while (t--) {
int n, q;
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
build(1, n, 1);
while ((q--))
{
int l, r;
scanf("%d%d", &l, &r);
Node ans = query(l, r, 1);
ll sum = 0;
for (int i = 0; i <= 20;i++)
{
sum = (sum + base[i] * ans.ans[i]%mod)%mod;
}
printf("%lld\n", sum);
}
}
}