自闭了几天后的我终于开始做题了。。然后调了3h一道点分治板子题,调了一天一道IOI。。。
最后还是自己手造数据debug出来的。。。
这题一看:树上路径问题,已知路径长度求balabala,显然是点分治(其实只要有一点点对点分治思想及应用的理解就能知道)。照普通点分治的做法,找重心分治,每次统计子树所有点到根的距离,然后开个桶判断一下是否出现即可(本题还要存一下边数)。然后我们要做的只有暴力统计取min了,时间复杂度\(O(n\log n)\)。
于是本蒟蒻自信满满地打下了以下代码:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
#define maxn 202000
#define maxk 10001000
struct edge {
int w, to, next;
} e[maxn << 1];
int head[maxn], dis[maxn], save[maxn], save2[maxn], cnt[maxn], vis[maxn], size[maxn], maxp[maxn], ecnt, sum, depth[maxn];
int get[maxk], get2[maxk];
int n, m, k, ans = 0x3f3f3f3f;
#define isdigit(x) ((x) >= '0' && (x) <= '9')
inline int read() {
int res = 0;
char c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) res = (res << 1) + (res << 3) + (c ^ 48), c = getchar();
return res;
}
void adde(int u, int v, int w) {
e[++ecnt] = (edge) {w, v, head[u]};
head[u] = ecnt;
}
int getrt(int x, int fa) {
int rt = 0;
maxp[x] = 0;
size[x] = 1;
for (int i = head[x]; i; i = e[i].next) {
int to = e[i].to;
if (vis[to] || to == fa) continue;
rt = getrt(to, x);
size[x] += size[to];
maxp[x] = max(maxp[x], size[to]);
}
maxp[x] = max(maxp[x], sum - size[x]);
if (maxp[rt] > maxp[x]) rt = x;
return rt;
}
void getdis(int x, int fa) {
save[++save[0]] = dis[x];
cnt[++cnt[0]] = depth[x];
for (int i = head[x]; i; i = e[i].next) {
int to = e[i].to;
if (vis[to] || to == fa) continue;
dis[to] = dis[x] + e[i].w;
depth[to] = depth[x] + 1;
getdis(to, x);
}
}
void work(int x) {
save2[0] = 0;
for (int i = head[x]; i; i = e[i].next) {
int to = e[i].to;
if (vis[to]) continue;
save[0] = cnt[0] = 0;
dis[to] = e[i].w;
depth[to] = 1;
getdis(to, x);
for (int j = 1; j <= save[0]; ++j) {
// get2[save[j]] = min(get2[save[j]], cnt[j]);
if (k >= save[j] && get[k - save[j]])
ans = min(ans, get2[k - save[j]] + cnt[j]);
}
for (int j = 1; j <= save[0]; ++j) {
get[save[j]] = 1, save2[++save2[0]] = save[j];
get2[save[j]] = min(get2[save[j]], cnt[j]);
}
}
for (int i = 1; i <= save2[0]; ++i) get[save2[i]] = 0, get2[save2[i]] = 0x3f3f3f3f;
}
void dfs(int x) {
vis[x] = get[0] = 1;
get2[0] = 0;
work(x);
for (int i = head[x]; i; i = e[i].next) {
int to = e[i].to;
if (vis[to]) continue;
maxp[0] = n;
sum = size[to];
dfs(getrt(to, 0));
}
}
int main() {
memset(get2, 0x3f, sizeof(get2));
n = read(), k = read();
for (int i = 1; i < n; ++i) {
int u = read(), v = read(), w = read();
adde(u + 1, v + 1, w);
adde(v + 1, u + 1, w);
}
maxp[0] = sum = n;
dfs(getrt(1, 0));
if (ans == 0x3f3f3f3f) puts("-1");
else printf("%d\n", ans);
return 0;
}
RE85。爆栈?算了就算爆栈现在除了重写一遍也没有别的解决办法了。。
尝试把两个桶开大2倍,变成90了。
再开大2倍,变成MLE了。看来是数据太大,桶开不下。于是写了个unordered_map
,常数极大,TLE到只有20分。
回过头来,发现save
里存的路径长度只有\(≤k\)时才对答案有贡献,不符合条件的都可以不存进桶里,桶可以只开到1e6多一点(保证k不越界就行)。加一句特判即可AC。
code:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
#define maxn 200007
#define maxk 1000007
struct edge {
int w, to, next;
} e[maxn << 1];
int head[maxn], dis[maxn], save[maxn], save2[maxn], cnt[maxn], vis[maxn], size[maxn], maxp[maxn], ecnt, sum, depth[maxn];
int get[maxk], get2[maxk];
int n, m, k, ans = 0x3f3f3f3f;
#define isdigit(x) ((x) >= '0' && (x) <= '9')
inline int read() {
int res = 0;
char c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)) res = (res << 1) + (res << 3) + (c ^ 48), c = getchar();
return res;
}
void adde(int u, int v, int w) {
e[++ecnt] = (edge) {w, v, head[u]};
head[u] = ecnt;
}
int getrt(int x, int fa) {
int rt = 0;
maxp[x] = 0;
size[x] = 1;
for (int i = head[x]; i; i = e[i].next) {
int to = e[i].to;
if (vis[to] || to == fa) continue;
rt = getrt(to, x);
size[x] += size[to];
maxp[x] = max(maxp[x], size[to]);
}
maxp[x] = max(maxp[x], sum - size[x]);
if (maxp[rt] > maxp[x]) rt = x;
return rt;
}
void getdis(int x, int fa) {
save[++save[0]] = dis[x];
cnt[++cnt[0]] = depth[x];
for (int i = head[x]; i; i = e[i].next) {
int to = e[i].to;
if (vis[to] || to == fa) continue;
dis[to] = dis[x] + e[i].w;
depth[to] = depth[x] + 1;
getdis(to, x);
}
}
void work(int x) {
save2[0] = 0;
for (int i = head[x]; i; i = e[i].next) {
int to = e[i].to;
if (vis[to]) continue;
save[0] = cnt[0] = 0;
dis[to] = e[i].w;
depth[to] = 1;
getdis(to, x);
for (int j = 1; j <= save[0]; ++j) {
// get2[save[j]] = min(get2[save[j]], cnt[j]);
if (k >= save[j] && get[k - save[j]])
ans = min(ans, get2[k - save[j]] + cnt[j]);
}
for (int j = 1; j <= save[0]; ++j)
if (save[j] <= k) {
get[save[j]] = 1, save2[++save2[0]] = save[j];
get2[save[j]] = min(get2[save[j]], cnt[j]);
}
}
for (int i = 1; i <= save2[0]; ++i) get[save2[i]] = 0, get2[save2[i]] = 0x3f3f3f3f;
}
void dfs(int x) {
vis[x] = get[0] = 1;
get2[0] = 0;
work(x);
for (int i = head[x]; i; i = e[i].next) {
int to = e[i].to;
if (vis[to]) continue;
maxp[0] = n;
sum = size[to];
dfs(getrt(to, 0));
}
}
int main() {
memset(get2, 0x3f, sizeof(get2));
n = read(), k = read();
for (int i = 1; i < n; ++i) {
int u = read(), v = read(), w = read();
adde(u + 1, v + 1, w);
adde(v + 1, u + 1, w);
}
maxp[0] = sum = n;
dfs(getrt(1, 0));
if (ans == 0x3f3f3f3f) puts("-1");
else printf("%d\n", ans);
return 0;
}