【P2710 数列】【Splay】

【P2710 数列】【Splay】

Splay的经典模板题,细节非常多。这里主要记录一些容易错的点,和不太容易理解的地方

  1. pushdown 函数和懒标记的含义
    有两种写法:
    第一种是懒标记代表其子节点是否更新,这也是通常的写法(线段树和平衡树都是),在这种写法中,pushdown的作用是将子节点的信息更新。
    这样做的好处是在很多情况下pushup(x)的时候,不需要先pushdown(son(x))了【具体见代码】
    在查询中可能也会更方便一些
    第二种是懒标记代表该节点是否更新,很少有人用这种写法,在这种写法中pushdown的作用是更新该节点,比上一种写法要少写“一半”(只更新一个点),但是在查询或pushup的时候得多考虑考虑是否需要pushdown了(可以把第一种写法理解为比第二种写法更快一步)
  2. 标记之间的顺序与覆盖
    不论是线段树还是平衡树,标记之间都要自定义一个顺序(例如线段树2),有的标记之间可以覆盖,如本题中的same就可以覆盖rev。
    笔者在本题中犯的一个错误就是在pushdown时,先判断same标记,若为真,就不再处理rev标记,但是这样做就导致rev标记没有清空,下次查到该点时还会再更新。
    对于这种情况有两种处理方式:
    一种是只让same标记和rev标记最多存在一种,即makesame时把rev清空,reverse时若same为真则return。
    第二种就是在pushdown时确保把所有标记都清空。
  3. 内存回收
    原理很简单,把删除的点用一个栈存起来,添加新点时优先从栈中取点。
    再删除中通常是用splay的提取区间操作,用dfs遍历一颗子树。
    笔者犯下的错误是在dfs中采取中序遍历方式,且在访问完一个点后就将该点信息清空,这样的话由于该点的信息已被清空,就无法遍历右子树了。所以一定要遍历完整颗子树后再将信息清空。

最后贴一下代码

#include<bits/stdc++.h>
using namespace std;

inline int read()
{
	register int x=0,w=1;
	register char ch=getchar();
	while((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
	if(ch=='-') {w=-1;ch=getchar();}
	while(ch>='0'&&ch<='9') {x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
	return x*w;
}
const int N=2e5+100,inf=1e8;
int n,m,a[N],nodes[N],top,rt;
struct node{
	int p,s[2],siz,rev,same,sum,val,ls,rs,ms;
}t[N];
char op[10];
void clear(int x)
{
	t[x].p=t[x].s[0]=t[x].s[1]=t[x].siz=t[x].rev=t[x].same=t[x].sum=t[x].val=t[x].ls=t[x].rs=t[x].ms=0;
}
int get(int x)
{
	return t[t[x].p].s[1]==x;
}
void pushup(int x)
{
	t[x].siz=t[t[x].s[0]].siz+t[t[x].s[1]].siz+1;
	t[x].sum=t[t[x].s[0]].sum+t[t[x].s[1]].sum+t[x].val;
	t[x].ls=max(t[x].s[0]?t[t[x].s[0]].ls:-inf,t[t[x].s[0]].sum+t[x].val+max(0,t[t[x].s[1]].ls));
	t[x].rs=max(t[x].s[1]?t[t[x].s[1]].rs:-inf,t[t[x].s[1]].sum+t[x].val+max(0,t[t[x].s[0]].rs));
	t[x].ms=max(max(t[x].s[0]?t[t[x].s[0]].ms:-inf,t[x].s[1]?t[t[x].s[1]].ms:-inf),t[x].val+max(0,t[t[x].s[0]].rs)+max(0,t[t[x].s[1]].ls));
}
void makesame(int x,int c)
{
	t[x].val=c;
	t[x].rev=0,t[x].same=1;
	t[x].sum=t[x].siz*c;
	if(c>=0) t[x].ls=t[x].rs=t[x].ms=t[x].sum;
	else t[x].ls=t[x].rs=t[x].ms=c;
}
void reverse(int x)
{
	if(t[x].same) return;
	swap(t[x].s[0],t[x].s[1]);
	swap(t[x].ls,t[x].rs);
	t[x].rev^=1;
}
void pushdown(int x)
{
	if(t[x].same)
	{
		t[x].same=0;
//		t[x].rev=0;
		if(t[x].s[0]) makesame(t[x].s[0],t[x].val);
		if(t[x].s[1]) makesame(t[x].s[1],t[x].val);
		return;
	}
	if(t[x].rev)
	{
		t[x].rev=0;
		if(t[x].s[0]) reverse(t[x].s[0]);
		if(t[x].s[1]) reverse(t[x].s[1]);
	}
}
int rkget(int k)
{
	int now=rt;
	while(now)
	{
		pushdown(now);
		if(k<=t[t[now].s[0]].siz) now=t[now].s[0];
		else
		{
			if(k==t[t[now].s[0]].siz+1) return now;
			k-=t[t[now].s[0]].siz+1;
			now=t[now].s[1];
		}
	}
}
int build(int l,int r,int p)
{
	int mid=l+r>>1;
	int u=nodes[top--];
	t[u].val=a[mid];
	t[u].p=p;
	if(l<mid) t[u].s[0]=build(l,mid-1,u);
	if(r>mid) t[u].s[1]=build(mid+1,r,u);
	pushup(u);
	return u;
}
void rotate(int x)
{
	int y=t[x].p,z=t[y].p,k=get(x);
	t[z].s[get(y)]=x;t[x].p=z;
	t[y].s[k]=t[x].s[k^1];t[t[y].s[k]].p=y;
	t[x].s[k^1]=y;t[y].p=x;
	pushup(y);pushup(x);
}
void splay(int x,int k)
{
	for(int y=t[x].p;y!=k;rotate(x),y=t[x].p)
	  if(t[y].p!=k) rotate(get(y)==get(x)?y:x);
	if(!k) rt=x; 
}
void dfs(int x)
{
	if(t[x].s[0]) dfs(t[x].s[0]);
	if(t[x].s[1]) dfs(t[x].s[1]);
	nodes[++top]=x;clear(x);
}
int main()
{
    n=read();m=read();
    for(int i=1;i<=n;++i) a[i]=read();
    for(int i=1;i<N;++i) nodes[++top]=i;
    rt=build(0,n+1,0);
    for(int i=1;i<=m;++i)
    {
    	scanf("%s",op);int x=read();
    	if(!strcmp(op,"INSERT"))
    	{
    		int cnt=read();
    		for(int j=1;j<=cnt;++j)
    		{
    			a[j]=read();
			}
			int l=rkget(x+1),r=rkget(x+2);
			splay(l,0);splay(r,l);
			t[r].s[0]=build(1,cnt,r);
			pushup(r);pushup(l);
		}
		else if(!strcmp(op,"DELETE"))
		{
			int cnt=read();
			int l=rkget(x),r=rkget(x+cnt+1);
			splay(l,0);splay(r,l);
			dfs(t[r].s[0]);
			t[r].s[0]=0;
			pushup(r);pushup(l);
		}
		else if(!strcmp(op,"REVERSE"))
		{
			int cnt=read();
			int l=rkget(x),r=rkget(x+cnt+1);
			splay(l,0);splay(r,l);
			reverse(t[r].s[0]);
			pushup(r);pushup(l);
		}
		else if(!strcmp(op,"MAKE-SAME"))
		{
			int cnt=read(),c=read();
			int l=rkget(x),r=rkget(x+cnt+1);
			splay(l,0);splay(r,l);
			makesame(t[r].s[0],c);
			pushup(r);pushup(l);
		}
		else if(!strcmp(op,"GET-SUM"))
		{
			int cnt=read();
			int l=rkget(x),r=rkget(x+cnt+1);
			splay(l,0);splay(r,l);
			printf("%d\n",t[t[r].s[0]].sum);
		}
		else if(!strcmp(op,"GET"))
		{
			int l=rkget(x+1);splay(l,0);
			printf("%d\n",t[rt].val);
		}
		else
		{
			int cnt=read();
			int l=rkget(x),r=rkget(x+cnt+1);
			splay(l,0);splay(r,l);
			printf("%d\n",t[t[r].s[0]].ms);
		}
	}
	return 0;
}
上一篇:Leetcode 239. 滑动窗口最大值(困难) 单调队列解决滑动窗口最大值


下一篇:#1051. Pop Sequence【栈 + 模拟】