题目描述
输入
输出
样例输入
1 2
1000 1 0
1 1000 1
样例输出
0.001
题解
分数规划+树形背包dp
二分答案mid,题目便转化为求是否存在满足题目条件的集合V,使得$\frac{\sum\limits_{i\in V}p_i}{\sum\limits_{i\in V}s_i}\ge mid$,即$\sum\limits_{i\in V}(s_i-mid·p_i)\ge 0$。
这就转化为了一个树形dp问题。
令a[i]=s[i]-mid*p[i],表示i的性价比。设f[i][j]表示从子树i中选出j个且选i,可以获得的最大性价比之和,显然f[i][1]=a[i]。
那么对于每个i的子节点son,相当于有体积为1~si[son]共si[son]个物品放入背包内,每个物品可以放或不放。这相当于01背包问题。
但是这样dp的时间复杂度好像是$O(n^3)$的。
事实上,这里面的有效状态是很少的,如果只枚举有效状态,dp的时间复杂度将到达可以接受的$O(n^2)$。
具体粗略证明:
更新一棵子树的时间复杂度=更新该节点的子节点的时间复杂度+计算该节点的时间复杂度。
计算该节点的复杂度,如果采用最优策略,使用严格的有效区间范围来进行dp,时间复杂度应该为
$O(\sum\limits_{i=1}^m(1+\sum\limits_{j=1}^{i-1}si_j)·si_i)=O(\sum\limits_{i=1}^m\sum\limits_{j=1}^{i-1}si_j·si_i+\sum\limits_{i=1}^msi_i)=O((\sum\limits_{i=1}^msi_i)^2-\sum\limits_{i=1}^msi_i^2+\sum\limits_{i=1}^msi_i)=O(si_x^2-\sum\limits_{i=1}^msi_i^2+si_x)$,
其中$si_i(i\in[1,m])$表示x的第i个儿子节点的子树大小(总共有m个儿子节点),$si_x$表示x的子树大小。
而叶子节点的时间复杂度是$O(1)$的,进而我们可以使用累加法计算出总体dp的时间复杂度为$O(si_{root}^2+\sum\limits_{i=1}^nsi_i^2)=O(n^2)$。
因此总的时间复杂度是$O(n^2\log m)$。
为了避免精度误差带来的答案错误,建议固定二分c次,c值视情况而定,本题中取30可过。
另外由于数据太水了,所以$O(n^3\log m)$的做法也是可以通过本题的。(其实可以自己做一个链的数据卡掉它)
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 2510
using namespace std;
int head[N] , to[N] , next[N] , cnt , si[N] , w[N] , v[N] , n;
double a[N] , mid , f[N][N];
void add(int x , int y)
{
to[++cnt] = y , next[cnt] = head[x] , head[x] = cnt;
}
void init(int x)
{
int i;
si[x] = 1;
for(i = head[x] ; i ; i = next[i]) init(to[i]) , si[x] += si[to[i]];
}
void dfs(int x)
{
int i , j , k , tot = 0 , b = 0;
memset(f[x] , 0xc2 , sizeof(f[x]));
if(x) f[x][1] = a[x] , tot ++ , b ++ ;
else f[x][0] = 0;
for(i = head[x] ; i ; i = next[i])
{
dfs(to[i]);
for(j = tot ; j >= b ; j -- )
for(k = 1 ; k <= si[to[i]] ; k ++ )
f[x][j + k] = max(f[x][j + k] , f[x][j] + f[to[i]][k]);
tot += si[to[i]];
}
}
int main()
{
int n , k , i , x , c = 30;
double l = 0 , r = 0;
scanf("%d%d" , &k , &n);
for(i = 1 ; i <= n ; i ++ ) scanf("%d%d%d" , &w[i] , &v[i] , &x) , add(x , i) , r = max(r , (double)v[i]);
init(0);
while(c -- )
{
mid = (l + r) / 2;
for(i = 1 ; i <= n ; i ++ ) a[i] = v[i] - mid * w[i];
dfs(0);
if(f[0][k] >= 0) l = mid;
else r = mid;
}
printf("%.3lf\n" , (l + r) / 2);
return 0;
}