Count on a tree II (树上莫队)

题目链接
题意:
给定一棵\(N\)个节点的树,节点编号从\(1\)到\(N\),每个节点都有一个整数权值。
现在,我们要进行\(M\)次询问,格式为\(u\) \(v\),对于每个询问你需要回答从\(u\)到\(v\)的路径上(包括两端点)共有多少种不同的点权值。

思路:
树上莫队,预处理得到树的欧拉序列。
\(dfs\)序列:是指将一棵树被\(dfs\)遍历时所经过的节点顺序,回溯时不再记录。
欧拉序列:\(dfs\)遍历,第一次遇到该点时记录一次时间戳\(first\),回溯时再记录一次时间戳\(last\),由此得到一序列为欧拉序列,其中\(first\)和\(last\)分别记录该点在序列中的第一次遍历和第二次遍历的时间戳。

令\(first[u]\)<\(first[v]\),即先遍历\(u\),对于\(u\)到\(v\)的路径分两种情况:
1.当\(u\)与\(v\)的最近公共祖先为\(u\)时,则从\(u\)到\(v\)的路径上的点即为区间\([first[u], first[v]]\)上的点,且区间内每个点只出现一次。
2.当\(u\)与\(v\)不在一颗子树上时,即它们的\(LCA\)为另一点时,则从\(u\)到\(v\)的路径上的点即为区间\([last[u], first[v]]\)上的只出现过一次的点再加上点\(LCA(u,v)\),(如果\(dfs\)从\(u\)到\(v\)时经过其他点,则在区间内必刚好出现两次,对此可以设置\(used\)数组,不断令\(used\)^\(1\)判断其是否出现两次)。特别注意:当处理区间\([l,r]\)内的点时,需映射到欧拉序列中,得到该点序号再处理,而对于\(LCA(u,v)\),无需映射,因为求得的\(LCA\)本身就是点的序号。
对于每次询问,得到区间左右端点,然后普通莫队离线处理即可。

code:

点击查看代码
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <deque>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
// #include <unordered_map>

#define fi first
#define se second
#define pb push_back
// #define endl "\n"
#define debug(x) cout << #x << ":" << x << endl;
#define bug cout << "********" << endl;
#define all(x) x.begin(), x.end()
#define lowbit(x) x & -x
#define fin(x) freopen(x, "r", stdin)
#define fout(x) freopen(x, "w", stdout)
#define ull unsigned long long
#define ll long long 

const double eps = 1e-15;
const int inf = 0x3f3f3f3f;
// const ll INF = 0x3f3f3f3f3f3f3f3f;
const double pi = acos(-1.0);
const int mod =  998244353;
const int maxn = 1e6 + 10;

using namespace std;

int our[maxn], n, m, s[maxn], p[maxn], ans[maxn];
int dp[maxn][32], dep[maxn], first[maxn], last[maxn];
int vis[maxn], tot, block, ret, used[maxn];
vector<int> v[maxn];
struct node{
    int l, r, lca, id;
    bool operator<(const node &a)const{
        return (l/block == a.l/block) ? ((l/block) & 1 ? r < a.r : r > a.r) : l < a.l; 
    }
}cnt[maxn];

void dfs(int u, int fa){
    dep[u] = dep[fa] + 1, dp[u][0] = fa;
    our[++ tot] = u;
    first[u] = tot;
    for(int i = 1; (1 << i) <= dep[fa]; i ++)dp[u][i] = dp[dp[u][i - 1]][i - 1];
    for(auto i : v[u]){
        if(i == fa)continue;
        dfs(i, u);
    }
    our[++ tot] = u;
    last[u] = tot;
}

int lca(int a, int b){
    if(dep[a] < dep[b])swap(a, b);
    int h = dep[a] - dep[b];
    for(int i = 25; i >= 0; i --){
        if((h >> i) & 1)a = dp[a][i];
    }
    if(a == b)return a;
    for(int i = 25; i >= 0; i --){
        if(dp[a][i] != dp[b][i])a = dp[a][i], b = dp[b][i];
    }
    return dp[a][0];
}

void work(int x){
    x = our[x];
    if(used[x])vis[s[x]] --, ret -= !vis[s[x]];
    else ret += !vis[s[x]], vis[s[x]] ++;
    used[x] ^= 1;
}

int main(){
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i ++)scanf("%d", &s[i]), p[i] = s[i];
    for(int i = 1, a, b; i < n; i ++){
        scanf("%d%d", &a, &b);
        v[a].pb(b), v[b].pb(a);
    }
    block = sqrt(tot);
    sort(p + 1, p + n + 1);
    int q = unique(p + 1, p + n + 1) - p - 1;
    for(int i = 1; i <= n; i ++)s[i] = lower_bound(p + 1, p + q + 1, s[i]) - p;
    dfs(1, 0);
    for(int i = 1; i <= m; i ++){
        int u, v;
        scanf("%d%d", &u, &v);
        int LCA = lca(u, v);
        if(first[u] > first[v])swap(u, v);
        if(LCA != u)cnt[i] = {last[u], first[v], LCA, i};
        else cnt[i] = {first[u], first[v], 0, i};
    }
    sort(cnt + 1, cnt + m + 1);
    int l = 1, r = 0;
    for(int i = 1; i <= m; i ++){
        while(r < cnt[i].r)work(++ r);
        while(r > cnt[i].r)work(r --);
        while(l < cnt[i].l)work(l ++);
        while(l > cnt[i].l)work(-- l);
        if(cnt[i].lca)work(cnt[i].lca);
        ans[cnt[i].id] = ret;
        if(cnt[i].lca)work(cnt[i].lca);
    }
    for(int i = 1; i <= m; i ++)printf("%d\n", ans[i]);
    return 0;
}
上一篇:[做题记录-数据结构] Luogu5210 [ZJOI2017]线段树


下一篇:【初赛解析】2021CSP-S初赛解析(不完全)