题目大意:
给你一颗树,每次求树上两点简单路径的交点个数
题目思路:
其实第一反应是lca,但是写了好多种情况并没有发现什么规律,
然后想用线段树维护个dfs序看序列里相同的数字的个数,但是dfs的顺序好像回影响答案,因为操作的是一个子树,
都到这里了,可以直接树剖,因为树剖时剖的时轻重链,可以利用重链链头和父亲节点关系“跳”从而实现对一个链上的的区间修改
然后每次询问第二条路径的上的1的个数就是答案
CODE:
int n, qq; vector<int> e[maxn]; int son[maxn], siz[maxn], dep[maxn], fa[maxn]; int tid2[maxn], top[maxn], tid[maxn], indexx; void dfs1(int u, int p) { siz[u] = 1, dep[u] = dep[p] + 1, fa[u] = p; for (int v : e[u]) { if (v == p) continue; dfs1(v, u); siz[u] += siz[v]; if (siz[son[u]] < siz[v]) son[u] = v; } } void dfs2(int u, int p) { top[u] = p, tid[u] = ++indexx, tid2[indexx] = u; if (son[u]) dfs2(son[u], p); for (int v : e[u]) if (v!=fa[u]&&v != son[u]) dfs2(v, v); } struct SegmentTree { int sum[maxn << 2], lazy[maxn << 2]; void pushDown(int i, int l, int r) { if(lazy[i]==0) return ; int mid = (l + r) >> 1; lazy[i << 1] += lazy[i], lazy[i << 1 | 1] += lazy[i]; sum[i << 1] += (mid - l + 1) * lazy[i], sum[i << 1 | 1] += (r - mid) * lazy[i]; lazy[i] = 0; } void pushUp(int i) { sum[i] = sum[i << 1] + sum[i << 1 | 1]; } void update(int i, int l, int r, int ql, int qr, int val) { if (ql <= l && qr >= r) { lazy[i] += val,sum[i] += (r - l + 1) * val; return; } pushDown(i, l, r); int mid = (l + r) >> 1; if (ql <= mid) update(i << 1, l, mid, ql, qr, val); if (qr > mid) update(i << 1 | 1, mid + 1, r, ql, qr, val); pushUp(i); } void update(int u, int v, int val) { while (top[u] != top[v]) { if (dep[top[u]] < dep[top[v]]) swap(u, v); int top1 = top[u]; update(1, 1, n, tid[top1], tid[u], val); u = fa[top1]; } if (dep[u] < dep[v])swap(u, v); update(1, 1, n, tid[v], tid[u], val); } int query(int i, int l, int r, int ql, int qr) { if (ql <= l && qr >= r) return sum[i]; int ans = 0; pushDown(i, l, r); int mid = (l + r) >> 1; if (ql <= mid) ans += query(i << 1, l, mid, ql, qr); if (qr > mid) ans += query(i << 1 | 1, mid + 1, r, ql, qr); return ans; } int query(int u, int v) { int ans = 0; while (top[u] != top[v]) { if (dep[top[u]] < dep[top[v]]) swap(u, v); int top1 = top[u]; ans += query(1, 1, n, tid[top1], tid[u]); u = fa[top1]; } if (dep[u] < dep[v]) swap(u, v); ans += query(1, 1, n, tid[v], tid[u]); return ans; } } st; int main() { n = read(), qq = read(); for (int i = 1; i <= n - 1; i++) { int u, v; u = read(), v = read(); e[u].push_back(v); e[v].push_back(u); } dfs1(1, 0), dfs2(1, 1); while (qq--) { int u1, v1, u2, v2; u1 = read(), v1 = read(), u2 = read(), v2 = read(); st.update(u1, v1, 1); out(st.query(u2, v2)); puts(""); st.update(u1, v1, -1); } return 0; }View Code