题目大意
给你一棵\(n\)个点的树,每个点有一种颜色;现在有\(m\)个询问,每次询问你\(x\)到\(y\)的路径上,若将\(a\)颜色视作\(b\)颜色,不同的颜色有几种。
\(n\leq 50000,m\leq 100000\)
分析
如果是把问题放到序列上:询问区间\([l,r]\)不同的颜色有几种。这个问题有两个已知的解法:
- 主席树(传送门)
- 莫队
看这题的数据范围显然是让你莫队了。(雾
树上莫队的第一步,是把树上问题转换为序列问题。我们求出原树的欧拉序,可以发现这个序列有这样的性质:
将一个点在欧拉序中首次出现和第二次出现的位置分别记作\(fir_u\)和\(las_u\),对于一条路径\((x,y)\)(假定\(fir_x<fir_y\))。
若\(lca(x,y)=x\),那么这条路径对应欧拉序中的区间\([fir_x,fir_y]\)。但是区间中出现两次的点要去掉,因为它们不属于这条路径。
若\(lca(x,y)\neq x\),那么这条路径对应欧拉序中的区间\([las_x,fir_y]\)。同样的要去掉出现两次的点,并且这个区间没有包括上\(lca\),要将\(lca\)再单独统计。
这样,树上问题就变成了序列问题。
为了不计算出现两次的点,我们开个标记数组,一个点每次出现,都把标记数组对应位置异或\(1\),那么一个点在标记数组中的值为\(1\)时才能被计算,当一个点对应的值变为\(0\)时又把它的贡献删去,这样问题便迎刃而解。再注意计算\(lca\)的答案即可。关于将\(a\)颜色视作\(b\)颜色的,只需判断区间中是否同时有\(a\)颜色和\(b\)颜色,有的话答案减\(1\),注意\(a=b\)要特判,不然要炸!
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 200007;
int n, m, col[N], ans[N], ord[N];
int tot, dfn, st[N], to[N << 1], nx[N << 1], fir[N], las[N], anc[N][17], dep[N];
void add(int u, int v) { to[++tot] = v, nx[tot] = st[u], st[u] = tot; }
void dfs(int u)
{
fir[u] = ++dfn, ord[dfn] = u;
for (int i = st[u]; i; i = nx[i]) if (!fir[to[i]]) anc[to[i]][0] = u, dep[to[i]] = dep[u] + 1, dfs(to[i]);
las[u] = ++dfn, ord[dfn] = u;
}
int getlca(int u, int v)
{
if (dep[u] < dep[v]) swap(u, v);
for (int i = 16; i >= 0; i--) if (dep[anc[u][i]] >= dep[v]) u = anc[u][i];
if (u == v) return u;
for (int i = 16; i >= 0; i--) if (anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
return anc[u][0];
}
int block, ret, be[N], tag[N], buc[N];
struct note { int l, r, id, a, b, lca; } q[N];
int cmp(note a, note b) { return be[a.l] == be[b.l] ? ((be[a.l] & 1) ? a.r < b.r : a.r > b.r) : a.l < b.l; }
void ins(int c, int v)
{
if (v == 1) { if (!buc[c]) ret++; buc[c]++; }
else { buc[c]--; if (!buc[c]) ret--; }
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &col[i]);
for (int i = 1, u, v; i <= n; i++)
{
scanf("%d%d", &u, &v);
if (u && v) add(u, v), add(v, u);
}
dep[1] = 1, dfs(1);
for (int j = 1; j <= 16; j++) for (int i = 1; i <= n; i++) anc[i][j] = anc[anc[i][j - 1]][j - 1];
block = sqrt(2 * n);
for (int i = 1; i <= 2 * n; i++) be[i] = i / block + 1;
for (int i = 1, x, y, a, b, lca; i <= m; i++)
{
scanf("%d%d%d%d", &x, &y, &a, &b);
if (fir[x] > fir[y]) swap(x, y);
lca = getlca(x, y);
if (lca == x) q[i] = (note){fir[x], fir[y], i, a, b, 0};
else q[i] = (note){las[x], fir[y], i, a, b, lca};
}
sort(q + 1, q + m + 1, cmp);
for (int i = 1, l = 1, r = 0; i <= m; i++)
{
while (l < q[i].l) tag[ord[l]] ^= 1, ins(col[ord[l]], tag[ord[l]]), ++l;
while (l > q[i].l) --l, tag[ord[l]] ^= 1, ins(col[ord[l]], tag[ord[l]]);
while (r < q[i].r) ++r, tag[ord[r]] ^= 1, ins(col[ord[r]], tag[ord[r]]);
while (r > q[i].r) tag[ord[r]] ^= 1, ins(col[ord[r]], tag[ord[r]]), --r;
if (q[i].lca) ins(col[q[i].lca], 1);
ans[q[i].id] = ret;
if (q[i].a != q[i].b && buc[q[i].a] && buc[q[i].b]) ans[q[i].id]--;
if (q[i].lca) ins(col[q[i].lca], 0);
}
for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
return 0;
}