给定一组数据(n个数据),进行m次操作,想要求某一段区间和,或者区间上同时加上或减去一个数。对于这种问题,采用最朴素的算法思想,求区间和的时间复杂度为O(mn),删改操作为O(mn^2),若使用前缀和预处理,可以将求区间和的复杂度降低至O(m),用差分预处理,也可以将删改的复杂度降至O(m)。但是如果m次操作中既有求区间和删改,那么时间复杂度将高得无法接受,于是我们可以使用线段树这种神奇的算法,两种操作的时间复杂度都可以降低至O(mlogn)。
线段树,顾名思义,每一个树的节点存储的信息为一个区间,对于本题,其对应的便是一个区间所有元素之和。树的根节点,存储的是1到n所有数据之和,其左节点存储1到n/2所有数据和,右节点存储n/2+1到n所有数据之和,之后以此类推。想要得到某一段数据和,只要从根节点开始搜索,只要O(logn)即可找到,删改操作也是同理。
线段树算法分为三步。
一.建树
void build(ll s,ll t,ll p){ if(s==t){ d[p]=a[s]; return; } ll m=s+((t-s)>>1); build(s,m,2*p); build(m+1,t,2*p+1); d[p]=d[2*p]+d[2*p+1]; }
二.更新操作(懒惰标记)
void add(ll l,ll r,ll c,ll s,ll t,ll p){ if(l<=s&&r>=t){ d[p]+=(t-s+1)*c; lazy[p]+=c; return; } ll m=s+((t-s)>>1); if(lazy[p]){ d[p*2]+=lazy[p]*(m-s+1); d[p*2+1]+=lazy[p]*(t-m); lazy[p*2]+=lazy[p]; lazy[p*2+1]+=lazy[p]; lazy[p]=0; } if(l<=m)add(l,r,c,s,m,p*2); if(r>m)add(l,r,c,m+1,t,p*2+1); d[p]=d[p*2]+d[p*2+1]; }
三.求区间和
ll getsum(ll l,ll r,ll s,ll t,ll p){ if(l<=s&&t<=r)return d[p]; ll m=s+((t-s)>>1); if(lazy[p]){ d[p*2]+=lazy[p]*(m-s+1); d[p*2+1]+=lazy[p]*(t-m); lazy[p*2]+=lazy[p]; lazy[p*2+1]+=lazy[p]; lazy[p]=0; } ll sum=0; if(l<=m)sum+=getsum(l,r,s,m,p*2); if(r>m)sum+=getsum(l,r,m+1,t,p*2+1); return sum; }
完整代码
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn=300005; ll a[maxn]; ll d[maxn]; void build(ll s,ll t,ll p){ if(s==t){ d[p]=a[s]; return; } ll m=s+((t-s)>>1); build(s,m,2*p); build(m+1,t,2*p+1); d[p]=d[2*p]+d[2*p+1]; } ll lazy[maxn]; void add(ll l,ll r,ll c,ll s,ll t,ll p){ if(l<=s&&r>=t){ d[p]+=(t-s+1)*c; lazy[p]+=c; return; } ll m=s+((t-s)>>1); if(lazy[p]){ d[p*2]+=lazy[p]*(m-s+1); d[p*2+1]+=lazy[p]*(t-m); lazy[p*2]+=lazy[p]; lazy[p*2+1]+=lazy[p]; lazy[p]=0; } if(l<=m)add(l,r,c,s,m,p*2); if(r>m)add(l,r,c,m+1,t,p*2+1); d[p]=d[p*2]+d[p*2+1]; } ll getsum(ll l,ll r,ll s,ll t,ll p){ if(l<=s&&t<=r)return d[p]; ll m=s+((t-s)>>1); if(lazy[p]){ d[p*2]+=lazy[p]*(m-s+1); d[p*2+1]+=lazy[p]*(t-m); lazy[p*2]+=lazy[p]; lazy[p*2+1]+=lazy[p]; lazy[p]=0; } ll sum=0; if(l<=m)sum+=getsum(l,r,s,m,p*2); if(r>m)sum+=getsum(l,r,m+1,t,p*2+1); return sum; } int main(){ ll n,m; scanf("%lld %lld",&n,&m); for(ll i=1;i<=n;i++)scanf("%lld",&a[i]); build(1,n,1); while(m--){ int opt; scanf("%d",&opt); ll x,y; if(opt==1){ ll k; scanf("%lld %lld %lld",&x,&y,&k); add(x,y,k,1,n,1); } else{ scanf("%lld %lld",&x,&y); printf("%lld\n",getsum(x,y,1,n,1)); } } return 0; }