对于一个区间有如下数字{5,6,7,8},他们的二进制表示分别为:
5:0101
6:0110
7:0111
8:1000
那么这区间数的总和可以这么计算:
1(2^3) + 3(2^2) + 2(2^1) + 2(2^0);
其中(2^i)次方前的系数就是第i位的1的个数之和
将他们异或上一个值 如 4:
5^4: 0101^0100 = 0001
6^4: 0110^0100 = 0010
7^4: 0111^0100 = 0011
8^4: 1000^0100 = 1100
那么异或后这一段区间的总和为:
1(2^3) + 1(2^2) + 2(2^1) + 2(2^0);
发现对比异或之前就只有(2^2)前的系数改变;
原因是4的二进制表示为0100, 只有(2^2)的系数为1;
所以对于需要异或的一个值x,如果x的第i位为1,那么这个区间内的第i位的1的个数就需要变化;
变化的结果就是区间长度减去第i位1的个数;
于是可以开20棵线段树维护每一位的结果;
代码段我对于一个节点开了个数组,就不搞个20棵了,本质上是一样的;
#include <bits/stdc++.h>
//#pragma GCC optimize(2)
using namespace std;
#define LL long long
#define ll long long
#define ULL unsigned long long
#define Pair pair<LL,LL>
#define ls rt<<1
#define rs rt<<1|1
#define Pi acos(-1.0)
#define eps 1e-6
#define DBINF 1e100
#define mod 1000000007
#define MAXN 100000
#define MXLEN 17
#define MS 100009
int n,m;
struct node{
int cnt[22];
int la[22];
bool isla;
}p[MS<<2];
void push_up(int rt){
for(int i=20;i>=0;i--){
p[rt].cnt[i] = p[ls].cnt[i] + p[rs].cnt[i];
}
}
void build(int l,int r,int rt){
if(l == r){
int x;
cin >> x;
for(int i=20;i>=0;i--){
p[rt].cnt[i] = ( (x>>i)&1 );
}
return;
}
int m = l+r>>1;
build(l,m,ls); build(m+1,r,rs);
push_up(rt);
}
void push_down(int rt,int l,int r){
if(p[rt].isla){
int m = l+r>>1;
int ln = m-l+1;
int rn = r-m;
p[ls].isla = p[rs].isla = false;
for(int i=20;i>=0;i--){
int t = p[rt].la[i];
p[rt].la[i] = 0;
if(t){
p[ls].cnt[i] = ln - p[ls].cnt[i];
p[rs].cnt[i] = rn - p[rs].cnt[i];
}
p[ls].la[i] ^= t;
p[rs].la[i] ^= t;
if(p[ls].la[i]) p[ls].isla = true;
if(p[rs].la[i]) p[rs].isla = true;
}
p[rt].isla = false;
}
}
void modify(int L,int R,int l,int r,int rt,int x){
if(L <= l && r <= R){
int sum = r-l+1;
p[rt].isla = false;
for(int i=20;i>=0;i--){
int t = (x>>i)&1;
if(t) p[rt].cnt[i] = sum-p[rt].cnt[i];
p[rt].la[i] ^= t;
if(p[rt].la[i]) p[rt].isla = true;
}
return;
}
int m = l+r>>1;
push_down(rt,l,r);
if(m >= L) modify(L,R,l,m,ls,x);
if(m < R) modify(L,R,m+1,r,rs,x);
push_up(rt);
}
LL query(int L,int R,int l,int r,int rt){
if(L <= l && r <= R){
LL sum = 0;
for(int i=20;i>=0;i--){
LL t = p[rt].cnt[i];
sum += (1ll<<i)*t;
}
return sum;
}
int m = l+r>>1;
push_down(rt,l,r);
LL ans = 0;
if(m >= L) ans += query(L,R,l,m,ls);
if(m < R) ans += query(L,R,m+1,r,rs);
return ans;
}
int main() {
ios::sync_with_stdio(false);
cin >> n;
build(1,n,1);
cin >> m;
for(int i=1;i<=m;i++){
int op,l,r,x;
cin >> op >> l >> r;
if(op == 1){
cout << query(l,r,1,n,1) << "\n";
}
else if(op == 2){
cin >> x;
modify(l,r,1,n,1,x);
}
}
return 0;
}