引入
序列上的很多问题可以用线段树等数据结构维护,树上的问题则可以转化为序列上的问题,然后用相应的数据结构维护。
树链剖分可以在\(O(logn)\)的复杂度内将树上的一条链转化为\(O(logn)\)个区间。
树链剖分
我们以P3384 【模板】轻重链剖分为例讲解
长什么样子
图中一棵树,每个点有点权,假如要查询10到5之间的路径上的点权和。我们将这段路径分为蓝色的1,2,3,4四段路径,分别求出这四个路径的路径和,再相加得到答案。
怎样求这一段路径的路径和
预处理的时候,将树转化为一个数组,给每个点一个下标seg[]
,该位置的元素即为该点的点权,用线段树维护这个数组。
通过一种巧妙的给seg
的方法,可以满足上述要求的这一段路径从上到下所有点依次在线段树中连续出现,于是就可以直接查询这个区间和,复杂度\(O(logn)\)
单个小路径查询复杂度\(O(logn)\),一条一般路径查询可以分为至多\(O(logn)\)个小路径,所以单次路径查询复杂度\(O(log^2n)\)
重要概念
轻重边、重儿子、重链:每个点的儿子中,子树内点数最多的儿子与它之间的边为重边,其余儿子与它之间的边为轻边,这个儿子为重儿子。只由重边构成的链叫重链。
(部分树)图中黄色的边为轻边,红色的边为重边,蓝色框起来的为重链。
可以发现,因为每个非叶结点有且仅有一个重儿子,所以重链一定是从上到下的,不会在lca处“拐弯”。
单个一个叶子结点也算一条重链,所以每个点一定在且仅在一个重链上。
一个点的轻儿子是一条重链的最顶端。
两次dfs——树剖核心操作
通过第一次dfs,我们可以求出每个点的重儿子,通过第二次dfs,我们可以给每个点赋seg
保证一条重链上的点编号连续。
第一次dfs:
处理出每个点的父亲fa[]
、深度dep[]
、子树大小siz[]
、重儿子son[]
。没什么难度。
参考代码:
void dfs1(int x, int F) { // F为当前点x的父节点
fa[x] = F, dep[x] = dep[F] + 1, siz[x] = 1;
for (int i = 0; i < G[x].size(); ++i) {
int y = G[x][i]; if (y == F) continue;
dfs1(y, x);
siz[x] += siz[y];
if (siz[y] > siz[son[x]]) son[x] = y;
}
}
第二次dfs:
处理出每个点所在重链顶端的点top[]
、每个点在线段树中的下标seg[]
、线段树中一个下标在树上对应的点rev[]
。
为什么要处理这些,后面再说。
怎样保证重链seg
连续?我们只需要保证从重链最顶端的点开始,往下每个重儿子的seg
为这个点seg
+1即可。
这个seg
实际上就是一个时间戳,翻译过来前面的要求就是保证dfs到每个点后,下一个dfs的点为它的重儿子。
时间戳可以新开一个变量表示,但我习惯用seg[0]
表示,因为这个地方没有实际意义(没有点0)。
参考代码:
void dfs2(int x, int F) { // F为当前点x所在重链顶端的点
top[x] = F;
seg[x] = ++ seg[0], rev[seg[x]] = x; // rev[]根据定义求即可
if (son[x]) dfs2(son[x], F);
for (int i = 0; i < G[x].size(); ++ i) {
int y = G[x][i];
if (y != fa[x] && y != son[x]) dfs2(y, y); // 每一个轻儿子都是一条重链的顶端
}
}
查询
如果要查询的两个点在同一条重链上(一点是另一点祖先),说明它们之间路径上的点为线段树中连续区间,直接查询即可。
如果不是在同一条重链上,我们让它们向上跳到同一条重链上。
每次让一个点跳到它所在重链的最顶端的点的父亲(完全跳过当前所在重链),同时将刚刚跳过的这个重链的答案在线段树中找到,加上。
每次跳之前判断当前两点是否在同一个重链内,如果不在,继续跳;如果跳了若干次之后在同一个重链内,终止跳的过程。
这时问题成为最开始的样子,直接加上这一部分重链答案,返回即可。
以上是大体思路,具体看代码。
参考代码:
inline int ask_path(int x, int y) { // x,y表示询问x到y路径上点权和(包括x和y)
int fx = top[x], fy = top[y], ans = 0; // 令fx,fy表示当前两个点x,y对应的重链顶端的点
while (fx != fy) { // 两个点在同一条重链当且仅当它们当前所在重链的顶部节点相同
// 如果两个点在同一条重链,停止跳
if (dep[fx] < dep[fy]) swap(fx, fy), swap(x, y);
// 每次让“深度”大的往上跳,总令x表示当前要向上跳的点,所以如果fx深度小于fy,交换即可
ans += query(1, 1, seg[0], seg[fx], seg[x]); // [seg[fx], seg[x]]即为当前所在重链在线段树上的下标区间
x = fa[fx], fx = top[x]; // x跳到上面重链底端(刚好跳过当前重链),fx跳到新重链的顶端
}
if (dep[x] < dep[y]) swap(x, y); // 为了之间的区间可以表示成[seg[y], seg[x]]的形式(y深度小)
ans += query(1, 1, seg[0], seg[y], seg[x]);
return ans;
}
为什么要每次让“深度”大的向上跳?(这里的深度其实是所在重链顶端节点的深度。)
如果不这样,图中情况令y先跳,则会跳到不在查询路径上的2点,答案错误。正确应该令x跳到1点,然后直接加上1、y之间的答案。
修改
学会查询,修改就是大同小异了,将query改成change即可,在每一段重链上区间修改。
线段树部分
与普通线段树不同的部分仅在于build(建树)时,当l==r,要将该点权值赋为a[rev[l]],而不是a[l],因为维护的序列下标为l处的点为rev[l],点权a[rev[l]]。
另外如果怕查询和修改时将一条重链上两端点代表区间写反,可以在线段树的查询和修改操作最开始加上if (x > y) swap(x, y);
子树的查询/修改
与路径查询不同,子树查询可以直接在线段树上查。
再次注意我们的seg
是一个dfs序,我们只不过交换了一个点先dfs进入哪些儿子,它还满足dfs序的特征:一个子树的所有节点的dfs序连续。dfs序中第一个点当然为子树根节点seg[x],子树内有siz[x]个点,所以它在线段树上的区间为[seg[x], seg[x] + siz[x] - 1]。
整理思路
讲了这么多,初学还不太直观,是不是云里雾里?
先看一遍完整代码吧(为了减少篇幅,已将快读快写省去):
#include <bits/stdc++.h>
using namespace std;
const int N = 100005;
int n, m, root, P, a[N]; // 输入
vector<int> G[N];
inline int add(int x, int y) {
return x + y >= P ? x + y - P : x + y;
}
int fa[N], dep[N], siz[N], son[N];
int top[N], seg[N], rev[N];
// SegmentTree
struct T {
int l, r, sum, tag;
} t[N << 2];
#define ls p << 1
#define rs p << 1 | 1
inline void pushup(int p) {
t[p].sum = add(t[ls].sum, t[rs].sum);
}
#define upd pushup
void build(int p, int l, int r) {
t[p].tag = 0, t[p].l = l, t[p].r = r;
if (l == r)
return t[p].sum = a[rev[l]], void();
int mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
upd(p);
}
inline void pushdown(int p) {
if (t[p].tag) {
t[ls].tag = add(t[ls].tag, t[p].tag);
t[rs].tag = add(t[rs].tag, t[p].tag);
t[ls].sum = add(t[ls].sum, (t[ls].r - t[ls].l + 1) * t[p].tag % P);
t[rs].sum = add(t[rs].sum, (t[rs].r - t[rs].l + 1) * t[p].tag % P);
t[p].tag = 0;
}
}
#define pd pushdown
void change(int p, int l, int r, int x, int y, int k) {
if (x > y) swap(x, y);
if (x <= l && r <= y) {
t[p].tag = add(t[p].tag, k);
t[p].sum = add(t[p].sum, (t[p].r - t[p].l + 1) * k % P);
return;
}
pd(p);
int mid = (l + r) >> 1;
if (x <= mid) change(ls, l, mid, x, y, k);
if (y > mid) change(rs, mid + 1, r, x, y, k);
upd(p);
}
int query(int p, int l, int r, int x, int y) {
if (x > y) swap(x, y);
if (x <= l && r <= y) return t[p].sum;
pd(p);
int mid = (l + r) >> 1, ans = 0;
if (x <= mid) ans = add(ans, query(ls, l, mid, x, y));
if (y > mid) ans = add(ans, query(rs, mid + 1, r, x, y));
return ans;
}
// 树剖
void dfs1(int x, int F) {
fa[x] = F, dep[x] = dep[F] + 1, siz[x] = 1;
for (int i = 0; i < G[x].size(); ++i) {
int y = G[x][i]; if (y == F) continue;
dfs1(y, x);
siz[x] += siz[y];
if (siz[y] > siz[son[x]]) son[x] = y;
}
}
void dfs2(int x, int F) {
top[x] = F;
seg[x] = ++ seg[0], rev[seg[x]] = x;
if (son[x]) dfs2(son[x], F);
for (int i = 0; i < G[x].size(); ++ i) {
int y = G[x][i];
if (y != fa[x] && y != son[x]) dfs2(y, y);
}
}
inline void add_path(int x, int y, int k) {
int fx = top[x], fy = top[y];
while (fx != fy) {
if (dep[fx] < dep[fy]) swap(fx, fy), swap(x, y);
change(1, 1, seg[0], seg[fx], seg[x], k);
x = fa[fx], fx = top[x];
}
if (dep[x] < dep[y]) swap(x, y);
change(1, 1, seg[0], seg[y], seg[x], k);
}
inline int ask_path(int x, int y) {
int fx = top[x], fy = top[y], ans = 0;
while (fx != fy) {
if (dep[fx] < dep[fy]) swap(fx, fy), swap(x, y);
ans = add(ans, query(1, 1, seg[0], seg[fx], seg[x]));
x = fa[fx], fx = top[x];
}
if (dep[x] < dep[y]) swap(x, y);
ans = add(ans, query(1, 1, seg[0], seg[y], seg[x]));
return ans;
}
signed main() {
n = read(), m = read(), root = read(), P = read();
for (int i = 1; i <= n; ++i) a[i] = read() % P;
for (int i = 1; i < n; ++i) {
int u = read(), v = read();
G[u].push_back(v), G[v].push_back(u);
}
dfs1(root, root), dfs2(root, root), build(1, 1, seg[0]);
for (int i = 1; i <= m; ++i) {
int op = read();
if (op == 1) {
int x = read(), y = read(), z = read();
add_path(x, y, z);
}
else if (op == 2) {
int x = read(), y = read();
print(ask_path(x, y)), putchar('\n');
}
else if (op == 3) {
int x = read(), z = read();
change(1, 1, seg[0], seg[x], seg[x] + siz[x] - 1, z);
}
else {
int x = read();
print(query(1, 1, seg[0], seg[x], seg[x] + siz[x] - 1)), putchar('\n');
}
}
return 0;
}
-
(读入完树和点权后)先dfs1、dfs2两遍求出树剖需要信息,再建线段树。
-
有路径修改就
add_path
分为一些重链在线段树上修改; -
有路径查询就
ask_path
分为一些重链在线段树上查询,将答案累加起来; -
有子树修改就直接修改线段树上一段区间,有子树查询就在线段树中直接查区间。
-
线段树都为常规操作,树剖就是那四个函数的板子,默写上去。
能看懂吧,看不懂的,自己画画图吧。基础讲解就过了。(雾~
一些简单扩展
查询一条路径上的所有边的边权和
边权化为点权即可,一条边用它的深度大的那个端点表示,即一个点的点权为它到它父亲的边的边权。
转化完之后要注意一个小细节:一次路径查询/修改的时候,路径的lca的点权不能加上,(画图可以清楚看出),因为这个点权的含义为它到它父亲的边权,不在查询/修改的路径上。
要做到这一点也非常容易,看代码:
inline void add_path2(int x, int y, int k) {
int fx = top[x], fy = top[y];
while (fx != fy) {
if (dep[fx] < dep[fy]) swap(fx, fy), swap(x, y);
change2(1, 1, seg[0], seg[fx], seg[x], k);
x = fa[fx], fx = top[x];
}
if (x == y) return;
if (dep[x] < dep[y]) swap(x, y);
change2(1, 1, seg[0], seg[y] + 1, seg[x], k); // 只在这里将深度小的y点去掉即可
}
如果是子树查询/修改呢?也很简单,把区间的左端点(子树根节点)去掉即可:
else if (op == 3) {
int x = read(), z = read();
change(1, 1, seg[0], seg[x] + 1, seg[x] + siz[x] - 1, z);
}
else {
int x = read();
print(query(1, 1, seg[0], seg[x] + 1, seg[x] + siz[x] - 1)), putchar('\n');
}
树链剖分求lca
如果你看懂了上面的基础树剖,很容易发现,其实add_path函数和ask_path函数跳到同一条重链后,那个深度小的点就是lca。
dfs1和dfs2和普通树剖一样,但不用维护seg[]和rev[](不涉及线段树)
参考代码:
inline int lca(int x, int y) {
int fx = top[x], fy = top[y];
while (fx != fy) {
if (dep[fx] < dep[fy]) swap(fx, fy), swap(x, y);
x = fa[fx], fx = top[x];
}
return dep[x] < dep[y] ? x : y;
}
更多扩展应用
我一开始说了,树剖是将链转化为线段树上一些区间,可不只是线段树1中涉及的操作那样简单,它可以是线段树2那样的多优先级懒标记,也可以是GSS1那样的最大子段和,也可以是各种线段树......
不过,像链上最大子段和,仔细想会发现,需要在线段树上维护的问题与两个要合并的区间哪个在左哪个在右有关,所以写ask_path时需要考虑的问题非常多,对码力要求非常高。
后记
这个题单是比较不错的,再难真没必要了(noip)。
树剖在树上是很常用的工具,要写得很熟很熟,学的前期10道题以内可能调试会很崩溃,练熟之后,基本不怎么需要调试。
树剖求lca确实在时间和空间上比倍增优秀很多。
空间复杂度:线性
时间复杂度(它轻重链的设计就是为了保证分成logn个区间,证明简单没用,我就不证明了):
预处理:dfs:线性, build:\(O(nlogn)\)
查询:树剖是\(O(logn)\)的,如果带线段树是\(O(log^2n)\)的,如果只有子树查询,(没必要写树剖)是\(O(logn)\)的
很多树上问题,不用手在草稿纸上画画图是不行的,没太懂可以手玩一会儿。练熟之后会觉得非常好理解,脑子里有图。