此版本不回收被删除的节点，因此 MAXN 必须足够容纳所有插入操作新建的节点。
/*
 * Treap storing a multiset of ints.
 *  - BST order on val[]; equal values share one node whose multiplicity is cnt[].
 *  - Heap order on rnd[]: a parent's priority is >= its children's (max-heap).
 *  - siz[id] = total multiplicity stored in the subtree rooted at id.
 *  - Index 0 is the null sentinel; its fields must all be zero because
 *    PushUp reads siz[0] and the traversals stop on id == 0.
 * Deleted nodes are never recycled, so MAXN must exceed the number of nodes
 * created (one per insertion of a value not currently present) since Init().
 * Uses INF, which must be defined elsewhere, as the "not found" return value.
 */
struct Treap {
#define ls ch[id][0]
#define rs ch[id][1]
    static const int MAXN = 500000 + 10;
    int top, root, ch[MAXN][2], rnd[MAXN];
    int val[MAXN], cnt[MAXN], siz[MAXN];
    // ll sum[MAXN];

    /* Allocate a fresh leaf holding c copies of v and return its index.
     * No bounds check: the caller must ensure top stays below MAXN. */
    int NewNode(int v, int c) {
        int id = ++top;
        ls = 0, rs = 0, rnd[id] = rand();
        val[id] = v, cnt[id] = c, siz[id] = c;
        // sum[id] = 1LL * c * v;
        return id;
    }

    /* Recompute the aggregates of id from its children. */
    void PushUp(int id) {
        siz[id] = siz[ls] + siz[rs] + cnt[id];
        // sum[id] = sum[ls] + sum[rs] + 1LL * cnt[id] * val[id];
    }

    /* Lift child ch[id][d ^ 1] into id's position; the old id becomes its
     * d-child (d == 0: left rotation, d == 1: right rotation). id is updated
     * in place to the new subtree root. */
    void Rotate(int &id, int d) {
        int tmp = ch[id][d ^ 1];
        ch[id][d ^ 1] = ch[tmp][d];
        ch[tmp][d] = id, id = tmp;
        // The old root is now below the new one, so update it first.
        PushUp(ch[id][d]), PushUp(id);
    }

    /* Insert c nodes with value v into the subtree rooted at id. */
    void InsertHelp(int &id, int v, int c) {
        if(!id) {
            id = NewNode(v, c);
            return;
        }
        if(v == val[id])
            cnt[id] += c;
        else {
            int d = val[id] > v ? 0 : 1;
            InsertHelp(ch[id][d], v, c);
            // Restore the max-heap order on rnd[] by lifting the child.
            if(rnd[id] < rnd[ch[id][d]])
                Rotate(id, d ^ 1);
        }
        PushUp(id);
    }

    /* Remove at most c copies of v from the subtree rooted at id. If this
     * removes every copy, the node is rotated down to a leaf and unlinked. */
    void RemoveHelp(int &id, int v, int c) {
        if(!id)
            return;
        if(v == val[id]) {
            if(cnt[id] > c)
                cnt[id] -= c;
            else if(ls || rs) {
                // Lift the child with the larger priority (a missing child
                // never wins), push this node down to side d, keep deleting.
                int d = (!rs || (ls && rnd[ls] > rnd[rs]));
                Rotate(id, d), RemoveHelp(ch[id][d], v, c);
            } else {
                id = 0;
                return;
            }
        } else {
            int d = (val[id] < v);
            RemoveHelp(ch[id][d], v, c);
        }
        PushUp(id);
    }

    /* Empty the treap. Node storage is reused from index 1 onward. */
    void Init() {
        top = 0, root = 0;
        // Re-establish the all-zero sentinel explicitly instead of relying
        // on the object having zero-initialized (static) storage.
        ch[0][0] = ch[0][1] = 0;
        rnd[0] = val[0] = cnt[0] = siz[0] = 0;
        // sum[0] = 0;
    }

    /* Insert c nodes with value v. A non-positive c is ignored; otherwise
     * c == 0 would create an empty node that GetPrev/GetNext still report. */
    void Insert(int v, int c = 1) {
        if(c <= 0)
            return;
        InsertHelp(root, v, c);
    }

    /* Remove at most c nodes with value v. A non-positive c is ignored;
     * otherwise a negative c would increase the stored multiplicity. */
    void Remove(int v, int c = 1) {
        if(c <= 0)
            return;
        RemoveHelp(root, v, c);
    }

    /* "Rank of value v" means the first node with value >= v:
     * returns 1 + (number of stored elements with value < v). */
    int GetRank(int v) {
        int id = root, res = 1;
        while(id) {
            if(val[id] > v)
                id = ls;
            else if(val[id] == v) {
                res += siz[ls];
                break;
            } else {
                res += siz[ls] + cnt[id];
                id = rs;
            }
        }
        return res;
    }

    /* Get the value with 1-based rank r (counting multiplicity).
     * Returns INF when r is outside [1, total size]. */
    int GetValue(int r) {
        int id = root, res = INF;
        while(id) {
            if(siz[ls] >= r)
                id = ls;
            else if(siz[ls] + cnt[id] >= r) {
                res = val[id];
                break;
            } else {
                r -= siz[ls] + cnt[id];
                id = rs;
            }
        }
        return res;
    }

    /* Get the value of the last node with value < v, or -INF if none. */
    int GetPrev(int v) {
        int id = root, res = -INF;
        while(id) {
            if(val[id] < v)
                res = val[id], id = rs;
            else
                id = ls;
        }
        return res;
    }

    /* Get the value of the first node with value > v, or INF if none. */
    int GetNext(int v) {
        int id = root, res = INF;
        while(id) {
            if(val[id] > v)
                res = val[id], id = ls;
            else
                id = rs;
        }
        return res;
    }

    // Optional range-sum extension: uncomment together with sum[] above.
    // /* Get the sum of nodes with value <= v. */
    // ll GetSumValue(int v) {
    //     int id = root;
    //     ll res = 0;
    //     while(id) {
    //         if(val[id] > v)
    //             id = ls;
    //         else if(val[id] == v) {
    //             res += sum[ls] + 1LL * cnt[id] * val[id];
    //             break;
    //         } else {
    //             res += sum[ls] + 1LL * cnt[id] * val[id];
    //             id = rs;
    //         }
    //     }
    //     return res;
    // }
    //
    // /* Get the sum of the first r nodes */
    // ll GetSumRank(int r) {
    //     int id = root;
    //     ll res = 0;
    //     while(id) {
    //         if(siz[ls] >= r)
    //             id = ls;
    //         else if(siz[ls] + cnt[id] >= r) {
    //             res += sum[ls] + 1LL * (r - siz[ls]) * val[id];
    //             break;
    //         } else {
    //             res += sum[ls] + 1LL * cnt[id] * val[id];
    //             r -= siz[ls] + cnt[id];
    //             id = rs;
    //         }
    //     }
    //     return res;
    // }
#undef ls
#undef rs
} treap;