Description
给定含有 \(n\) 个点的森林,\(m\) 次询问,每次询问包含两个参数 \(v,p\),求有多少个节点与节点 \(v\) 含有相同的 \(p\) 级祖先。
\(1\le n\le 10^5\)
Solution
看到线段树合并的题解那么少,就来补一篇吧。
前置知识
- 线段树合并
解法
我们考虑将问题转换一下:首先对于每个询问,求出其 \(p\) 级祖先,记作点 \(u\),问题实际上转化为了点 \(u\) 的 \(p\) 级子孙数量。
然后考虑点 \(c\) 若能成为 \(u\) 的 \(p\) 级子孙,\(c\) 实际上要满足如下两个条件:
- 在 \(u\) 的子树内
- \(\operatorname{dep}_c=\operatorname{dep}_u+p\)
于是问题转化为 \(u\) 子树内有多少个点满足 \(\operatorname{dep}_c=\operatorname{dep}_u+p\) ,考虑使用线段树合并解决该问题。
对询问离线,挂在每个点上(这个过程可以用 vector
),这里指的是求出了 \(p\) 级祖先之后的询问。
使用动态开点权值线段树,设节点 \(x\) 的区间为 \([l,r]\),则节点 \(x\) 维护当前有多少个点满足 \(\operatorname{dep}\) 在 \([l,r]\) 之间。
进行一次 dfs,首先添加点 \(u\) 的信息,然后再将其子树的线段树与其进行合并,最后对挂在该点的询问进行查询即可。
时间复杂度:\(\mathcal O(n\log n)\)。
Code
constexpr int N = 1e5 + 5 ;
int n,tot,m;
int fa[N][20],dep[N],ans[N],head[N],to[N << 1],nxt[N << 1],srt[N],rt[N];
std::vector<std::pair<int,int> > q[N] ;
//first k,second id
inline void add(int u,int v) {
static int cnt = 0 ;
to[++cnt] = v,nxt[cnt] = head[u],head[u] = cnt ;
}
void dfs1(int x,int f) {
dep[x] = dep[f] + 1,fa[x][0] = f ;
for (int i = 1; i < 20; ++i)
fa[x][i] = fa[fa[x][i - 1]][i - 1] ;
for (int i = head[x]; i; i = nxt[i])
if (int v = to[i]; v != f)
dfs1(v,x);
}
inline int jump(int x,int k) {
for (int i = 19; ~i; --i)
if (k & (1 << i))
x = fa[x][i] ;
return x ;
}
namespace seg {
#define mid ((l + r) >> 1)
constexpr int MAX_NODE = N << 5 ;
int cnt = 0 ;
int ls[MAX_NODE],rs[MAX_NODE],sum[MAX_NODE] ;
void check(int x,int l,int r) {
if (!x)
return ;
//std::cout << "sum " << l << "~" << r << " = " << sum[x] << '\n' ;
check(ls[x],l,mid),check(rs[x],mid + 1,r) ;
}
inline void pushup(int x) {
sum[x] = sum[ls[x]] + sum[rs[x]] ;
}
inline void change(int &x,int l,int r,int pos) {
x = !x ? ++cnt : x ;
if (l == r)
return ++sum[x],void() ;
if (pos <= mid)
change(ls[x],l,mid,pos) ;
else
change(rs[x],mid + 1,r,pos) ;
pushup(x) ;
}
int merge(int x,int y,int l,int r) {
if (!x || !y)
return x | y ;
if (l == r)
return sum[x] += sum[y],x ;
return ls[x] = merge(ls[x],ls[y],l,mid),rs[x] = merge(rs[x],rs[y],mid + 1,r),pushup(x),x ;
}
int query(int x,int l,int r,int pos) {
if (l == r)
return sum[x] ;
if (pos <= mid)
return query(ls[x],l,mid,pos) ;
else
return query(rs[x],mid + 1,r,pos) ;
}
}
void dfs2(int x) {
seg::change(srt[x],1,n,dep[x]) ;
//std::cout << "x = " << x << '\n' ;
//seg::check(srt[x],1,n) ;
for (int i = head[x]; i; i = nxt[i])
if (int v = to[i]; v != fa[x][0])
dfs2(v),srt[x] = seg::merge(srt[x],srt[v],1,n);
for (auto it : q[x])
ans[it.second] = seg::query(srt[x],1,n,it.first) - 1;
}
int main(int argc, const char **argv) {
#ifdef helios
freopen("std.in", "r", stdin);
freopen("std.ans", "w", stdout);
#endif
std::ios::sync_with_stdio(false), std::cin.tie(nullptr), std::cout.tie(nullptr);
in(n) ;
for (int i = 1,x; i <= n; ++i) {
in(x) ;
if (x)
add(x,i),add(i,x) ;
else
rt[++tot] = i ;
}
for (int i = 1; i <= tot; ++i)
dep[rt[i]] = 1,dfs1(rt[i],0) ;
in(m) ;
for (int i = 1,x,k; i <= m; ++i) {
in(x,k) ;
int ac = jump(x,k) ;
if (ac == 0) {
ans[i] = 0 ;
continue ;
}
q[ac].push_back(std::make_pair(k + dep[ac],i)) ;
}
for (int i = 1; i <= tot; ++i)
dfs2(rt[i]) ;
for (int i = 1; i <= m; ++i)
out(' ',ans[i]) ;
seg::check(srt[1],1,n) ;
std::cerr << "This program costs " << clock() << " ms" ;
return 0 ;
}
对代码有疑问可以私信我。