树形dp套树形数据结构
对每个节点 \(i\) ,分两步进行:
1.令 \(f_i\) 为Mitya在节点 \(i\) 停止游戏最多可以吃到多少块饼干
我们可以进行一次dfs,用一个树形数据结构(树状数组、线段树、平衡树)来维护从树根节点 \(1\) 到节点 \(i\) 经过的路径上所有节点的“属性二元组”—— \((t_i,x_i)\) ,为了吃到更多块饼干,应该尽量吃 \(t_i\) 小的饼干。
以树状数组为例。由于 \(1≤t_i≤10^6\) ,直接在值域上建立两个树状数组(当然可以先用 \(O(n\ log\ n)\) 的时间离散化,不过此题不必要),一个记录时间 \(ct\) ,一个记录个数 \(cx\) 。dfs到节点 \(i\) 时,首先执行add操作,在 \(ct\) 的位置 \(t_i\) 上加 \(t_ix_i\) ,同时在 \(cx\) 的位置 \(t_i\) 上加 \(x_i\) 。在 \(ct\) 上二分查找 \(ask_{ct}(k)≤T-2S\) 的最大值,其中 \(S\) 为从树根节点 \(1\) 到节点 \(i\) 所需的时间,一来一回所以需要两倍的 \(S\) 。那么 \(f_i\) 的值为在 \(cx\) 上 \(ask_{cx}(k)\) 与剩下时间 \(T-2S-ask_{ct}(k)\) 除以后一个吃单块饼干的时间 \(k+1\) 向下取整后的值之和(当然若 \(k=m\) 则不需要加上后面的部分)。dfs回溯的时候进行类似操作即可。
2.考虑Vasya的干扰,由于Vasya可以删除节点 \(i\) 与 \(i\) 的子节点之间的边,为了删除的有意义,就必须要删掉节点 \(i\) 的所有子节点中结果最大的节点。
因此,令 \(m1_i,m2_i\) 分别为在以节点 \(i\) 为根的子树中除节点 \(i\) 自己外结果的最大值和非严格次大值,特别的,若不存在次大值, \(m2_i=0\)
令 \(dp_i\) 为按照游戏规则在以节点 \(i\) 为根的子树中,Mitya在某一个节点停止游戏可以吃到的饼干数的最大值
显然,若节点 \(i\) 为根节点 \(1\) ,由于Mitya先进行操作,因此 \(dp_i=max(f_i,m1_i)\)
否则,\(dp_i=max(f_i,m2_i)\)
\(dp_1\) 即为所求
时间复杂度 \(O(n\ log^2\ maxT)\)
代码:
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 100006, M = 1000006;
ll n, m, T, x[N], t[N], f[N], ct[M], cx[M], m1[N], m2[N], dp[N];
vector<pair<int, ll> > e[N];
void addct(int a, ll k) {
while (a <= m) {
ct[a] += k;
a += a & -a;
}
}
ll askct(int a) {
ll ans = 0;
while (a) {
ans += ct[a];
a -= a & -a;
}
return ans;
}
void addcx(int a, ll k) {
while (a <= m) {
cx[a] += k;
a += a & -a;
}
}
ll askcx(int a) {
ll ans = 0;
while (a) {
ans += cx[a];
a -= a & -a;
}
return ans;
}
void dfs(int a, ll S) {
addct(t[a], t[a] * x[a]);
addcx(t[a], x[a]);
int l = 1, r = m + 1;
while (l < r) {
int mid = (l + r) >> 1;
if (askct(mid) <= T - (S << 1)) l = mid + 1;
else r = mid;
}
int k = l - 1;
f[a] = askcx(k);
if (k != m) f[a] += (T - (S << 1) - askct(k)) / (k + 1);
for (unsigned int i = 0; i < e[a].size(); i++)
dfs(e[a][i].first, S + e[a][i].second);
addct(t[a], -t[a] * x[a]);
addcx(t[a], -x[a]);
}
void dfs(int a) {
for (unsigned int i = 0; i < e[a].size(); i++) {
int y = e[a][i].first;
dfs(y);
if (dp[y] >= m1[a]) {
m2[a] = m1[a];
m1[a] = dp[y];
} else if (dp[y] >= m2[a]) m2[a] = dp[y];
}
if (a == 1) dp[a] = max(f[a], m1[a]);
else dp[a] = max(f[a], m2[a]);
}
int main() {
cin >> n >> T;
for (int i = 1; i <= n; i++) scanf("%lld", &x[i]);
for (int i = 1; i <= n; i++) {
scanf("%lld", &t[i]);
m = max(m, t[i]);
}
for (int i = 2; i <= n; i++) {
int p;
ll l;
scanf("%d %lld", &p, &l);
e[p].push_back(make_pair(i, l));
}
dfs(1, 0);
dfs(1);
cout << dp[1] << endl;
return 0;
}