TREAP 的实现及应用
概念
treap是tree+heap的合成,顾名思义,它既有堆(heap)的性质,又有二叉查找树(tree)的性质.
treap中每个节点有两个属性 A A A和 B B B,其中 A A A属性符合堆的性质:任意一个节点的 A A A均小于(或大于)其儿子的 A A A;
而属性 B B B符合二叉查找树的性质:任意节点的 B B B均大于其左子树中所有节点的 B B B值,且小于其右子树中所有节点的 B B B值。
如下图所示:字母表示A属性,数字表示B属性。
让我们回顾一下二叉查找树的构建过程,如果输入数据依次给出每个节点的值,而节点的值恰好是严格递增或者严格递减的,我们会得到一条链。而如果我们随机的打乱输入顺序,即在构建二叉查找树时,按随机的顺序插入节点,则很难退化成一条链,更大概率是比较平衡的一棵二叉树。
现在我们给每个节点随机的分配一个优先级,按照优先级由高到低的顺序,选择节点插入到树中,这和刚才随机的打乱输入顺序是一样的,这样我们构建的二叉树基本上也是平衡的。而这棵树正是treap,优先级是A属性,节点的值是B属性;A属性符合堆的性质,B属性符合二叉查找树的性质。
treap中,优先级是随机分配的,它通过随机性来保证不会过度失衡,但也不能保证它严格平衡。所以,treap是一种弱平衡的树。
看一道例题:普通平衡树
题目描述 普通平衡树
你需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入一个整数x
2. 删除一个整数x(若有多个相同的数,只删除一个)
3. 查询整数x的排名(若有多个相同的数,输出最小的排名),相同的数依次排名,不并列排名
4. 查询排名为x的数,排名的概念同3
5. 求x的前驱(前驱定义为小于x,且最大的数),保证x有前驱
6. 求x的后继(后继定义为大于x,且最小的数),保证x有后继
输入格式
第一行为n,表示操作的个数(n <= 500000)
下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6, -10^7 <= x <= 10^7)
大规模输入数据,建立读入优化
输出格式
对于操作3,4,5,6每行输出一个数,表示对应答案
数据规模
$ 1 \leq n \leq 500000$
采用treap来解决这个题。
节点
节点用结构体表示,其中包括左右儿子,节点的值,节点的优先级等
节点的优先级采用随机数。
struct node{
int ch[2], val, pri, sz;
}arr[MAXN];
旋转操作
旋转操作有一个技巧,如果是从上而下递归做的,以旋转边上方的点作为参数,实现起来相对方便。
void rote(int &r, int flg){
int t = arr[r].ch[1-flg];
arr[r].ch[1-flg] = arr[t].ch[flg];
arr[t].ch[flg] = r;
pushup(r);
pushup(t);
r = t;
}
插入节点
插入函数与普通的二叉查找树差不多,只是当儿子的优先级大于父亲的优先级时,需要做旋转。
void insert(int &r, int x){
if(r == 0){
arr[++pcnt].val = x, arr[pcnt].pri = rand(), ++arr[pcnt].sz;
r = pcnt;
pushup(r);
return;
}
if(x <= arr[r].val) {
insert(arr[r].ch[0], x);
if(arr[r].ch[0] && arr[r].pri > arr[arr[r].ch[0]].pri) rote(r, 1); //right tot
}
else {
insert(arr[r].ch[1], x);
if(arr[r].ch[1] && arr[r].pri > arr[arr[r].ch[1]].pri) rote(r, 0); // left rot
}
pushup(r);
}
删除操作
删除操作和普通的二叉查找树完全一样,不需要旋转。
首先递归地找到待删除节点,找到了以后,判断一下:若待删除节点不足两个儿子时,直接删除,儿子取代它的位置;否则,在左子树中找最大的一个节点(该节点一定没有右儿子),将该节点剪切下来替代待删除节点。
这里有一个地方要格外注意,如果树中存在多个值相同的节点,要注意删除操作是一次删一个还是一次删掉所有。
更好的方法是将多个值相同的节点合并为一个节点,增加一个属性,用于表示该节点出现的次数。
//此处treap中存储了重复节点,删除时采用的由下网上逐一替换,最终保证删除1个节点
int del(int &r, int x){
int tmp;
if(arr[r].val == x || arr[r].val > x && arr[r].ch[0] == 0 || arr[r].val < x && arr[r].ch[1] == 0){
if(arr[r].ch[0] == 0 || arr[r].ch[1] == 0) {
tmp = arr[r].val;
r = arr[r].ch[0] + arr[r].ch[1];
return tmp;
}
else{
tmp = arr[r].val;
arr[r].val = del(arr[r].ch[0], x);
}
}
else if(x < arr[r].val) tmp = del(arr[r].ch[0], x);
else tmp = del(arr[r].ch[1], x);
pushup(r);
return tmp;
}
查找第x个元素
int xth(int r, int x){
if(arr[arr[r].ch[0]].sz >= x) return xth(arr[r].ch[0], x);
else if(arr[arr[r].ch[0]].sz + 1 >= x) return arr[r].val;
else return xth(arr[r].ch[1], x - arr[arr[r].ch[0]].sz - 1);
}
询问x的排名
int getrank(int r, int x){
if(r == 0) return 1;
if(x > arr[r].val)
return arr[arr[r].ch[0]].sz + 1 + getrank(arr[r].ch[1], x);
else return getrank(arr[r].ch[0], x);
}
查找前驱
int getpre(int r, int x){
if(r == 0) return -MOD;
if(arr[r].val >= x) return getpre(arr[r].ch[0], x);
else return max(arr[r].val, getpre(arr[r].ch[1], x));
}
查找后继
int getnxt(int r, int x){
if(r == 0) return MOD;
if(arr[r].val <= x)return getnxt(arr[r].ch[1], x);
else return min(arr[r].val, getnxt(arr[r].ch[0], x));
}
完整代码如下:
#include <bits/stdc++.h>
using namespace std;
#define MAXN 1000005
#define MOD 1000000000
struct node{
int ch[2], val, pri, sz;
}arr[MAXN];
int n, m, opt, x, pcnt, rt;
int cntres;
void pushup(int r){
if(r) arr[r].sz = arr[arr[r].ch[0]].sz + arr[arr[r].ch[1]].sz + 1;
}
void rote(int &r, int flg){
int t = arr[r].ch[1-flg];
arr[r].ch[1-flg] = arr[t].ch[flg];
arr[t].ch[flg] = r;
pushup(r);
pushup(t);
r = t;
}
void insert(int &r, int x){
if(r == 0){
arr[++pcnt].val = x, arr[pcnt].pri = rand(), ++arr[pcnt].sz;
r = pcnt;
pushup(r);
return;
}
if(x <= arr[r].val) {
insert(arr[r].ch[0], x);
if(arr[r].ch[0] && arr[r].pri > arr[arr[r].ch[0]].pri) rote(r, 1); //right tot
}
else {
insert(arr[r].ch[1], x);
if(arr[r].ch[1] && arr[r].pri > arr[arr[r].ch[1]].pri) rote(r, 0); // left rot
}
pushup(r);
}
int del(int &r, int x){
int tmp;
if(arr[r].val == x || arr[r].val > x && arr[r].ch[0] == 0 || arr[r].val < x && arr[r].ch[1] == 0){
if(arr[r].ch[0] == 0 || arr[r].ch[1] == 0) {
tmp = arr[r].val;
r = arr[r].ch[0] + arr[r].ch[1];
return tmp;
}
else{
tmp = arr[r].val;
arr[r].val = del(arr[r].ch[0], x);
}
}
else if(x < arr[r].val) tmp = del(arr[r].ch[0], x);
else tmp = del(arr[r].ch[1], x);
pushup(r);
return tmp;
}
bool find(int r, int x){
if(r == 0) return 0;
if(arr[r].val == x) return 1;
else if(x < arr[r].val) return find(arr[r].ch[0], x);
else return find(arr[r].ch[1], x);
}
int xth(int r, int x){
if(arr[arr[r].ch[0]].sz >= x) return xth(arr[r].ch[0], x);
else if(arr[arr[r].ch[0]].sz + 1 >= x) return arr[r].val;
else return xth(arr[r].ch[1], x - arr[arr[r].ch[0]].sz - 1);
}
int getrank(int r, int x){
if(r == 0) return 1;
if(x > arr[r].val) return arr[arr[r].ch[0]].sz + 1 + getrank(arr[r].ch[1], x);
else return getrank(arr[r].ch[0], x);
}
int getpre(int r, int x){
if(r == 0) return -MOD;
if(arr[r].val >= x) return getpre(arr[r].ch[0], x);
else return max(arr[r].val, getpre(arr[r].ch[1], x));
}
int getnxt(int r, int x){
if(r == 0) return MOD;
if(arr[r].val <= x)return getnxt(arr[r].ch[1], x);
else return min(arr[r].val, getnxt(arr[r].ch[0], x));
}
int main(){
srand(time(0));
int rescnt = 0, res = 0;
scanf("%d", &n);
for(int i = 1; i <= n; i++){
scanf("%d %d", &opt, &x);
if(opt == 1){insert(rt, x); cntres++;}
else if(opt == 2) { del(rt, x), --cntres;}
else if(opt == 3) printf("%d\n",getrank(rt, x));
else if(opt == 4) printf("%d\n", xth(rt, x));
else if(opt == 5) printf("%d\n", getpre(rt, x));
else printf("%d\n", getnxt(rt, x));
}
return 0;
}
非旋转的treap
非旋转treap是由范浩强大佬发明的,它摈弃了旋转操作,而是采用分裂和合并操作来作为基本操作,实现插入、删除节点的同时,维护二叉查找树的有序性和堆的性质。
如下图所示,这是一棵treap,每个节点中的第一个数字为权值,第二个数字为优先级。
权值满足二叉查找树的性质,优先级满足堆的性质。
split操作
分裂操作是非旋转treap中最难理解的操作,它实质上是通过一条路径,将树分成两棵树。我们插入节点、或删除节点,都需要先找到一个插入节点的位置或者删除节点的位置,于是就得到了一条从根到该位置的路径(路径上有可能只有1个点)。
以插入一个权值为20的节点为例:首先从根开始,找到新节点对应的位置,它的位置应该位于节点 J J J的右儿子处。于是我们得到一条从 A A A到 J J J的路径,将这条路径上的边全部断开,小于等于20的点分到左侧;大于20的点分到右侧,于是得到了两棵树:如下图:
接下来生成权值为 20 20 20的节点 Q Q Q,优先级假设为 3 3 3.
接下来我们需要用到合并操作merge。
merge操作
合并操作的对象是两棵树,这两棵树一定满足,左边的树权最大值小于右边的树的权值最小值。我们根据其优先级来合并。为了描述方面,我们设左边的树为 L L L,右边的树为 R R R, 首先比较两棵树的树根,谁优先级小,谁就作为新的树根,假设 L L L的优先级较小,则问题转换为 L L L的右子树与 R R R的合并问题了;否则就是 R R R的根作为新树的根,问题转换为 L L L和 R . r s o n R.rson R.rson的合并问题了,这样递归下去,直到某棵树问空,则递归结束。
合并操作比较简单,就不上图了。
像上面把插入节点Q的代码,我们先将子树A和Q合并,再将合并的新树继续与子树C合并即可。
删除操作的代码,也是可以用 s p l i t split split和 m e r g e merge merge操作来完成的。比如要删除权值为 x x x的节点。
先通过一次split操作,将权值小于等于 x x x和权值大于 x x x的节点分开,然后再将权值等于 x x x的和小于 x x x的分开,此时得到了三棵树。
如果权值等于 x x x的节点不止一个,则我们可以将权值等于 x x x的那颗树的左子树和右子树合并,这样就等于删掉了根节点。
现在还是三棵树,我们将它们从左到右依次合并为一棵树。这样就完成了删除任务。我们需要两次 s p l i t split split和三次 m e r g e merge merge操作。
而如果要删除所有权值为 x x x的节点,则我们只需要两次 s p l i t split split和一次 m e r g e merge merge操作.
模板题 普通平衡树
你需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入一个整数x
2. 删除一个整数x(若有多个相同的数,只删除一个)
3. 查询整数x的排名(若有多个相同的数,输出最小的排名),相同的数依次排名,不并列排名
4. 查询排名为x的数,排名的概念同3
5. 求x的前驱(前驱定义为小于x,且最大的数),保证x有前驱
6. 求x的后继(后继定义为大于x,且最小的数),保证x有后继
输入格式
第一行为n,表示操作的个数(n <= 500000)
下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6, -10^7 <= x <= 10^7)
大规模输入数据,建立读入优化
输出格式
对于操作3,4,5,6每行输出一个数,表示对应答案
数据规模
$ 1 \leq n \leq 500000$
采用非旋转treap来解决这个题
#include <bits/stdc++.h>
using namespace std;
#define MAXN 1000005
#define INF 999999999
int n;
struct node{
int ch[2], val, pri, sz;
}tree[MAXN];
int rt, root1, root2, tot;
void update(int rt){
if(rt == 0) return;
tree[rt].sz = tree[tree[rt].ch[0]].sz + tree[tree[rt].ch[1]].sz + 1;
}
void split(int rt, int &xroot, int &yroot, int v){
if(rt == 0) xroot = yroot = 0;
else if(v < tree[rt].val) yroot = rt, split(tree[rt].ch[0], xroot, tree[yroot].ch[0], v);
else xroot = rt, split(tree[rt].ch[1], tree[xroot].ch[1], yroot, v);
update(rt);
}
void merge(int &rt, int xroot, int yroot){
if(xroot == 0 || yroot == 0) rt = xroot + yroot;
else if(tree[xroot].pri < tree[yroot].pri){
rt = xroot;
merge(tree[rt].ch[1], tree[rt].ch[1], yroot);
}
else {
rt = yroot;
merge(tree[rt].ch[0], xroot, tree[rt].ch[0]);
}
update(rt);
}
void insert(int &rt, int v){
split(rt, root1, root2, v);
tree[++tot].val = v, tree[tot].pri = rand(), tree[tot].sz = 1, rt = tot;
merge(root1, root1, tot);
merge(rt, root1, root2);
}
void del(int &rt, int v){
int z;
split(rt, root1, root2, v);
split(root1, root1, z, v - 1);
merge(z, tree[z].ch[0], tree[z].ch[1]);
merge(rt, root1, z);
merge(rt, rt, root2);
}
int getxth(int rt, int v){
if(v <= tree[tree[rt].ch[0]].sz) return getxth(tree[rt].ch[0], v);
else if(v <= tree[tree[rt].ch[0]].sz + 1) return tree[rt].val;
else return getxth(tree[rt].ch[1], v - tree[tree[rt].ch[0]].sz - 1);
}
int getrank(int rt, int v){
if(rt == 0) return 1;
if(v <= tree[rt].val) return getrank(tree[rt].ch[0], v);
else return tree[tree[rt].ch[0]].sz + 1 + getrank(tree[rt].ch[1], v);
}
int getpre(int rt, int v){
if(rt == 0) return -INF;
if(v <= tree[rt].val) return getpre(tree[rt].ch[0], v);
else{
return max(tree[rt].val, getpre(tree[rt].ch[1], v));
}
}
int getnxt(int rt, int v){
if(rt == 0) return INF;
if(v >= tree[rt].val) return getnxt(tree[rt].ch[1], v);
else return min(tree[rt].val, getnxt(tree[rt].ch[0], v));
}
int main(){
int opt, x;
scanf("%d", &n);
for(int i = 1; i <= n; i++){
scanf("%d %d", &opt, &x);
switch (opt)
{
case 1: insert(rt, x); break;
case 2: del(rt, x); break;
case 4: printf("%d\n",getxth(rt, x)); break;
case 3: printf("%d\n",getrank(rt, x)); break;
case 5: printf("%d\n",getpre(rt, x)); break;
default: printf("%d\n", getnxt(rt,x)); break;
}
}
}