Splay Tree(伸展树)
简介
Splay Tree是一种二叉查找树(BST),即满足二叉树上任意一个节点的左儿子权值>自身权值>右儿子权值,它通过旋转操作使得树上单次操作的均摊复杂度为 \(\log n\),由Daniel Sleator和Robert Endre Tarjan(又是Tarjan)发明,希望了解复杂度证明的可以自行查询资料(我不会证)
实现
存储、维护与更新
一般我们把树上的每一个节点都作为结构体存放,节点中维护当前节点对应的值,当前节点对应值的个数,两个子节点的编号,子树中节点的个数以及父节点编号
为了实现方便,常用 \(ch[0]\) 或 \(son[0]\) 表示左儿子,用 \(ch[1]\) 或 \(son[1]\) 表示右儿子
struct Node
{
int v,cnt,siz,fa,ch[2];//节点值,个数,子树及自身节点个数,父亲,儿子 0->left 1->right
};
实现中常常需要改变节点间的关系,我们需要从子节点更新当前节点的信息
void update(int x)
{
node[x].siz=node[node[x].ch[0]].siz+node[node[x].ch[1]].siz+node[x].cnt;
}
旋转
旋转可以说是Splay的基础,利用旋转才能让树保持平衡,Splay中有左旋(Left Rotation)和右旋(Right Rotation),如下图
不难发现,旋转操作需要让树仍然满足二叉查找树的性质
具体地说,例如在这张图中,右旋C节点我们就需要 1.用C替换B成为A的子节点 2.把C的右儿子作为B的左儿子 3.把B作为C的右儿子
右旋B节点我们就需要 1.用B替换C成为A的子节点 2.把B的左儿子作为C的右儿子 3.把C作为B的左儿子
void rotate(int x)
{
int y=node[x].fa,z=node[y].fa,k=(node[y].ch[1]==x);//y是x的父亲,z是y的父亲,k表示x是不是y的左儿子
node[z].ch[node[z].ch[1]==y]=x;//用x替换y作为z的儿子
node[x].fa=z;
node[y].ch[k]=node[x].ch[k^1];//把x的对应儿子转移给y
node[node[x].ch[k^1]].fa=y;
node[x].ch[k^1]=y;//把y更新为x的儿子
node[y].fa=x;
update(y),update(x);//y是x的子节点所以必须先更新y再更新x
}
代码中 ^1 表示异或1,其中 1^1=0 ,0^1=1,也就是取反的意思
伸展(Splay)
Splay操作即为把一个节点通过不断旋转旋转到根,每次查询或者修改操作后都将操作的节点Splay到根就能保证单次操作复杂度均摊为 \(\log n\)
但是我们不能单纯地把节点不断向上转,考虑如下这种情况,我们想把C转到根
我们发现如果单纯把C向上转的话,如果原本是链,转动之后还是链,并不能优化,所以我们需要双旋,也就是在当前节点和当前节点的父节点都是自己父节点的左儿子或者右儿子时,我们先转动当前节点的父节点,再转动当前节点,如下图
这样就让单次操作复杂度均摊到了 \(\log n\),代码实现如下
void splay(int x,int target)//让x成为target的子节点,根节点是节点0的子节点
{
while(node[x].fa!=target)//转到目标就停止
{
int y=node[x].fa,z=node[y].fa;
if(z!=target)//如果转一下就满足就不多转一次
((node[z].ch[0]==y)^(node[y].ch[0]==x))?rotate(x):rotate(y);//三点为链就转y否则转x
rotate(x);//转一下x
}
if(!target)Root=x;//如果是转到根就更新根
}
插入
插入一个新节点要先从根据插入节点与当前找到的节点的大小关系从根节点开始向两个子节点走,直到找到插入节点值与当前节点相等,或者当前找到的节点编号不存在(即这个值从未出现过),然后更新信息即可
具体实现很好理解,看代码
void insert(int x)//插入x
{
int cur=Root,from=0;//当前找到的节点编号cur以及节点来源from
while(cur&&x!=node[cur].v)//找到了有相同值或者进入了未定义节点就停下
from=cur,cur=node[cur].ch[x>node[cur].v];//往子节点走
if(cur)//已经存在就增加个数即可
++node[cur].cnt;
else//创建新节点
{
cur=++node_cnt;//分配新编号
if(!from) Root=cur;//是根就更新根信息
else node[from].ch[x>node[from].v]=cur;//不是根就更新父节点信息
//更新新节点信息
node[cur].v=x;
node[cur].cnt=1;
node[cur].fa=from;
node[cur].siz=1;
node[cur].ch[0]=node[cur].ch[1]=0;
}
splay(cur,0);//转到根
}
查找
查找操作就是在树上找一个值对应的节点,并且把这个节点转到根,方便进行操作,不断往下找就行,看代码
void find(int x)//查找元素,调用后根即为查找的元素
{
int cur=Root;//从根开始查
if(!cur)return;//树为空就退出
while(node[cur].ch[x>node[cur].v]&&x!=node[cur].v)
//x不是当前节点值且当前节点有更小或更大值就进入子节点继续找
cur=node[cur].ch[x>node[cur].v];//进入子节点
splay(cur,0);//将找到的节点转到根
}
查找前驱后继
前驱定义为比一个数小的数中最大的,后继定义为比一个数大的数中最小的
直接把要查找的值对应的节点转到根,要查前驱就从左儿子开始一直往右找,要查后继就从右儿子开始一直往左找,非常简单
int find_pre_id(int x)//查前驱编号
{
find(x);//转到根
if(node[Root].v<x)return Root;//原树中没有这个值,就直接返回根
int cur=node[Root].ch[0];//进左子树
while(node[cur].ch[1]) cur=node[cur].ch[1];//往右
return cur;
}
int find_nxt_id(int x)//查后继编号
{
find(x);//转到根
if(node[Root].v>x)return Root;//原树中没有这个值,就直接返回根
int cur=node[Root].ch[1];//进右子树
while(node[cur].ch[0]) cur=node[cur].ch[0];//往左
return cur;
}
int find_pre(int x)//查前驱值
{
x=find_pre_id(x);//找到前驱编号
return node[x].v;//返回值
}
int find_nxt(int x)//查后继值
{
x=find_nxt_id(x);//找到后继编号
return node[x].v;//返回值
}
查数的排名
排名定义为比一个数小的数的个数+1,只需要将这个数转到根,返回比它左儿子的 \(size\) 即可(因为已经插入了极小值,所以不需要再多+1)
int get_rank(int x)
{
find(x);//转到根
return node[node[Root].ch[0]].siz;
//比他小的数的个数+1就是排名,这里因为我们插入了极小值就不用+1了
}
查排名对应的数
我们从根节点开始找,左边和当前节点个数小于排名,就说明这个数一定在右子树,我们把排名减去左边和当前节点个数然后进入右子树查询新排名
如果左子树节点更多就直接进左子树查
否则当前找到的节点就是对应的数,返回即可
int kth(int rank)//查排名为k的数
{
++rank;//这里因为我们插入了极小值,排名需要+1
int cur=Root,son;//cur从根开始,son为当前节点的右儿子
if(node[cur].siz<rank) return -1;//没有这么多数就退出
while(1)
{
son=node[cur].ch[0];
if(rank>node[son].siz+node[cur].cnt)//左边和当前节点个数不到k
{
rank-=node[son].siz+node[cur].cnt;//减去这么多个
cur=node[cur].ch[1];//进入右子树
}
else if(node[son].siz>=rank) cur=son;//左子树节点更多就进左子树查
else return node[cur].v;//找到了就返回
}
}
删除
删除操作较为复杂,我们先找到要删除数的前驱和后继,把前驱转到根,后继转到根节点也就是前驱的子节点,此时显然有后继是前驱的右儿子,后继的左儿子一定大于前驱小于后继,也就是我们要删除的数,根据前驱和后继的定义,后继一定有且仅有一个左儿子,对这个节点进行删除即可
void erase(int x)
{
int x_pre=find_pre_id(x),x_nxt=find_nxt_id(x);//找x的前驱后继
splay(x_pre,0);//把前驱转到根
splay(x_nxt,x_pre);//把后继转到根的子节点
int cur=node[x_nxt].ch[0];//此时x一定是后继的左儿子
if(node[cur].cnt>1)//删不完
{
--node[cur].cnt;//减少一个
splay(cur,0);//转
}
else node[x_nxt].ch[0]=0;//切断后继的左子树
}
初始化
在实际使用过程中,为了防止越界,我们常常会在树中插入一个极小值,一个极大值,本文关于排名的代码都是以已经插入极小值极大值为前提,自己编写程序时一定要记得初始化插入一个极小值,一个极大值
纯享版封装Splay
code
template<int N,typename _Tp=int,_Tp INF=2147483647> class Splay
{
private:
int Root,node_cnt;
struct Node
{
_Tp v;
int cnt,siz,fa,ch[2];
};
Node node[N];
void update(int x)//更新
{
node[x].siz=node[node[x].ch[0]].siz+node[node[x].ch[1]].siz+node[x].cnt;
}
void rotate(int x)//旋转
{
int y=node[x].fa,z=node[y].fa,k=(node[y].ch[1]==x);
node[z].ch[node[z].ch[1]==y]=x;
node[x].fa=z;
node[y].ch[k]=node[x].ch[k^1];
node[node[x].ch[k^1]].fa=y;
node[x].ch[k^1]=y;
node[y].fa=x;
update(y),update(x);
}
void splay(int x,int target)//转到目标节点的儿子
{
while(node[x].fa!=target)
{
int y=node[x].fa,z=node[y].fa;
if(z!=target)
((node[z].ch[0]==y)^(node[y].ch[0]==x))?rotate(x):rotate(y);
rotate(x);
}
if(!target)Root=x;
}
void find(_Tp x)//对应值节点转到根
{
int cur=Root;
if(!cur)return;
while(node[cur].ch[x>node[cur].v]&&x!=node[cur].v)
cur=node[cur].ch[x>node[cur].v];
splay(cur,0);
}
public:
Splay(){Root=node_cnt=0;insert(INF),insert(-INF);}//初始化
void insert(_Tp x)//插入
{
int cur=Root,from=0;
while(cur&&x!=node[cur].v)
from=cur,cur=node[cur].ch[x>node[cur].v];
if(cur)
++node[cur].cnt;
else
{
cur=++node_cnt;
if(!from) Root=cur;
else node[from].ch[x>node[from].v]=cur;
node[cur].v=x;
node[cur].cnt=1;
node[cur].fa=from;
node[cur].siz=1;
node[cur].ch[0]=node[cur].ch[1]=0;
}
splay(cur,0);
}
int find_pre_id(_Tp x)//查前驱编号
{
find(x);
if(node[Root].v<x)return Root;
int cur=node[Root].ch[0];
while(node[cur].ch[1]) cur=node[cur].ch[1];
return cur;
}
int find_nxt_id(_Tp x)//查后继编号
{
find(x);
if(node[Root].v>x)return Root;
int cur=node[Root].ch[1];
while(node[cur].ch[0]) cur=node[cur].ch[0];
return cur;
}
_Tp find_pre(_Tp x)//查前驱值
{
x=find_pre_id(x);
return node[x].v;
}
_Tp find_nxt(_Tp x)//查后继值
{
x=find_nxt_id(x);
return node[x].v;
}
void erase(_Tp x)//删除
{
int x_pre=find_pre_id(x),x_nxt=find_nxt_id(x);
splay(x_pre,0);
splay(x_nxt,x_pre);
int cur=node[x_nxt].ch[0];
if(node[cur].cnt>1)
{
--node[cur].cnt;
splay(cur,0);
}
else node[x_nxt].ch[0]=0;
}
_Tp kth(int rank)//找排名为k的
{
++rank;
int cur=Root,son;
if(node[cur].siz<rank) return INF;
while(1)
{
son=node[cur].ch[0];
if(rank>node[son].siz+node[cur].cnt)
{
rank-=node[son].siz+node[cur].cnt;
cur=node[cur].ch[1];
}
else if(node[son].siz>=rank) cur=son;
else return node[cur].v;
}
}
int get_rank(_Tp x)//查排名
{
find(x);
return node[node[Root].ch[0]].siz;
}
};
食用方法
将上方代码加入您的代码中,定义时您需要提供一个参数 \(N\) 表示该Splay Tree的节点数上限,另有两个可选参数 \(\_Tp\) 表示您在其中所存储值的数据类型,\(INF\) 表示您所希望的正无穷大小,它的类型为 \(\_Tp\) 的类型,如果您不提供可选参数,\(\_Tp\) 将默认为int,\(INF\) 将默认为long long类型的 \(2147483647\)
具体见下方栗子
Splay<100000> A;
//定义一个名为A,节点个数至多100000个,储值类型为int,正无穷为2147483647的Splay Tree
Splay<500,long long> B;
//定义一个名为B,节点个数至多500个,储值类型为long long,正无穷为2147483647的Splay Tree
Splay<114514,double,1919810> C;
//定义一个名为C,节点个数至多114514个,储值类型为double,正无穷为1919810的Splay Tree
Splay<123,short int> D[233];
//定义一个名为D的容量为233的一维数组,每一个下标有一棵节点个数至多123个,储值类型为short int,正无穷为2147483647的Splay Tree
致谢
FJN 妹子 和 npy SYQ
该文为本人原创,转载请注明出处