题目大意
小 S
在一棵被称为 OJT
的树上刷题。
这棵树上,有 \(n\) 个节点,每个节点上都有一道题目,每个节点上的题目难度可能会不同。
小 S
的能力有限,仅为k个单位能力,在第 \(i\) 个节点上的题目,难度为 \(D_i\)。
由于题目过毒,每做一道题都会杀死小S的脑细胞,使小 S
的能力值下降,做第 \(i\) 个节点上的题目会使他的能力值下降 \(c_i\) 个单位。
对于一道题,小 S
能拿到小 S
目前的能力值 \(\div D_i * 100\)(向下取整)分(若小 S
目前的能力值 \(\ge D_i\),小 S
就能拿到 \(100\)),因信号问题,实际的 \(D_i\) 为 \(\sum_{a=1}^{d[i]}\sum_{b=1}^{d[i]}{(d[i] \bmod a)(d[i] \bmod b)} \bmod 100\)。
小 S
希望可以拿到尽量高的总分数,希望你帮他找到他最多可以获得的总分数。
PS:小 S
总是从根节点 \(1\) 出发,每次向所在节点的其中一个子节点走,小 S
可以选择不做当前节点上的题目。
解题思路
树上背包好题,不过赛场上没想到,可惜。
确定思路后,优化思路。
考虑化简 \(D_i\),
根据模的意义,简化式子,设 \(D_i=a\),得 \(a=\sum_{i=1}^{a}\sum_{j=1}^{a}{(a - \left\lfloor \frac{a}{i}\right\rfloor i)(a - \left\lfloor \frac{a}{j}\right\rfloor j)} \bmod 100 \\ =\sum_{i=1}^{a}\sum_{j=1}^{a} a^2-a \left\lfloor \frac{a}{i} \right\rfloor i - a \left\lfloor \frac{a}{j} \right\rfloor j + \left\lfloor \frac{a}{i} \right\rfloor\left\lfloor \frac{a}{j} \right\rfloor ij \\ =a^4+2a^2 \sum\left \lfloor \frac{a}{i}\right \rfloor i+\sum\sum (\left\lfloor \frac{a}{i} \right\rfloor \left\lfloor \frac{a}{j} \right \rfloor ij)\)。
AC CODE
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int _ = 2007;
int n, k;
int c[_], d[_];
int tot, head[_], to[_ << 1], nxt[_ << 1];
int dp[_][_];
int ans;
int read()
{
int x = 0;
char c = getchar();
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9')
{
x = x * 10 + c - '0';
c = getchar();
}
return x;
}
void add(int u, int v)
{
to[++tot] = v;
nxt[tot] = head[u];
head[u] = tot;
}
int f(int x)
{
int ans = 0;
for(int i = 1; i <= x; ++i)
ans = ans + x % i;
return ans % 100;
}
void dfs(int u, int fa)
{
for(int i = 0; i <= k; ++i)
{
dp[u][i] = dp[fa][i];
if(i + c[u] <= k)
{
if(i + c[u] >= d[u])
dp[u][i] = max(dp[u][i], dp[fa][i + c[u]] + 100);
else
dp[u][i] = max(dp[u][i], dp[fa][i + c[u]] + (i + c[u]) * 100 / d[u]);
}
ans = max(ans, dp[u][i]);
}
for(int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if(v == fa) continue;
dfs(v, u);
}
}
signed main()
{
n = read();
k = read();
for(register int i = 1; i < n; ++i)
{
int u, v;
u = read();
v = read();
add(u, v);
add(v, u);
}
for(register int i = 1; i <= n; ++i)
{
int u;
u = read();
u = f(u);
d[i] = u * u % 100;
}
for(register int i = 1; i <= n; ++i)
c[i] = read();
dfs(1, 0);
printf("%lld\n", ans);
return 0;
}