描述
给定一颗树(边权为1),选取一个节点子集,使得该集合中任意两个节点之间的距离都大于K。求这个集合节点最多是多少
输入
第一行是两个整数N,K
接下来是N-1行,每行2个整数x,y,表示x与y有一条边
输出
1个整数表示最多的节点数
样例输入
3 1
1 2
1 3
样例输出
2
提示
测试点 | N的上限 | K | 特征 |
---|---|---|---|
1 | 15 | 1 | |
2 | 1000 | 1 | 链 |
3 | 1000 | 1 | |
4 | 100000 | 1 | 链 |
5 | 100000 | 1 | |
6 | 15 | 2 | |
7 | 1000 | 2 | 链 |
8 | 1000 | 2 | |
9 | 100000 | 2 | 链 |
10 | 100000 | 2 |
树形dp入门题。
T=2的情况有点意思。
设当前访问第i个节点。
f[i][0]" role="presentation" style="position: relative;">f[i][0]f[i][0]:i不选但i父亲选。
f[i][1]" role="presentation" style="position: relative;">f[i][1]f[i][1]:不选且i父亲不选。
f[i][2]" role="presentation" style="position: relative;">f[i][2]f[i][2]:i选。
显然有:
f[i][2]=1+∑vf[v][0]" role="presentation" style="position: relative;">f[i][2]=1+∑vf[v][0]f[i][2]=1+∑vf[v][0]
以及:
f[i][0]=∑vf[v][1]" role="presentation" style="position: relative;">f[i][0]=∑vf[v][1]f[i][0]=∑vf[v][1]
关键是f[i][1]" role="presentation" style="position: relative;">f[i][1]f[i][1]
这个东西需要考虑儿子之间是否冲突,因此最优值的产生有两种可能:
1. 所有儿子都不选。
2. 某一个儿子选,其余不选。
因此有f[i][1]=(∑vf[v][1])+max(0,f[v][2]−f[v][1])" role="presentation" style="position: relative;">f[i][1]=(∑vf[v][1])+max(0,f[v][2]−f[v][1])f[i][1]=(∑vf[v][1])+max(0,f[v][2]−f[v][1])。
代码:
#include<bits/stdc++.h>
#define N 100005
using namespace std;
inline int read(){
int ans=0;
char ch=getchar();
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
return ans;
}
int first[N],n,k,cnt=0,f[N][3];
struct edge{int v,next;}e[N<<1];
inline void add(int u,int v){e[++cnt].v=v,e[cnt].next=first[u],first[u]=cnt;}
inline int max(int a,int b){return a>b?a:b;}
inline int dfs1(int p,bool k,int fa){
if(f[p][k]!=-1)return f[p][k];
f[p][k]=k;
for(int i=first[p];i;i=e[i].next){
int v=e[i].v;
if(v==fa)continue;
if(!k)f[p][k]+=max(dfs1(v,0,p),dfs1(v,1,p));
else f[p][k]+=dfs1(v,0,p);
}
return f[p][k];
}
inline int dfs2(int p,int k,int fa){
if(f[p][k]!=-1)return f[p][k];
f[p][k]=(k==2);
if(!k){
for(int i=first[p];i;i=e[i].next){
int v=e[i].v;
if(v==fa)continue;
f[p][k]+=dfs2(v,1,p);
}
}
else if(k==1){
int max1=0;
for(int i=first[p];i;i=e[i].next){
int v=e[i].v;
if(v==fa)continue;
f[p][k]+=dfs2(v,1,p);
int tmp=dfs2(v,2,p)-dfs2(v,1,p);
if(max1<tmp)max1=tmp;
}
f[p][k]+=max1;
}
else for(int i=first[p];i;i=e[i].next){
int v=e[i].v;
if(v==fa)continue;
f[p][k]+=dfs2(v,0,p);
}
return f[p][k];
}
int main(){
n=read(),k=read();
for(int i=1;i<n;++i){
int u=read(),v=read();
add(u,v),add(v,u);
}
memset(f,-1,sizeof(f));
if(k==1)cout<<max(dfs1(1,1,1),dfs1(1,0,1));
else cout<<max(dfs2(1,1,1),dfs2(1,2,1));
return 0;
}