P4462 [CQOI2018]异或序列
题目描述
已知一个长度为n的整数数列\(a_1,a_2,...,a_n\),给定查询参数\(l、r\),问在\(a_l,a_{l+1},...,a_r\)?区间内,有多少子序列满足异或和等于\(k\)。也就是说,对于所有的\(x,y (I ≤ x ≤ y ≤ r)\),能够满足\(a_x \bigoplus a_{x+1} \bigoplus ... \bigoplus a_y = k\)的\(x,y\)有多少组。
输入格式
输入文件第一行,为3个整数\(n,m,k。\)
第二行为空格分开的n个整数,即\(a_1,a_2,..a_n\)。
接下来m行,每行两个整数\(l_j,r_j\)?,表示一次查询。
输出格式
输出文件共m行,对应每个查询的计算结果。
输入输出样例
输入 #1
4 5 1
1 2 3 1
1 4
1 3
2 3
2 4
4 4
输出 #1
4
2
1
2
1
说明/提示
对于30%的数据,\(1 ≤ n, m ≤ 1000\)
对于100%的数据,\(1 ≤ n, m ≤ 10^5, 0 ≤ k, a_i ≤ 10^5,1 ≤ l_j ≤ r_j ≤ n\)
Solution
异或操作有一个性质
已知a^b=c
则a^c=b,b^c=a
那么我们将原序列\(a[]\)变成前缀和序列\(s[]\)之后
\(a_x \bigoplus a_{x+1} \bigoplus ... \bigoplus a_y = k\)\(\Rightarrow\)\(s_{x-1}\bigoplus s_r=k\)\(\Rightarrow\)\(s_{r}\bigoplus k=s_{x-1}\)
那么问题就由多少个区间异或和为\(k\)转化成了多少个数对的异或值为\(k\)
可以使用莫队求解
重点还是这个\(add()\)函数和\(remove()\)函数
当我们加入一个值\(a[x]\)时,我们想要知道当前所在区间有多少个\(a[i]\bigoplus a[x]=k\),其中\(i<x\),也就是\(a[x]\bigoplus k\)的个数,因为这些数都可以和\(a[x]\)异或起来得到\(k\)
void add(int x) {ans+=cnt[a[x]^k]; cnt[a[x]]++;}
当我们删除一个值时,我们依然需要知道前面有多少个\(a[i]\bigoplus a[x]=k\),其中\(i<x\)
void remove(int x) {ans-=cnt[a[x]^k]; cnt[a[x]]--;}
代码是这样子的吗,不对!
当k=0时,\(a[x]\bigoplus k=a[x]\),如果我们先用\(ans\)减去贡献,再使\(cnt[]--\),答案会少一
原因就是我们要统计的其实是\(i<x\)的\(cnt[a[x]]\)的个数,不能把第\(x\)个位置上的值算进去,所以需要先\(cnt[a[x]]--\),再用\(ans\)减去贡献
void remove(int x) {cnt[a[x]]--; ans-=cnt[a[x]^k];}
当然,在k!=0的情况下,顺序是没有影响的,因为如果k!=0,那么\(a[x]\bigoplus k!=a[x]\),那么当前的\(cnt[a[x]]\)不会影响到答案
如果不相信的话,下面是hack数据
3 2 0
0 0 0
1 2
3 3
正确答案应该是
3
1
Code
#include<bits/stdc++.h>
#define lol long long
#define in(i) (i=read())
using namespace std;
const lol N=1e5+10,mod=1e9+7;
lol read() {
lol ans=0,f=1; char i=getchar();
while(i<‘0‘ || i>‘9‘) {if(i==‘-‘) f=-1; i=getchar();}
while(i>=‘0‘ && i<=‘9‘) ans=(ans<<1)+(ans<<3)+(i^48),i=getchar();
return ans*f;
}
int n,m,k,block,ans;
int a[N],cnt[N],sum[N];
struct query{
int l,r,id,pos;
bool operator < (const query &a) const {
return pos==a.pos?r<a.r:pos<a.pos;
}
}t[N];
void add(int x) {ans+=cnt[a[x]^k]; cnt[a[x]]++;}
void remove(int x) {cnt[a[x]]--; ans-=cnt[a[x]^k];}
int main() {
in(n), in(m), in(k); block=sqrt(n);
for (int i=1;i<=n;i++) in(a[i]), a[i]^=a[i-1];
for (int i=1,l,r;i<=m;i++) {
in(l), in(r);
t[i].l=l-1, t[i].r=r;
t[i].id=i;
t[i].pos=(l-1)/block+1;
}
sort(t+1,t+1+m);
for (int i=1,curl=1,curr=0;i<=m;i++) {
int l=t[i].l,r=t[i].r;
while(curl<l) remove(curl++);
while(curl>l) add(--curl);
while(curr<r) add(++curr);
while(curr>r) remove(curr--);
sum[t[i].id]=ans;
}
for (int i=1;i<=m;i++) cout<<sum[i]<<endl;
}