题目大意
\(n\)个点的树, 树上每一个点有一个宝石\(w_i\), 给出一个固定的数字不重复的序列\(p_i\)和一些询问\(u_i, v_i\), 对于每一个询问求出\(u_i\)到\(v_i\)的路径上, 按所给顺序\(p\)收集宝石的最多个数.
\(n,q <= 10^5\)
解题思路
考场上是一点没想到这个.做题还是少了.
我们考虑把颜色重编号, 使得颜色是在\(p\)中的顺序. 设这个新颜色为\(c_i\).
那么问题转化为在询问上找最长的连续正整数子序列.
我们把询问拆成两端 一段是\(u\)到\(lca\), 一段是\(lca\)到\(v\) (注意在代码中 要小心\(lca\)被算两次的细节, 下面的做法是把\(lca\)放到了第一部分处理)
对于前一部分, 我们尽量找出长的一段来, 设一个倍增数组\(f_{u,i}\)表示在\(u\)到根的路径上颜色为\(c_u+2^i\)的最深的节点所在的位置. 可以通过桶来辅助处理.
到了一个点 就往上跳到一个颜色最大的深度大于等于\(lca\)的节点 这能得到前半部分的答案.
然而后一部分不太好求. 某些神仙于是想到二分.
该预处理的差不多, 只是变成了\(c_u-2^i\)的颜色.
我们二分到一个答案, 往上跳对应的步数, 如果深度大于\(lca\)那么合法.
这么一讲貌似不难嘛 但是难想 也不好写.
时间复杂度\(O(n\log^2n)\)
#include <cstdio>
#include <vector>
#define L 17
#define M 50010
#define N 200010
#define fo(i, a, b) for(int i = (a); i <= (b); ++i)
#define fd(i, a, b) for(int i = (a); i >= (b); --i)
using namespace std;
inline int read()
{
int x = 0; char ch = getchar();
while(ch < '0' || ch > '9') ch = getchar();
while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
return x;
}
int n, m, c, q, p[N], lc[N], dep[N], ord[N], ans[N], w[N], lastw[M], f[N][L + 5], g[N][L + 5];
int last[N], pre[N << 1], to[N << 1];
vector<int> s[N], t[N];
inline void add(int u, int v){static int tot = 0; to[++tot] = v, pre[tot] = last[u], last[u] = tot;}
void dfs1(int u)
{
dep[u] = dep[f[u][0]] + 1;
fo(i, 1, L) f[u][i] = f[f[u][i - 1]][i - 1];
for(int i = last[u]; i; i = pre[i])
{
int v = to[i];
if(v == f[u][0]) return ;
f[v][0] = u; dfs1(v);
}
}
inline int lca(int x, int y)
{
if(dep[x] < dep[y]) return lca(y, x);
fd(i, L, 0)
if(dep[f[x][i]] >= dep[y])
x = f[x][i];
if(x == y) return x;
fd(i, L, 0) if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
void dfs2(int u)
{
int tmp = lastw[w[u]];
lastw[w[u]] = dep[u];
g[dep[u]][0] = lastw[w[u] + 1];
fo(i, 1, L) g[dep[u]][i] = g[g[dep[u]][i - 1]][i - 1];
for(int i = 0, siz = s[u].size(); i < siz; ++i)
{
int qry = s[u][i], cur = lastw[1];
if(cur < dep[lc[qry]]) continue ;
++ans[qry];
fd(j, L, 0)
if(g[cur][j] >= dep[lc[qry]])
{
ans[qry] += (1 << j);
cur = g[cur][j];
}
}
for(int i = last[u]; i; i = pre[i])
{
int v = to[i];
if(v != f[u][0]) dfs2(v);
}
lastw[w[u]] = tmp;
}
void dfs3(int u)
{
int tmp = lastw[w[u]];
lastw[w[u]] = dep[u];
g[dep[u]][0] = lastw[w[u] - 1];
fo(i, 1, L) g[dep[u]][i] = g[g[dep[u]][i - 1]][i - 1];
for(int i = 0, siz = t[u].size(); i < siz; ++i)
{
int qry = t[u][i];
int l = ans[qry] + 1, r = c, res = ans[qry];
while(l <= r)
{
int mid = (l + r) >> 1;
int cur = lastw[mid], cnt = mid - ans[qry] - 1;
fd(j, L, 0) ((1 << j) & cnt) && (cur = g[cur][j]);
cur > dep[lc[qry]] ? (l = (res = mid) + 1) : (r = mid - 1);
}
ans[qry] = res;
}
for(int i = last[u]; i; i = pre[i])
{
int v = to[i];
if(v != f[u][0]) dfs3(v);
}
lastw[w[u]] = tmp;
}
int main()
{
freopen("gem.in", "r", stdin);
freopen("gem.out", "w", stdout);
int u, v;
n = read(), m = read(), c = read();
fo(i, 1, c) p[i] = read(), ord[p[i]] = i;
fo(i, 1, n) w[i] = ord[read()];
fo(i, 2, n) u = read(), v = read(), add(u, v), add(v, u);
dfs1(1);
q = read();
fo(i, 1, q) u = read(), v = read(), lc[i] = lca(u, v), s[u].push_back(i), t[v].push_back(i);
dfs2(1);
dfs3(1);
fo(i, 1, q) printf("%d\n", ans[i]);
return 0;
}