点分树
点分树,又叫动态点分治。普通点分治,每次询问都要找子树的重心,而点分树则是将子树的重心提出来,建一颗新树。树高不超过 \(logn\)。
点分树和原树有如下关系:
1.点分树中若两个点的 lca 为 x,则原树中两点的路径也过 x。
2.点分树中若两个点都在 x 的同一个子树中,则在原树中把 x 删掉,两点仍在一个连通块内。
可以理解为点分树的作用就是存下点分治时的访问顺序,不用每次都点分治求重心,以便动态操作。
P6329 震波
题意描述
给定一颗 n 个结点的带点权的树。每次询问查询与 x 距离不超过 k 的点的点权和 或 将 x 的点权修改为 y。
Solution
基本思路是在每一个节点开两个数据结构 \(S_0\) 和 \(S_1\)。\(S_0\) 保存的是当前结点 x 子树(不包括 x)的贡献,\(S_1\) 保存的是当前结点 x 的父节点除去 x 子树(不包括 x) 的贡献。
单次询问和单次修改的时间复杂度都为 点分树树高 * 数据结构单次询问或修改时间复杂度。
对于本道题可以选用树状数组。但对每个结点开一个 \(O(n)\) 的树状数组空间显然不够。所以对每个结点都只开它子树大小的空间,复杂度为 \(O(logn)\)。
具体实现
树状数组
struct BIT {
vector<int> c;
void build(int len) {while(c.size() < len + 5) c.push_back(0);}
void add(int x, int val) {while(x < c.size()) c[x] += val, x += x & -x;}
int query(int x) {
x = min(x, (int)c.size() - 1);
//注意这里一定要取 min,因为查询的长度很可能大于了子树大小
int res = 0;
while(x) res += c[x], x -= x & -x;
return res;
}
}s[2][MAXN];
询问两点间的距离
void dfs1(int now, int fa) {
for(int i = hd[now]; i; i = nt[i]) {
if(to[i] == fa) continue;
dep[to[i]] = dep[now] + 1;
st[to[i]][0] = now;
for(int j = 1; (1 << j) < dep[to[i]]; j++)
st[to[i]][j] = st[st[to[i]][j - 1]][j - 1];
dfs1(to[i], now);
}
}
int getlca(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20; i >= 0; i--)
if(dep[st[x][i]] >= dep[y]) x = st[x][i];
if(x == y) return x;
for(int i = 20; i >= 0; i--)
if(st[x][i] != st[y][i]) x = st[x][i], y = st[y][i];
return st[x][0];
}
int getdis(int x, int y) {
int lca = getlca(x, y);
return dep[x] + dep[y] - 2 * dep[lca];
}
建立点分树
void dfs2(int now, int fa) {
dfn[++cnt] = now;
siz[now] = 1, son[now] = -1;
for(int i = hd[now]; i; i = nt[i]) {
if(to[i] == fa || vis[to[i]]) continue;
dfs2(to[i], now);
siz[now] += siz[to[i]];
son[now] = max(son[now], siz[to[i]]);
}
}
void dfs3(int now, int fa, int len) {
s[0][root].add(len, a[now]);
for(int i = hd[now]; i; i = nt[i])
if(to[i] != fa && !vis[to[i]]) dfs3(to[i], now, len + 1);
}
void divid(int now, int fa) {
cnt = 0; dfs2(now, 0);
MX = INF;
for(int i = 1; i <= cnt; i++)
son[dfn[i]] = max(son[dfn[i]], cnt - siz[dfn[i]]);
for(int i = 1; i <= cnt; i++)
if(MX > son[dfn[i]]) MX = son[dfn[i]], root = dfn[i];
vis[root] = true;
now = root; f[now] = fa;
s[0][now].build(cnt); s[1][now].build(cnt);
if(fa) {
for(int i = 1; i <= cnt; i++)
s[1][now].add(getdis(dfn[i], fa), a[dfn[i]]);
}
for(int i = hd[now]; i; i = nt[i])
if(!vis[to[i]]) dfs3(to[i], now, 1);
for(int i = hd[now]; i; i = nt[i])
if(!vis[to[i]]) divid(to[i], now);
}
void prepare() {
dep[1] = 1; dfs1(1, 0);
divid(1, 0);
}
查询和修改
void modify(int x, int val) {
int now = x;
while(f[now]) {
int d = getdis(x, f[now]);
s[1][now].add(d, -a[x]);
s[1][now].add(d, val);
now = f[now];
s[0][now].add(d, -a[x]);
s[0][now].add(d, val);
}
a[x] = val;
}
int query(int x, int val) {
int now = x, res = 0;
res += a[x] + s[0][x].query(val);
while(f[now]) {
int d = getdis(f[now], x);
if(d <= val) res += a[f[now]] + s[0][f[now]].query(val - d) - s[1][now].query(val - d);
now = f[now];
}
return res;
}
完整代码
#include<cstdio>
#include<vector>
using namespace std;
const int MAXN = 1e5 + 5, INF = 0x7ffffff;
int n, m, a[MAXN];
int tot, hd[MAXN], to[MAXN << 1], nt[MAXN << 1];
int dep[MAXN], f[MAXN], st[MAXN][21];
int cnt, dfn[MAXN];
int MX, root, siz[MAXN], son[MAXN];
bool vis[MAXN];
struct BIT {
vector<int> c;
void build(int len) {while(c.size() < len + 5) c.push_back(0);}
void add(int x, int val) {
while(x < c.size())
c[x] += val,
x += x & -x;
}
int query(int x) {
x = min(x, (int)c.size() - 1);
int res = 0;
while(x) res += c[x], x -= x & -x;
return res;
}
}s[2][MAXN];
void add(int x, int y) {
to[++tot] = y, nt[tot] = hd[x], hd[x] = tot;
to[++tot] = x, nt[tot] = hd[y], hd[y] = tot;
}
void dfs1(int now, int fa) {
for(int i = hd[now]; i; i = nt[i]) {
if(to[i] == fa) continue;
dep[to[i]] = dep[now] + 1;
st[to[i]][0] = now;
for(int j = 1; (1 << j) < dep[to[i]]; j++)
st[to[i]][j] = st[st[to[i]][j - 1]][j - 1];
dfs1(to[i], now);
}
}
int getlca(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20; i >= 0; i--)
if(dep[st[x][i]] >= dep[y]) x = st[x][i];
if(x == y) return x;
for(int i = 20; i >= 0; i--)
if(st[x][i] != st[y][i]) x = st[x][i], y = st[y][i];
return st[x][0];
}
int getdis(int x, int y) {
int lca = getlca(x, y);
return dep[x] + dep[y] - 2 * dep[lca];
}
void dfs2(int now, int fa) {
dfn[++cnt] = now;
siz[now] = 1, son[now] = -1;
for(int i = hd[now]; i; i = nt[i]) {
if(to[i] == fa || vis[to[i]]) continue;
dfs2(to[i], now);
siz[now] += siz[to[i]];
son[now] = max(son[now], siz[to[i]]);
}
}
void dfs3(int now, int fa, int len) {
s[0][root].add(len, a[now]);
for(int i = hd[now]; i; i = nt[i])
if(to[i] != fa && !vis[to[i]]) dfs3(to[i], now, len + 1);
}
void divid(int now, int fa) {
cnt = 0; dfs2(now, 0);
MX = INF;
for(int i = 1; i <= cnt; i++)
son[dfn[i]] = max(son[dfn[i]], cnt - siz[dfn[i]]);
for(int i = 1; i <= cnt; i++)
if(MX > son[dfn[i]]) MX = son[dfn[i]], root = dfn[i];
vis[root] = true;
now = root; f[now] = fa;
s[0][now].build(cnt); s[1][now].build(cnt);
if(fa) {
for(int i = 1; i <= cnt; i++)
s[1][now].add(getdis(dfn[i], fa), a[dfn[i]]);
}
for(int i = hd[now]; i; i = nt[i])
if(!vis[to[i]]) dfs3(to[i], now, 1);
for(int i = hd[now]; i; i = nt[i])
if(!vis[to[i]]) divid(to[i], now);
}
void prepare() {
dep[1] = 1; dfs1(1, 0);
divid(1, 0);
}
void modify(int x, int val) {
int now = x;
while(f[now]) {
int d = getdis(x, f[now]);
s[1][now].add(d, -a[x]);
s[1][now].add(d, val);
now = f[now];
s[0][now].add(d, -a[x]);
s[0][now].add(d, val);
}
a[x] = val;
}
int query(int x, int val) {
int now = x, res = 0;
res += a[x] + s[0][x].query(val);
while(f[now]) {
int d = getdis(f[now], x);
if(d <= val) res += a[f[now]] + s[0][f[now]].query(val - d) - s[1][now].query(val - d);
now = f[now];
}
return res;
}
int main() {
scanf("%d%d",&n, &m);
for(int i = 1; i <= n; i++) scanf("%d",&a[i]);
for(int i = 1, x, y; i < n; i++) {
scanf("%d%d",&x, &y);
add(x, y);
}
prepare();
int las = 0, op, x, y;
for(int i = 1; i <= m; i++) {
scanf("%d%d%d",&op, &x, &y);
x ^= las, y ^= las;
if(op == 0) printf("%d\n",las = query(x, y));
else if(op == 1) modify(x, y);
}
return 0;
}