题目链接
You are given a tree consisting of $n$ vertices, and $m$ simple vertex paths. Your task is to find how many pairs of those paths intersect at exactly one vertex. More formally, you have to find the number of pairs $(i, j)$ $(1 \leq i < j \leq m)$ such that $path_i$ and $path_j$ have exactly one vertex in common.
Input
First line contains a single integer $n$ ($1 \leq n \leq 3 \cdot 10^5$).
Next $n - 1$ lines describe the tree. Each line contains two integers $u$ and $v$ ($1 \leq u, v \leq n$) describing an edge between vertices $u$ and $v$.
Next line contains a single integer $m$ ($1 \leq m \leq 3 \cdot 10^5$).
Next $m$ lines describe paths. Each line describes a path by its two endpoints $u$ and $v$ ($1 \leq u, v \leq n$). The given path is all the vertices on the shortest path from $u$ to $v$ (including $u$ and $v$).
Output
Output a single integer — the number of pairs of paths that intersect at exactly one vertex.
Examples
input
5
1 2
1 3
1 4
3 5
4
2 3
2 4
3 4
3 5
output
2
input
1
3
1 1
1 1
1 1
output
3
input
5
1 2
1 3
1 4
3 5
6
2 3
2 4
3 4
3 5
1 1
1 2
output
7
Note
The tree in the first example and paths look like this. Pairs $(1,4)$ and $(3,4)$ intersect at one vertex.
In the second example all three paths contain the same single vertex, so all pairs $(1, 2)$, $(1, 3)$ and $(2, 3)$ intersect at one vertex.
The third example is the same as the first example with two additional paths. Pairs $(1,4)$, $(1,5)$, $(2,5)$, $(3,4)$, $(3,5)$, $(3,6)$ and $(5,6)$ intersect at one vertex.
口胡了一个 $O(m\log^2 n)$ 的暴力做法，当时没来得及写，后来发现真的可行。
如上图所示，有一个公共点的两条路径只可能是这 $3$ 种情况。
观察发现，这三种情况的公共点都是其中一条路径两个端点的 $\text{lca}$。因此可以维护节点值，按照路径 $\text{lca}$ 的深度由大到小添加路径。每次加入一条路径时，将这条路径的 $\text{lca}$ 的值加 $1$，并将 $\text{lca}$ 朝各个非 $\text{lca}$ 端点方向的子节点的值减 $1$。

这样，如果新加入的路径经过原来某条路径的 $\text{lca}$，对答案的贡献加 $1$；如果还经过该路径在 $\text{lca}$ 之外的点，则贡献减 $1$。因此将答案加上新路径上的节点权值和即可。

这样处理对于第三种情况会存在问题：如果一条路径同时经过另一条路径在 $\text{lca}$ 处的两个子节点，会多减 $1$ 次贡献。可以利用容斥原理，用 $\text{map}$ 统计 $\text{lca}$ 两个子节点相同的路径数，并加上这部分值即可。

节点值的维护可以用树链剖分实现。$\text{lca}$ 的子节点可以用树上倍增在 $O(\log n)$ 时间内查询。
#include<bits/stdc++.h>
using namespace std;
// Capacity for per-vertex arrays: n <= 3e5 plus a little slack.
constexpr int N = 300000 + 10;

// Forward-star adjacency list of the tree.
//   head[v] : index of the most recently added edge leaving v (0 = none)
//   ver[e]  : target vertex of edge e
//   Next[e] : previously added edge leaving the same vertex (0 ends the list)
//   tot     : number of edge slots used so far (edges are 1-indexed)
// Every undirected tree edge is stored twice, hence the N << 1 capacity.
int head[N], ver[N << 1], Next[N << 1], tot;

// n: number of vertices, m: number of query paths (both read in main).
int n, m;

// Prepends the directed edge x -> y to x's adjacency list.
inline void add(int x, int y) {
    const int e = ++tot;
    Next[e] = head[x];  // link to x's previous first edge
    ver[e] = y;
    head[x] = e;        // new edge becomes the list head
}
// Segment tree with point-add updates and range-sum queries over integer
// positions. Nodes use heap numbering: node p has children 2p and 2p+1.
struct Seg_Tree {
    // Node record: covered interval [l, r] and the sum of the values inside it.
    struct T {
        int l, r, ans;
    } t[N << 2];

    // Initialises node p to cover [l, r] and recursively builds its children.
    // Every sum starts at 0.
    void build(int p, int l, int r) {
        t[p].l = l;
        t[p].r = r;
        t[p].ans = 0;
        if (l != r) {
            const int mid = (l + r) >> 1;
            build(p << 1, l, mid);
            build(p << 1 | 1, mid + 1, r);
        }
    }

    // Adds v to the leaf for position x inside node p's subtree. Then it
    // recomputes the sums of every node from that leaf back up to p.
    void change(int p, int x, int v) {
        const int root = p;
        // Walk down to the leaf holding position x.
        while (t[p].l != t[p].r) {
            const int mid = (t[p].l + t[p].r) >> 1;
            p = (x <= mid) ? (p << 1) : (p << 1 | 1);
        }
        t[p].ans += v;
        // Refresh ancestors, stopping at the node the update started from.
        while (p != root) {
            p >>= 1;
            t[p].ans = t[p << 1].ans + t[p << 1 | 1].ans;
        }
    }

    // Sum of the values at positions in [l, r] intersected with node p's
    // interval. The query is expected to overlap node p's interval.
    int ask(int p, int l, int r) {
        if (l <= t[p].l && t[p].r <= r) return t[p].ans;
        const int mid = (t[p].l + t[p].r) >> 1;
        const int left = (l <= mid) ? ask(p << 1, l, r) : 0;
        const int right = (mid < r) ? ask(p << 1 | 1, l, r) : 0;
        return left + right;
    }
} S;
// Heavy-light decomposition of the tree (rooted at vertex 1) on top of the
// global segment tree S. It supports adding to a vertex's value and summing
// values along a tree path.
//
// The two DFS passes are iterative. Recursing once per tree level overflows a
// typical 8 MB stack on a path-shaped tree with n = 3e5. The iterative
// versions produce exactly the same fa/d/siz/son/dfn/top values as the
// recursive ones, including heavy-child tie-breaking and numbering order.
struct Tree {
    // d[v]  : depth (the root has depth 1, set in init)
    // fa[v] : parent (never written for the root, so it stays 0)
    // siz[v]: subtree size
    // son[v]: heavy child (0 for a leaf)
    // dfn[v]: position of v in the HLD order, used as the segment-tree index
    // top[v]: first vertex of the heavy chain containing v
    // num   : running dfn counter
    int d[N], fa[N], siz[N], son[N], dfn[N], top[N], num;

    // Builds the decomposition from vertex 1 and an all-zero segment tree over [1, n].
    inline void init() {
        d[1] = 1;
        dfs1(1, 0);
        dfs2(1, 1);
        S.build(1, 1, n);
    }

    // Fills fa, d, siz and son for the subtree rooted at x, whose parent is f.
    // d[x] must already be set by the caller.
    void dfs1(int x, int f) {
        // Pass 1: breadth-first order. Every vertex appears after its parent;
        // record parents and depths.
        std::vector<int> order(1, x);
        for (std::size_t k = 0; k < order.size(); ++k) {
            const int u = order[k];
            const int pu = (u == x) ? f : fa[u];
            for (int i = head[u]; i; i = Next[i]) {
                const int y = ver[i];
                if (y == pu) continue;
                fa[y] = u;
                d[y] = d[u] + 1;
                order.push_back(y);
            }
        }
        // Pass 2: reverse order, so all children are finished before their parent.
        // Children are scanned in adjacency order with a strict '>', so the
        // heavy child is the same one the recursive version would pick.
        for (std::size_t k = order.size(); k-- > 0;) {
            const int u = order[k];
            const int pu = (u == x) ? f : fa[u];
            siz[u] = 1;
            son[u] = 0;
            for (int i = head[u]; i; i = Next[i]) {
                const int y = ver[i];
                if (y == pu) continue;
                siz[u] += siz[y];
                if (siz[y] > siz[son[u]]) son[u] = y;
            }
        }
    }

    // Assigns dfn and top in heavy-first pre-order, starting at x whose chain head is p.
    void dfs2(int x, int p) {
        std::vector<std::pair<int, int> > stk;  // (vertex, head of its chain)
        std::vector<int> light;                  // light children of the current vertex
        stk.push_back(std::make_pair(x, p));
        while (!stk.empty()) {
            const int u = stk.back().first;
            const int chain = stk.back().second;
            stk.pop_back();
            top[u] = chain;
            dfn[u] = ++num;
            light.clear();
            for (int i = head[u]; i; i = Next[i]) {
                const int y = ver[i];
                if (fa[u] == y || y == son[u]) continue;
                light.push_back(y);
            }
            // Push light children in reverse, so they are popped in adjacency order.
            // Each one starts its own chain.
            for (std::size_t k = light.size(); k-- > 0;)
                stk.push_back(std::make_pair(light[k], light[k]));
            // Push the heavy child last, so it is popped next. Its chain then
            // gets consecutive dfn values before any light subtree.
            if (son[u]) stk.push_back(std::make_pair(son[u], chain));
        }
    }

    // Sum of the stored values over every vertex on the tree path x .. y
    // (both endpoints included).
    int ask(int x, int y) {
        int ans = 0;
        // Jump chain by chain, always lifting the endpoint whose chain head is deeper.
        while (top[x] != top[y]) {
            if (d[top[x]] < d[top[y]]) std::swap(x, y);
            ans += S.ask(1, dfn[top[x]], dfn[x]);
            x = fa[top[x]];
        }
        // Both endpoints now share a chain, whose dfn range is contiguous.
        if (d[x] > d[y]) std::swap(x, y);
        ans += S.ask(1, dfn[x], dfn[y]);
        return ans;
    }

    // Adds v to the value stored at vertex x.
    void change(int x, int v) {
        S.change(1, dfn[x], v);
    }
} T;
// Binary-lifting ancestor queries on the tree rooted at vertex 1. Depths come
// from the global Tree T, whose init must run first.
//
// The table-filling traversal is iterative. A recursive DFS nests once per
// tree level and can overflow the call stack on a path-shaped tree with n = 3e5.
struct LCA {
    // t      : largest jump exponent used
    // f[v][j]: the 2^j-th ancestor of v, or 0 once the jump passes the root
    int t, f[N][20];

    // Sets t = floor(log2(n)), then fills the jump table from vertex 1.
    inline void init() {
        // Integer floor(log2(n)). It avoids the float round-trip of
        // (int)log2(n) and stays defined for n == 0. With 2^(t+1) > n, the
        // jumps 2^t, ..., 2^0 can cover any depth difference in the tree.
        t = 0;
        while ((2 << t) <= n) ++t;
        dfs(1, 0);
    }

    // Fills f[][] for every vertex below x, where fa is x's parent.
    // Parents are always processed before their children, so
    // f[f[y][j-1]][j-1] is ready when f[y][j] is computed.
    void dfs(int x, int fa) {
        std::vector<int> stk(1, x);
        while (!stk.empty()) {
            const int u = stk.back();
            stk.pop_back();
            const int parent = (u == x) ? fa : f[u][0];
            for (int i = head[u]; i; i = Next[i]) {
                const int y = ver[i];
                if (y == parent) continue;
                f[y][0] = u;
                for (int j = 1; j <= t; j++)
                    f[y][j] = f[f[y][j - 1]][j - 1];
                stk.push_back(y);
            }
        }
    }

    // Lowest common ancestor of x and y. The deeper vertex is first lifted to the
    // other's depth; then both are lifted while their ancestors still differ.
    inline int lca(int x, int y) {
        if (T.d[x] > T.d[y]) std::swap(x, y);
        for (int i = t; i >= 0; i--)
            if (T.d[f[y][i]] >= T.d[x])
                y = f[y][i];
        if (x == y)
            return x;
        for (int i = t; i >= 0; i--)
            if (f[x][i] != f[y][i])
                x = f[x][i], y = f[y][i];
        return f[x][0];
    }

    // Orders the pair so that x is the deeper vertex. Returns the ancestor of x
    // whose depth is d[y] + 1, i.e. the child of y on the way down to x.
    // Assumes y is a proper ancestor of x (true at every call site in main).
    inline int son(int x, int y) {
        if (T.d[x] < T.d[y]) std::swap(x, y);
        for (int i = t; i >= 0; i--)
            if (T.d[f[x][i]] > T.d[y]) x = f[x][i];
        return x;
    }
} L;
// A query path: endpoints x and y, their lowest common ancestor l, and d,
// the depth of l. Paths are stored 1-indexed in p[].
struct P {
    int x, y, l, d;

    // Sort key for std::sort: paths whose lca lies deeper come first.
    bool operator<(P other) const {
        return other.d < d;
    }
} p[N];

// Keyed by the ordered pair (sx, sy), with sx < sy, of an lca's two children.
// main uses it to count already-processed paths that go down through both of
// those children.
std::map<std::pair<int, int>, int> cnt;
int main() {
scanf("%d", &n);
for (int i = 1, x, y; i <= n - 1; i++) {
scanf("%d%d", &x, &y);
add(x, y), add(y, x);
}
T.init(), L.init();
scanf("%d", &m);
for (int i = 1, x, y, l; i <= m; i++) {
scanf("%d%d", &x, &y);
l = L.lca(x, y);
p[i] = {x, y, l, T.d[l]};
}
sort(p + 1, p + 1 + m);
long long ans = 0;
for (int i = 1; i <= m; i++) {
if (p[i].x == p[i].y)
ans += T.ask(p[i].x, p[i].y), T.change(p[i].x, 1);
else if (p[i].x == p[i].l || p[i].y == p[i].l) {
ans += T.ask(p[i].x, p[i].y);
if (T.d[p[i].x] < T.d[p[i].y])swap(p[i].x, p[i].y);
T.change(p[i].y, 1), T.change(L.son(p[i].x, p[i].y), -1);
} else {
int sx = L.son(p[i].x, p[i].l), sy = L.son(p[i].y, p[i].l);
if (sx > sy)swap(sx, sy), swap(p[i].x, p[i].y);
ans += T.ask(p[i].x, sx) + T.ask(p[i].y, p[i].l) + cnt[{sx, sy}];
T.change(p[i].l, 1), T.change(sx, -1), T.change(sy, -1), cnt[{sx, sy}]++;
}
}
printf("%lld\n", ans);
return 0;
}