题目大意:
有一个01序列,现在对于这个序列有五种变换操作和询问操作: 0 a b 把[a, b]区间内的所有数全变成0;1 a b 把[a, b]区间内的所有数全变成1;2 a b 把[a,b]区间内的所有数全部取反,也就是说把所有的0变成1,把所有的1变成0;3 a b 询问[a, b]区间内总共有多少个1;4 a b 询问[a, b]区间内最多有多少个连续的1。
思路:
维护每一段数的和、左端和右端以及整段中连续的0和1的长度,并使用标记进行下传。
代码:
#include<cstdio>
#include<cstring>
#include<iostream>
#define N 400000
using namespace std; int root,num,cnt,lmax[][N],rmax[][N],mmax[][N],sum[N],tag[][N],rev[N],h[N],t[N],rc[N],lc[N],a[N]; void up_date(int x,int k)
{
int l=lc[x],r=rc[x];
lmax[k][x]=lmax[k][l],rmax[k][x]=rmax[k][r];
if (lmax[k][l]==t[l]-h[l]+) lmax[k][x]+=lmax[k][r];
if (rmax[k][r]==t[r]-h[r]+) rmax[k][x]+=rmax[k][l];
mmax[k][x]=max(max(mmax[k][l],mmax[k][r]),rmax[k][l]+lmax[k][r]);
} void build(int l,int r,int &cur)
{
cur=++num,tag[][cur]=tag[][cur]=rev[cur]=;
h[cur]=l,t[cur]=r;
if (l==r)
{
lc[cur]=rc[cur]=;
lmax[][cur]=rmax[][cur]=mmax[][cur]=(a[l]==);
lmax[][cur]=rmax[][cur]=sum[cur]=mmax[][cur]=(a[r]==);
return;
}
int mid=l+r>>;
build(l,mid,lc[cur]),build(mid+,r,rc[cur]);
up_date(cur,),up_date(cur,),sum[cur]=sum[lc[cur]]+sum[rc[cur]];
} void mark(int x,int k)
{
tag[k][x]=,tag[k^][x]=rev[x]=,sum[x]=(t[x]-h[x]+)*k;
lmax[k][x]=rmax[k][x]=mmax[k][x]=t[x]-h[x]+;
lmax[k^][x]=rmax[k^][x]=mmax[k^][x]=;
} void re(int x)
{
rev[x]^=,sum[x]=t[x]-h[x]+-sum[x];
swap(lmax[][x],lmax[][x]),swap(rmax[][x],rmax[][x]),swap(mmax[][x],mmax[][x]);
} void push_down(int x)
{
if (tag[][x]) mark(lc[x],),mark(rc[x],),tag[][x]=;
if (tag[][x]) mark(lc[x],),mark(rc[x],),tag[][x]=;
if (rev[x]) re(lc[x]),re(rc[x]),rev[x]=;
} void change(int l,int r,int cur,int k)
{
if (h[cur]>r || t[cur]<l) return;
if (h[cur]>=l && t[cur]<=r)
{
if (k<) mark(cur,k);
else re(cur);
return;
}
push_down(cur);
change(l,r,lc[cur],k),change(l,r,rc[cur],k);
up_date(cur,),up_date(cur,),sum[cur]=sum[lc[cur]]+sum[rc[cur]];
} int ask1(int l,int r,int cur)
{
if (h[cur]>r || t[cur]<l) return ;
if (h[cur]>=l && t[cur]<=r) return sum[cur];
push_down(cur);
return ask1(l,r,lc[cur])+ask1(l,r,rc[cur]);
} int ask2(int l,int r,int cur)
{
if (l==h[cur] && r==t[cur]) return cur;
int mid=h[cur]+t[cur]>>;
push_down(cur);
if (r<=mid) return ask2(l,r,lc[cur]);
else if (l>mid) return ask2(l,r,rc[cur]);
else
{
int ans=++cnt;
lc[ans]=ask2(l,mid,lc[cur]),rc[ans]=ask2(mid+,r,rc[cur]);
up_date(ans,),up_date(ans,),sum[ans]=sum[lc[ans]]+sum[rc[ans]];
return ans;
}
} int main()
{
int n,m,i,x,y,z;
scanf("%d%d",&n,&m);
for (i=;i<=n;i++) scanf("%d",&a[i]);
build(,n,root);
for (i=;i<=m;i++)
{
scanf("%d%d%d",&z,&x,&y);
x++,y++;
if (z==) printf("%d\n",ask1(x,y,root));
else if (z==) cnt=num,printf("%d\n",mmax[][ask2(x,y,root)]);
else change(x,y,root,z);
}
return ;
}