二逼平衡树 bzoj-3196 Tyvj-1730
题目大意:请写出一个维护序列的数据结构支持:查询给定权值排名;查询区间k小值;单点修改;查询区间内定值前驱;查询区间内定值后继。
注释:$1\le n,m\le 5\times 10^4$。
想法:
在这里给予三种题解:
1)首先,最容易想到的应该就是树状数组套主席树也就是常说的带修改主席树。
第一个操作是简单的,我们只需要提取出当前区间的权值线段树后在上面二分即可。
第二个操作是主席树的看家本领好伐
第三个操作就是待修改主席树的意义。我们利用外层的树状数组即可实现这个操作。
第四个操作一个比较简单的办法就是在主席树的每个节点上再维护一个最大值。只需要二分之后输出左侧节点的最大值即可。
最后一个操作和第四个操作同理。
空间复杂度$O(nlog^2n)$,时间复杂度$O(nlog^2n)$。
2)其次,更容易想到的就是位置线段树套平衡树。
我们对于位置线段树上的每个节点对应的位置区间上的所有数建立一个以权值为关键字的非旋转$Treap$(嘻嘻就是比$splay$优越)
第一个操作比较简单,每次询问我们可以把它对应成外层线段树的$log$节点,也就对应了$log$棵非旋转$Treap$。我们对于每一棵非旋转$Treap$都求一下比$k$小的个数,然后把这$log$个加起来。因为求排名我们再$+1$即可$O(log^2n)$。
第二个操作较为复杂:我们要求区间$k$小值,其实就是因为这个操作我们才需要树套树。因为位置线段树的局限性,我们首先二分答案然后验证。验证就是第一个操作的验证,查询有多少个比它少即可,时间复杂度$O(log^3n)$。
第三个操作是将包含对应位置的位置线段树上的节点对应的非旋转$Treap$修改一次。我们发现这就是位置线段树对应位置到根节点的一条链,而每一棵非旋转$Treap$的修改只需要删除之前的权值然后插入对应权值即可。
第四个操作更容易了,我们像第一个操作一样将询问分成$log$个节点后,我们只需要在那$log$棵非旋转$Treap$上每棵都求一遍前驱求一次最大值即可。求前驱就是正常的平衡树求前驱。
第五个操作同理,求$log$个后继求最小值。
空间复杂度$O(nlogn)$,时间复杂度$O(nlog^3n)$。
附上唯一写了的代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define inf 2147483647
#define N 50010
using namespace std;
inline void dbug(int x) {printf("-----------------------------%d-----------------------------\n",x);}
inline char nc() {static char *p1,*p2,buf[100000]; return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;}
int rd() {int x=0; char c=nc(); while(!isdigit(c)) c=nc(); while(isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=nc(); return x;}
struct par {int x,y;};
struct Node
{
int ls,rs;
int val,key,size,mx,mn;
}a[N*30];
int rt[N<<2],cnt,num[N];
inline void Max(int &x,int y) {x=max(x,y);}
inline void Min(int &x,int y) {x=min(x,y);}
inline void pushup_tr(int x)
{
// puts("pushup_tr");
int ls=a[x].ls,rs=a[x].rs;
// a[ls].mx=a[rs].mx=-inf;
a[x].size=1; a[x].mx=a[x].mn=a[x].val;
if(ls) a[x].size+=a[ls].size,Max(a[x].mx,a[ls].mx),Min(a[x].mn,a[ls].mn);
if(rs) a[x].size+=a[rs].size,Max(a[x].mx,a[rs].mx),Min(a[x].mn,a[rs].mn);
}
inline int newnode(int val)
{
// puts("newnode");
int x=++cnt;
a[x].val=a[x].mx=a[x].mn=val; a[x].ls=a[x].rs=0;
a[x].size=1; a[x].key=rand();
return x;
}
int merge(int x,int y)
{
// puts("merge");
if(!x||!y) return x|y;
if(a[x].key>a[y].key)
{
a[x].rs=merge(a[x].rs,y); pushup_tr(x);
return x;
}
else
{
a[y].ls=merge(x,a[y].ls); pushup_tr(y);
return y;
}
}
par split(int x,int k)
{
// printf("%d %d\n",x,k);
if(!k) return (par){0,x};
int ls=a[x].ls,rs=a[x].rs;
if(k==a[ls].size)
{
a[x].ls=0; pushup_tr(x);
return (par){ls,x};
}
else if(k==a[ls].size+1)
{
a[x].rs=0; pushup_tr(x);
return (par){x,rs};
}
else if(k<a[ls].size)
{
par t=split(ls,k);
a[x].ls=t.y; pushup_tr(x);
return (par){t.x,x};
}
else
{
par t=split(rs,k-a[ls].size-1);
a[x].rs=t.x; pushup_tr(x);
return (par){x,t.y};
}
}
int getrank(int x,int k)
{
int ls=a[x].ls,rs=a[x].rs;
if(!x) return 0;
if(k<a[x].val) return getrank(ls,k);
else if(k>a[x].val) return a[ls].size+1+getrank(rs,k);
else
{
if(k>a[ls].mx) return a[ls].size;
else return getrank(ls,k);
}
}
int getpre(int x,int k)
{
if(k<=a[x].mn) return -inf;
int ls=a[x].ls,rs=a[x].rs;
if(k<a[x].val) return getpre(ls,k);
else if(k>a[x].val)
{
if(k<=a[rs].mn) return a[x].val;
else return getpre(rs,k);
}
else
{
if(k>a[ls].mx) return a[ls].mx;
else return getpre(ls,k);
}
}
int getnxt(int x,int k)
{
if(!x) return 0;
if(k>=a[x].mx) return inf;
int ls=a[x].ls,rs=a[x].rs;
if(k>a[x].val) return getnxt(rs,k);
else if(k<a[x].val)
{
if(k>=a[ls].mx) return a[x].val;
else return getnxt(ls,k);
}
else
{
if(k<a[rs].mn) return a[rs].mn;
else return getnxt(rs,k);
}
}
void del(int &x,int k)
{
int rk=getrank(x,k);
if(!x) return;
par t1=split(x,rk);
par t2=split(t1.y,1);
x=merge(t1.x,t2.y);
}
void insert(int &x,int k)
{
// puts("insert");
int rk=getrank(x,k);
par t1=split(x,rk);
x=merge(t1.x,merge(newnode(k),t1.y));
}
void build(int l,int r,int p)
{
rt[p]=newnode(num[l]);
for(int i=l+1;i<=r;i++)
{
int rk=getrank(rt[p],num[i]);
par t=split(rt[p],rk);
rt[p]=merge(t.x,merge(newnode(num[i]),t.y));
}
if(l==r) return;
int mid=(l+r)>>1;
build(l,mid,p<<1); build(mid+1,r,p<<1|1);
}
void update(int x,int val,int l,int r,int p)
{
// puts("update");
del(rt[p],num[x]); insert(rt[p],val);
if(l==r) return;
int mid=(l+r)>>1;
if(x<=mid) update(x,val,l,mid,p<<1);
else update(x,val,mid+1,r,p<<1|1);
}
int query_rank(int x,int y,int k,int l,int r,int p)
{
// printf("%d %d %d %d %d %d\n",x,y,k,l,r,p);
if(x<=l&&r<=y) return getrank(rt[p],k);
int mid=(l+r)>>1,ans=0;
if(x<=mid) ans+=query_rank(x,y,k,l,mid,p<<1);
if(mid<y) ans+=query_rank(x,y,k,mid+1,r,p<<1|1);
return ans;
}
int query_pre(int x,int y,int k,int l,int r,int p)
{
// puts("query_pre");
if(x<=l&&r<=y) return getpre(rt[p],k);
int mid=(l+r)>>1,ans=-inf;
if(x<=mid) Max(ans,query_pre(x,y,k,l,mid,p<<1));
if(mid<y) Max(ans,query_pre(x,y,k,mid+1,r,p<<1|1));
return ans;
}
int query_nxt(int x,int y,int k,int l,int r,int p)
{
// puts("query_nxt");
if(x<=l&&r<=y) return getnxt(rt[p],k);
int mid=(l+r)>>1,ans=inf;
if(x<=mid) Min(ans,query_nxt(x,y,k,l,mid,p<<1));
if(mid<y) Min(ans,query_nxt(x,y,k,mid+1,r,p<<1|1));
return ans;
}
void output(int x)
{
int ls=a[x].ls,rs=a[x].rs;
if(!x) return;
if(ls) output(ls);
printf("%d ",a[x].val);
if(rs) output(rs);
}
int main()
{
// freopen("a.in","r",stdin);
// freopen("a.out","w",stdout);
srand(20021214);
a[0].mx=-inf; a[0].mn=inf;
int n=rd(),m=rd(); for(int i=1;i<=n;i++) num[i]=rd();
build(1,n,1);
for(int i=1;i<=m;i++)
{
int opt=rd(),x,y,z;
if(opt==1)
{
x=rd(),y=rd(),z=rd();
printf("%d\n",query_rank(x,y,z,1,n,1)+1);
}
else if(opt==2)
{
x=rd(),y=rd(),z=rd();
int l=0,r=inf;
while(l<r)
{
int mid=(l+r)>>1;
int b1=query_rank(x,y,mid,1,n,1);
int b2=query_rank(x,y,mid+1,1,n,1);
if(b1<z&&b2>=z) {printf("%d\n",mid); break;}
else if(b1>=z) r=mid;
else l=mid+1;
}
}
else if(opt==3)
{
x=rd(),y=rd();
update(x,y,1,n,1);
num[x]=y;
}
else if(opt==4)
{
x=rd(),y=rd(),z=rd();
printf("%d\n",query_pre(x,y,z,1,n,1));
}
else
{
x=rd(),y=rd(),z=rd();
printf("%d\n",query_nxt(x,y,z,1,n,1));
}
}
return 0;
}
// void output_sgmtr(int l,int r,int p)
// {
// dbug(p);
// output(rt[p]); puts("");
// if(l==r) return;
// int mid=(l+r)>>1;
// output_sgmtr(l,mid,p<<1); output_sgmtr(mid+1,r,p<<1|1);
// }
// int main()
// {
// freopen("a.in","r",stdin);
// int n=rd(),m=rd(); for(int i=1;i<=n;i++) num[i]=rd();
// build(1,n,1);
// output(rt[1]); puts("");
// // output_sgmtr(1,n,1);
// // printf("%d\n",getrank(rt[1],4));
// printf("%d\n",query_rank(3,6,4,1,n,1));
// return 0;
// }
3)最后一个就是权值线段树套位置线段树。
第一个操作显然,我们对于外层权值线段树的每个节点维护一下对应权值下数的个数。查询的时候从外层根节点开始遍历,如果左儿子维护的线段树最大值大于查询权值则递归左儿子;反之递归右儿子并将答案加上左儿子中对应区间的$size$,单次操作时间复杂度$O(log^2n)$。
第二个操作就是这种写法的亮点。权值线段树可以带着一帮小弟一起二分!所以我们在第二种解法中外面的二分就可以删掉了。具体地,我们每次查询一下左儿子对应区间的$size$时候比$k$大。如果比$k$大我们就直接递归左儿子,反之将$k-=size[ls]$,然后递归右儿子即可,单次操作时间复杂度$O(log^2n)$。
第三个操作就是将之前权值对应的链中每个都删一下,然后把新权值对应的链再加回来即可。只考虑删除:首先包含权值的节点就是对应的叶子到根的链,有$log$个。对应的位置线段树都需要进行一次单点减。所以单次操作的时间复杂度是$O(log^2n)$。
第四个操作更好弄了。外层的权值线段树再维护一下最大权值然后在外层权值线段树二分即可,单次操作时间复杂度为$O(log^2n)$。
第五个操作同理。
这种做法的空间复杂度是$O(nlogn)$,时间复杂度$O(nlog^2n)$。
小结:这三种做法各有优劣吧,总体来讲带修改主席树好写,权值线段树套位置线段树优越(滑稽),位置线段树套非旋转$Treap$虽然常数较大但是更加直观而且调试时间明显少。