看到这题的第一个想法就是:
树链剖分+线段树套平衡树(set)
对于每个线段树节点维护一个set,记录该节点代表的区间有哪几种奶牛。
效率大概是$O(Nlog^2N)$(也可能是$log^3$?,我太蒻了不会证),吸个氧就过了。
代码:
1 #include<cstdio> 2 #include<algorithm> 3 #include<set> 4 #define N 100005 5 6 inline void rd(int &x){ 7 x=0;char c=getchar(); 8 while(c<'0'||c>'9')c=getchar(); 9 while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar(); 10 } 11 12 int n,m,a[N]; 13 14 int hd[N],_hd; 15 struct edge{ 16 int v,nxt; 17 }e[N<<1]; 18 inline void addedge(int u,int v){ 19 e[++_hd]=(edge){v,hd[u]}; 20 hd[u]=_hd; 21 } 22 23 int fa[N],dep[N],sz[N],son[N]; 24 inline void dfs1(int u,int Fa){ 25 fa[u]=Fa; 26 dep[u]=dep[Fa]+1; 27 sz[u]=1; 28 for(int i=hd[u];i;i=e[i].nxt){ 29 int v=e[i].v; 30 if(v==Fa) 31 continue; 32 dfs1(v,u); 33 sz[u]+=sz[v]; 34 if(sz[v]>sz[son[u]]) 35 son[u]=v; 36 } 37 } 38 int id[N],_id,pos[N],top[N]; 39 inline void dfs2(int u){ 40 id[u]=++_id; 41 pos[_id]=u; 42 top[u]=u==son[fa[u]]?top[fa[u]]:u; 43 if(son[u]) 44 dfs2(son[u]); 45 for(int i=hd[u];i;i=e[i].nxt){ 46 int v=e[i].v; 47 if(v==fa[u]||v==son[u]) 48 continue; 49 dfs2(v); 50 } 51 } 52 53 std::set<int> s[N<<2]; 54 inline void build(int p,int L,int R){ 55 for(int i=L;i<=R;i++) 56 s[p].insert(a[pos[i]]); 57 if(L==R) 58 return; 59 int mid=(L+R)>>1; 60 build(p<<1,L,mid); 61 build(p<<1|1,mid+1,R); 62 } 63 inline bool sch(int p,int L,int R,int l,int r,int x){ 64 if(L>r||R<l) 65 return 0; 66 if(l<=L&&R<=r) 67 return s[p].count(x); 68 int mid=(L+R)>>1; 69 return sch(p<<1,L,mid,l,r,x)||sch(p<<1|1,mid+1,R,l,r,x); 70 } 71 72 inline bool lca_sch(int u,int v,int x){ 73 while(top[u]!=top[v]){ 74 if(dep[top[u]]<dep[top[v]]) 75 std::swap(u,v); 76 if(sch(1,1,n,id[top[u]],id[u],x)) 77 return 1; 78 u=fa[top[u]]; 79 } 80 if(dep[u]<dep[v]) 81 std::swap(u,v); 82 return sch(1,1,n,id[v],id[u],x); 83 } 84 85 int main(){ 86 // freopen("milkvisits.in","r",stdin); 87 // freopen("milkvisits.out","w",stdout); 88 rd(n),rd(m); 89 for(int i=1;i<=n;i++) 90 scanf("%d",&a[i]); 91 for(int i=1;i<n;i++){ 92 int u,v; 93 rd(u),rd(v); 94 addedge(u,v); 95 addedge(v,u); 96 } 97 dfs1(1,0); 98 dfs2(1); 99 build(1,1,n); 100 while(m--){ 101 int u,v,x; 102 rd(u),rd(v),rd(x); 103 printf("%d",lca_sch(u,v,x)); 104 } 105 106 #define w 0 107 return ~~('0')?(0^w^0):(0*w*0); 108 }