分析
segment tree beats!模板题。
看了gxz的博客突然发现自己写的mxbt
和mnbt
两个标记没用诶。
代码
#include <bits/stdc++.h>
#define rin(i,a,b) for(register int i=(a);i<=(b);++i)
#define irin(i,a,b) for(register int i=(a);i>=(b);--i)
#define trav(i,a) for(register int i=head[a];i;i=e[i].nxt)
typedef long long LL;
using std::cin;
using std::cout;
using std::endl;
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
const int MAXN=5e5+5;
int n,m,a[MAXN];
int maxn[MAXN<<2],minn[MAXN<<2],
mxcnt[MAXN<<2],mncnt[MAXN<<2],
mxsec[MAXN<<2],mnsec[MAXN<<2],
tag[MAXN<<2],
ql,qr,kk;
LL sum[MAXN<<2];
bool mxbt[MAXN<<2],mnbt[MAXN<<2];
#define mid ((l+r)>>1)
#define lc (o<<1)
#define rc ((o<<1)|1)
inline void pushup(int o){
sum[o]=sum[lc]+sum[rc];
if(maxn[lc]>maxn[rc]){
maxn[o]=maxn[lc];
mxcnt[o]=mxcnt[lc];
mxsec[o]=std::max(mxsec[lc],maxn[rc]);
}
else if(maxn[lc]<maxn[rc]){
maxn[o]=maxn[rc];
mxcnt[o]=mxcnt[rc];
mxsec[o]=std::max(mxsec[rc],maxn[lc]);
}
else{
maxn[o]=maxn[lc];
mxcnt[o]=mxcnt[lc]+mxcnt[rc];
mxsec[o]=std::max(mxsec[lc],mxsec[rc]);
}
if(minn[lc]<minn[rc]){
minn[o]=minn[lc];
mncnt[o]=mncnt[lc];
mnsec[o]=std::min(mnsec[lc],minn[rc]);
}
else if(minn[lc]>minn[rc]){
minn[o]=minn[rc];
mncnt[o]=mncnt[rc];
mnsec[o]=std::min(mnsec[rc],minn[lc]);
}
else{
minn[o]=minn[lc];
mncnt[o]=mncnt[lc]+mncnt[rc];
mnsec[o]=std::min(mnsec[lc],mnsec[rc]);
}
}
inline void pushtag(int o,int l,int r,int _kk){
sum[o]+=1ll*_kk*(r-l+1);
maxn[o]+=_kk;
minn[o]+=_kk;
mxsec[o]+=_kk;
mnsec[o]+=_kk;
tag[o]+=_kk;
}
inline void pushmaxbeat(int o,int _kk){
if(minn[o]>=_kk) return;
else if(minn[o]<_kk){
sum[o]+=1ll*(_kk-minn[o])*mncnt[o];
if(maxn[o]==minn[o]) maxn[o]=_kk;
else if(mxsec[o]==minn[o]) mxsec[o]=_kk;
minn[o]=_kk,mxbt[o]=true;
}
}
inline void pushminbeat(int o,int _kk){
if(maxn[o]<=_kk) return;
else if(maxn[o]>_kk){
sum[o]+=1ll*(_kk-maxn[o])*mxcnt[o];
if(minn[o]==maxn[o]) minn[o]=_kk;
else if(mnsec[o]==maxn[o]) mnsec[o]=_kk;
maxn[o]=_kk,mnbt[o]=true;
}
}
inline void pushdown(int o,int l,int r){
if(mxbt[o]){
pushmaxbeat(lc,minn[o]-tag[o]);
pushmaxbeat(rc,minn[o]-tag[o]);
mxbt[o]=false;
}
if(mnbt[o]){
pushminbeat(lc,maxn[o]-tag[o]);
pushminbeat(rc,maxn[o]-tag[o]);
mnbt[o]=false;
}
if(tag[o]){
pushtag(lc,l,mid,tag[o]);
pushtag(rc,mid+1,r,tag[o]);
tag[o]=0;
}
}
void build(int o,int l,int r){
if(l==r){
maxn[o]=minn[o]=sum[o]=a[l];
mxcnt[o]=mncnt[o]=1;
mxsec[o]=-1e9,mnsec[o]=1e9;
return;
}
build(lc,l,mid);
build(rc,mid+1,r);
pushup(o);
}
void add(int o,int l,int r){
if(ql<=l&&r<=qr){
pushtag(o,l,r,kk);
return;
}
pushdown(o,l,r);
if(mid>=ql) add(lc,l,mid);
if(mid<qr) add(rc,mid+1,r);
pushup(o);
}
void checkmax(int o,int l,int r){
if(ql<=l&&r<=qr){
if(kk<=minn[o]) return;
else if(minn[o]<kk&&kk<mnsec[o]){
sum[o]+=1ll*(kk-minn[o])*mncnt[o];
if(maxn[o]==minn[o]) maxn[o]=kk;
else if(mxsec[o]==minn[o]) mxsec[o]=kk;
minn[o]=kk,mxbt[o]=true;
return;
}
}
pushdown(o,l,r);
if(mid>=ql) checkmax(lc,l,mid);
if(mid<qr) checkmax(rc,mid+1,r);
pushup(o);
}
void checkmin(int o,int l,int r){
if(ql<=l&&r<=qr){
if(kk>=maxn[o]) return;
else if(maxn[o]>kk&&kk>mxsec[o]){
sum[o]+=1ll*(kk-maxn[o])*mxcnt[o];
if(minn[o]==maxn[o]) minn[o]=kk;
else if(mnsec[o]==maxn[o]) mnsec[o]=kk;
maxn[o]=kk,mnbt[o]=true;
return;
}
}
pushdown(o,l,r);
if(mid>=ql) checkmin(lc,l,mid);
if(mid<qr) checkmin(rc,mid+1,r);
pushup(o);
}
LL querysum(int o,int l,int r){
if(ql<=l&&r<=qr) return sum[o];
pushdown(o,l,r);
LL ret=0;
if(mid>=ql) ret+=querysum(lc,l,mid);
if(mid<qr) ret+=querysum(rc,mid+1,r);
return ret;
}
int querymax(int o,int l,int r){
if(ql<=l&&r<=qr) return maxn[o];
pushdown(o,l,r);
int ret=-1e9;
if(mid>=ql) ret=std::max(ret,querymax(lc,l,mid));
if(mid<qr) ret=std::max(ret,querymax(rc,mid+1,r));
return ret;
}
int querymin(int o,int l,int r){
if(ql<=l&&r<=qr) return minn[o];
pushdown(o,l,r);
int ret=1e9;
if(mid>=ql) ret=std::min(ret,querymin(lc,l,mid));
if(mid<qr) ret=std::min(ret,querymin(rc,mid+1,r));
return ret;
}
#undef mid
#undef lc
#undef rc
int main(){
n=read();
rin(i,1,n) a[i]=read();
build(1,1,n);
m=read();
while(m--){
int opt=read(),l=read(),r=read();
if(opt==1){
int x=read();
ql=l,qr=r,kk=x;
add(1,1,n);
}
else if(opt==2){
int x=read();
ql=l,qr=r,kk=x;
checkmax(1,1,n);
}
else if(opt==3){
int x=read();
ql=l,qr=r,kk=x;
checkmin(1,1,n);
}
else if(opt==4){
ql=l,qr=r;
printf("%lld\n",querysum(1,1,n));
}
else if(opt==5){
ql=l,qr=r;
printf("%d\n",querymax(1,1,n));
}
else{
ql=l,qr=r;
printf("%d\n",querymin(1,1,n));
}
}
return 0;
}