点分树

点分树

点分树,又叫动态点分治。普通点分治,每次询问都要找子树的重心,而点分树则是将子树的重心提出来,建一颗新树。树高不超过 \(logn\)

点分树和原树有如下关系:
1.点分树中若两个点的 lca 为 x,则原树中两点的路径也过 x。
2.点分树中若两个点都在 x 的同一个子树中,则在原树中把 x 删掉,两点仍在一个连通块内。

可以理解为点分树的作用就是存下点分治时的访问顺序,不用每次都点分治求重心,以便动态操作。

P6329 震波

题意描述

给定一颗 n 个结点的带点权的树。每次询问查询与 x 距离不超过 k 的点的点权和 或 将 x 的点权修改为 y。

Solution

基本思路是在每一个节点开两个数据结构 \(S_0\)\(S_1\)\(S_0\) 保存的是当前结点 x 子树(不包括 x)的贡献,\(S_1\) 保存的是当前结点 x 的父节点除去 x 子树(不包括 x) 的贡献。

单次询问和单次修改的时间复杂度都为 点分树树高 * 数据结构单次询问或修改时间复杂度

对于本道题可以选用树状数组。但对每个结点开一个 \(O(n)\) 的树状数组空间显然不够。所以对每个结点都只开它子树大小的空间,复杂度为 \(O(logn)\)

具体实现

树状数组

struct BIT {
	vector<int> c;
	void build(int len) {while(c.size() < len + 5) c.push_back(0);}
	void add(int x, int val) {while(x < c.size()) c[x] += val, x += x & -x;}
	int query(int x) {
		x = min(x, (int)c.size() - 1);
                //注意这里一定要取 min,因为查询的长度很可能大于了子树大小
		int res = 0;
		while(x) res += c[x], x -= x & -x;
		return res;
	}
}s[2][MAXN];

询问两点间的距离

void dfs1(int now, int fa) {
	for(int i = hd[now]; i; i = nt[i]) {
		if(to[i] == fa) continue;
		dep[to[i]] = dep[now] + 1;
		st[to[i]][0] = now;
		for(int j = 1; (1 << j) < dep[to[i]]; j++)
		 st[to[i]][j] = st[st[to[i]][j - 1]][j - 1];
		dfs1(to[i], now);
	}
}

int getlca(int x, int y) {
	if(dep[x] < dep[y]) swap(x, y);
	for(int i = 20; i >= 0; i--)
	 if(dep[st[x][i]] >= dep[y]) x = st[x][i];
	if(x == y) return x;
	for(int i = 20; i >= 0; i--)
	 if(st[x][i] != st[y][i]) x = st[x][i], y = st[y][i];
	return st[x][0];
}

int getdis(int x, int y) {
	int lca = getlca(x, y);
	return dep[x] + dep[y] - 2 * dep[lca];
}

建立点分树

void dfs2(int now, int fa) {
	dfn[++cnt] = now;
	siz[now] = 1, son[now] = -1;
	for(int i = hd[now]; i; i = nt[i]) {
		if(to[i] == fa || vis[to[i]]) continue;
		dfs2(to[i], now);
		siz[now] += siz[to[i]];
		son[now] = max(son[now], siz[to[i]]);
	}
}

void dfs3(int now, int fa, int len) {
	s[0][root].add(len, a[now]);
	for(int i = hd[now]; i; i = nt[i])
	 if(to[i] != fa && !vis[to[i]]) dfs3(to[i], now, len + 1);
}

void divid(int now, int fa) {
	cnt = 0; dfs2(now, 0);
	MX = INF;
	for(int i = 1; i <= cnt; i++)
	 son[dfn[i]] = max(son[dfn[i]], cnt - siz[dfn[i]]);
	for(int i = 1; i <= cnt; i++)
	 if(MX > son[dfn[i]]) MX = son[dfn[i]], root = dfn[i];
	vis[root] = true;
	now = root; f[now] = fa;
	s[0][now].build(cnt); s[1][now].build(cnt);
	if(fa) {
		for(int i = 1; i <= cnt; i++)
		 s[1][now].add(getdis(dfn[i], fa), a[dfn[i]]);
	}
	for(int i = hd[now]; i; i = nt[i])
	 if(!vis[to[i]]) dfs3(to[i], now, 1);
	for(int i = hd[now]; i; i = nt[i])
	 if(!vis[to[i]]) divid(to[i], now);
}

void prepare() {
	dep[1] = 1; dfs1(1, 0);
	divid(1, 0);
}

查询和修改

void modify(int x, int val) {
	int now = x;
	while(f[now]) {
		int d = getdis(x, f[now]);
		s[1][now].add(d, -a[x]);
		s[1][now].add(d, val);
		now = f[now];
		s[0][now].add(d, -a[x]);
		s[0][now].add(d, val);
	}
	a[x] = val;
}

int query(int x, int val) {
	int now = x, res = 0;
	res += a[x] + s[0][x].query(val);
	while(f[now]) {
		int d = getdis(f[now], x);
		if(d <= val) res += a[f[now]] + s[0][f[now]].query(val - d) - s[1][now].query(val - d);
		now = f[now];
	}
	return res;
}

完整代码

#include<cstdio>
#include<vector>
using namespace std;

const int MAXN = 1e5 + 5, INF = 0x7ffffff;

int n, m, a[MAXN];
int tot, hd[MAXN], to[MAXN << 1], nt[MAXN << 1];
int dep[MAXN], f[MAXN], st[MAXN][21];
int cnt, dfn[MAXN];
int MX, root, siz[MAXN], son[MAXN];
bool vis[MAXN];

struct BIT {
	vector<int> c;
	void build(int len) {while(c.size() < len + 5) c.push_back(0);}
	void add(int x, int val) {
	while(x < c.size())
	 c[x] += val,
	 x += x & -x;
	}
	int query(int x) {
		x = min(x, (int)c.size() - 1);
		int res = 0;
		while(x) res += c[x], x -= x & -x;
		return res;
	}
}s[2][MAXN];

void add(int x, int y) {
	to[++tot] = y, nt[tot] = hd[x], hd[x] = tot;
	to[++tot] = x, nt[tot] = hd[y], hd[y] = tot;
}

void dfs1(int now, int fa) {
	for(int i = hd[now]; i; i = nt[i]) {
		if(to[i] == fa) continue;
		dep[to[i]] = dep[now] + 1;
		st[to[i]][0] = now;
		for(int j = 1; (1 << j) < dep[to[i]]; j++)
		 st[to[i]][j] = st[st[to[i]][j - 1]][j - 1];
		dfs1(to[i], now);
	}
}

int getlca(int x, int y) {
	if(dep[x] < dep[y]) swap(x, y);
	for(int i = 20; i >= 0; i--)
	 if(dep[st[x][i]] >= dep[y]) x = st[x][i];
	if(x == y) return x;
	for(int i = 20; i >= 0; i--)
	 if(st[x][i] != st[y][i]) x = st[x][i], y = st[y][i];
	return st[x][0];
}

int getdis(int x, int y) {
	int lca = getlca(x, y);
	return dep[x] + dep[y] - 2 * dep[lca];
}

void dfs2(int now, int fa) {
	dfn[++cnt] = now;
	siz[now] = 1, son[now] = -1;
	for(int i = hd[now]; i; i = nt[i]) {
		if(to[i] == fa || vis[to[i]]) continue;
		dfs2(to[i], now);
		siz[now] += siz[to[i]];
		son[now] = max(son[now], siz[to[i]]);
	}
}

void dfs3(int now, int fa, int len) {
	s[0][root].add(len, a[now]);
	for(int i = hd[now]; i; i = nt[i])
	 if(to[i] != fa && !vis[to[i]]) dfs3(to[i], now, len + 1);
}

void divid(int now, int fa) {
	cnt = 0; dfs2(now, 0);
	MX = INF;
	for(int i = 1; i <= cnt; i++)
	 son[dfn[i]] = max(son[dfn[i]], cnt - siz[dfn[i]]);
	for(int i = 1; i <= cnt; i++)
	 if(MX > son[dfn[i]]) MX = son[dfn[i]], root = dfn[i];
	vis[root] = true;
	now = root; f[now] = fa;
	s[0][now].build(cnt); s[1][now].build(cnt);
	if(fa) {
		for(int i = 1; i <= cnt; i++)
		 s[1][now].add(getdis(dfn[i], fa), a[dfn[i]]);
	}
	for(int i = hd[now]; i; i = nt[i])
	 if(!vis[to[i]]) dfs3(to[i], now, 1);
	for(int i = hd[now]; i; i = nt[i])
	 if(!vis[to[i]]) divid(to[i], now);
}

void prepare() {
	dep[1] = 1; dfs1(1, 0);
	divid(1, 0);
}

void modify(int x, int val) {
	int now = x;
	while(f[now]) {
		int d = getdis(x, f[now]);
		s[1][now].add(d, -a[x]);
		s[1][now].add(d, val);
		now = f[now];
		s[0][now].add(d, -a[x]);
		s[0][now].add(d, val);
	}
	a[x] = val;
}

int query(int x, int val) {
	int now = x, res = 0;
	res += a[x] + s[0][x].query(val);
	while(f[now]) {
		int d = getdis(f[now], x);
		if(d <= val) res += a[f[now]] + s[0][f[now]].query(val - d) - s[1][now].query(val - d);
		now = f[now];
	}
	return res;
}

int main() {
	scanf("%d%d",&n, &m);
	for(int i = 1; i <= n; i++) scanf("%d",&a[i]);
	for(int i = 1, x, y; i < n; i++) {
		scanf("%d%d",&x, &y);
		add(x, y);
	}
	prepare();
	int las = 0, op, x, y;
	for(int i = 1; i <= m; i++) {
		scanf("%d%d%d",&op, &x, &y);
		x ^= las, y ^= las;
		if(op == 0) printf("%d\n",las = query(x, y));
		else if(op == 1) modify(x, y);
	}
	return 0;
}

点分树

上一篇:DateUtils 新增年月日时分秒


下一篇:QT 读取Xml文件