【模板】Splay

Splay 均摊复杂度证明见此处 \(\rightarrow\) 链接

代码如下

#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
const int inf=0x3f3f3f3f; struct node{
#define ls(x) t[x].ch[0]
#define rs(x) t[x].ch[1]
int fa,ch[2],val,size,cnt;
}t[maxn];
int tot,root;
inline int get(int x){return x==rs(t[x].fa);}
inline void pushup(int x){
t[x].size=t[ls(x)].size+t[rs(x)].size+t[x].cnt;
}
inline int find(int val){
int x=root;
while(t[x].val!=val&&t[x].ch[t[x].val<val])x=t[x].ch[t[x].val<val];
return x;
}
inline void rotate(int x){
int fa=t[x].fa,gfa=t[fa].fa;
int d1=get(x),d2=get(fa);
t[fa].ch[d1]=t[x].ch[d1^1],t[t[x].ch[d1^1]].fa=fa;
t[x].ch[d1^1]=fa,t[fa].fa=x;
t[x].fa=gfa,t[gfa].ch[d2]=x;
pushup(fa),pushup(x);
}
inline void splay(int x,int goal){
while(t[x].fa!=goal){
int fa=t[x].fa,gfa=t[fa].fa;
if(gfa!=goal)get(x)==get(fa)?rotate(fa):rotate(x);
rotate(x);
}
if(!goal)root=x;
}
void insert(int val){
int x=root,fa=0;
while(x&&t[x].val!=val)fa=x,x=t[x].ch[t[x].val<val];
if(x)++t[x].cnt;
else{
x=++tot;
if(fa)t[fa].ch[t[fa].val<val]=x;
t[x].fa=fa,t[x].val=val,t[x].cnt=t[x].size=1;
}
splay(x,0);
}
int kth(int x,int k){
if(k<=t[ls(x)].size)return kth(ls(x),k);
else if(k>t[ls(x)].size+t[x].cnt)return kth(rs(x),k-t[ls(x)].size-t[x].cnt);
else return t[x].val;
}
int getrank(int val){
splay(find(val),0);
return t[ls(root)].size;
}
int getpre(int val){
splay(find(val),0);
if(t[root].val<val)return root;
int x=ls(root);
while(rs(x))x=rs(x);
return x;
}
int getnxt(int val){
splay(find(val),0);
if(t[root].val>val)return root;
int x=rs(root);
while(ls(x))x=ls(x);
return x;
}
void remove(int val){
int pre=getpre(val),nxt=getnxt(val);
splay(pre,0),splay(nxt,pre);
if(t[ls(nxt)].cnt>1)--t[ls(nxt)].cnt,splay(ls(nxt),0);
else ls(nxt)=0,splay(nxt,0);
}
void initial(){insert(-inf),insert(inf);} int main(){
initial();
int opt,val,n;
scanf("%d",&n);
while(n--){
scanf("%d%d",&opt,&val);
switch(opt){
case 1:insert(val);break;
case 2:remove(val);break;
case 3:printf("%d\n",getrank(val));break;
case 4:printf("%d\n",kth(root,val+1));break;
case 5:printf("%d\n",t[getpre(val)].val);break;
case 6:printf("%d\n",t[getnxt(val)].val);break;
}
}
return 0;
}
上一篇:jenkins jmeter持续集成批处理jmx脚本


下一篇:Android 开源项目源码解析(第二期)