LOJ #2533. 「CTSC2018」暴力写挂(边分治合并)

题意

给你两个有 \(n\) 个点的树 \(T, T'\) ,求一对点对 \((x, y)\) 使得

\[depth(x) + depth(y) - (depth(LCA(x , y)) + depth′ (LCA′ (x, y)))
\]

最大。

数据范围

对于所有数据, \(n \le 366666 , |v| \le 2017011328\) 。

题解

第一次写边分治(原来碰到过都弃疗啦) 。

我们看这个式子不太舒服,化简一下:

\[\frac 1 2 (dist(x, y) - depth(x) - depth(y) + 2depth′ (LCA′ (x, y)) )
\]

这样的话,我们可以考虑枚举第二棵树的 \(LCA'\) ,那么我们意味着需要在这棵树的两个子树内找到点对 \((x, y)\) 使得它们的 \(dist(x, y) - depth(x) - depth(y)\) 最大。

怎么做呢,我们可以考虑边分治。边分治有什么好处呢?每次分治的话只有两边。

但是为了让复杂度正确,我们考虑使用二叉化,把度数降下来。不然可能在菊花处退化到 \(\mathcal O(n^2)\) 。

具体来说可以依次考虑每个儿子,然后每次拆掉原来的边,搞个新点和新边保证相对关系不变就行了,这样的点数好像是浪费最少的,时间和空间都比较优秀。

假设我们当前分治的两边是 \(U, V\) ,分治重心边的节点是 \(u, v\) 那么我们只需要找到一对点 \(p_1 \in U, p_2 \in V\) 使得 \(dist(p_1, u) + dist(p_2, v) + depth(p_1) + depth(p_2)\) 最大。

我们可以先每次分治预处理出每个分治块内的点到分治中心边的距离,那么我们对于每个点有个 \(dist(x, p) + dep_x\) 的信息。

那么对于每个点 \(x\) 会有 \(\log\) 个信息,我们记下每次相对于分治边的方向,那么可以唯一确定这个信息所在的分治块的位置!

这样有什么好处呢?我们在第二棵树上直接对于每个点的 \(\log\) 信息合并,那么这次合并的节点刚好对应上第一棵树上同一个分治块,这个类似于线段树合并的操作,用左边和右边的信息合并就行了。

这样复杂度是 \(\mathcal O(n \log n)\) ,注意空间也是 \(\mathcal O(n \log n)\) 的,要卡下空间。

代码

参考了 zhoushuyu 神仙的代码 ,写的好精简啊QAQ

#include <bits/stdc++.h>

#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl
#define Travel(i, u, G) for (int i = G.Head[u], v = G.to[i], w = G.val[i]; i; i = G.Next[i], v = G.to[i], w = G.val[i])
#define fir first
#define sec second using namespace std; typedef long long ll;
typedef pair<int, int> PII; template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; } inline int read() {
int x(0), sgn(1); char ch(getchar());
for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
return x * sgn;
} void File() {
#ifdef zjp_shadow
freopen ("2553.in", "r", stdin);
freopen ("2553.out", "w", stdout);
#endif
} const int N = 366666 + 233, NN = N * 3;
const ll inf = 0x3f3f3f3f3f3f3f3f; template<int Maxn, int Maxm>
struct Graph { int Head[Maxn], Next[Maxm], to[Maxm], val[Maxm], e; Graph() { e = 1; } inline void add_edge(int u, int v, int w) {
to[++ e] = v; Next[e] = Head[u]; Head[u] = e; val[e] = w;
} inline void Add(int u, int v, int w) {
add_edge(u, v, w); add_edge(v, u, w);
} }; Graph<N, N << 1> T1, T2;
Graph<NN, N << 3> DT; int node, n; ll dep[N]; void Build(int u, int fa = 0) {
int lst = u;
Travel(i, u, T1) if (v != fa) {
dep[v] = dep[u] + w, Build(v, u);
DT.Add(lst, ++ node, 0), DT.Add(node, v, w), lst = node;
}
} int minsz, id, sz[NN]; bool vis[N << 3]; void Get_Edge(int u, int fa, int tot) {
sz[u] = 1;
Travel(i, u, DT) if (v != fa && !vis[i]) {
Get_Edge(v, u, tot); sz[u] += sz[v];
if (chkmin(minsz, max(sz[v], tot - sz[v]))) id = i;
}
} const int Node = (NN) * 20; int ch[Node][2], stot; ll val[Node][2]; PII info[NN]; int rt[NN]; void Get_Info(int u, int fa, ll dis, int dir) {
if (u <= n) {
++ stot;
if (!info[u].fir) rt[u] = stot;
else ch[info[u].fir][info[u].sec] = stot;
info[u] = make_pair(stot, dir);
val[stot][dir] = dis + dep[u];
val[stot][dir ^ 1] = - inf;
}
sz[u] = 1;
Travel(i, u, DT) if (v != fa && !vis[i])
Get_Info(v, u, dis + w, dir), sz[u] += sz[v];
} int maxdep;
void Solve(int u, int tot) {
if (tot == 1) return;
minsz = node + 1; Get_Edge(u, 0, tot);
int x = DT.to[id], y = DT.to[id ^ 1];
vis[id] = vis[id ^ 1] = true;
Get_Info(x, 0, 0, 0);
Get_Info(y, 0, DT.val[id], 1);
Solve(x, sz[x]); Solve(y, sz[y]);
} ll ans = - inf; int Merge(int x, int y, ll dec) {
if (!x || !y) return x | y;
chkmax(ans, max(val[x][0] + val[y][1], val[x][1] + val[y][0]) - dec);
Rep (i, 2) chkmax(val[x][i], val[y][i]), ch[x][i] = Merge(ch[x][i], ch[y][i], dec);
return x;
} void Dfs(int u, int fa = 0, ll dis = 0) {
chkmax(ans, (dep[u] - dis) * 2);
Travel(i, u, T2) if (v != fa) {
Dfs(v, u, dis + w);
rt[u] = Merge(rt[u], rt[v], 2 * dis);
}
} int main () { File(); node = n = read();
For (i, 1, n - 1) {
int u = read(), v = read(), w = read(); T1.Add(u, v, w);
}
For (i, 1, n - 1) {
int u = read(), v = read(), w = read(); T2.Add(u, v, w);
} Build(1); Solve(1, node); Dfs(1); printf ("%lld\n", ans >> 1); return 0; }
上一篇:Luogu五月月赛


下一篇:CF552E 字符串 表达式求值