treap的实现及应用

TREAP 的实现及应用

概念

treap是tree+heap的合成,顾名思义,它既有堆(heap)的性质,又有二叉查找树(tree)的性质.

treap中每个节点有两个属性 A A A和 B B B,其中 A A A属性符合堆的性质:任意一个节点的 A A A均小于(或大于)其儿子的 A A A;

而属性 B B B符合二叉查找树的性质:任意节点的 B B B均大于其左子树中所有节点的 B B B值,且小于其右子树中所有节点的 B B B值。

如下图所示:字母表示A属性,数字表示B属性。

treap的实现及应用

让我们回顾一下二叉查找树的构建过程,如果输入数据依次给出每个节点的值,而节点的值恰好是严格递增或者严格递减的,我们会得到一条链。而如果我们随机的打乱输入顺序,即在构建二叉查找树时,按随机的顺序插入节点,则很难退化成一条链,更大概率是比较平衡的一棵二叉树。

现在我们给每个节点随机的分配一个优先级,按照优先级由高到低的顺序,选择节点插入到树中,这和刚才随机的打乱输入顺序是一样的,这样我们构建的二叉树基本上也是平衡的。而这棵树正是treap,优先级是A属性,节点的值是B属性;A属性符合堆的性质,B属性符合二叉查找树的性质。

treap中,优先级是随机分配的,它通过随机性来保证不会过度失衡,但也不能保证它严格平衡。所以,treap是一种弱平衡的树。
看一道例题:普通平衡树

题目描述 普通平衡树

你需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入一个整数x
2. 删除一个整数x(若有多个相同的数,只删除一个)
3. 查询整数x的排名(若有多个相同的数,输出最小的排名),相同的数依次排名,不并列排名
4. 查询排名为x的数,排名的概念同3
5. 求x的前驱(前驱定义为小于x,且最大的数),保证x有前驱
6. 求x的后继(后继定义为大于x,且最小的数),保证x有后继

输入格式

第一行为n,表示操作的个数(n <= 500000)

下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6,   -10^7 <= x <= 10^7)

大规模输入数据,建立读入优化

输出格式

对于操作3,4,5,6每行输出一个数,表示对应答案

数据规模

$ 1 \leq n \leq 500000$

采用treap来解决这个题。

节点

节点用结构体表示,其中包括左右儿子,节点的值,节点的优先级等
节点的优先级采用随机数。

struct node{
    int ch[2], val, pri, sz;
}arr[MAXN];

旋转操作

旋转操作有一个技巧,如果是从上而下递归做的,以旋转边上方的点作为参数,实现起来相对方便。

void rote(int &r, int flg){
    int t = arr[r].ch[1-flg];
    arr[r].ch[1-flg] = arr[t].ch[flg];
    arr[t].ch[flg] = r;
    pushup(r);
    pushup(t);
    r = t;
}

插入节点

插入函数与普通的二叉查找树差不多,只是当儿子的优先级大于父亲的优先级时,需要做旋转。

void insert(int &r,  int x){
    if(r == 0){
        arr[++pcnt].val = x, arr[pcnt].pri = rand(), ++arr[pcnt].sz;
        r = pcnt;
        pushup(r);
        return;
    }
    if(x <= arr[r].val) {
        insert(arr[r].ch[0], x);
         if(arr[r].ch[0] && arr[r].pri > arr[arr[r].ch[0]].pri) rote(r, 1); //right tot
    }
    else  {
        insert(arr[r].ch[1], x);
         if(arr[r].ch[1] && arr[r].pri > arr[arr[r].ch[1]].pri) rote(r, 0); // left rot
    }
    pushup(r);
}

删除操作

删除操作和普通的二叉查找树完全一样,不需要旋转。
首先递归地找到待删除节点,找到了以后,判断一下:若待删除节点不足两个儿子时,直接删除,儿子取代它的位置;否则,在左子树中找最大的一个节点(该节点一定没有右儿子),将该节点剪切下来替代待删除节点。
这里有一个地方要格外注意,如果树中存在多个值相同的节点,要注意删除操作是一次删一个还是一次删掉所有。
更好的方法是将多个值相同的节点合并为一个节点,增加一个属性,用于表示该节点出现的次数。

//此处treap中存储了重复节点,删除时采用的由下网上逐一替换,最终保证删除1个节点
int del(int &r, int x){
    int tmp;
    if(arr[r].val == x || arr[r].val > x && arr[r].ch[0] == 0 || arr[r].val < x && arr[r].ch[1] == 0){
        if(arr[r].ch[0] == 0 || arr[r].ch[1] == 0) {
            tmp = arr[r].val;
            r = arr[r].ch[0] + arr[r].ch[1];
            return tmp;
        }
        else{
            tmp = arr[r].val; 
            arr[r].val = del(arr[r].ch[0], x);
        }
    }
    else if(x < arr[r].val) tmp = del(arr[r].ch[0], x);
    else tmp = del(arr[r].ch[1], x);
    pushup(r);
    return tmp;
}

查找第x个元素

int xth(int r, int x){
    if(arr[arr[r].ch[0]].sz >= x) return xth(arr[r].ch[0], x);
    else if(arr[arr[r].ch[0]].sz + 1 >= x) return arr[r].val;
    else return xth(arr[r].ch[1], x - arr[arr[r].ch[0]].sz - 1);
}

询问x的排名

int getrank(int r, int x){
    if(r == 0) return 1;
    if(x > arr[r].val) 
    return arr[arr[r].ch[0]].sz + 1 + getrank(arr[r].ch[1], x);
    else return getrank(arr[r].ch[0], x);
}

查找前驱

int getpre(int r, int x){
    if(r == 0) return -MOD;
    if(arr[r].val >= x) return getpre(arr[r].ch[0], x);
    else return max(arr[r].val, getpre(arr[r].ch[1], x));
}

查找后继

int getnxt(int r, int x){
    if(r == 0) return MOD;
    if(arr[r].val <= x)return getnxt(arr[r].ch[1], x);
    else return min(arr[r].val, getnxt(arr[r].ch[0], x));
}

完整代码如下:

#include <bits/stdc++.h>
using namespace std;
#define MAXN 1000005
#define MOD 1000000000
struct node{
    int ch[2], val, pri, sz;
}arr[MAXN];
int n, m, opt, x, pcnt, rt;
int cntres;
void pushup(int r){
    if(r) arr[r].sz = arr[arr[r].ch[0]].sz + arr[arr[r].ch[1]].sz + 1;
}
void rote(int &r, int flg){
    int t = arr[r].ch[1-flg];
    arr[r].ch[1-flg] = arr[t].ch[flg];
    arr[t].ch[flg] = r;
    pushup(r);
    pushup(t);
    r = t;
}
void insert(int &r,  int x){
    if(r == 0){
        arr[++pcnt].val = x, arr[pcnt].pri = rand(), ++arr[pcnt].sz;
        r = pcnt;
        pushup(r);
        return;
    }
    if(x <= arr[r].val) {
        insert(arr[r].ch[0], x);
         if(arr[r].ch[0] && arr[r].pri > arr[arr[r].ch[0]].pri) rote(r, 1); //right tot
    }
    else  {
        insert(arr[r].ch[1], x);
         if(arr[r].ch[1] && arr[r].pri > arr[arr[r].ch[1]].pri) rote(r, 0); // left rot
    }
    pushup(r);
}
int del(int &r, int x){
    int tmp;
    if(arr[r].val == x || arr[r].val > x && arr[r].ch[0] == 0 || arr[r].val < x && arr[r].ch[1] == 0){
        if(arr[r].ch[0] == 0 || arr[r].ch[1] == 0) {
            tmp = arr[r].val;
            r = arr[r].ch[0] + arr[r].ch[1];
            return tmp;
        }
        else{
            tmp = arr[r].val;
            arr[r].val = del(arr[r].ch[0], x);
        }
    }
    else if(x < arr[r].val) tmp = del(arr[r].ch[0], x);
    else tmp = del(arr[r].ch[1], x);
    pushup(r);
    return tmp;
}
bool find(int r, int x){
    if(r == 0) return 0;
    if(arr[r].val == x) return 1;
    else if(x < arr[r].val) return find(arr[r].ch[0], x);
    else return find(arr[r].ch[1], x);
}
int xth(int r, int x){
    if(arr[arr[r].ch[0]].sz >= x) return xth(arr[r].ch[0], x);
    else if(arr[arr[r].ch[0]].sz + 1 >= x) return arr[r].val;
    else return xth(arr[r].ch[1], x - arr[arr[r].ch[0]].sz - 1);
}
int getrank(int r, int x){
    if(r == 0) return 1;
    if(x > arr[r].val) return arr[arr[r].ch[0]].sz + 1 + getrank(arr[r].ch[1], x);
    else return getrank(arr[r].ch[0], x);
}
int getpre(int r, int x){
    if(r == 0) return -MOD;
    if(arr[r].val >= x) return getpre(arr[r].ch[0], x);
    else return max(arr[r].val, getpre(arr[r].ch[1], x));
}
int getnxt(int r, int x){
    if(r == 0) return MOD;
    if(arr[r].val <= x)return getnxt(arr[r].ch[1], x);
    else return min(arr[r].val, getnxt(arr[r].ch[0], x));
}
int main(){
    srand(time(0));
    int rescnt = 0, res = 0;
    scanf("%d", &n);
    for(int i = 1; i <= n; i++){
        scanf("%d %d", &opt, &x);
        if(opt == 1){insert(rt, x); cntres++;}
        else if(opt == 2) { del(rt, x), --cntres;}
        else if(opt == 3) printf("%d\n",getrank(rt, x));  
        else if(opt == 4) printf("%d\n", xth(rt, x));
        else if(opt == 5) printf("%d\n", getpre(rt, x));
        else printf("%d\n", getnxt(rt, x)); 
    }
    return 0;
}

非旋转的treap

非旋转treap是由范浩强大佬发明的,它摈弃了旋转操作,而是采用分裂和合并操作来作为基本操作,实现插入、删除节点的同时,维护二叉查找树的有序性和堆的性质。

如下图所示,这是一棵treap,每个节点中的第一个数字为权值,第二个数字为优先级。

权值满足二叉查找树的性质,优先级满足堆的性质。

split操作

分裂操作是非旋转treap中最难理解的操作,它实质上是通过一条路径,将树分成两棵树。我们插入节点、或删除节点,都需要先找到一个插入节点的位置或者删除节点的位置,于是就得到了一条从根到该位置的路径(路径上有可能只有1个点)。

以插入一个权值为20的节点为例:首先从根开始,找到新节点对应的位置,它的位置应该位于节点 J J J的右儿子处。于是我们得到一条从 A A A到 J J J的路径,将这条路径上的边全部断开,小于等于20的点分到左侧;大于20的点分到右侧,于是得到了两棵树:如下图:

treap的实现及应用

接下来生成权值为 20 20 20的节点 Q Q Q,优先级假设为 3 3 3.

接下来我们需要用到合并操作merge

merge操作

合并操作的对象是两棵树,这两棵树一定满足,左边的树权最大值小于右边的树的权值最小值。我们根据其优先级来合并。为了描述方面,我们设左边的树为 L L L,右边的树为 R R R, 首先比较两棵树的树根,谁优先级小,谁就作为新的树根,假设 L L L的优先级较小,则问题转换为 L L L的右子树与 R R R的合并问题了;否则就是 R R R的根作为新树的根,问题转换为 L L L和 R . r s o n R.rson R.rson的合并问题了,这样递归下去,直到某棵树问空,则递归结束。

合并操作比较简单,就不上图了。

像上面把插入节点Q的代码,我们先将子树A和Q合并,再将合并的新树继续与子树C合并即可。

删除操作的代码,也是可以用 s p l i t split split和 m e r g e merge merge操作来完成的。比如要删除权值为 x x x的节点。

先通过一次split操作,将权值小于等于 x x x和权值大于 x x x的节点分开,然后再将权值等于 x x x的和小于 x x x的分开,此时得到了三棵树。

如果权值等于 x x x的节点不止一个,则我们可以将权值等于 x x x的那颗树的左子树和右子树合并,这样就等于删掉了根节点。

现在还是三棵树,我们将它们从左到右依次合并为一棵树。这样就完成了删除任务。我们需要两次 s p l i t split split和三次 m e r g e merge merge操作。

而如果要删除所有权值为 x x x的节点,则我们只需要两次 s p l i t split split和一次 m e r g e merge merge操作.

模板题 普通平衡树

你需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入一个整数x
2. 删除一个整数x(若有多个相同的数,只删除一个)
3. 查询整数x的排名(若有多个相同的数,输出最小的排名),相同的数依次排名,不并列排名
4. 查询排名为x的数,排名的概念同3
5. 求x的前驱(前驱定义为小于x,且最大的数),保证x有前驱
6. 求x的后继(后继定义为大于x,且最小的数),保证x有后继

输入格式

第一行为n,表示操作的个数(n <= 500000)

下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6,   -10^7 <= x <= 10^7)

大规模输入数据,建立读入优化

输出格式

对于操作3,4,5,6每行输出一个数,表示对应答案

数据规模

$ 1 \leq n \leq 500000$

采用非旋转treap来解决这个题

#include <bits/stdc++.h>
using namespace std;
#define MAXN 1000005
#define INF 999999999
int n;
struct node{
    int ch[2], val, pri, sz;
}tree[MAXN];
int rt, root1, root2, tot;
void update(int rt){
    if(rt == 0) return;
    tree[rt].sz = tree[tree[rt].ch[0]].sz + tree[tree[rt].ch[1]].sz + 1;
}
void split(int rt, int &xroot, int &yroot, int v){
    if(rt == 0) xroot = yroot = 0;
    else if(v < tree[rt].val) yroot = rt, split(tree[rt].ch[0], xroot, tree[yroot].ch[0], v);
    else xroot = rt, split(tree[rt].ch[1], tree[xroot].ch[1], yroot, v);
    update(rt);
} 

void merge(int &rt, int xroot, int yroot){
    if(xroot == 0 || yroot == 0) rt = xroot + yroot;
    else if(tree[xroot].pri < tree[yroot].pri){
        rt = xroot;
        merge(tree[rt].ch[1], tree[rt].ch[1], yroot);
    }
    else {
        rt = yroot;
        merge(tree[rt].ch[0], xroot, tree[rt].ch[0]);
    }
    update(rt);
}

void insert(int &rt, int v){
    split(rt, root1, root2, v);
    tree[++tot].val = v, tree[tot].pri = rand(), tree[tot].sz = 1, rt = tot;
    merge(root1, root1, tot);
    merge(rt, root1, root2);
}

void del(int &rt, int v){
    int z;
    split(rt, root1, root2, v);
    split(root1, root1, z, v - 1);
    merge(z, tree[z].ch[0], tree[z].ch[1]);
    merge(rt, root1, z);
    merge(rt, rt, root2);
}

int getxth(int rt, int v){
    if(v <= tree[tree[rt].ch[0]].sz) return getxth(tree[rt].ch[0], v);
    else if(v <= tree[tree[rt].ch[0]].sz + 1) return tree[rt].val;
    else return getxth(tree[rt].ch[1], v - tree[tree[rt].ch[0]].sz - 1);
}

int getrank(int rt, int v){
    if(rt == 0) return 1;
    if(v <= tree[rt].val) return getrank(tree[rt].ch[0], v);
    else return tree[tree[rt].ch[0]].sz + 1 + getrank(tree[rt].ch[1], v);
}

int getpre(int rt, int v){
    if(rt == 0) return -INF;
    if(v <= tree[rt].val) return getpre(tree[rt].ch[0], v);
    else{
        return max(tree[rt].val, getpre(tree[rt].ch[1], v));
    }
}

int getnxt(int rt, int v){
    if(rt == 0) return INF;
    if(v >= tree[rt].val) return getnxt(tree[rt].ch[1], v);
    else return min(tree[rt].val, getnxt(tree[rt].ch[0], v));
}
int main(){
    int opt, x;
    scanf("%d", &n);
    for(int i = 1; i <= n; i++){
        scanf("%d %d", &opt, &x);
        switch (opt)
        {
        case 1: insert(rt, x); break;
        case 2: del(rt, x); break;
        case 4: printf("%d\n",getxth(rt, x)); break;
        case 3: printf("%d\n",getrank(rt, x)); break;
        case 5: printf("%d\n",getpre(rt, x)); break;
        default: printf("%d\n", getnxt(rt,x)); break;
        }
    } 
}
上一篇:fhq_treap


下一篇:OI学习日志 12月份