正题
题目链接:https://www.luogu.com.cn/problem/P5311
题目大意
给出\(n\)个点的一棵树,每个节点有一个颜色,\(m\)次询问提出区间\([l,r]\)的点构成的生成子图中\(x\)所在连通块的颜色数。
\(1\leq n,m,a_i\leq 10^5\)
解题思路
用点分树解决本题是很妙的想法。/bx
考虑点分树如何解决,对于一个询问\(l,r,x\),如果\(x\)在点分树上的一个祖先\(y\)满足\(x\)到\(y\)的路径都是\([l,r]\)的点,那么此时在\(y\)的点分子树上的所有节点\(z\)都满足如果\(y\)到\(z\)的路径都是\([l,r]\)的节点,那么\(x\)到\(z\)的也是。具体原因很好理解,因为两条路径重复的那一段路已经满足条件了。
那么我们此时就可以将一个条件拆分成两个条件挂在\(y\)上了,具体地我们对于每个询问找到\(x\)点分树上深度最小的祖先\(y\)满足\(x\)到\(y\)的路径上都是\([l,r]\)的节点,然后把这个询问挂在这个点上。
然后我们暴力枚举所有点然后处理它的点分树子树,假设现在枚举到\(x\)点,我们把它的所有儿子按照\(x\)到它们的路径上的最小值从大到小询问,然后挂在\(x\)点上的询问按照\(l\)从大到小排序。
然后暴力遍历,记录每个颜色最小的\(mx\)表示\(y\)到这个颜色的点需要经过的路径最大值,然后把权值丢到树状数组上查询即可。
时间复杂度:\(O(n\log^2 n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define lowbit(x) (x&-x)
using namespace std;
const int N=1e5+10;
struct edge{
int to,next;
}a[N<<1];
struct node{
int x,fa,l,r;
};
int n,m,tot,num,root,ls[N],c[N];
int siz[N],f[N],t[N],lat[N],ans[N];
vector<node> anc[N],son[N],q[N];
bool v[N];
void addl(int x,int y){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;return;
}
void Groot(int x,int fa){
siz[x]=1;f[x]=0;
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa||v[y])continue;
Groot(y,x);siz[x]+=siz[y];
f[x]=max(f[x],siz[y]);
}
f[x]=max(f[x],num-siz[x]);
if(f[x]<f[root])root=x;
return;
}
void dfs(int x,int fa,int fr,int mi=1e9,int mx=-1e9){
mi=min(x,mi);mx=max(x,mx);
anc[x].push_back((node){x,fr,mi,mx});
son[fr].push_back((node){x,fr,mi,mx});
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa||v[y])continue;
dfs(y,x,fr,mi,mx);
}
return;
}
void Build(int x){
int S=num;v[x]=1;dfs(x,0,x);
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(v[y])continue;
num=(siz[y]>siz[x])?(S-siz[x]):siz[y];
root=0;Groot(y,x);Build(root);
}
return;
}
bool cmp(node x,node y)
{return x.l>y.l;}
void Change(int x,int val){
while(x<=n){
t[x]+=val;
x+=lowbit(x);
}
return;
}
int Ask(int x){
int ans=0;
while(x){
ans+=t[x];
x-=lowbit(x);
}
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&c[i]);
for(int i=1;i<n;i++){
int x,y;
x=i,y=i+1;
scanf("%d%d",&x,&y);
addl(x,y);addl(y,x);
}
f[0]=n+1;num=n;Groot(1,1);
Build(root);
for(int t=1;t<=m;t++){
int l,r,x;
scanf("%d%d%d",&l,&r,&x);
for(int i=0;i<anc[x].size();i++)
if(anc[x][i].l>=l&&anc[x][i].r<=r)
{q[anc[x][i].fa].push_back((node){x,t,l,r});break;}
}
for(int i=0;i<N;i++)lat[i]=n+1;
for(int p=1;p<=n;p++){
sort(q[p].begin(),q[p].end(),cmp);
sort(son[p].begin(),son[p].end(),cmp);
int z=0;
for(int i=0;i<son[p].size();i++){
node x=son[p][i];
while(z<q[p].size()&&q[p][z].l>x.l)
{node y=q[p][z];ans[y.fa]=Ask(y.r);z++;}
Change(lat[c[x.x]],-1);
lat[c[x.x]]=min(lat[c[x.x]],x.r);
Change(lat[c[x.x]],1);
}
while(z<q[p].size())
{node y=q[p][z];ans[y.fa]=Ask(y.r);z++;}
for(int i=0;i<son[p].size();i++){
node x=son[p][i];
Change(lat[c[x.x]],-1);
lat[c[x.x]]=n+1;
}
}
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
return 0;
}