终于调出来这道题了,写篇题解(
首先碰到这样的题我们肯定要考虑每种操作会对树的形态产生怎样的影响:
-
插入操作:对于 BST 有一个性质是,当你插入一个节点时,其在 BST 上的父亲肯定是,你把 BST 中父亲按权值 sort 一遍排成一列后,在待插入的数的两侧的数对应的节点中,深度较大者。因此我们考虑用一个
set
,将所有点的权值和编号压进去然后在里面lower_bound
即可找出待插入点两侧的点。 -
单旋最小值:稍微画几个图即可发现,对于最小值代表的点 \(x\),如果 \(x\) 已经是根了就可以忽略此次操作,否则假设 \(x\) 在 splay 上的父亲为 \(f\),右儿子为 \(son\),原来的根为 \(rt\),那么此次操作等价于以下四个删断边操作:
- 断开 \(x,f\) 之间的边
- 断开 \(x,son\) 之间的边(如果 \(x\) 不存在右儿子则忽略)
- 连上 \(x,rt\) 之间的边,其中 \(x\) 为 \(rt\) 的父亲
- 连上 \(f,son\) 之间的边
注意到这里涉及删断边,并且在任意时刻图都是一棵森林,因此可以 LCT 维护。
-
单旋最大值:同单旋最小值的情况,只不过这里需要把右儿子改为左儿子
-
单旋删除最小值:与单旋最小值的情况类似,只不过这次不需要连 \(x\) 与 \(rt\) 之间的边
-
单旋删除最大值:与单旋删除最小值的情况类似,只不过这里也需要把右儿子改为左儿子
程序的大致框架构建出来了,接下来考虑如何具体实现每个操作:
- 查询一个点在 BST 上的深度:直接把这个点
access
一遍并转到 splay 的根,那么这个点的siz
就是该点在 BST 上的深度大小。 - 查询一个点的左/右儿子:在我们 LCT 的过程中,我们失去了原 BST 上左右儿子的信息,因此我们无法直接通过将它转到根,然后调用其
ch[0]/ch[1]
的方法求其左右儿子。不过注意到每个点在 BST 上儿子个数 \(\le 2\),因此我们考虑 top tree 的思想,用一个set
维护其虚儿子,这样我们每次查询一个点的左右儿子时,只需把它access
一遍并转到根,然后在它的虚儿子集合中找到键值大于 / 小于该点的键值的点即可。
最后是一些注意点:
- 在
rotate
时,如果 \(x\) 的父亲是 \(x\) 所在splay
的根,那么我们要在 \(x\) 父亲的父亲的虚儿子集合中删除 \(y\) 加入 \(x\),这一点在普通的 top tree 中不用考虑,因为转 \(x\) 不会影响 \(x\) 的父亲的父亲的子树的大小,但是这里我们维护的是一个点的虚儿子具体是什么,虚儿子变了,父亲的信息也要改变。 - 在查询左右儿子时,不能找到一个键值比待查询点键值大 / 小的点就
return
,要在对应子树中找到深度最浅(中序遍历中第一位)的点再返回。
const int MAXN=1e5;
const int INF=0x3f3f3f3f;
int ncnt=0;
struct node{int ch[2],f,siz,rev_lz,val;set<int> img_ch;} s[MAXN+5];
void pushup(int k){s[k].siz=s[s[k].ch[0]].siz+s[s[k].ch[1]].siz+1;}
int ident(int k){return ((s[s[k].f].ch[0]==k)?0:((s[s[k].f].ch[1]==k)?1:-1));}
void connect(int k,int f,int op){s[k].f=f;if(~op) s[f].ch[op]=k;}
void rotate(int x){
int y=s[x].f,z=s[y].f,dx=ident(x),dy=ident(y);
connect(s[x].ch[dx^1],y,dx);connect(y,x,dx^1);connect(x,z,dy);
pushup(y);pushup(x);assert(~dx);
if(dy==-1&&z){
s[z].img_ch.erase(s[z].img_ch.find(y));
s[z].img_ch.insert(x);
}
}
void splay(int k){
while(~ident(k)){
if(ident(s[k].f)==-1) rotate(k);
else if(ident(k)==ident(s[k].f)) rotate(s[k].f),rotate(k);
else rotate(k),rotate(k);
}
}
void access(int k){
int pre=0;
for(;k;pre=k,k=s[k].f){
splay(k);
if(s[k].ch[1]) s[k].img_ch.insert(s[k].ch[1]);s[k].ch[1]=pre;
if(s[k].ch[1]) s[k].img_ch.erase(s[k].img_ch.find(s[k].ch[1]));
pushup(k);
}
}
int findroot(int k){
access(k);splay(k);
while(s[k].ch[0]) k=s[k].ch[0];
splay(k);return k;
}
void link(int x,int y){
access(x);splay(x);
s[x].f=y;s[y].img_ch.insert(x);
}//y is x's father
int getfa(int x){
access(x);splay(x);x=s[x].ch[0];
while(s[x].ch[1]) x=s[x].ch[1];
return x;
}
int getls(int x){
access(x);splay(x);
for(int c:s[x].img_ch) if(s[c].val<s[x].val){
while(s[c].ch[0]) c=s[c].ch[0];
return c;
}
return 0;
}
int getrs(int x){
access(x);splay(x);
for(int c:s[x].img_ch) if(s[c].val>s[x].val){
while(s[c].ch[0]) c=s[c].ch[0];
return c;
}
return 0;
}
void cut(int x,int y){
access(x);splay(x);int son=s[x].ch[0];
s[x].ch[0]=s[son].f=0;pushup(x);
}//y is x's father
set<pii> st;
int calc_dep(int x){access(x);splay(x);return s[x].siz;}
void splay_mn(){
pii p=*++st.begin();int id=p.se;
access(id);splay(id);printf("%d\n",s[id].siz);
if(findroot(id)==id) return;
int fa=getfa(id),rt=findroot(id);
cut(id,fa);int son=getrs(id);
if(son) assert(getfa(son)==id),cut(son,id),link(son,fa);
link(rt,id);assert(findroot(fa)==id);
}
void splay_mx(){
pii p=*-- --st.end();int id=p.se;
access(id);splay(id);printf("%d\n",s[id].siz);
if(findroot(id)==id) return;
int fa=getfa(id),rt=findroot(id);
cut(id,fa);int son=getls(id);
if(son) assert(getfa(son)==id),cut(son,id),link(son,fa);
link(rt,id);
}
void del_mn(){
pii p=*++st.begin();int id=p.se;st.erase(st.find(p));
access(id);splay(id);printf("%d\n",s[id].siz);
if(findroot(id)==id){
int son=getrs(id);
if(son) cut(son,id);
return;
}
int fa=getfa(id),rt=findroot(id);
cut(id,fa);int son=getrs(id);
if(son) assert(getfa(son)==id),cut(son,id),link(son,fa);
}
void del_mx(){
pii p=*-- --st.end();int id=p.se;st.erase(st.find(p));
access(id);splay(id);printf("%d\n",s[id].siz);
if(findroot(id)==id){
int son=getls(id);
if(son) cut(son,id);
return;
}
int fa=getfa(id),rt=findroot(id);
cut(id,fa);int son=getls(id);
if(son) assert(getfa(son)==id),cut(son,id),link(son,fa);
}
int main(){
int qu;scanf("%d",&qu);
st.insert(mp(0,0));st.insert(mp(INF,0));
while(qu--){
int opt;scanf("%d",&opt);
if(opt==1){
int x;scanf("%d",&x);
s[++ncnt].val=x;s[ncnt].siz=1;
st.insert(mp(x,ncnt));
if(st.size()>3){
pii nxt=*st.upper_bound(mp(x,ncnt));
pii pre=*--st.lower_bound(mp(x,ncnt));
int L=(pre.se)?calc_dep(pre.se):0;
int R=(nxt.se)?calc_dep(nxt.se):0;
if(L>R) link(ncnt,pre.se);
else link(ncnt,nxt.se);
} printf("%d\n",calc_dep(ncnt));
} else if(opt==2) splay_mn();
else if(opt==3) splay_mx();
else if(opt==4) del_mn();
else del_mx();
}
return 0;
}