P2633 Count on a tree 题解

Descirption

Luogu传送门

Solution

主席树 + LCA + 树上差分

看到 查询第 \(k\) 小的点权,自然想到主席树。

那么这道题就是在一棵树上维护一个主席树。

考虑一个数列上的主席树是如何建的,转换到一棵树上应该不难吧(

再来看两个点间的区间第 \(k\) 小如何找。

经典思想:树上差分

都来做这道题了不会真的有人还不会树上差分吧,不会吧不会吧。

一个点出现的次数即为:

\[cnt_x + cnt_y - cnt_{lca(x, y)} - cnt_{fa[(lca(x, y))]} \]

我们就在主席树上查一下这个即可。

我用的树链剖分找的 \(lca\)。

Code

个人觉得代码还是很优美的。

#include <bits/stdc++.h>
#define ls(x) t[x].l
#define rs(x) t[x].r

using namespace std;

namespace IO{
    inline int read(){
        int x = 0;
        char ch = getchar();
        while(!isdigit(ch)) ch = getchar();
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x;
    }

    template <typename T> inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;

const int N = 1e5 + 10;
int n, m, tot;
int a[N], b[N];

namespace Segment_Tree{
    struct Tree{
        int rt, sum, l, r;
    }t[N << 5];
    int cnt;

    inline int build(int l, int r){
        int rt = ++cnt;
        if(l == r) return rt;
        int mid = (l + r) >> 1;
        ls(rt) = build(l, mid);
        rs(rt) = build(mid + 1, r);
        return rt;
    }

    inline int update(int pre, int l, int r, int x){
        int rt = ++cnt;
        ls(rt) = ls(pre), rs(rt) = rs(pre), t[rt].sum = t[pre].sum + 1;
        if(l == r) return rt;
        int mid = (l + r) >> 1;
        if(x <= mid) ls(rt) = update(ls(pre), l, mid, x);
        else rs(rt) = update(rs(pre), mid + 1, r, x);
        return rt;
    }

    inline int query(int u, int v, int p, int fap, int l, int r, int k){
        if(l == r) return l;
        int res = t[ls(u)].sum + t[ls(v)].sum - t[ls(p)].sum - t[ls(fap)].sum;
        int mid = (l + r) >> 1;
        if(res >= k) return query(ls(u), ls(v), ls(p), ls(fap), l, mid, k);
        else return query(rs(u), rs(v), rs(p), rs(fap), mid + 1, r, k - res);
    }
}
using namespace Segment_Tree;

namespace Tree_Chain_cut{
    struct node{
        int v, nxt;
    }edge[N << 1];
    int head[N], ecnt;

    inline void add(int x, int y){
        edge[++ecnt] = (node){y, head[x]};
        head[x] = ecnt;
    }

    int dep[N], siz[N], son[N], fa[N];

    inline void dfs1(int x, int p){
        t[x].rt = update(t[p].rt, 1, tot, a[x]);
        dep[x] = dep[p] + 1, fa[x] = p, siz[x] = 1;
        for(int i = head[x]; i; i = edge[i].nxt){
            int y = edge[i].v;
            if(y == p) continue;
            dfs1(y, x);
            siz[x] += siz[y];
            if(siz[y] > siz[son[x]]) son[x] = y;
        }
    }

    int top[N];

    inline void dfs2(int x, int topfa){
        top[x] = topfa;
        if(!son[x]) return;
        dfs2(son[x], topfa);
        for(int i = head[x]; i; i = edge[i].nxt){
            int y = edge[i].v;
            if(y == son[x] || y == fa[x]) continue;
            dfs2(y, y);
        }
    }

    inline int lca(int x, int y){
        while(top[x] != top[y]){
            if(dep[top[x]] < dep[top[y]]) swap(x, y);
            x = fa[top[x]];
        }
        return dep[x] < dep[y] ? x : y;
    }
}
using namespace Tree_Chain_cut;

inline void prework(){
    sort(b + 1, b + 1 + n);
    tot = unique(b + 1, b + 1 + n) - b - 1;
    for(int i = 1; i <= n; ++i) a[i] = lower_bound(b + 1, b + 1 + tot, a[i]) - b;
    for(int i = 1; i < n; ++i){
        int u = read(), v = read();
        add(u, v), add(v, u);
    }
    t[0].rt = build(1, tot);
    dfs1(1, 0), dfs2(1, 1);
}

int main(){
    n = read(), m = read();
    for(int i = 1; i <= n; ++i) a[i] = read(), b[i] = a[i];
    prework();
    int lst = 0;
    while(m--){
        int u = read() ^ lst, v = read(), k = read();
        int p = lca(u, v);
        write(lst = b[query(t[u].rt, t[v].rt, t[p].rt, t[fa[p]].rt, 1, tot, k)]), puts("");
    }
    return 0;
}

\[\_EOF\_ \]

上一篇:SpringMVC - 数据怎么显示到前端 Model, ModelMap, ModelAndView


下一篇:简单的springmvc搭建