http://acm.hdu.edu.cn/showproblem.php?pid=4616
要记录各种状态的段 a[2][4]
a[0][j]表示以trap为起点一共有j个trap的最优值
a[1][j]表示不以trap为起点一共有j个trap的最优值
dp[x][i][j] 表示以x为根节点的子树从各个叶子到x节点的各状态最优值
每到一个节点 要枚举经过此节点的所有符合要求的段中最优的(需要合并段)
代码:
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<cmath>
#include<set>
#include<map>
#include<stack>
#include<vector>
#include<algorithm>
#include<queue>
#include<bitset>
#include<deque>
#include<numeric> #pragma comment(linker, "/STACK:1024000000,1024000000") using namespace std; typedef long long ll;
typedef unsigned int uint;
typedef pair<int,int> pp;
const double eps=1e-9;
const int INF=0x3f3f3f3f;
const ll MOD=1000000007;
const int N=100005;
int head[N],I;
struct node
{
int j,next;
}edge[N*2];
int value[N],trap[N];
int dp[N][2][4];
int ans,C;
void add(int i,int j)
{
edge[I].j=j;
edge[I].next=head[i];
head[i]=I++;
}
void init(int n)
{
for(int i=0;i<n;++i)
scanf("%d %d",&value[i],&trap[i]);
memset(head,-1,sizeof(head));I=0;
for(int i=1;i<n;++i)
{
int l,r;
scanf("%d %d",&l,&r);
add(l,r);
add(r,l);
}
}
void copyArr(int (*b)[4],int (*a)[4])
{
for(int i=0;i<2;++i)
for(int j=0;j<4;++j)
b[i][j]=a[i][j];
}
void clArr(int (*b)[4])
{
for(int i=0;i<2;++i)
for(int j=0;j<4;++j)
b[i][j]=-1;
b[0][0]=b[1][0]=0;
}
void update(int (*b)[4],int x)
{
if(trap[x]==0)
{
for(int i=0;i<2;++i)
for(int j=0;j<4;++j)
if(b[i][j]!=-1)
b[i][j]+=value[x];
b[0][0]=0;
}else
{
for(int i=0;i<2;++i)
for(int j=3;j>0;--j)
{
if(b[i][j-1]!=-1)
b[i][j]=b[i][j-1]+value[x];
}
b[0][0]=0;
b[1][0]=0;
}
}
void print(int (*b)[4])
{
for(int i=0;i<2;++i)
{
for(int j=0;j<4;++j)
printf("%4d ",b[i][j]);printf("\n");
}printf("\n");
}
void findAns(int (*b)[4],int (*v1)[4],int (*v2)[4],int x)
{
int c=C-trap[x];
int tmp=0;
for(int i=0;i<2;++i)
for(int j=0;j<4;++j)
{
for(int l=0;l<2;++l)
for(int r=0;r<4;++r)
{
if(j+r>c) continue;
if(j+r==c)
{
if(i+l==2) continue;
if(i!=l)
{
if(i==0&&j==0) continue;
if(l==0&&r==0) continue;
}
}
if(v1[l][r]!=b[l][r])
tmp=max(tmp,max(0,v1[l][r])+max(0,b[i][j]));
else
tmp=max(tmp,max(0,v2[l][r])+max(0,b[i][j]));
}
}
ans=max(ans,tmp+value[x]);
}
void dfs(int pre,int x,int (*a)[4])
{
int b[2][4];
copyArr(b,a);
update(b,x);
int v1[2][4],v2[2][4];
clArr(v1);clArr(v2);
for(int t=head[x];t!=-1;t=edge[t].next)
{
int l=edge[t].j;
if(l==pre) continue;
dfs(x,l,b);
for(int i=0;i<2;++i)
for(int j=0;j<4;++j)
{
v2[i][j]=max(v2[i][j],dp[l][i][j]);
if(v1[i][j]<v2[i][j])
swap(v1[i][j],v2[i][j]);
}
}
copyArr(dp[x],v1);
update(dp[x],x);
for(int i=0;i<2;++i)
for(int j=0;j<4;++j)
{
v2[i][j]=max(v2[i][j],a[i][j]);
if(v1[i][j]<v2[i][j])
swap(v1[i][j],v2[i][j]);
}
findAns(a,v1,v2,x);
for(int t=head[x];t!=-1;t=edge[t].next)
{
int l=edge[t].j;
if(l==pre) continue;
findAns(dp[l],v1,v2,x);
}
}
int main()
{
//freopen("data.in","r",stdin);
int T;
scanf("%d",&T);
while(T--)
{
int n;
scanf("%d %d",&n,&C);
init(n);
int a[2][4];
clArr(a);
memset(dp,-1,sizeof(dp));
ans=0;
dfs(-1,0,a);
printf("%d\n",ans);
}
return 0;
}