题目链接
题目大意
让你求区间异或和前\(k\)大的异或和的和。
正解
这道题目是Blue sky大佬教我做的(祝贺bluesky大佬进HA省A队)
我们做过某一些题目,非常的相似。【超级钢琴】还有【最小函数值】还有【最大异或和】
感觉这一些题目拼在一起就变成了这一道水题。
首先我们需要预处理出,所有区间的异或最大值。
这个东西可以用可持久化\(01trie\)实现,那么我们思考一下如何实现查询第\(k\)大的值的操作。
以下是关于01字典树中查询第k大的操作的讲解
可以参考平衡树和01trie贪心的策略。
因为我们是找到当前的子节点的另外一个。
因为从高位开始贪心,所以如果不相同则一定是比我们要求的答案要大,那么就减去这一部分,并且调到另外一个儿子上。
代码实现
ll query(int rt, ll val, int kth, int len) {
ll res = 0;
for (int i = len; ~i; i --) {
int p = (val >> i) & 1;
if (cnt[ch[rt][p ^ 1]] < kth) kth -= cnt[ch[rt][p ^ 1]], rt = ch[rt][p];
else res += (1ll << i), rt = ch[rt][p ^ 1];
}
return res;
}
那么维护以每一个节点为结束的区间异或最大值。
那么回归正题,参照最小函数值和超级钢琴的思路,我们就每一次取出最大值之后,将这个区间次大值拎出来,放入优先队列中。
运行k遍就是我们需要的答案了。
#include <bits/stdc++.h>
#define ms(a, b) memset(a, b, sizeof(a))
#define ll long long
#define ull unsigned long long
#define ms(a, b) memset(a, b, sizeof(a))
#define inf 0x3f3f3f3f
#define db double
#define Pi acos(-1)
#define eps 1e-8
#define N 600005
using namespace std;
template <typename T> T power(T x, T y, T mod) { x %= mod; T res = 1; for (; y; y >>= 1) { if (y & 1) res = (res * x) % mod; x = (x * x) % mod; } return res; }
template <typename T> void read(T &x) {
x = 0; T fl = 1; char ch = 0;
for (; ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') fl = -1;
for (; ch >= '0' && ch <= '9'; ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
x *= fl;
}
template <typename T> void write(T x) {
if (x < 0) x = -x, putchar('-');
if (x > 9) write(x / 10); putchar(x % 10 + '0');
}
template <typename T> void writeln(T x) { write(x); puts(""); }
struct node {
ll val; int id;
node(ll Val = 0, int Id = 0) { val = Val; id = Id; }
bool operator < (const node &B) const { return val < B.val; }
};
priority_queue<node> q;
struct L_Trie {
int ch[N * 40][2], cnt[N * 40], tot;
void ins(int &rt, int pre, ll val, int len) {
rt = ++ tot; int k = rt;
for (int i = len; ~i; i --) {
ch[k][0] = ch[pre][0]; ch[k][1] = ch[pre][1]; cnt[k] = cnt[pre] + 1;
int p = (val >> i) & 1;
ch[k][p] = ++ tot;
k = ch[k][p]; pre = ch[pre][p];
}
cnt[k] = cnt[pre] + 1;
}
ll query(int rt, ll val, int kth, int len) {
ll res = 0;
for (int i = len; ~i; i --) {
int p = (val >> i) & 1;
if (cnt[ch[rt][p ^ 1]] < kth) kth -= cnt[ch[rt][p ^ 1]], rt = ch[rt][p];
else res += (1ll << i), rt = ch[rt][p ^ 1];
}
return res;
}
} trie;
int root[N], kth[N];
ll sumxor[N], a[N];
int n, k;
ll ans;
int main() {
read(n); read(k);
trie.ins(root[0], 0, 0, 31);
for (int i = 1; i <= n; i ++) {
kth[i] = 1; read(a[i]);
sumxor[i] = sumxor[i - 1] ^ a[i];
trie.ins(root[i], root[i - 1], sumxor[i], 31);
q.push(node(trie.query(root[i - 1], sumxor[i], kth[i], 31), i));
}
while (k --) {
node cur = q.top(); q.pop();
ans += cur.val; kth[cur.id] ++;
q.push(node(trie.query(root[cur.id - 1], sumxor[cur.id], kth[cur.id], 31), cur.id));
}
printf("%lld\n", ans);
return 0;
}