替罪羊树
一种基于部分重建的自平衡二叉搜索树。在替罪羊树上,插入或删除节点的平摊最坏时间复杂度是\(O(log n)\),搜索节点的最坏时间复杂度是\(O(log n)\)。
我们定义一个平衡树因子\(\alpha\)。对于替罪羊树的每个节点\(t\),需要满足\(max(siz[ls],siz[rs]<\alpha * size[t])\),其中\(ls,rs\)分别是\(t\)的左儿子,右儿子。
通俗的来讲,就是要保证每一个节点的左右子树的大小都不超过它本身大小的\(\alpha\)倍,否则就把这个节点及它的子树重构,使其满足这个性质
一般取\(\alpha=0.75\),使其达到最佳性能
#include<bits/stdc++.h>
using namespace std;
const int N = 100005;
struct node
{
int l, r, val, siz, cnt;
}nod[N];
int n, cnt, root;
vector<int> p;
void pushup(int rt)
{
int l = nod[rt].l, r = nod[rt].r;
nod[rt].siz = nod[l].siz + nod[r].siz;
}
int build(int l, int r)//按照线段树的方法建树
{
if(l > r) return 0;
int mid = (l + r) >> 1;
nod[p[mid]].l = build(l, mid - 1);
nod[p[mid]].r = build(mid + 1, r);
pushup(p[mid]);
return p[mid];
}
void dfs(int rt)//要保证大小关系
{
if(rt == 0) return;
dfs(nod[rt].l);
p.push_back(rt);
dfs(nod[rt].r);
}
void rebuild(int &rt)
{
if(rt == 0) return;
if(nod[rt].siz * 0.75 < nod[nod[rt].l].siz || nod[rt].siz * 0.75 < nod[nod[rt].r].siz)
{
p.clear();
p.push_back(-1);//为了保证下标从1开始
dfs(rt);
rt = build(1, p.size() - 1);
}
}
void insert(int &rt, int x)
{
if(rt == 0)
{
rt = ++cnt;
nod[rt].val = x;
nod[rt].siz = nod[rt].cnt = 1;
return;
}
rebuild(rt);
if(nod[rt].val == x)
{
nod[rt].cnt ++;
nod[rt].siz ++;
return;
}
if(nod[rt].val < x)
{
insert(nod[rt].r, x);
pushup(rt);
return;
}
if(nod[rt].val > x)
{
insert(nod[rt].l, x);
pushup(rt);
return;
}
}
int delmin(int &rt)
{
if(nod[rt].l)//向左儿子跳
{
int ret = delmin(nod[rt].l);
pushup(rt);
return ret;
}
int ret = rt;
rt = nod[rt].r;//传址符
return ret;
}
void del(int &rt,int x)
{
if(nod[rt].val > x)
{
del(nod[rt].l, x);
pushup(rt);
}
if(nod[rt].val < x)
{
del(nod[rt].r, x);
pushup(rt);
}
if(nod[rt].val == x)
{
if(nod[rt].cnt > 1)
{
nod[rt].cnt --;
nod[rt].siz --;
return;
}
if(nod[rt].l == 0)
{
rt = nod[rt].r;
return;
}
if(nod[rt].r == 0)
{
rt = nod[rt].l;
return;
}
int tmp = delmin(nod[rt].r);
nod[rt].val = nod[tmp].val;
nod[rt].cnt = nod[tmp].cnt;
pushup(rt);
return;
}
}
int getk(int rt, int x)
{
if(nod[rt].val == x) return nod[nod[rt].l].siz + 1;
if(nod[rt].val < x) return nod[nod[rt].l].siz + nod[rt].cnt + getk(nod[rt].r, x);
if(nod[rt].val > x) return getk(nod[rt].l, x);
}
int getkth(int rt, int x)
{
if(nod[nod[rt].l].siz + 1 <= x && x <= nod[nod[rt].l].siz + nod[rt].cnt) return nod[rt].val;
if(nod[nod[rt].l].siz + 1 > x) return getkth(nod[rt].l, x);
if(nod[nod[rt].l].siz + nod[rt].cnt < x) return getkth(nod[rt].r, x-(nod[nod[rt].l].siz + nod[rt].cnt));
}
int getpre(int rt, int x)
{
int p = rt, ans;
while(p)
{
if(x <= nod[p].val) p = nod[p].l;
else
{
ans = p;
p = nod[p].r;
}
}
return ans;
}
int getsuc(int rt, int x)
{
int p = rt, ans;
while(p)
{
if(nod[p].val <= x) p = nod[p].r;
else
{
ans = p;
p = nod[p].l;
}
}
return ans;
}
int main()
{
scanf("%d", &n);
while(n --)
{
int opt, x;
scanf("%d%d", &opt, &x);
if(opt == 1) insert(root, x);
if(opt == 2) del(root, x);
if(opt == 3) printf("%d\n", getk(root, x));
if(opt == 4) printf("%d\n", getkth(root, x));
if(opt == 5) printf("%d\n", nod[getpre(root, x)].val);
if(opt == 6) printf("%d\n", nod[getsuc(root, x)].val);
}
return 0;
}