前言
前几天有幸听学长讲平衡树,想着好久没写博客了,记录一下。
简介
Splay,平衡树的一种,依靠每次将访问到的点旋到根来保持树的平衡。
并且,Splay 还可以高效解决序列翻转等操作。
实现
前提
以下代码是基于这样的定义的:
struct Tree{int ch[2],val,siz,fa;}nd[MAXN];//表示某一个节点
void pushup(int rt){nd[rt].siz=nd[nd[rt].ch[0]].siz+nd[nd[rt].ch[1]].siz+1;}//更新某个节点子树的大小
int chk(int rt){return rt==nd[nd[rt].fa].ch[1];}//返回 rt 是左孩子还是右孩子
旋转
首先,对于依赖旋转的平衡树,这个操作是十分重要的。我们可以通过图片来理解,如何将父子关系互换并且不违背二叉搜索树的性质。
旋转有两种,对于左儿子用右旋,右儿子用左旋,下面以右旋为例。
首先最关心的是二叉搜索树的性质,我们可以发现,原来有:
在旋转后仍然满足:
\[uls'<u'<urs'<fa'<fars' \]因此旋转没有破坏二叉搜索树的美妙性质。那我们来看旋转的实现:
- 把 \(u\) 变成 \(fa\) 的父亲的儿子,\(u\) 的父亲更改为 \(fa\) 的父亲;
- 把 \(u\) 的右儿子变成 \(fa\) 的左儿子;
- 互换 \(u\) 与 \(fa\) 的父子关系。
左旋也一样,左右儿子倒过来就可以,因此我们可以根据 \(u\) 初始是左儿子还是右儿子,将两种旋转合并。
void rot(int rt){
int p=nd[rt].fa,g=nd[p].fa,d=chk(rt);
nd[g].ch[chk(p)]=rt;nd[rt].fa=g;
nd[p].ch[d]=nd[rt].ch[d^1];
nd[nd[rt].ch[d^1]].fa=p;
nd[rt].ch[d^1]=p;nd[p].fa=rt;
pushup(p);pushup(rt);//注意最后的更新
}
由于旋转是平衡树的基本操作,所以这里就先这样,主要理解双旋对 Splay 的优化。
双旋
双旋关心节点 \(u\) 的父亲的儿子类型与 \(u\) 的关系。
对此,我们可以分成 \(3\) 种情况:
- 父亲是根节点。此时直接旋转就可以了。
- 如果父亲的儿子类型与 \(u\) 相同。那么就先旋父亲,再旋 \(u\)。
- 如果不同。那么就把 \(u\) 旋转 \(2\) 次。
为什么要这么麻烦呢?前面说过,Splay 是把访问的节点旋到根来维护平衡的,那我直接一个一个旋不就好了?为什么要定义一个双旋呢?
很简单,来看一个例子:
如果我访问顺序是 \(5\to 4\to 3\cdots 1\to 5\to 4\cdots\),那么可以发现,如果只是用单旋,每次查找为 \(\mathcal{O}(n)\),而用单旋不能改变链的事实,所以总的复杂度会高达 \(\mathcal{O}(n^2)\),直接 GG。(不理解可以自己手动模拟一下,发现每次旋到根后,整体还是同样形状的链)
那如果采用双旋呢?
可以看一下第一次操作如果用双旋结果是什么(把 \(5\) 转到根上)。
可以看到改变了链的形式,使得高度《大大》降低。
所以采用双旋,可以有效规避在链的情况下出现时间爆炸的情况。
所以我们采用双旋来实现 Splay,而对于把一个节点旋到根的操作,我们称之 \(Splay\) 操作(((
void splay(int rt){
while(nd[rt].fa!=gl){
int p=nd[rt].fa,g=nd[p].fa;
if(g==0) rot(rt);//case 1
else if(chk(rt)==chk(p)) rot(p),rot(rt);//case 2
else rot(rt),rot(rt);// case 3
}root=rt;
}
别的操作
有了 \(Splay\) 操作,就基本完成了 Splay,接下来就是一些比较细节的,和别的平衡树异曲同工的操作了。只要记住,对于所有操作,我们只要对目标点操作完后 \(Splay\) 一下就可以了,非常舒服。
插入
void ins(int rt,int val,int f){
if(!rt){
rt=++tot;nd[rt].val=val;nd[rt].siz=1;
nd[rt].fa=f;nd[f].ch[val>=nd[f].val]=rt;
splay(rt,0);return;
}
int d=(nd[rt].val<=val);
ins(nd[rt].ch[d],val,rt);
}
查找第 k 大
int find(int rt,int k){
pushdown(rt);
int cur=nd[nd[rt].ch[0]].siz+1;
//如果是有重复元素的并且记录了 cnt 的,这里的 1 要改 nd[rt].cnt
if(k<cur) return find(nd[rt].ch[0],k);
else if(k==cur) return rt;
else return find(nd[rt].ch[1],k-cur);
}
To be continued
序列上的 Splay
先来看个题哈:P3391 【模板】文艺平衡树
这题需要支持区间翻转,并输出最终结果。那如果暴力的话是 \(\mathcal{O}(n^2)\) 的,显然 TLE。
这时候,我们考虑把位置作为权值,建立一棵 Splay。如果对 \([l,r]\) 翻转,那么我们就在书上查找 \(l-1\) 和 \(r+1\),然后把 \(l-1\) 转到根,把 \(r+1\) 转到根的右儿子,此时可以发现,根据二叉搜索树的性质,区间 \([l,r]\) 就是以根的右儿子的左儿子为根的子树。
然后我们在这里记一个 \(tag\),之后如果向下访问了,就 \(pushdown\) 即可。
呃,一点小问题,我们知道 \(Splay\) 操作是可以把节点旋转到根的,那怎么旋转到根的右儿子呢?我们可以多加一个 \(gl\) 参数,表示目标点的父亲。
void splay(int rt,int gl){
while(nd[rt].fa!=gl){
int p=nd[rt].fa,g=nd[p].fa;
if(g==gl) rot(rt);
else if(chk(rt)==chk(p)) rot(p),rot(rt);
else rot(rt),rot(rt);
}if(!gl) root=rt;
}
那我们就把这题做完了……
注意,由于我们用到 \(l-1\) 和 \(r+1\),所以需要一个极小值和极大值防止翻转 \([1,n]\) 的时候爆炸~
贴个代码,以示诚意。
Code
#include<bits/stdc++.h>
#define ll long long
#define inf (1<<30)
#define INF (1ll<<60)
using namespace std;
const int MAXN=1e5+10;
int tot,root;
struct Tree{int ch[2],val,siz,fa,rev;}nd[MAXN];
void pushup(int rt){nd[rt].siz=nd[nd[rt].ch[0]].siz+nd[nd[rt].ch[1]].siz+1;}
int chk(int rt){return rt==nd[nd[rt].fa].ch[1];}
void pushdown(int rt){
if(nd[rt].rev==0) return;
nd[nd[rt].ch[0]].rev^=1;
nd[nd[rt].ch[1]].rev^=1;
nd[rt].rev=0;
swap(nd[rt].ch[0],nd[rt].ch[1]);
}
void rot(int rt){
int p=nd[rt].fa,g=nd[p].fa,d=chk(rt);
nd[g].ch[chk(p)]=rt;nd[rt].fa=g;
nd[p].ch[d]=nd[rt].ch[d^1];
nd[nd[rt].ch[d^1]].fa=p;
nd[rt].ch[d^1]=p;nd[p].fa=rt;
pushup(p);pushup(rt);
}
void splay(int rt,int gl){
while(nd[rt].fa!=gl){
int p=nd[rt].fa,g=nd[p].fa;
if(g==gl) rot(rt);
else if(chk(rt)==chk(p)) rot(p),rot(rt);
else rot(rt),rot(rt);
}
if(!gl) root=rt;
}
void ins(int rt,int val,int f){
if(!rt){
rt=++tot;nd[rt].val=val;nd[rt].siz=1;
nd[rt].fa=f;nd[f].ch[val>=nd[f].val]=rt;
splay(rt,0);return;
}
int d=(nd[rt].val<=val);
ins(nd[rt].ch[d],val,rt);
}
int find(int rt,int k){
pushdown(rt);
int cur=nd[nd[rt].ch[0]].siz+1;
if(k<cur) return find(nd[rt].ch[0],k);
else if(k==cur) return rt;
else return find(nd[rt].ch[1],k-cur);
}
int n,m;
void print(int rt){
pushdown(rt);
if(nd[rt].ch[0]) print(nd[rt].ch[0]);
if(nd[rt].val-1>=1&&nd[rt].val-1<=n) printf("%d ",nd[rt].val-1);
if(nd[rt].ch[1]) print(nd[rt].ch[1]);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n+2;i++) ins(root,i,0);
//这里和讲的不太一样,由于查找 0 比较麻烦,所以还是将整个数组向后移动一个,把 1 和 n+2 当成边界
int l,r;
while(m--){
scanf("%d%d",&l,&r);
l=find(root,l);r=find(root,r+2);
splay(l,0); splay(r,l);
nd[nd[nd[root].ch[1]].ch[0]].rev^=1;
}
print(root);
}