题目链接
题意:
给定一棵\(N\)个节点的树,节点编号从\(1\)到\(N\),每个节点都有一个整数权值。
现在,我们要进行\(M\)次询问,格式为\(u\) \(v\),对于每个询问你需要回答从\(u\)到\(v\)的路径上(包括两端点)共有多少种不同的点权值。
思路:
树上莫队,预处理得到树的欧拉序列。
\(dfs\)序列:是指将一棵树被\(dfs\)遍历时所经过的节点顺序,回溯时不再记录。
欧拉序列:\(dfs\)遍历,第一次遇到该点时记录一次时间戳\(first\),回溯时再记录一次时间戳\(last\),由此得到一序列为欧拉序列,其中\(first\)和\(last\)分别记录该点在序列中的第一次遍历和第二次遍历的时间戳。
令\(first[u]\)<\(first[v]\),即先遍历\(u\),对于\(u\)到\(v\)的路径分两种情况:
1.当\(u\)与\(v\)的最近公共祖先为\(u\)时,则从\(u\)到\(v\)的路径上的点即为区间\([first[u], first[v]]\)上的点,且区间内每个点只出现一次。
2.当\(u\)与\(v\)不在一颗子树上时,即它们的\(LCA\)为另一点时,则从\(u\)到\(v\)的路径上的点即为区间\([last[u], first[v]]\)上的只出现过一次的点再加上点\(LCA(u,v)\),(如果\(dfs\)从\(u\)到\(v\)时经过其他点,则在区间内必刚好出现两次,对此可以设置\(used\)数组,不断令\(used\)^\(1\)判断其是否出现两次)。特别注意:当处理区间\([l,r]\)内的点时,需映射到欧拉序列中,得到该点序号再处理,而对于\(LCA(u,v)\),无需映射,因为求得的\(LCA\)本身就是点的序号。
对于每次询问,得到区间左右端点,然后普通莫队离线处理即可。
code:
点击查看代码
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <deque>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
// #include <unordered_map>
#define fi first
#define se second
#define pb push_back
// #define endl "\n"
#define debug(x) cout << #x << ":" << x << endl;
#define bug cout << "********" << endl;
#define all(x) x.begin(), x.end()
#define lowbit(x) x & -x
#define fin(x) freopen(x, "r", stdin)
#define fout(x) freopen(x, "w", stdout)
#define ull unsigned long long
#define ll long long
const double eps = 1e-15;
const int inf = 0x3f3f3f3f;
// const ll INF = 0x3f3f3f3f3f3f3f3f;
const double pi = acos(-1.0);
const int mod = 998244353;
const int maxn = 1e6 + 10;
using namespace std;
int our[maxn], n, m, s[maxn], p[maxn], ans[maxn];
int dp[maxn][32], dep[maxn], first[maxn], last[maxn];
int vis[maxn], tot, block, ret, used[maxn];
vector<int> v[maxn];
struct node{
int l, r, lca, id;
bool operator<(const node &a)const{
return (l/block == a.l/block) ? ((l/block) & 1 ? r < a.r : r > a.r) : l < a.l;
}
}cnt[maxn];
void dfs(int u, int fa){
dep[u] = dep[fa] + 1, dp[u][0] = fa;
our[++ tot] = u;
first[u] = tot;
for(int i = 1; (1 << i) <= dep[fa]; i ++)dp[u][i] = dp[dp[u][i - 1]][i - 1];
for(auto i : v[u]){
if(i == fa)continue;
dfs(i, u);
}
our[++ tot] = u;
last[u] = tot;
}
int lca(int a, int b){
if(dep[a] < dep[b])swap(a, b);
int h = dep[a] - dep[b];
for(int i = 25; i >= 0; i --){
if((h >> i) & 1)a = dp[a][i];
}
if(a == b)return a;
for(int i = 25; i >= 0; i --){
if(dp[a][i] != dp[b][i])a = dp[a][i], b = dp[b][i];
}
return dp[a][0];
}
void work(int x){
x = our[x];
if(used[x])vis[s[x]] --, ret -= !vis[s[x]];
else ret += !vis[s[x]], vis[s[x]] ++;
used[x] ^= 1;
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i ++)scanf("%d", &s[i]), p[i] = s[i];
for(int i = 1, a, b; i < n; i ++){
scanf("%d%d", &a, &b);
v[a].pb(b), v[b].pb(a);
}
block = sqrt(tot);
sort(p + 1, p + n + 1);
int q = unique(p + 1, p + n + 1) - p - 1;
for(int i = 1; i <= n; i ++)s[i] = lower_bound(p + 1, p + q + 1, s[i]) - p;
dfs(1, 0);
for(int i = 1; i <= m; i ++){
int u, v;
scanf("%d%d", &u, &v);
int LCA = lca(u, v);
if(first[u] > first[v])swap(u, v);
if(LCA != u)cnt[i] = {last[u], first[v], LCA, i};
else cnt[i] = {first[u], first[v], 0, i};
}
sort(cnt + 1, cnt + m + 1);
int l = 1, r = 0;
for(int i = 1; i <= m; i ++){
while(r < cnt[i].r)work(++ r);
while(r > cnt[i].r)work(r --);
while(l < cnt[i].l)work(l ++);
while(l > cnt[i].l)work(-- l);
if(cnt[i].lca)work(cnt[i].lca);
ans[cnt[i].id] = ret;
if(cnt[i].lca)work(cnt[i].lca);
}
for(int i = 1; i <= m; i ++)printf("%d\n", ans[i]);
return 0;
}