题目链接
题意:
给定数列\(a_{1}、a_{2}、a_{3}...a_{n}\),两种操作:
- \(1\) \(l\) \(r\) \(v\):将区间\(\left[ l,r\right]\)内的每个\(a_{i}\)都增加\(v\)。
- \(2\) \(l\) \(r\):询问\(\displaystyle\sum_{i=l}^{r}\displaystyle\sum_{j=i+1}^{r} a_{i}a_{j}\)的值,即所有满足\(l\le i<j\le r\)的\(a_{i}a_{j}\)之和。
思路:
直接用线段树维护。
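法一:\(\displaystyle\sum_{i=l}^{r}\displaystyle\sum_{j=i+1}^{r} a_{i}a_{j}=\dfrac{(\displaystyle\sum_{i=l}^{r}a_{i})^2-\displaystyle\sum_{i=l}^{r}a_{i}^2}{2}\),因此线段树每个结点只需维护区间和与区间平方和。
区间(长度\(len\))整体加\(d\)时,平方和变为\(\displaystyle\sum a_{i}^2+2d\displaystyle\sum a_{i}+len\cdot d^{2}\)(用加之前的区间和计算),区间和变为\(\displaystyle\sum a_{i}+len\cdot d\);模意义下除以\(2\)用\(2\)的逆元代替。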
法二:令\(V_{l,r}=\displaystyle\sum_{i=l}^{r}\displaystyle\sum_{j=i+1}^{r} a_{i}a_{j},
sum_{l,r}=\displaystyle\sum_{i=l}^{r}a_{i}\),\(mid = l+r>>1\),
则区间\(V_{l,r}=V_{l,mid}+V_{mid+1,r}+sum_{l,mid}\cdot sum_{mid+1,r}\),线段树每个结点保存所覆盖区间的价值和元素和。
对于区间\(\left[l, r\right]\)上加\(d\),可推出区间价值的变化公式:\(V_{l,r}'=V_{l,r}+(r-l)\cdot sum_{l,r}\cdot d+\dfrac{(r-l+1)(r-l)}{2}d^{2}\),其中\(sum_{l,r}\)是加\(d\)之前的区间和;之后再更新\(sum_{l,r}'=sum_{l,r}+(r-l+1)\cdot d\)。
code:
法一:
点击查看代码
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <deque>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
#include <unordered_map>
#define fi first
#define se second
#define pb push_back
// #define endl "\n"
// Function-like macros parenthesize their argument so they compose inside
// larger expressions (e.g. lowbit(i) + 1, lowbit(a + b), debug(a & b)).
#define debug(x) cout << #x << ":" << (x) << endl;
#define bug cout << "********" << endl;
#define all(x) (x).begin(), (x).end()
// Lowest set bit of x; the whole expansion is parenthesized as well.
#define lowbit(x) ((x) & -(x))
#define fin(x) freopen(x, "r", stdin)
#define fout(x) freopen(x, "w", stdout)
#define ull unsigned long long
#define ll long long
const double eps = 1e-15;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const double pi = acos(-1.0);
const int mod = 1e9 + 7;    // prime modulus used for all answers (inverses via Fermat)
const int maxn = 3e5 + 10;  // upper bound on n; sizes the input and tree arrays
using namespace std;
// Children of segment-tree node `rt` (expects a variable named rt in scope).
#define lson (rt << 1)
#define rson (rt << 1 | 1)
// Modular inverse of a modulo the prime `mod`, by Fermat's little theorem:
// a^(mod-2) is the inverse of a. Uses binary exponentiation.
// Intended for 0 <= a < mod so that a * a fits in long long.
ll inv(ll a){
    ll result = 1;
    for(ll e = mod - 2; e > 0; e >>= 1, a = a * a % mod){
        if(e & 1) result = result * a % mod;
    }
    return result;
}
struct node{
ll sum, sum2;
node(ll sum, ll sum2):sum(sum), sum2(sum2){}
node(){sum = sum2 = 0;}
}t[maxn << 2];
int s[maxn], n, m;
ll lazy[maxn << 2];
void pushup(int rt){
t[rt].sum = (t[lson].sum + t[rson].sum) % mod;
t[rt].sum2 = (t[lson].sum2 + t[rson].sum2) % mod;
}
void pushdown(int rt, int len){
if(lazy[rt]){
ll d = lazy[rt], len1 = (len + 1) >> 1, len2 = len >> 1;
lazy[lson] = (lazy[lson] + d) % mod;
lazy[rson] = (lazy[rson] + d) % mod;
t[lson].sum2 = (t[lson].sum2 + 2ll * t[lson].sum * d + len1 * d % mod * d) % mod;
t[lson].sum = (t[lson].sum + len1 * d) % mod;
t[rson].sum2 = (t[rson].sum2 + 2ll * t[rson].sum * d + len2 * d % mod * d) % mod;
t[rson].sum = (t[rson].sum + len2 * d) % mod;
lazy[rt] = 0;
}
}
void build(int rt, int l, int r){
t[rt].sum = t[rt].sum2 = lazy[rt] = 0;
if(l == r){
t[rt].sum = s[l];
t[rt].sum2 = 1ll * s[l] * s[l] % mod;
return ;
}
int mid = l + r >> 1;
build(lson, l, mid), build(rson, mid + 1, r);
pushup(rt);
}
void update(int rt, int l, int r, int L, int R, int d){
if(L <= l && r <= R){
int len = r - l + 1;
lazy[rt] = (lazy[rt] + d) % mod;
t[rt].sum2 = (t[rt].sum2 + 2ll * d * t[rt].sum + 1ll * len * d % mod * d) % mod;
t[rt].sum = (t[rt].sum + 1ll * len * d) % mod;
return ;
}
pushdown(rt, r - l + 1);
int mid = l + r >> 1;
if(L <= mid)update(lson, l, mid, L, R, d);
if(mid < R)update(rson, mid + 1, r, L, R, d);
pushup(rt);
}
// Return the (sum, square-sum) aggregate of [L, R]; node rt covers [l, r].
// A child that does not intersect [L, R] contributes an all-zero node.
node query(int rt, int l, int r, int L, int R){
    if(L <= l && r <= R) return t[rt];
    pushdown(rt, r - l + 1);
    const int mid = (l + r) >> 1;
    const node lhs = (L <= mid) ? query(lson, l, mid, L, R) : node();
    const node rhs = (R > mid) ? query(rson, mid + 1, r, L, R) : node();
    return node((lhs.sum + rhs.sum) % mod, (lhs.sum2 + rhs.sum2) % mod);
}
int main(){
int t;
scanf("%d", &t);
while(t --){
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i ++)scanf("%d", &s[i]);
build(1, 1, n);
int a, l, r, v;
while(m --){
scanf("%d%d%d", &a, &l, &r);
if(a == 1){
scanf("%d", &v);
update(1, 1, n, l, r, v);
}
else{
node p = query(1, 1, n, l, r);
ll ans = (p.sum * p.sum % mod - p.sum2 + mod) % mod * inv(2) % mod;
printf("%lld\n", ans);
}
}
}
return 0;
}
法二:
点击查看代码
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <deque>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
#include <unordered_map>
#define fi first
#define se second
#define pb push_back
// #define endl "\n"
// Function-like macros parenthesize their argument so they compose inside
// larger expressions (e.g. lowbit(i) + 1, lowbit(a + b), debug(a & b)).
#define debug(x) cout << #x << ":" << (x) << endl;
#define bug cout << "********" << endl;
#define all(x) (x).begin(), (x).end()
// Lowest set bit of x; the whole expansion is parenthesized as well.
#define lowbit(x) ((x) & -(x))
#define fin(x) freopen(x, "r", stdin)
#define fout(x) freopen(x, "w", stdout)
#define ull unsigned long long
#define ll long long
const double eps = 1e-15;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const double pi = acos(-1.0);
const int mod = 1e9 + 7;    // modulus applied to all stored sums and answers
const int maxn = 3e5 + 10;  // upper bound on n; sizes the input and tree arrays
using namespace std;
struct node{
ll sum, v;
node(ll ret1, ll v1):sum(ret1), v(v1){}
node(){sum = v = 0;}
}t[maxn << 2];
ll lazy[maxn << 2];
int s[maxn], n, m;
void pushup(int rt){
t[rt].sum = (t[rt << 1].sum + t[rt << 1 | 1].sum) % mod;
t[rt].v = (t[rt << 1].v + t[rt << 1 | 1].v + t[rt << 1].sum * t[rt << 1 | 1].sum % mod) % mod;
}
void pushdown(int rt, int len){
if(lazy[rt]){
int len1 = (len + 1) >> 1, len2 = len >> 1;
ll d = lazy[rt];
lazy[rt << 1] = (lazy[rt << 1] + d) % mod;
lazy[rt << 1 | 1] = (lazy[rt << 1 | 1] + d) % mod;
t[rt << 1].v = (t[rt << 1].v + t[rt << 1].sum * d % mod * (len1 - 1) + (len1 - 1) * len1 /2 % mod * (d * d % mod)) % mod;
t[rt << 1 | 1].v = (t[rt << 1 | 1].v + t[rt << 1 | 1].sum * d % mod * (len2 - 1) + (len2 - 1) * len2 /2 % mod * (d * d % mod)) % mod;
t[rt << 1].sum = (t[rt << 1].sum + len1 * d) % mod;
t[rt << 1 | 1].sum = (t[rt << 1 | 1].sum + len2 * d) % mod;
lazy[rt] = 0;
}
}
// Build the subtree rooted at rt over s[l..r], clearing pair sums and lazy tags.
void build(int rt, int l, int r){
    lazy[rt] = 0;
    t[rt].v = 0;
    if(l != r){
        const int mid = (l + r) >> 1;
        build(rt << 1, l, mid);
        build(rt << 1 | 1, mid + 1, r);
        pushup(rt);
    }
    else{
        t[rt].sum = s[l];  // a single element has no pairs, so v stays 0
    }
}
void update(int rt, int l, int r, int L, int R, int d){
if(L <= l && r <= R){
int len = r - l + 1;
lazy[rt] = (lazy[rt] + d) % mod;
t[rt].v = (t[rt].v + t[rt].sum * d % mod * (len - 1) + 1ll * (len - 1) * len / 2 % mod * d % mod * d % mod) % mod;
t[rt].sum = (t[rt].sum + 1ll * len * d )% mod;
return ;
}
pushdown(rt, r - l + 1);
int mid = l + r >> 1;
if(L <= mid)update(rt << 1, l, mid, L, R, d);
if(mid < R)update(rt << 1 | 1, mid + 1, r, L, R, d);
pushup(rt);
}
// Return the (sum, pair-product sum) aggregate of [L, R]; node rt covers [l, r].
// A child that does not intersect [L, R] contributes an all-zero node.
node query(int rt, int l, int r, int L, int R){
    if(L <= l && r <= R) return t[rt];
    pushdown(rt, r - l + 1);
    const int mid = (l + r) >> 1;
    const node lhs = (L <= mid) ? query(rt << 1, l, mid, L, R) : node();
    const node rhs = (R > mid) ? query(rt << 1 | 1, mid + 1, r, L, R) : node();
    // Cross pairs between the two halves contribute lhs.sum * rhs.sum.
    return node((lhs.sum + rhs.sum) % mod, (lhs.v + rhs.v + lhs.sum * rhs.sum) % mod);
}
int main(){
int t;
scanf("%d", &t);
while(t --){
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i ++)scanf("%d", &s[i]);
build(1, 1, n);
int a, l, r, v;
while(m --){
scanf("%d%d%d", &a, &l, &r);
if(a == 1){
scanf("%d", &v);
update(1, 1, n, l, r, v);
}
else{
node p = query(1, 1, n, l, r);
printf("%lld\n", p.v);
}
}
}
return 0;
}