P2633 Count on a tree
Description
- 给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
-
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
- M行,表示每个询问的答案。
Sample Input
8 5 105 2 9 3 8 5 7 7 1 2 1 3 1 4 3 5 3 6 3 7 4 8 2 5 1 0 5 2 10 5 3 11 5 4 110 8 2
Sample Output
2 8 9 105 7
Data Size
- N,M<=100000
题解:
- 主席树。
- 跑到树上的主席树,挺好玩的。
- 首先想想模板主席树是怎样实现的:每个节点根据前一个节点建立。然后利用前缀和思想拿第r棵树 - 第l-1棵树得到[l,r]区间的信息,操作在这棵新树上操作即可。
- 那么到了树上呢?
- 可以对于每一个节点,在它父亲基础上建树。这样每一个节点所保存的信息就是它自身到根节点这条链上的信息。
- 然后我们就可以解决某点到根路径上的第k大啦!
- 等等,不是要解决x-y路径上的第k大吗?
- s[u]+s[v]−s[lca(u,v)]−s[fa[lca(u,v)]]。这不就表示成了x-y路径的信息了嘛:D
#include <iostream>
#include <cstdio>
#include <algorithm>
#define N 200005
#define find(x) (lower_bound(b + 1, b + 1 + cnt, x) - b)
using namespace std;
struct T {int l, r, sum;} t[N << 5];
struct E {int next, to;} e[N * 2];
int n, m, num, cnt, dex, last;
int h[N], a[N], b[N], fat[N], dep[N];
int son[N], top[N], size[N], rt[N];
int read()
{
int x = 0, f = 1; char c = getchar();
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
return x *= f;
}
void add(int u, int v)
{
e[++num].next = h[u];
e[num].to = v;
h[u] = num;
}
void dfs1(int x, int fath, int depth)
{
size[x] = 1, fat[x] = fath, dep[x] = depth;
int maxSon = 0;
for(int i = h[x]; i != 0; i = e[i].next)
if(e[i].to != fath)
{
dfs1(e[i].to, x, depth + 1);
size[x] += size[e[i].to];
if(size[e[i].to] > maxSon)
{
maxSon = size[e[i].to];
son[x] = e[i].to;
}
}
}
void dfs2(int x, int head)
{
top[x] = head;
if(!son[x]) return;
dfs2(son[x], head);
for(int i = h[x]; i != 0; i = e[i].next)
if(e[i].to != fat[x] && e[i].to != son[x])
dfs2(e[i].to, e[i].to);
}
int lca(int x, int y)
{
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fat[top[x]];
}
if(dep[x] < dep[y]) return x;
return y;
}
int build(int l, int r)
{
int p = ++dex, mid = l + r >> 1;
if(l == r) return p;
t[p].l = build(l, mid), t[p].r = build(mid + 1, r);
return p;
}
int upd(int las, int l, int r, int val)
{
int p = ++dex, mid = l + r >> 1;
t[p].l = t[las].l, t[p].r = t[las].r, t[p].sum = t[las].sum + 1;
if(l == r) return p;
if(val <= mid) t[p].l = upd(t[las].l, l, mid, val);
else t[p].r = upd(t[las].r, mid + 1, r, val);
return p;
}
void dfs(int x)
{
rt[x] = upd(rt[fat[x]], 1, cnt, find(a[x]));
for(int i = h[x]; i != 0; i = e[i].next)
if(e[i].to != fat[x]) dfs(e[i].to);
}
int ask(int s1, int s2, int fa, int pa, int l, int r, int rank)
{
int size = t[t[s2].l].sum + t[t[s1].l].sum - t[t[fa].l].sum - t[t[pa].l].sum;
int mid = l + r >> 1;
if(l == r) return l;
if(rank <= size) return ask(t[s1].l, t[s2].l, t[fa].l, t[pa].l, l, mid, rank);
else return ask(t[s1].r, t[s2].r, t[fa].r, t[pa].r, mid + 1, r, rank - size);
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)
a[i] = read(), b[++cnt] = a[i];
sort(b + 1, b + 1 + cnt);
cnt = unique(b + 1, b + 1 + cnt) - b - 1;
for(int i = 1; i < n; i++)
{
int u = read(), v = read();
add(u, v), add(v, u);
}
rt[0] = build(1, cnt);
dfs1(1, 0, 1), dfs2(1, 1), dfs(1);
for(int i = 1; i <= m; i++)
{
int u = read() ^ last, v = read(), rank = read(), head = lca(u, v);
int res = ask(rt[u], rt[v], rt[head], rt[fat[head]], 1, cnt, rank);
printf("%d\n", last = b[res]);
}
return 0;
}