题意
一个 \(n\) 个点的树,每个点上有 \(d_i\) 个物品,每件体积为 \(c_i\),价值是 \(w_i\)
在树上选择一些物品(每个点可以选多个),使得他们可以组成一个连通块,求能获得的最大价值。
题解
考虑直接树形dp。设 \(dp_{i,j}\) 为在 \(i\) 子树内使用 \(j\) 的体积,并且选了 \(i\) ,能够得到的最大价值。时间复杂度是 \(O(nm^2)\) 的。
瓶颈在于背包是 \(O(m^2)\) 的。
那么我们考虑如何在可接受的复杂度内计算 \(dp_{i,j}\)
\(dp_{i,j}\) 表示的是 \(i\) 子树内一个包括了 \(i\) 的连通块,其实就是要求选某个物品的话,必须至少选一个他父亲节点的物品
而这题是多重背包,二进制分组/单调队列一下,求解一次 \(dp_{i,j}\),复杂度会是 \(O(nm\log d)/O(nm)\)
考虑使用点分治,就会变成 \(O(nm\log n\log d)/O(nm\log n)\)
下面是二进制分组版本,因为不想写单调队列,也不太会写。。。
很容易写错,,而且跑得比下面那种慢多了。。。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define mkp make_pair
#define pb push_back
#define PII pair<int, int>
#define PLL pair<ll, ll>
#define ls(x) ((x) << 1)
#define rs(x) ((x) << 1 | 1)
#define fi first
#define se second
const int N = 510, M = 4010, inf = 0x3f3f3f3f;
int n, m, siz, rt, ans, c[N], d[N], w[N], id[N];
int e, to[N << 1], nxt[N << 1], hd[N];
bool vis[N]; int sz[N], mxsz[N], tim, dp[N][M], f[M];
void add(int u, int v) {
to[++e] = v; nxt[e] = hd[u]; hd[u] = e;
}
void findrt(int u, int fa) {
sz[u] = 1; mxsz[u] = 0;
for(int i = hd[u]; i; i = nxt[i]) {
int v = to[i]; if(v == fa || vis[v]) continue;
findrt(v, u); sz[u] += sz[v];
if(sz[v] > mxsz[u]) mxsz[u] = sz[v];
}
mxsz[u] = max(mxsz[u], siz - sz[u]);
if(mxsz[u] < mxsz[rt]) rt = u;
}
void init() {
ans = 0; mxsz[0] = inf;
for(int i = 1; i <= n; i++)
hd[i] = vis[i] = 0;
for(int i = 1; i <= e; i++)
to[i] = nxt[i] = 0;
e = 0;
}
void dfs(int u, int fa) {
sz[u] = 1; id[++tim] = u;
for(int i= hd[u]; i; i = nxt[i]) {
int v = to[i]; if(vis[v] || v == fa) continue;
dfs(v, u); sz[u] += sz[v];
}
return;
}
void calc(int u) {
tim = 0; dfs(u, 0);
dp[tim + 1][0] = 0; for(int i = 1; i <= m; i++) dp[tim + 1][i] = -inf;
for(int i = tim; i >= 1; i--) {
int p = id[i], t = d[p] - 1;
for(int k = m; k >= 0; k--)
f[k] = (k >= c[p]) ? (dp[i + 1][k - c[p]] + w[p]) : -inf;
for(int j = 1; j <= t; t -= j, j <<= 1) {
for(int k = m; k >= j * c[p]; k--)
f[k] = max(f[k], f[k - j * c[p]] + j * w[p]);
}
if(t) {
for(int k = m; k >= t * c[p]; k--)
f[k] = max(f[k], f[k - t * c[p]] + t * w[p]);
}
for(int k = 0; k <= m; k++)
dp[i][k] = max(dp[i + sz[p]][k], f[k]);
}
for(int i = 0; i <= m; i++)
ans = max(ans, dp[1][i]);
}
void solve(int u) {
calc(u);
vis[u] = 1;
for(int i = hd[u]; i; i = nxt[i]) {
int v = to[i]; if(vis[v]) continue;
siz = sz[v]; rt = 0; findrt(v, 0);
solve(rt);
}
return;
}
int main(){
int T; scanf("%d", &T);
while(T--) {
init();
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
for(int i = 1; i <= n; i++) scanf("%d", &c[i]);
for(int i = 1; i <= n; i++) scanf("%d", &d[i]);
for(int i = 1, x, y; i < n; i++)
scanf("%d%d", &x, &y), add(x, y), add(y, x);
rt = 0; siz = n; findrt(1, 0);
solve(rt);
printf("%d\n", ans);
}
return 0;
}
还有一种dp的方法,可以在dfs时就直接dp。
\(dp_{i,j}\) 改为表示dfs完 \(i\) 点及其子树,此时所访问过的所有点在用了 \(j\) 的体积的情况下的最大价值。
让 \(f_{j}\) 表示现在dfs到现在的点 \(i\),且选了 \(u\),所访问过的所有点的在用了 \(j\) 的体积的情况下的最大价值。计算后变成 和 \(dp_{i,j}\) 一样的定义。
那么这个 \(f_j\) 往下dfs下去的时候,一定会强制选当前点的父亲,否则就不可能出现当前点了,然后再多重背包一下就行了。
我说的巨大多不清楚,其实代码很清楚。。。但我觉得实际上实现起来复杂的一批(恼)
其实本质上和前一种方法是差不多的,状态都是当前dfs已经访问的所有点/dfs序中还没访问的那些点
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define mkp make_pair
#define pb push_back
#define PII pair<int, int>
#define PLL pair<ll, ll>
#define ls(x) ((x) << 1)
#define rs(x) ((x) << 1 | 1)
#define fi first
#define se second
const int N = 510, M = 4010, inf = 0x3f3f3f3f;
int n, m, siz, rt, ans, c[N], d[N], w[N];
int e, to[N << 1], nxt[N << 1], hd[N];
bool vis[N]; int sz[N], mxsz[N], f[M], dp[N][M];
void add(int u, int v) {
to[++e] = v; nxt[e] = hd[u]; hd[u] = e;
}
void findrt(int u, int fa) {
sz[u] = 1; mxsz[u] = 0;
for(int i = hd[u]; i; i = nxt[i]) {
int v = to[i]; if(v == fa || vis[v]) continue;
findrt(v, u); sz[u] += sz[v];
if(sz[v] > mxsz[u]) mxsz[u] = sz[v];
}
mxsz[u] = max(mxsz[u], siz - sz[u]);
if(mxsz[u] < mxsz[rt]) rt = u;
}
void init() {
ans = 0; mxsz[0] = inf;
for(int i = 1; i <= n; i++)
hd[i] = vis[i] = 0;
for(int i = 1; i <= e; i++)
to[i] = nxt[i] = 0;
e = 0;
}
void dfs(int u, int fa, int cst) {
if(cst > m) return;
for(int i = 0; i <= m; i++)
dp[u][i] = f[i];
for(int i = m; i >= 0; i--) {
if(i >= cst + c[u])
f[i] = f[i - c[u]] + w[u];
else f[i] = -inf;
}
int t = d[u] - 1;
for(int i = 1; i <= t; t -= i, i <<= 1)
for(int j = m; j >= i * c[u]; j--)
f[j] = max(f[j], f[j - i * c[u]] + i * w[u]);
if(t) {
for(int j = m; j >= t * c[u]; j--)
f[j] = max(f[j], f[j - t * c[u]] + t * w[u]);
}
for(int i= hd[u]; i; i = nxt[i]) {
int v = to[i]; if(vis[v] || v == fa) continue;
dfs(v, u, cst + c[u]);
}
for(int i = 0; i <= m; i++)
dp[u][i] = f[i] = max(dp[u][i], f[i]);
return;
}
void calc(int u) {
f[0] = 0; for(int i = 1; i <= m; i++) f[i] = -inf;
dfs(u, 0, 0);
for(int i = 0; i <= m; i++)
ans = max(ans, dp[u][i]);
}
void solve(int u) {
calc(u);
vis[u] = 1;
for(int i = hd[u]; i; i = nxt[i]) {
int v = to[i]; if(vis[v]) continue;
siz = sz[v]; rt = 0; findrt(v, 0);
solve(rt);
}
return;
}
int main(){
int T; scanf("%d", &T);
while(T--) {
init();
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
for(int i = 1; i <= n; i++) scanf("%d", &c[i]);
for(int i = 1; i <= n; i++) scanf("%d", &d[i]);
for(int i = 1, x, y; i < n; i++)
scanf("%d%d", &x, &y), add(x, y), add(y, x);
rt = 0; siz = n; findrt(1, 0);
solve(rt);
printf("%d\n", ans);
}
return 0;
}