[学习笔记]树形动态规划

关于树形DP

树形动态规划,顾名思义,就是在树的数据结构上做动态规划。由于树天生就是一种递归的数据结构,因此树形 DP 的实现方式通常都是用记忆化搜索。

因为转移有 push 和 pull 型两种,自然树形 DP 的转移也有两种顺序:

  1. 叶 → \rightarrow → 根,算出一个节点的子节点信息后得到该节点的信息,这一种转移顺序比较常用;
  2. 根 → \rightarrow → 叶,这种转移顺序极为少见,但是也不是没有。(路人:你这不是废话

树形 DP 的时间复杂度初学者可能会分析错,大多数普通的题目都是 O ( n ) O(n) O(n) 的。我们来分析一下:
在大多数情况下,状态的个数即为节点的个数 n n n,而每个状态转移的复杂度是 O ( son i ) O(\text{son}_i) O(soni​)(这里 son i \text{son}_i soni​ 表示节点 i i i 的子节点个数),所以总的复杂度就是 O ( ∑ i = 1 n son i ) O(\sum_{i=1}^n \text{son}_i) O(∑i=1n​soni​),即 O ( n ) O(n) O(n)。

树形 DP 的使用情境十分好判断,就是在树上做一遍 DP。对于一些问题,我们可以建立起模型,用树的数据结构求解。

树形 DP 其实也带有一点套路题的色彩,很多情况下我们的状态都是“设 d p i dp_i dpi​ 为以 i i i 为根节点的子树中……”,为了设计状态转移方程,我们考虑当前这个节点的状态和它的父节点、它的子节点间有什么关系。

经典问题 & 例题

树形 DP 的经典问题有:树的重心问题,树的最长路径/最长链/最远点对/直径问题,树的中心问题,树的点覆盖问题,树的独立集问题等。这些问题比较基础,就不展开讲了,许多教材上都有。


例题1

一本通 5.2 练习 2 旅游规划

首先明确一下我们的思路:我们首先要求出最长路径的长度,然后判断每个点是否在最长路径上,即经过这个点的最长路径能否达到全局最长路径的长度。

Step1: 求出最长路径的长度。
这是个经典问题。我们以任意一点为根,把无根树变成有根树,那么树中的一个最长链必然在以某个节点为根的子树中且经过这个节点。所以以一个点为根的子树中的经过这个节点的最长链长度,一定等于从这个点向下出发(即不越出这个子树)的最长链和次长链之和。
分别用 d 1 i d1_i d1i​ 和 d 2 i d2_i d2i​ 表示从节点 i i i 向下出发的最长链和次长链长度,那么枚举 i i i 的子节点 v v v,有

  • 如果 d 1 v + 1 > d 1 i d1_v+1>d1_i d1v​+1>d1i​,那么 d 2 i ← d 1 i d2_i \leftarrow d1_i d2i​←d1i​ 且 d 1 i ← d 1 v + 1 d1_i \leftarrow d1_v + 1 d1i​←d1v​+1;
  • 如果不满足 d 1 v + 1 > d 1 i d1_v+1>d1_i d1v​+1>d1i​ 但 d 1 v + 1 > d 2 i d1_v+1>d2_i d1v​+1>d2i​,那么 d 2 i ← d 1 v + 1 d2_i \leftarrow d1_v+1 d2i​←d1v​+1。

最长路径的长度即为 max ⁡ { d 1 i + d 2 i } \max \{d1_i+d2_i\} max{d1i​+d2i​}

Step2:求出经过一个点的最长路径。
从一个点出发的最长路径只有2种形态:在子树中的最长路径和不在子树中的最长路径。
子树中的最长路径:刚刚在 Step1 中求过了,就是 d 1 i d1_i d1i​;
不在子树中的最长路径:设 u p i up_i upi​ 表示从点 i i i 向上出发的最长路径,那么它必然经过点 i i i 的父节点 f a fa fa。接下来又分两种情况:

  1. 转到了 f a fa fa 的子树中。这时我们要记录 d 1 f a d1_{fa} d1fa​ 是从哪里转移过来的。设 c i c_i ci​ 表示 d 1 i d1_i d1i​ 是从 c i c_i ci​ 转移过来的,如果 c f a = i c_{fa} = i cfa​=i,那么整条路径的长度就是 d 1 i + d 2 f a + 1 d1_i+d2_{fa}+1 d1i​+d2fa​+1,否则会产生重复的计算;反之,则整条路径的长度就是 d 1 i + d 1 f a + 1 d1_i+d1_{fa}+1 d1i​+d1fa​+1。
  2. 不进入 f a fa fa 的子树,那么路径长度就是 d 1 i + u p f a + 1 d1_i+up_{fa}+1 d1i​+upfa​+1。

在以上三种可能的路径长度中取最大值即可。

思路呼之欲出了:两遍 DFS(或者说两遍 DP),第一遍求出 d 1 i , d 2 i d1_i,d2_i d1i​,d2i​ 和 c i c_i ci​,第二遍求出 u p i up_i upi​。

加这道题的目的一是为了回顾树中最长路径的求法;二是展示一下无返回值类型(void)的记忆化搜索;三是告诉大家,我们经常会在树形 DP 中用到两遍 DFS。

剩下的足够用代码解释了:

#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn = 2e5 + 10;
int n, h[maxn], en;
struct Edge
{
	int u;
	int v;
	int next;
} e[maxn << 1];
int d1[maxn], d2[maxn], c[maxn], up[maxn];

void addedge(int u, int v)
{
	en++;
	e[en].u = u;
	e[en].v = v;
	e[en].next = h[u];
	h[u] = en;
	return;
}

void dfs1(int x, int fa)
{
	for(int i = h[x]; i != 0; i = e[i].next)
	{
		int v = e[i].v;
		if(v == fa)
			continue;
		dfs1(v, x);
		if(d1[v] + 1 > d1[x])
		{
			d2[x] = d1[x];
			d1[x] = d1[v] + 1;
			c[x] = v;
		}
		else
			d2[x] = max(d2[x], d1[v] + 1);
	}
	return;
}

void dfs2(int x, int fa)
{
	up[x] = max(up[fa], c[fa] == x ? d2[fa] : d1[fa]) + 1;
	for(int i = h[x]; i != 0; i = e[i].next)
		if(e[i].v != fa)
			dfs2(e[i].v, x);
	return;
}

int main()
{
	scanf("%d", &n);
	for(int i = 1; i < n; i++)
	{
		int u, v;
		scanf("%d%d", &u, &v);
		u++;
		v++;
        //题目中的点编号是0,1...n-1,为方便操作同时避免意外的错误
        //将u和v都加1
		addedge(u, v);
		addedge(v, u);
	}
	dfs1(1, 0);
	dfs2(1, 0);
	int len = 0;
	for(int i = 1; i <= n; i++)
		len = max(len, d1[i] + d2[i]);  //求出最长路径的长度
	for(int i = 1; i <= n; i++)  //判断有哪些点在最长路径上
		if(d1[i] + d2[i] == len || d1[i] + up[i] == len)
			printf("%d\n", i - 1);  //注意一开始编号加了1
	return 0;	
}

例题2

ZJOI2008 骑士

Observation:如果一个骑士憎恨另一个骑士,我们在这两个骑士间建立一个无向边,那么整张图是一个基环树森林。
Proof:整张图是由若干连通块构成的。假设其中一个连通块点的个数为 m m m,因为每个骑士都有且仅有一个最厌恶的骑士且不是自己,这个连通块中就会有 m m m 条边。一个 m m m 个节点的图,有 m m m 条边且连通,自然是基环树。基环树就是一个树上加了一条边形成了一个环。

那么我们只要让每个基环树的战斗力尽可能大,基环树森林的总战斗力就会最大。

把基环树看成环上的每个节点都作为根节点长出了一棵有根树,每棵树中我们做一遍“没有上司的舞会”模型(题目在这里),然后再在环上做一遍序列 DP 就没问题了。但是有一个问题:区间 DP 我们可以破环为链,这里我们没办法破环为链,做个特殊处理就好啦。

代码:(码量稍微有点大)

//C++11
#include<cstdio>
#include<cctype>
#include<vector>
#include<stack>
using namespace std;
const int maxn = 1000010;
int n;
long long a[maxn];
int h[maxn], en;
struct Edge
{
	int u;
	int v;
	int next;
};
Edge e[maxn << 1];

stack<int> stk;
vector<int> cycle;
long long f[maxn][2], g[maxn][2][2], ans;
bool circ[maxn], vis[maxn], found[maxn], ok;

inline long long read()
{
	long long x = 0;
	bool flag = true;
	char ch = getchar();
	while(!isdigit(ch))
	{
		if(ch == '-')
			flag = false;
		ch = getchar();
	}
	while(isdigit(ch))
	{
		x = (x << 1) + (x << 3) + (ch ^ 48);
		ch = getchar();
	}
	return flag ? x : -x;
}

inline void addedge(int u, int v)
{
	en++;
	e[en].u = u;
	e[en].v = v;
	e[en].next = h[u];
	h[u] = en;
	return;
}

inline int get(int edge)  //找反向边编号
{
	return edge & 1 ? edge + 1 : edge - 1;
}

void find_circle(int x, int pre)
{
	if(ok) return;  //已经找到环了
	stk.push(x);  //把节点放进栈里
	if(found[x])  //发现形成了一个环
	{
		while(!stk.empty() && !circ[stk.top()])  //把栈里的节点倒出来
		{
			circ[stk.top()] = true;
			cycle.push_back(stk.top());
			stk.pop();
		}
		ok = true;
		return;
	}
	found[x] = true;
	for(int i = h[x]; i != 0; i = e[i].next)
	{
		if(get(i) == pre)
			continue;
        //注意这里之所以用反向边判断而不是用走回原来节点判断
        //是因为两个骑士可能互相憎恨
		int v = e[i].v;
		find_circle(v, i);
	}
	if(!stk.empty())
		stk.pop();   //注意细节
	return;
} 

void dp(int x, int fa)
{
	vis[x] = true;
	f[x][1] = a[x];
	for(int i = h[x]; i != 0; i = e[i].next)
	{
		int v = e[i].v;
		if(v == fa || circ[v])
			continue;
		dp(v, x);
		f[x][0] += max(f[v][0], f[v][1]);
		f[x][1] += f[v][0]; 
	}
	return;
}

inline void dp2()
{
    //g[x][b1][b2]中,b1表示x选不选,b2表示1选不选
	int len = cycle.size(), head = cycle[0], sec = cycle[1], last = cycle[len - 1];
	g[sec][0][0] = f[head][0] + f[sec][0];
	g[sec][1][0] = f[head][0] + f[sec][1];
	for(int i = 2; i < len; ++i)
	{
		int u = cycle[i - 1], v = cycle[i];
		g[v][0][0] = max(g[u][0][0], g[u][1][0]) + f[v][0];
		g[v][1][0] = g[u][0][0] + f[v][1];
	}
	g[sec][0][1] = f[head][1] + f[sec][0];
	if(len > 2)
	{
		int third = cycle[2];
		g[third][0][1] = g[sec][0][1] + f[third][0];
		g[third][1][1] = g[sec][0][1] + f[third][1];
	}
	for(int i = 3; i < len; ++i)
	{
		int u = cycle[i - 1], v = cycle[i];
		g[v][0][1] = max(g[u][0][1], g[u][1][1]) + f[v][0];
		g[v][1][1] = g[u][0][1] + f[v][1];
	}
	ans += max(g[last][0][0], max(g[last][0][1], g[last][1][0]));
	return;
}

int main()
{
	n = read();
	for(int i = 1; i <= n; ++i)
	{
		int target;
		a[i] = read();
		target = read();
		addedge(i, target);
		addedge(target, i);
	}
	for(int i = 1; i <= n; ++i)
	{
		if(vis[i])
			continue;  //注意整张图是个森林
		ok = false;
		while(!stk.empty())
			stk.pop();
		cycle.clear();  //细节,清空栈和动态数组
		find_circle(i, -1);  //找环
		for(int x : cycle)  //仅C++11及以上标准支持
			dp(x, 0);  //树形DP
		dp2();  //序列DP
	}
	printf("%lld\n", ans);
	return 0;
}

换根DP

推荐阅读:[学习笔记]换根DP

笔者没有学这个……所以没有写……

上一篇:关于sv中竞争冒险的理解


下一篇:VHDL中的delta cycle