读题找到无解的情况:大头吃掉的\(K\)个果子后,剩下的\(N - K\)个果子不够其他的\(M-1\)个头分,即\(N-K<M-1\)
题目中只限制除了大头外的头要吃至少一个果子,也就是说一段树枝相连的两个果子若被同一个头吃掉,则只会被大头吃掉,否则不是最优(\(M=2\)的情况除外)。
证明(口胡):
如果出现树枝相连的两个果子被同一个非大头吃掉,则可以把其中一个换成另一个非大头,修改后没有改变大头吃的果子数,但使难受值变小了,所以“树枝相连的两个果子被同一个非大头吃掉”一定不是最优解。
考虑DP:
设\(\large dp_{i, j, k}\)表示在以\(i\)为根节点的子树中有\(j\)个果子归大头吃,\(k=0\)表示点\(i\)不归大头吃,\(k=1\)表示点\(i\)归大头吃。
然后就产生了一个类似于树上背包的DP, 注意特殊处理\(M = 2\)的情况。
转移时只需注意分类讨论:
两个都不归大头吃,只有m=2时算入难受;
只有子节点归大头吃,不算入难受值;
只有根节点归大头吃, 不算入难受值;
都归大头吃, 算入难受值。
转移方程:
\[dp[x][j + l][0] = \min_{j\le tot, l\le size_y}(dp[y][l][1] + dp[x][j][0], dp[y][l][0] + dp[x][j][0] + [m=2] \times e[i].dat ); \] \[dp[x][j + l + 1][1] = \min_{j\le tot, l\le size_y}(dp[y][l][1] + dp[x][j + 1][1] + e[i].dat, dp[y][l][0] + dp[x][j + 1][1]); \]转移会产生覆盖,需要另开一个数组来记录答案。
代码:
#include<bits/stdc++.h>
using namespace std;
int n, m, k;
struct node{
int to, dat, nxt;
}e[610];
int h[1010], cnt;
void add(int x, int y, int z){
e[++cnt].to = y;
e[cnt].dat = z;
e[cnt].nxt = h[x];
h[x] = cnt;
}
int dp[310][310][2];
int f[310], size[310];
void dfs(int x, int fa){
f[x] = fa;
size[x] = 1;
for(int i = h[x]; i; i = e[i].nxt){
int y = e[i].to;
if(y != fa)
dfs(y, x), size[x] += size[y];
}
}
int pp[310][2];
void dfss(int x, int fa){
int tot = 0;//注意tot不包括x本身
dp[x][0][0] = 0;//初始化
dp[x][1][1] = 0;
for(int i = h[x]; i; i = e[i].nxt){
int y = e[i].to;
if(y == fa)
continue;
dfss(y, x);
memset(pp, 0x3f, sizeof(pp));//用于转移记录答案, 防止覆盖。
for(int j = min(tot, k); j >= 0; j--){
for(int l = min(size[y], k - j); l >= 0; l--){
if(j + l > k)
continue;pp[j + l + 1][1] = min(pp[j + l + 1][1], dp[y][l][0] + dp[x][j + 1][1]);
if(l >= 1)
pp[j + l + 1][1] = min(pp[j + l + 1][1], dp[y][l][1] + dp[x][j + 1][1] + e[i].dat);
//以下为核心转移---------------------------
pp[j + l][0] = min(pp[j + l][0], dp[y][l][0] + dp[x][j][0] + (m == 2 ? e[i].dat : 0));//两个都不归大头吃,只有m=2时算入难受值
pp[j + l][0] = min(pp[j + l][0], dp[y][l][1] + dp[x][j][0]);//只有子节点归大头吃,不算入难受值。
if(j + l <= min(tot + size[y], k)){
pp[j + l + 1][1] = min(pp[j + l + 1][1], dp[y][l][0] + dp[x][j + 1][1]);//只有根节点归大头吃, 不算入难受值
if(l >= 1)
pp[j + l + 1][1] = min(pp[j + l + 1][1], dp[y][l][1] + dp[x][j + 1][1] + e[i].dat);//都归大头吃, 算入难受值
//以上为核心转移---------------------------
}
}
}
tot += size[y];//扩大子树大小
for(int qsqsq = 0; qsqsq <= tot + 1; qsqsq++){//重新赋值
dp[x][qsqsq][0] = pp[qsqsq][0];
dp[x][qsqsq][1] = pp[qsqsq][1];
}
}
}
signed main(){
memset(dp, 0x3f, sizeof(dp));
scanf("%d%d%d", &n, &m, &k);
if(n - k < m - 1){
puts("-1");
return 0;
}
for(int i = 1; i < n; i++){
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
add(x, y, z);
add(y, x, z);
}
dfs(1, 0);
dfss(1, 0);
else
cout <<dp[1][k][1] <<"\n";
}
能跑到32ms,但为什么我的代码这么长啊。