[C++]树链剖分
预备知识
- 树的基础知识
- 关于这个本文有介绍
- 邻接表存图
- 线段树基础
- 最近公共祖先LCA
- 虽然用不到这个思想 但是有类似的
- 有助于快速理解代码
- 建议阅读这篇Blog
算法思想
树链剖分
顾名思义 就是把树形结构改良成链状结构
这样可以通过线段树方便的维护
为了更好的讲解
这里先列举出几个概念:
- 重儿子 是指当前节点的所有儿子中子树最大的儿子
- 重链 全部由重儿子组成的链
代码讲解
Code
#include<bits/stdc++.h>
#define maxn 200007
#define mid ((l+r)>>1)
#define li i<<1
#define ri 1+(i<<1)
using namespace std;
int n,m,root,mod;
int deep[maxn],father[maxn],son[maxn],sub[maxn];
int head[maxn],cnt,value[maxn];
int top[maxn],id[maxn],value_sort[maxn];
struct Edge{
int u,v;
Edge(int a = 0,int b = 0){
u = head[a];
v = b;
}
}e[maxn << 1];
struct Tree{
int l,r,sum;
int lazy;
}t[maxn << 1];
void Read(){
int a,b;
cin >> n >> m >> root >> mod;
for(int i = 1;i <= n;i++) cin >> value[i];
for(int i = 1;i < n;i++){
cin >> a >> b;
e[++cnt] = Edge(a,b);
head[a] = cnt;
e[++cnt] = Edge(b,a);
head[b] = cnt;
}
}
int dfs1(int u,int fa){
deep[u] = deep[fa] + 1;
father[u] = fa;
sub[u] = 1;
int maxson = -1;
for(int i = head[u];i;i = e[i].u){
int ev = e[i].v;
if(ev == fa) continue;
sub[u] += dfs1(ev,u);
if(sub[ev] > maxson){
maxson = sub[ev];
son[u] = ev;
}
}
return sub[u];
}
void dfs2(int u,int topf){
id[u] = ++cnt;
value_sort[cnt] = value[u];
top[u] = topf;
if(!son[u]) return;
dfs2(son[u],topf);
for(int i = head[u];i;i = e[i].u){
int ev = e[i].v;
if(!id[ev])
dfs2(ev,ev);
}
}
void Build(int i,int l,int r){
t[i].l = l;
t[i].r = r;
if(l == r){
t[i].sum = value_sort[l];
return ;
}
Build(li,l,mid);
Build(ri,mid+1,r);
t[i].sum = t[li].sum + t[ri].sum;
}
void push(int i){
t[li].lazy = (t[li].lazy + t[i].lazy) % mod;
t[ri].lazy = (t[ri].lazy + t[i].lazy) % mod;
int mid_ = (t[i].l + t[i].r) >> 1;
t[li].sum = (t[li].sum + t[i].lazy * (mid_-t[i].l+1)) % mod;
t[ri].sum = (t[ri].sum + t[i].lazy * (t[i].r - mid_)) % mod;
t[i].lazy = 0;
}
void add(int i,int l,int r,int k){
if(l <= t[i].l && t[i].r <= r){
t[i].sum += k * (t[i].r - t[i].l + 1);
t[i].lazy += k;
return ;
}
if(t[i].lazy != 0) push(i);
if(t[li].r >= l)
add(li,l,r,k);
if(t[ri].l <= r)
add(ri,l,r,k);
t[i].sum = (t[li].sum + t[ri].sum) % mod;
}
int search(int i,int l,int r){
if(l <= t[i].l && t[i].r <= r)
return t[i].sum;
push(i);
int ans = 0;
if(t[li].r >= l) ans = (ans + search(li,l,r)) % mod;
if(t[ri].l <= r) ans = (ans + search(ri,l,r)) % mod;
return ans;
}
int search_tree(int x,int y){
int ans = 0;
while(top[x] != top[y]){
if(deep[top[x]] < deep[top[y]]) swap(x,y);
ans = (ans + search(1,id[top[x]],id[x])) % mod;
x = father[top[x]];
}
if(deep[x] > deep[y]) swap(x,y);
ans = (ans + search(1,id[x],id[y])) % mod;
return ans;
}
void add_tree(int x,int y,int k){
while(top[x] != top[y]){
if(deep[top[x]] < deep[top[y]]) swap(x,y);
add(1,id[top[x]],id[x],k);
x = father[top[x]];
}
if(deep[x] > deep[y]) swap(x,y);
add(1,id[x],id[y],k);
}
void interaction(){
int tot;
int x,y,z;
for(int i = 1;i <= m;i++){
cin >> tot;
if(tot == 1){
cin >> x >> y >> z;
add_tree(x,y,z%mod);
}
if(tot == 2){
cin >> x >> y;
cout << search_tree(x,y)%mod << endl;
}
if(tot == 3){
cin >> x >> z;
add(1,id[x],id[x]+sub[x]-1,z%mod);
}
if(tot == 4){
cin >> x;
cout << search(1,id[x],id[x]+sub[x]-1)%mod << endl;
}
}
}
int main(){
Read();
dfs1(root,0);
cnt = 0;
dfs2(root,root);
Build(1,1,n);
interaction();
return 0;
}