【P2710 数列】【Splay】
Splay的经典模板题,细节非常多。这里主要记录一些容易错的点,和不太容易理解的地方
-
pushdown 函数和懒标记的含义
有两种写法:
第一种是懒标记代表其子节点是否更新,这也是通常的写法(线段树和平衡树都是),在这种写法中,pushdown的作用是将子节点的信息更新。
这样做的好处是在很多情况下pushup(x)的时候,不需要先pushdown(son(x))了【具体见代码】
在查询中可能也会更方便一些
第二种是懒标记代表该节点是否更新,很少有人用这种写法,在这种写法中pushdown的作用是更新该节点,比上一种写法要少写“一半”(只更新一个点),但是在查询或pushup的时候得多考虑考虑是否需要pushdown了(可以把第一种写法理解为比第二种写法更快一步) -
标记之间的顺序与覆盖
不论是线段树还是平衡树,标记之间都要自定义一个顺序(例如线段树2),有的标记之间可以覆盖,如本题中的same就可以覆盖rev。
笔者在本题中犯的一个错误就是在pushdown时,先判断same标记,若为真,就不再处理rev标记,但是这样做就导致rev标记没有清空,下次查到该点时还会再更新。
对于这种情况有两种处理方式:
一种是只让same标记和rev标记最多存在一种,即makesame时把rev清空,reverse时若same为真则return。
第二种就是在pushdown时确保把所有标记都清空。 -
内存回收
原理很简单,把删除的点用一个栈存起来,添加新点时优先从栈中取点。
再删除中通常是用splay的提取区间操作,用dfs遍历一颗子树。
笔者犯下的错误是在dfs中采取中序遍历方式,且在访问完一个点后就将该点信息清空,这样的话由于该点的信息已被清空,就无法遍历右子树了。所以一定要遍历完整颗子树后再将信息清空。
最后贴一下代码
#include<bits/stdc++.h>
using namespace std;
inline int read()
{
register int x=0,w=1;
register char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if(ch=='-') {w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') {x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
return x*w;
}
const int N=2e5+100,inf=1e8;
int n,m,a[N],nodes[N],top,rt;
struct node{
int p,s[2],siz,rev,same,sum,val,ls,rs,ms;
}t[N];
char op[10];
void clear(int x)
{
t[x].p=t[x].s[0]=t[x].s[1]=t[x].siz=t[x].rev=t[x].same=t[x].sum=t[x].val=t[x].ls=t[x].rs=t[x].ms=0;
}
int get(int x)
{
return t[t[x].p].s[1]==x;
}
void pushup(int x)
{
t[x].siz=t[t[x].s[0]].siz+t[t[x].s[1]].siz+1;
t[x].sum=t[t[x].s[0]].sum+t[t[x].s[1]].sum+t[x].val;
t[x].ls=max(t[x].s[0]?t[t[x].s[0]].ls:-inf,t[t[x].s[0]].sum+t[x].val+max(0,t[t[x].s[1]].ls));
t[x].rs=max(t[x].s[1]?t[t[x].s[1]].rs:-inf,t[t[x].s[1]].sum+t[x].val+max(0,t[t[x].s[0]].rs));
t[x].ms=max(max(t[x].s[0]?t[t[x].s[0]].ms:-inf,t[x].s[1]?t[t[x].s[1]].ms:-inf),t[x].val+max(0,t[t[x].s[0]].rs)+max(0,t[t[x].s[1]].ls));
}
void makesame(int x,int c)
{
t[x].val=c;
t[x].rev=0,t[x].same=1;
t[x].sum=t[x].siz*c;
if(c>=0) t[x].ls=t[x].rs=t[x].ms=t[x].sum;
else t[x].ls=t[x].rs=t[x].ms=c;
}
void reverse(int x)
{
if(t[x].same) return;
swap(t[x].s[0],t[x].s[1]);
swap(t[x].ls,t[x].rs);
t[x].rev^=1;
}
void pushdown(int x)
{
if(t[x].same)
{
t[x].same=0;
// t[x].rev=0;
if(t[x].s[0]) makesame(t[x].s[0],t[x].val);
if(t[x].s[1]) makesame(t[x].s[1],t[x].val);
return;
}
if(t[x].rev)
{
t[x].rev=0;
if(t[x].s[0]) reverse(t[x].s[0]);
if(t[x].s[1]) reverse(t[x].s[1]);
}
}
int rkget(int k)
{
int now=rt;
while(now)
{
pushdown(now);
if(k<=t[t[now].s[0]].siz) now=t[now].s[0];
else
{
if(k==t[t[now].s[0]].siz+1) return now;
k-=t[t[now].s[0]].siz+1;
now=t[now].s[1];
}
}
}
int build(int l,int r,int p)
{
int mid=l+r>>1;
int u=nodes[top--];
t[u].val=a[mid];
t[u].p=p;
if(l<mid) t[u].s[0]=build(l,mid-1,u);
if(r>mid) t[u].s[1]=build(mid+1,r,u);
pushup(u);
return u;
}
void rotate(int x)
{
int y=t[x].p,z=t[y].p,k=get(x);
t[z].s[get(y)]=x;t[x].p=z;
t[y].s[k]=t[x].s[k^1];t[t[y].s[k]].p=y;
t[x].s[k^1]=y;t[y].p=x;
pushup(y);pushup(x);
}
void splay(int x,int k)
{
for(int y=t[x].p;y!=k;rotate(x),y=t[x].p)
if(t[y].p!=k) rotate(get(y)==get(x)?y:x);
if(!k) rt=x;
}
void dfs(int x)
{
if(t[x].s[0]) dfs(t[x].s[0]);
if(t[x].s[1]) dfs(t[x].s[1]);
nodes[++top]=x;clear(x);
}
int main()
{
n=read();m=read();
for(int i=1;i<=n;++i) a[i]=read();
for(int i=1;i<N;++i) nodes[++top]=i;
rt=build(0,n+1,0);
for(int i=1;i<=m;++i)
{
scanf("%s",op);int x=read();
if(!strcmp(op,"INSERT"))
{
int cnt=read();
for(int j=1;j<=cnt;++j)
{
a[j]=read();
}
int l=rkget(x+1),r=rkget(x+2);
splay(l,0);splay(r,l);
t[r].s[0]=build(1,cnt,r);
pushup(r);pushup(l);
}
else if(!strcmp(op,"DELETE"))
{
int cnt=read();
int l=rkget(x),r=rkget(x+cnt+1);
splay(l,0);splay(r,l);
dfs(t[r].s[0]);
t[r].s[0]=0;
pushup(r);pushup(l);
}
else if(!strcmp(op,"REVERSE"))
{
int cnt=read();
int l=rkget(x),r=rkget(x+cnt+1);
splay(l,0);splay(r,l);
reverse(t[r].s[0]);
pushup(r);pushup(l);
}
else if(!strcmp(op,"MAKE-SAME"))
{
int cnt=read(),c=read();
int l=rkget(x),r=rkget(x+cnt+1);
splay(l,0);splay(r,l);
makesame(t[r].s[0],c);
pushup(r);pushup(l);
}
else if(!strcmp(op,"GET-SUM"))
{
int cnt=read();
int l=rkget(x),r=rkget(x+cnt+1);
splay(l,0);splay(r,l);
printf("%d\n",t[t[r].s[0]].sum);
}
else if(!strcmp(op,"GET"))
{
int l=rkget(x+1);splay(l,0);
printf("%d\n",t[rt].val);
}
else
{
int cnt=read();
int l=rkget(x),r=rkget(x+cnt+1);
splay(l,0);splay(r,l);
printf("%d\n",t[t[r].s[0]].ms);
}
}
return 0;
}