线段树

在线段树中,树中的结点和原数组中的区间本质上是同一个概念,因此,本文的用词将会灵活转换,以达到更好的阅读效果。

原文最后的 例题选讲 部分现已迁移到 这篇博客,读者可在掌握本文内容后自行前往阅读。

本文中的线段树模板均维护区间和。

基本概念

  1. 线段树:一种特殊的二叉搜索树。它用来维护一个数组中的区间和、区间最值等问题。线段树的每个结点表示原数组中的一段区间,层数低的结点表示大区间,层数高的结点表示小区间。能用线段树解决的问题必须满足传递性区间加法

  2. 区间加法:类似于区间\(dp\),大区间的信息可以从规模更小的区间推出。比如大区间之和可以两个规模更小的区间之和相加得出。经典的区间加法问题有区间和、区间最大值等问题,经典的非区间加法问题有区间最长下降子序列等问题。

  3. \(lazy\)标记:假设在维护线段树的时候,已经找到了被目标区间覆盖的结点,则我们可以只更新该结点的信息并打上\(lazy\)标记,不更新其子区间。待需要用到其子区间的时候,再下传\(lazy\)标记并更新信息,以节省代码效率和优化时间复杂度。

  4. 下传标记:用父结点的\(lazy\)标记更新当前结点的\(lazy\)标记,并相应地更新当前结点的信息,最后清空父结点\(lazy\)标记的流程。

存储方式

通常情况下,线段树会使用堆式存储法或者指针存储法来存储。

堆式存储法指用一个数组模拟一棵二叉树,假设当前结点为\(x\),则其左儿子编号为\(2x\),其右儿子编号为\(2x + 1\),其父结点编号为\(\lfloor \frac{x}{2} \rfloor\)。

指针存储法是指用一个结构体来表示树中一个结点,并额外加上指向左儿子和右儿子的指针。每次建立新结点的时候,使用new来新建一个结点,并更新其信息,以达到优化空间复杂度的效果。

堆式存储法的优势在于时间复杂度优秀,缺点在于浪费了很多没有用到的空间;指针存储法的优点在于空间复杂度优秀,缺点在于指针操作不易编写和结点操作常数太大。

结构及性质

  1. 线段树以一棵二叉树的形式维护;

  2. 线段树中的每一个结点都对应着原数组的一个区间;

  3. 结点的层次和对应区间的大小成反比;

  4. 线段树大约有\(O(logn)\)层,并且线段树是一棵较为平衡的二叉树;

  5. 存储线段树需要的空间大约是原数组长度的\(4\)到\(5\)倍。

建树

从根结点开始,依次更新它的左右子树信息,直到叶子结点为止。约定push_up函数为更新树状数组中结点信息的函数。

参考代码如下:

void build(int k, int l, int r) {
	tree[k].l = l;
	tree[k].r = r;
	if (l == r) {
		tree[k].sum = num[l];
		return;
	}
	int mid = (l + r) / 2;
	build(2 * k, l, mid);
	build(2 * k + 1, mid + 1, r);
	push_up(k);
}

单点修改

单点修改指更新原数组中的某个下标为\(x\)的元素,并相应地更新线段树的操作。

我们可以从根结点开始,查找该元素所在的区间:设当前区间的中间点为\(mid\),若\(x \leq mid\),说明\(x\)在当前结点的左孩子内;否则说明\(x\)在右孩子内。假如遇到了左右边界相等的区间,说明当前区间一定是待修改的元素。从当前元素开始回溯,将回溯路径上的结点信息全部更新即可。

参考代码如下:

void update(int k, int x, int y) {
	if (tree[k].l == tree[k].r) {
		tree[k].value = y;
		return;
	}
	int mid = (tree[k].l + tree[k].r) / 2;
	if (r <= mid) {
		update(2 * k, x, y);
	} else {
		update(2 * k + 1, x, y);
	}
	push_up(k);
}

区间修改

基本流程

区间修改指更新原数组中的某一段左右边界分别为\(l\)和\(r\)的区间,并相应地更新线段树的操作。

对于线段树中的某一个规模大于待修改区间的区间,其与待修改区间的关系一定在这三种关系之中:

  1. 待修改区间被完全包含在该区间的左孩子内;

  2. 待修改区间被完全包含在该区间的右孩子内;

  3. 待修改区间一部分属于该区间的左孩子,一部分属于该区间的右孩子。

首先下传当前结点的\(\textbf{lazy}\)标记,再分类讨论。对于第一种关系和第二种关系,分别在左孩子或右孩子内查询即可;对于第三种关系,我们可以把待修改区间以\(mid\)为中心切割成两个子区间,再分别在左孩子和右孩子内查询子区间,最后合并子区间以更新信息即可。

问题在于,修改具体应该如何实现?如果我们想建树一样,每一次都从根结点开始一层层更新信息,时间开销就会异常巨大。因此,我们需要一种新的概念以解决这种情况,也就是即将引入的\(\textbf{lazy}\)标记

\(lazy\)标记

\(lazy\)标记的含义是:对于当前已经被更新的结点,不更新其子区间的信息,并打上一个标记以表示该结点被更新,该结点的子结点未被更新。

因为我们在更新区间的时候,不一定要用到其子区间的值。因此,我们只需要更新规模大的区间,等到需要用到子区间的时候再更新子区间,达到提升更新效率的效果。这种操作相当于把更新的任务挪移一部分到查询中。又因为查询操作是从某一个结点开始合并其子结点,不用遍历整棵树,时间效率比纯更新高,所以我们可以在更新的时候只更新待更新的区间并打上\(lazy\)标记,下次用到子区间的时候再下传标记

参考代码如下:

void update(int k, int l, int r, long long x) {
	if (tree[k].l >= l && tree[k].r <= r) {
		tree[k].sum += (tree[k].r - tree[k].l + 1) * x;
		tree[k].lazy += x;
		return;
	}
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2;
	if (l <= mid) {
		update(2 * k, l, r, x);
	}
	if (r > mid) {
		update(2 * k + 1, l, r, x);
	}
	push_up(k);
}

下传标记

如果我们在意某个区间的子区间,则我们需要下传标记,更新所有子区间的值。例如在查询的时候,假如我们遇到了打过\(lazy\)标记的结点,说明它的子结点没有被更新过。此时如果查询的区间有一部分在其子结点内,就会发生用旧信息回答查询的情况。

所以,在查询操作之前,我们需要将父结点的\(lazy\)标记继承到当前结点,并更新当前结点的信息,最后将父结点的\(lazy\)标记清空,表示其子结点已经被更新。这个流程被称为下传标记。

参考代码如下:

void push_down(int k) {
	if (tree[k].l == tree[k].r) {
		tree[k].lazy = 0;
		return;
	}
	tree[2 * k].sum += (tree[2 * k].r - tree[2 * k].l + 1) * tree[k].lazy;
	tree[2 * k + 1].sum += (tree[2 * k + 1].r - tree[2 * k + 1].l + 1) * tree[k].lazy;
	tree[2 * k].lazy += tree[k].lazy;
	tree[2 * k + 1].lazy += tree[k].lazy;
	tree[k].lazy = 0;
}

区间查询

如果我们要查询某个区间的信息,那么我们可以从根结点开始,找到所有查询区间的子区间,再将所有子区间的答案合并起来即可。

参考代码如下:

long long query(int k, int l, int r) {
	if (tree[k].l >= l && tree[k].r <= r) {
		return tree[k].sum;
	}
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2;
	long long ret = 0;
	if (l <= mid) {
		ret += query(2 * k, l, r);
	}
	if (r > mid) {
		ret += query(2 * k + 1, l, r);
	}
	return ret;
}

模板

以带\(lazy\)标记的线段树为准。

单种操作

例题链接

#include <cstdio>
using namespace std;
#define maxn 100005

struct node {
	int l, r;
	long long sum, lazy;
}tree[5 * maxn];

int n, m;
long long num[maxn];

void push_up(int k) {
	tree[k].sum = (tree[2 * k].sum + tree[2 * k + 1].sum);
}

void push_down(int k) {
	if (tree[k].l == tree[k].r) {
		tree[k].lazy = 0;
		return;
	}
	tree[2 * k].sum += (tree[2 * k].r - tree[2 * k].l + 1) * tree[k].lazy;
	tree[2 * k + 1].sum += (tree[2 * k + 1].r - tree[2 * k + 1].l + 1) * tree[k].lazy;
	tree[2 * k].lazy += tree[k].lazy;
	tree[2 * k + 1].lazy += tree[k].lazy;
	tree[k].lazy = 0;
}

void build(int k, int l, int r) {
	tree[k].l = l;
	tree[k].r = r;
	if (l == r) {
		tree[k].sum = num[l];
		return;
	}
	int mid = (l + r) / 2;
	build(2 * k, l, mid);
	build(2 * k + 1, mid + 1, r);
	push_up(k);
}

void update(int k, int l, int r, long long x) {
	if (tree[k].l >= l && tree[k].r <= r) {
		tree[k].sum += (tree[k].r - tree[k].l + 1) * x;
		tree[k].lazy += x;
		return;
	}
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2;
	if (l <= mid) {
		update(2 * k, l, r, x);
	}
	if (r > mid) {
		update(2 * k + 1, l, r, x);
	}
	push_up(k);
}

long long query(int k, int l, int r) {
	if (tree[k].l >= l && tree[k].r <= r) {
		return tree[k].sum;
	}
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2;
	long long ret = 0;
	if (l <= mid) {
		ret += query(2 * k, l, r);
	}
	if (r > mid) {
		ret += query(2 * k + 1, l, r);
	}
	return ret;
}

int main() {
	int op, x, y;
	long long k;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%lld", &num[i]);
	}
	build(1, 1, n);
	for (int i = 1; i <= m; i++) {
		scanf("%d", &op);
		if (op == 1) {
			scanf("%d%d%lld", &x, &y, &k);
			update(1, x, y, k);
		} else if (op == 2) {
			scanf("%d%d", &x, &y);
			printf("%lld\n", query(1, x, y));
		}
	}
	return 0;
}

多种操作

例题链接

显然,对于有多种操作的线段树,\(lazy\)标记也要相应地增加。那么问题来了:我们应该怎样更新线段树的信息,如何下发\(lazy\)标记?

回顾乘法分配律:\((a + b)c = ac + bc\)。此处的\(a\)对应着线段树中的区间和,\(b\)则对应着加法标记,\(c\)对应着乘法标记。每次更新乘法标记的时候,加法标记也要相应地乘上乘法标记。这样分配律一下,区间和就是原本的区间和乘以乘法标记再加上加法标记。

根据上面的推论,设加法标记为\(add\),乘法标记为\(mul\),区间和为\(sum\)。\(add_{son} = add_{son} \times mul_{fa} + add_{fa}\),\(mul_{son} = mul_{son} \times mul_{fa}\),\(sum_{son} = sum_{son} \times mul_{fa} + add_{son} \times len_{son}\)。

参考代码如下:

#include <cstdio>
using namespace std;
#define maxn 100005

struct node {
	int l, r;
	long long sum, add, mul;
}tree[5 * maxn];

int n, m, p;
long long num[maxn];

void push_up(int k) {
	tree[k].sum = (tree[2 * k].sum + tree[2 * k + 1].sum) % p;
}

void push_down(int k) {
	tree[2 * k].sum = (tree[2 * k].sum * tree[k].mul + tree[k].add * (tree[2 * k].r - tree[2 * k].l + 1)) % p;
	tree[2 * k + 1].sum = (tree[2 * k + 1].sum * tree[k].mul + tree[k].add * (tree[2 * k + 1].r - tree[2 * k + 1].l + 1)) % p;
	tree[2 * k].mul = (tree[2 * k].mul * tree[k].mul) % p;
	tree[2 * k + 1].mul = (tree[2 * k + 1].mul * tree[k].mul) % p;
	tree[2 * k].add = (tree[2 * k].add * tree[k].mul + tree[k].add) % p;
	tree[2 * k + 1].add = (tree[2 * k + 1].add * tree[k].mul + tree[k].add) % p;
	tree[k].mul = 1;
	tree[k].add = 0;
}

void build(int k, int l, int r) {
	tree[k].l = l;
	tree[k].r = r;
	tree[k].mul = 1;
	if (l == r) {
		tree[k].sum = num[l] % p;
		return;
	}
	int mid = (l + r) / 2;
	build(2 * k, l, mid);
	build(2 * k + 1, mid + 1, r);
	push_up(k);
}

void update_mul(int k, int l, int r, long long x) {
	if (tree[k].l >= l && tree[k].r <= r) {
		tree[k].sum = (tree[k].sum * x) % p;
		tree[k].mul = (tree[k].mul * x) % p;
		tree[k].add = (tree[k].add * x) % p;
		return;
	}
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2;
	if (l <= mid) {
		update_mul(2 * k, l, r, x);
	}
	if (r > mid) {
		update_mul(2 * k + 1, l, r, x);
	}
	push_up(k);
}

void update_add(int k, int l, int r, long long x) {
	if (tree[k].l >= l && tree[k].r <= r) {
		tree[k].add = (tree[k].add + x) % p;
		tree[k].sum = (tree[k].sum + x * (tree[k].r - tree[k].l + 1)) % p;
		return;
	}
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2;
	if (l <= mid) {
		update_add(2 * k, l, r, x);
	}
	if (r > mid) {
		update_add(2 * k + 1, l, r, x);
	}
	push_up(k);
}

long long query(int k, int l, int r)
{
	if (tree[k].l >= l && tree[k].r <= r) {
		return tree[k].sum % p;
	}
	push_down(k);
	int mid = (tree[k].l + tree[k].r) / 2;
	long long ret = 0;
	if (l <= mid) {
		ret = (ret + query(2 * k, l, r)) % p;
	} 
	if (r > mid) {
		ret = (ret + query(2 * k + 1, l, r)) % p;
	}
	return ret;
}

int main() {
	int op, x, y;
	long long k;
	scanf("%d%d%d", &n, &m, &p);
	for (int i = 1; i <= n; i++) {
		scanf("%lld", &num[i]);
	}
	build(1, 1, n);
	for (int i = 1; i <= m; i++) {
		scanf("%d%d%d", &op, &x, &y);
		if (op == 1) {
			scanf("%lld", &k);
			update_mul(1, x, y, k);
		} else if (op == 2) {
			scanf("%lld", &k);
			update_add(1, x, y, k);
		} else {
			printf("%lld\n", query(1, x, y) % p);
		}
	}
	return 0;
}
上一篇:#离线,倒序,线段树#Comet OJ - Contest #15 E 栈的数据结构题


下一篇:ahk之路:利用ahk在window7下实现窗口置顶