【HAOI2015】树上染色—树形dp

【HAOI2015】树上染色

【题目描述】
有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。问收益最大值是多少。
【输入格式】
第一行两个整数N,K。
接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。输入保证所有点之间是联通的。
【输出格式】
输出一个正整数,表示收益的最大值。
【输入样例1】
3 1
1 2 1
1 3 2
【输出样例1】
3
【输入样例2】
5 2
1 2 3
1 5 1
2 3 1
2 4 2
【输出样例2】
17
【样例解释】
在第二个样例中,将点1,2染黑就能获得最大收益。
【数据范围】
对于30%的数据,N<=20
对于50%的数据,N<=100
对于100%的数据,N<=2000,0<=K<=N。

题解:

这题一看就是dp,单个点无法产生贡献,只能是两个黑点或两个白点才能产生贡献

如果我们dp围绕点来进行,首先你无法定义一种状态,其次你不方便转移

所以我们考虑一条边能产生的贡献

首先一条边能产生的贡献为:

   Wi=dis×((这条边左边的黑点×这条边右边的黑点)+(这条边左边的白点×这条边右边的白点))

我们考虑一对点产生的贡献,这一条边L左侧的黑点要想连接L右侧的黑点,必定经过L,所以L会有贡献

有几对这样的黑点,边L就被经过几次,白点同理,所以得到了上面的方程

我们设f[i][j]表示以i为根节点的子树中染j个黑点的最大收益,

size[x]表示以x为根的子树大小,目标:f[1][k];

设当前搜索到以x为根的子树,设son为x的一个儿子,则:

  $f[x][j]=max(f[x][p]+f[son][j-p]+Wi),$

  $j<=min(size[x],k),p<=min(size[son],k)$

Wi表示x与son所连边产生的贡献,在上面已经解释过了

更新f的时候,如果正着推,就需要把f存到一个临时数组里;循环完后再赋值给f,如果倒着枚举直接更新就好了

当然你还可以减少枚举的数量,即j<=min(size[x],k),p<=min(size[son],k)。

下面给出两份代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
#define MAXN 2005
using namespace std;
ll n,k;
ll fr[MAXN<<],to[MAXN<<],nxt[MAXN<<],pre[MAXN],cnt=,dis[MAXN<<];
void add(ll u,ll v,ll w){
cnt++,fr[cnt]=u,to[cnt]=v,nxt[cnt]=pre[u],pre[u]=cnt,dis[cnt]=w;
}
ll f[MAXN][MAXN],temp[MAXN],size[MAXN];
void dfs(ll x,ll fa){
size[x]=;
for(ll i=pre[x];i;i=nxt[i]){
ll y=to[i];
if(y==fa) continue;
dfs(y,x);
ll num_b1=min(k,size[x]),num_b2=min(k,size[y]),l;
memset(temp,,sizeof(temp));
for(ll j=;j<=num_b1;j++){
for(ll p=;p<=num_b2&&p+j<=k;p++){
l=dis[i]*((k-p)*p+(size[y]-p)*(n-k-size[y]+p));
temp[j+p]=max(temp[j+p],f[x][j]+f[y][p]+l);
}
}
ll m=min(num_b1+num_b2,k);
for(ll j=;j<=m;j++)
f[x][j]=temp[j];
size[x]+=size[y];
}
}
int main(){
scanf("%lld%lld",&n,&k);
for(ll i=,u,v,d;i<n;i++){
scanf("%lld%lld%lld",&u,&v,&d);
add(u,v,d),add(v,u,d);
}
dfs(,);
printf("%lld\n",f[][k]);
return ;
}

正序

#include <iostream>
#include <cstdio>
#include <cstring>
#define ll long long
#define MAXN 2005
using namespace std;
ll n, k;
ll to[MAXN << ], nxt[MAXN << ], pre[MAXN], cnt = , dis[MAXN << ];
void add(ll u, ll v, ll w) { cnt++, to[cnt] = v, nxt[cnt] = pre[u], pre[u] = cnt, dis[cnt] = w; }
ll f[MAXN][MAXN], temp[MAXN][MAXN], size[MAXN];
void dfs(ll x, ll fa) {
size[x] = ;
for (ll i = pre[x]; i; i = nxt[i]) {
ll y = to[i];
if (y == fa)
continue;
dfs(y, x);
ll num_b1 = min(k, size[x]), l;
for (ll j = num_b1; j >= ; j--) {
ll num_b2 = min(k - j, size[y]);
for (ll p = num_b2; p >= ; p--) {
l = dis[i] * ((k - p) * p + (size[y] - p) * (n - k - size[y] + p));
f[x][j + p] = max(f[x][j + p], f[x][j] + f[y][p] + l);
}
}
size[x] += size[y];
}
}
int main() {
scanf("%lld%lld", &n, &k);
for (ll i = , u, v, d; i < n; i++) {
scanf("%lld%lld%lld", &u, &v, &d);
add(u, v, d), add(v, u, d);
}
dfs(, );
printf("%lld\n", f[][k]);
return ;
}

倒序,感谢loj码风优化

上一篇:大数据篇:Spark


下一篇:ArrayList其实就那么一回事儿之源码浅析