在线段树中,树中的结点和原数组中的区间本质上是同一个概念,因此,本文的用词将会灵活转换,以达到更好的阅读效果。
原文最后的 例题选讲 部分现已迁移到 这篇博客,读者可在掌握本文内容后自行前往阅读。
本文中的线段树模板均维护区间和。
基本概念
-
线段树:一种特殊的二叉搜索树。它用来维护一个数组中的区间和、区间最值等问题。线段树的每个结点表示原数组中的一段区间,层数低的结点表示大区间,层数高的结点表示小区间。能用线段树解决的问题必须满足传递性和区间加法。
-
区间加法:类似于区间\(dp\),大区间的信息可以从规模更小的区间推出。比如大区间之和可以两个规模更小的区间之和相加得出。经典的区间加法问题有区间和、区间最大值等问题,经典的非区间加法问题有区间最长下降子序列等问题。
-
\(lazy\)标记:假设在维护线段树的时候,已经找到了被目标区间覆盖的结点,则我们可以只更新该结点的信息并打上\(lazy\)标记,不更新其子区间。待需要用到其子区间的时候,再下传\(lazy\)标记并更新信息,以节省代码效率和优化时间复杂度。
-
下传标记:用父结点的\(lazy\)标记更新当前结点的\(lazy\)标记,并相应地更新当前结点的信息,最后清空父结点\(lazy\)标记的流程。
存储方式
通常情况下,线段树会使用堆式存储法或者指针存储法来存储。
堆式存储法指用一个数组模拟一棵二叉树,假设当前结点为\(x\),则其左儿子编号为\(2x\),其右儿子编号为\(2x + 1\),其父结点编号为\(\lfloor \frac{x}{2} \rfloor\)。
指针存储法是指用一个结构体来表示树中一个结点,并额外加上指向左儿子和右儿子的指针。每次建立新结点的时候,使用new
来新建一个结点,并更新其信息,以达到优化空间复杂度的效果。
堆式存储法的优势在于时间复杂度优秀,缺点在于浪费了很多没有用到的空间;指针存储法的优点在于空间复杂度优秀,缺点在于指针操作不易编写和结点操作常数太大。
结构及性质
-
线段树以一棵二叉树的形式维护;
-
线段树中的每一个结点都对应着原数组的一个区间;
-
结点的层次和对应区间的大小成反比;
-
线段树大约有\(O(logn)\)层,并且线段树是一棵较为平衡的二叉树;
-
存储线段树需要的空间大约是原数组长度的\(4\)到\(5\)倍。
建树
从根结点开始,依次更新它的左右子树信息,直到叶子结点为止。约定push_up
函数为更新树状数组中结点信息的函数。
参考代码如下:
void build(int k, int l, int r) {
tree[k].l = l;
tree[k].r = r;
if (l == r) {
tree[k].sum = num[l];
return;
}
int mid = (l + r) / 2;
build(2 * k, l, mid);
build(2 * k + 1, mid + 1, r);
push_up(k);
}
单点修改
单点修改指更新原数组中的某个下标为\(x\)的元素,并相应地更新线段树的操作。
我们可以从根结点开始,查找该元素所在的区间:设当前区间的中间点为\(mid\),若\(x \leq mid\),说明\(x\)在当前结点的左孩子内;否则说明\(x\)在右孩子内。假如遇到了左右边界相等的区间,说明当前区间一定是待修改的元素。从当前元素开始回溯,将回溯路径上的结点信息全部更新即可。
参考代码如下:
void update(int k, int x, int y) {
if (tree[k].l == tree[k].r) {
tree[k].value = y;
return;
}
int mid = (tree[k].l + tree[k].r) / 2;
if (r <= mid) {
update(2 * k, x, y);
} else {
update(2 * k + 1, x, y);
}
push_up(k);
}
区间修改
基本流程
区间修改指更新原数组中的某一段左右边界分别为\(l\)和\(r\)的区间,并相应地更新线段树的操作。
对于线段树中的某一个规模大于待修改区间的区间,其与待修改区间的关系一定在这三种关系之中:
-
待修改区间被完全包含在该区间的左孩子内;
-
待修改区间被完全包含在该区间的右孩子内;
-
待修改区间一部分属于该区间的左孩子,一部分属于该区间的右孩子。
首先下传当前结点的\(\textbf{lazy}\)标记,再分类讨论。对于第一种关系和第二种关系,分别在左孩子或右孩子内查询即可;对于第三种关系,我们可以把待修改区间以\(mid\)为中心切割成两个子区间,再分别在左孩子和右孩子内查询子区间,最后合并子区间以更新信息即可。
问题在于,修改具体应该如何实现?如果我们想建树一样,每一次都从根结点开始一层层更新信息,时间开销就会异常巨大。因此,我们需要一种新的概念以解决这种情况,也就是即将引入的\(\textbf{lazy}\)标记。
\(lazy\)标记
\(lazy\)标记的含义是:对于当前已经被更新的结点,不更新其子区间的信息,并打上一个标记以表示该结点被更新,该结点的子结点未被更新。
因为我们在更新区间的时候,不一定要用到其子区间的值。因此,我们只需要更新规模大的区间,等到需要用到子区间的时候再更新子区间,达到提升更新效率的效果。这种操作相当于把更新的任务挪移一部分到查询中。又因为查询操作是从某一个结点开始合并其子结点,不用遍历整棵树,时间效率比纯更新高,所以我们可以在更新的时候只更新待更新的区间并打上\(lazy\)标记,下次用到子区间的时候再下传标记。
参考代码如下:
void update(int k, int l, int r, long long x) {
if (tree[k].l >= l && tree[k].r <= r) {
tree[k].sum += (tree[k].r - tree[k].l + 1) * x;
tree[k].lazy += x;
return;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
if (l <= mid) {
update(2 * k, l, r, x);
}
if (r > mid) {
update(2 * k + 1, l, r, x);
}
push_up(k);
}
下传标记
如果我们在意某个区间的子区间,则我们需要下传标记,更新所有子区间的值。例如在查询的时候,假如我们遇到了打过\(lazy\)标记的结点,说明它的子结点没有被更新过。此时如果查询的区间有一部分在其子结点内,就会发生用旧信息回答查询的情况。
所以,在查询操作之前,我们需要将父结点的\(lazy\)标记继承到当前结点,并更新当前结点的信息,最后将父结点的\(lazy\)标记清空,表示其子结点已经被更新。这个流程被称为下传标记。
参考代码如下:
void push_down(int k) {
if (tree[k].l == tree[k].r) {
tree[k].lazy = 0;
return;
}
tree[2 * k].sum += (tree[2 * k].r - tree[2 * k].l + 1) * tree[k].lazy;
tree[2 * k + 1].sum += (tree[2 * k + 1].r - tree[2 * k + 1].l + 1) * tree[k].lazy;
tree[2 * k].lazy += tree[k].lazy;
tree[2 * k + 1].lazy += tree[k].lazy;
tree[k].lazy = 0;
}
区间查询
如果我们要查询某个区间的信息,那么我们可以从根结点开始,找到所有查询区间的子区间,再将所有子区间的答案合并起来即可。
参考代码如下:
long long query(int k, int l, int r) {
if (tree[k].l >= l && tree[k].r <= r) {
return tree[k].sum;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
long long ret = 0;
if (l <= mid) {
ret += query(2 * k, l, r);
}
if (r > mid) {
ret += query(2 * k + 1, l, r);
}
return ret;
}
模板
以带\(lazy\)标记的线段树为准。
单种操作
#include <cstdio>
using namespace std;
#define maxn 100005
struct node {
int l, r;
long long sum, lazy;
}tree[5 * maxn];
int n, m;
long long num[maxn];
void push_up(int k) {
tree[k].sum = (tree[2 * k].sum + tree[2 * k + 1].sum);
}
void push_down(int k) {
if (tree[k].l == tree[k].r) {
tree[k].lazy = 0;
return;
}
tree[2 * k].sum += (tree[2 * k].r - tree[2 * k].l + 1) * tree[k].lazy;
tree[2 * k + 1].sum += (tree[2 * k + 1].r - tree[2 * k + 1].l + 1) * tree[k].lazy;
tree[2 * k].lazy += tree[k].lazy;
tree[2 * k + 1].lazy += tree[k].lazy;
tree[k].lazy = 0;
}
void build(int k, int l, int r) {
tree[k].l = l;
tree[k].r = r;
if (l == r) {
tree[k].sum = num[l];
return;
}
int mid = (l + r) / 2;
build(2 * k, l, mid);
build(2 * k + 1, mid + 1, r);
push_up(k);
}
void update(int k, int l, int r, long long x) {
if (tree[k].l >= l && tree[k].r <= r) {
tree[k].sum += (tree[k].r - tree[k].l + 1) * x;
tree[k].lazy += x;
return;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
if (l <= mid) {
update(2 * k, l, r, x);
}
if (r > mid) {
update(2 * k + 1, l, r, x);
}
push_up(k);
}
long long query(int k, int l, int r) {
if (tree[k].l >= l && tree[k].r <= r) {
return tree[k].sum;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
long long ret = 0;
if (l <= mid) {
ret += query(2 * k, l, r);
}
if (r > mid) {
ret += query(2 * k + 1, l, r);
}
return ret;
}
int main() {
int op, x, y;
long long k;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%lld", &num[i]);
}
build(1, 1, n);
for (int i = 1; i <= m; i++) {
scanf("%d", &op);
if (op == 1) {
scanf("%d%d%lld", &x, &y, &k);
update(1, x, y, k);
} else if (op == 2) {
scanf("%d%d", &x, &y);
printf("%lld\n", query(1, x, y));
}
}
return 0;
}
多种操作
显然,对于有多种操作的线段树,\(lazy\)标记也要相应地增加。那么问题来了:我们应该怎样更新线段树的信息,如何下发\(lazy\)标记?
回顾乘法分配律:\((a + b)c = ac + bc\)。此处的\(a\)对应着线段树中的区间和,\(b\)则对应着加法标记,\(c\)对应着乘法标记。每次更新乘法标记的时候,加法标记也要相应地乘上乘法标记。这样分配律一下,区间和就是原本的区间和乘以乘法标记再加上加法标记。
根据上面的推论,设加法标记为\(add\),乘法标记为\(mul\),区间和为\(sum\)。\(add_{son} = add_{son} \times mul_{fa} + add_{fa}\),\(mul_{son} = mul_{son} \times mul_{fa}\),\(sum_{son} = sum_{son} \times mul_{fa} + add_{son} \times len_{son}\)。
参考代码如下:
#include <cstdio>
using namespace std;
#define maxn 100005
struct node {
int l, r;
long long sum, add, mul;
}tree[5 * maxn];
int n, m, p;
long long num[maxn];
void push_up(int k) {
tree[k].sum = (tree[2 * k].sum + tree[2 * k + 1].sum) % p;
}
void push_down(int k) {
tree[2 * k].sum = (tree[2 * k].sum * tree[k].mul + tree[k].add * (tree[2 * k].r - tree[2 * k].l + 1)) % p;
tree[2 * k + 1].sum = (tree[2 * k + 1].sum * tree[k].mul + tree[k].add * (tree[2 * k + 1].r - tree[2 * k + 1].l + 1)) % p;
tree[2 * k].mul = (tree[2 * k].mul * tree[k].mul) % p;
tree[2 * k + 1].mul = (tree[2 * k + 1].mul * tree[k].mul) % p;
tree[2 * k].add = (tree[2 * k].add * tree[k].mul + tree[k].add) % p;
tree[2 * k + 1].add = (tree[2 * k + 1].add * tree[k].mul + tree[k].add) % p;
tree[k].mul = 1;
tree[k].add = 0;
}
void build(int k, int l, int r) {
tree[k].l = l;
tree[k].r = r;
tree[k].mul = 1;
if (l == r) {
tree[k].sum = num[l] % p;
return;
}
int mid = (l + r) / 2;
build(2 * k, l, mid);
build(2 * k + 1, mid + 1, r);
push_up(k);
}
void update_mul(int k, int l, int r, long long x) {
if (tree[k].l >= l && tree[k].r <= r) {
tree[k].sum = (tree[k].sum * x) % p;
tree[k].mul = (tree[k].mul * x) % p;
tree[k].add = (tree[k].add * x) % p;
return;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
if (l <= mid) {
update_mul(2 * k, l, r, x);
}
if (r > mid) {
update_mul(2 * k + 1, l, r, x);
}
push_up(k);
}
void update_add(int k, int l, int r, long long x) {
if (tree[k].l >= l && tree[k].r <= r) {
tree[k].add = (tree[k].add + x) % p;
tree[k].sum = (tree[k].sum + x * (tree[k].r - tree[k].l + 1)) % p;
return;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
if (l <= mid) {
update_add(2 * k, l, r, x);
}
if (r > mid) {
update_add(2 * k + 1, l, r, x);
}
push_up(k);
}
long long query(int k, int l, int r)
{
if (tree[k].l >= l && tree[k].r <= r) {
return tree[k].sum % p;
}
push_down(k);
int mid = (tree[k].l + tree[k].r) / 2;
long long ret = 0;
if (l <= mid) {
ret = (ret + query(2 * k, l, r)) % p;
}
if (r > mid) {
ret = (ret + query(2 * k + 1, l, r)) % p;
}
return ret;
}
int main() {
int op, x, y;
long long k;
scanf("%d%d%d", &n, &m, &p);
for (int i = 1; i <= n; i++) {
scanf("%lld", &num[i]);
}
build(1, 1, n);
for (int i = 1; i <= m; i++) {
scanf("%d%d%d", &op, &x, &y);
if (op == 1) {
scanf("%lld", &k);
update_mul(1, x, y, k);
} else if (op == 2) {
scanf("%lld", &k);
update_add(1, x, y, k);
} else {
printf("%lld\n", query(1, x, y) % p);
}
}
return 0;
}