考虑将宝石的种类变成每种宝石出现的下标,这样问题变成树上路径找一段正整数前缀使得这个前缀是这个路径的子序列。
先跑出一遍 dfs 序,然后进行倍增,记 \(nxt_{i,j}\) 为从点 \(i\) 向下走,再装 \(2^j\) 个宝石走到的节点,\(pre_{i,j}\) 表示向上走的同样情况。然后重链剖分,对于每个询问按照 \(lca\) 拆成两部分,上行的部分在重链上尽可能多的跳 \(pre\) 数组,下行继续跳 \(nxt\) 数组即可。
比较难写。
#include<iostream>
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
struct edge
{
int nxt,to;
}e[200001<<1];
int n,m,c,q,tot,h[200001],p[200001],w[200001],dep[200001],fa[200001],s[200001],son[200001],id[200001],top[200001],cnt,pre[200001][16],nxt[200001][16],ans;
pair<int,int> node[200001];
vector<pair<int,int> > tmp;
vector<int> v[200001];
inline int read()
{
int x=0;
char c=getchar();
while(c<‘0‘||c>‘9‘)
c=getchar();
while(c>=‘0‘&&c<=‘9‘)
{
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x;
}
void print(int x)
{
if(x>=10)
print(x/10);
putchar(x%10+‘0‘);
}
inline void add(int x,int y)
{
e[++tot].nxt=h[x];
h[x]=tot;
e[tot].to=y;
}
void dfs1(int k,int f,int deep)
{
dep[k]=deep;
fa[k]=f;
s[k]=1;
int maxson=-1;
for(register int i=h[k];i;i=e[i].nxt)
{
if(e[i].to==f)
continue;
dfs1(e[i].to,k,deep+1);
s[k]+=s[e[i].to];
if(s[e[i].to]>maxson)
{
maxson=s[e[i].to];
son[k]=e[i].to;
}
}
}
void dfs2(int k,int t)
{
id[k]=++cnt;
top[k]=t;
if(!son[k])
return;
dfs2(son[k],t);
for(register int i=h[k];i;i=e[i].nxt)
{
if(e[i].to==fa[k]||e[i].to==son[k])
continue;
dfs2(e[i].to,e[i].to);
}
}
inline int LCA(int x,int y)
{
while(top[x]^top[y])
{
if(dep[top[x]]<dep[top[y]])
x^=y^=x^=y;
x=fa[top[x]];
}
if(dep[x]>dep[y])
x^=y^=x^=y;
return x;
}
inline int find1(int x,int minn)
{
int res=0;
for(register int i=15;~i;--i)
if(pre[x][i]&&pre[x][i]>=minn)
{
x=pre[x][i];
res|=1<<i;
}
return res;
}
inline int find2(int x,int maxn)
{
int res=0;
for(register int i=15;~i;--i)
if(nxt[x][i]&&nxt[x][i]<=maxn)
{
x=nxt[x][i];
res|=1<<i;
}
return res;
}
inline void q1(int x,int y)
{
while(top[x]^top[y])
{
int l=id[top[x]],r=id[x];
int pos=upper_bound(v[ans].begin(),v[ans].end(),r)-v[ans].begin()-1;
if(pos==-1||v[ans][pos]<l)
{
x=fa[top[x]];
continue;
}
ans+=find1(v[ans][pos],l)+1;
x=fa[top[x]];
}
if(x^y)
{
int l=id[y],r=id[x];
int pos=upper_bound(v[ans].begin(),v[ans].end(),r)-v[ans].begin()-1;
if(pos==-1||v[ans][pos]<l)
return;
ans+=find1(v[ans][pos],l)+1;
}
}
inline void q2(int x,int y)
{
tmp.clear();
while(top[x]^top[y])
{
tmp.push_back(make_pair(id[top[x]],id[x]));
x=fa[top[x]];
}
tmp.push_back(make_pair(id[y],id[x]));
reverse(tmp.begin(),tmp.end());
for(register int i=0;i<(int)tmp.size();++i)
{
int l=tmp[i].first,r=tmp[i].second;
int pos=lower_bound(v[ans].begin(),v[ans].end(),l)-v[ans].begin();
if(pos==(int)v[ans].size()||v[ans][pos]>r)
continue;
ans+=find2(v[ans][pos],r)+1;
}
}
int main()
{
n=read(),m=read(),c=read();
for(register int i=1;i<=c;++i)
p[read()]=i;
for(register int i=1;i<=n;++i)
w[i]=p[read()];
for(register int i=1;i<n;++i)
{
int x=read(),y=read();
add(x,y);
add(y,x);
}
dfs1(1,0,1);
dfs2(1,1);
for(register int i=1;i<=n;++i)
node[i]=make_pair(id[i],w[i]);
sort(node+1,node+n+1);
for(register int i=1;i<=n;++i)
v[node[i].second].push_back(node[i].first);
for(register int i=n;i;--i)
{
int pos=upper_bound(v[node[i].second+1].begin(),v[node[i].second+1].end(),node[i].first)-v[node[i].second+1].begin();
if(pos!=(int)v[node[i].second+1].size())
{
nxt[i][0]=v[node[i].second+1][pos];
for(register int j=1;j<=15;++j)
nxt[i][j]=nxt[nxt[i][j-1]][j-1];
}
}
for(register int i=1;i<=n;++i)
{
int pos=upper_bound(v[node[i].second+1].begin(),v[node[i].second+1].end(),node[i].first)-v[node[i].second+1].begin()-1;
if(pos!=-1)
{
pre[i][0]=v[node[i].second+1][pos];
for(register int j=1;j<=15;++j)
pre[i][j]=pre[pre[i][j-1]][j-1];
}
}
//for(register int i=1;i<=n;++i)
//printf("%d %d\n",pre[i][0],nxt[i][0]);
q=read();
while(q--)
{
int x=read(),y=read(),lca=LCA(x,y);
ans=1;
q1(x,lca);
q2(y,lca);
print(ans-1);
putchar(‘\n‘);
}
return 0;
}