Problem Description
In this problem we consider a rooted tree with N vertices. The vertices are numbered from 1 to N, and vertex 1 is the root. Each vertex carries an integer weight. Your task is to answer a list of queries: for each query, report how many distinct weights appear exactly K times among the vertices of the subtree rooted at vertex u.
给出一棵根为\(1\)的树,每个点都有权值,每次询问以\(u\)为根的子树中恰好出现\(k\)次的权值有几种
权值可以离散化,离散化之后范围在\([1,n]\),然后通过树上启发式合并可以得到所有点的答案,对于每个询问查询答案即可
具体方法就是维护每个权值出现的次数 \(cnt[x]\),以及出现次数恰好为 \(c\) 的权值种数 \(app[c]\);处理完以 \(u\) 为根的子树后,答案即为 \(app[k]\)
//#pragma GCC optimize("O3")
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<bits/stdc++.h>
using namespace std;
// Array capacity for vertices, weights and occurrence counts.
// NOTE(review): assumes n <= 100000 per the judge's limits — TODO confirm.
const int MAXN = 100000 + 7;
// n: number of vertices in the current test case; k: the exact occurrence count queried.
int n, k;
// w[v]: weight of vertex v (solve() compresses these into [1, n]).
int w[MAXN];
// ret[v]: answer for the subtree rooted at v, filled in by search().
int ret[MAXN];
// sz[v]: size of v's subtree; son[v]: heavy child of v (0 for a leaf). Set by dfs().
int sz[MAXN], son[MAXN];
// cnt[x]: occurrences of compressed weight x among the vertices currently counted.
int cnt[MAXN];
// app[c]: number of distinct weights whose cnt is exactly c.
// app[0] is never queried meaningfully and drifts negative while vertices are counted.
int app[MAXN];
// G[v]: adjacency list of vertex v (undirected tree edges).
vector<int> G[MAXN];
// Preparatory pass over the subtree of u (par is u's parent, 0 at the root).
// Fills sz[] with subtree sizes and son[] with the child owning the largest
// subtree; son[u] stays 0 when u has no children (sz[0] is 0, so any child wins).
void dfs(int u, int par){
    sz[u] = 1;
    son[u] = 0;
    for(size_t idx = 0; idx < G[u].size(); ++idx){
        const int child = G[u][idx];
        if(child == par)
            continue;
        dfs(child, u);
        sz[u] += sz[child];
        // Strict comparison: on ties the earliest child in G[u] is kept.
        if(sz[child] > sz[son[u]])
            son[u] = child;
    }
}
// Records one more occurrence of compressed weight x: the weight moves from
// the app[] bucket for its old count to the bucket for its new count.
void inc(int x){
    int &times = cnt[x];
    --app[times];
    ++times;
    ++app[times];
}
// Removes one occurrence of compressed weight x: the weight moves from the
// app[] bucket for its old count down to the bucket for its new count.
void dec(int x){
    int &times = cnt[x];
    --app[times];
    --times;
    ++app[times];
}
// Walks the whole subtree of u (par is u's parent and is not entered) in
// pre-order and applies inc() when add is true, dec() otherwise, to each
// visited vertex's weight.
void update(int u, int par, bool add){
    add ? inc(w[u]) : dec(w[u]);
    for(size_t idx = 0; idx < G[u].size(); ++idx){
        const int child = G[u][idx];
        if(child != par)
            update(child, u, add);
    }
}
// DSU-on-tree ("small to large") pass over the subtree of u; par is u's parent.
// On exit ret[u] == app[k] computed with every vertex of u's subtree counted.
// If clear is true, the subtree's contribution is removed from cnt/app before
// returning. Otherwise it is left in place so the caller (u's parent, for
// which u is the heavy child) can build on it.
// Preconditions: dfs() has filled son[] for this tree, and cnt/app currently
// contain no vertex of u's subtree.
// NOTE(review): recursion depth equals the tree height (up to n on a path);
// with the /STACK pragma commented out, very deep trees may exhaust the stack.
void search(int u, int par, bool clear){
// 1) Answer every light child on its own, wiping its counts afterwards.
for(int v : G[u]) if(v!=par and v!=son[u]) search(v,u,true);
// 2) Answer the heavy child last and keep its counts, so they need not be rebuilt.
if(son[u]) search(son[u],u,false);
// 3) Re-add each light child's subtree on top of the heavy child's counts.
for(int v : G[u]) if(v!=par and v!=son[u]) update(v,u,true);
// 4) Count u itself; cnt/app now describe exactly u's subtree.
inc(w[u]);
ret[u] = app[k];
// Undo the whole subtree when the caller needs a clean slate.
if(clear) update(u,par,false);
}
// Replaces a[1..len] in place by the 1-based rank of each value among the
// distinct values of a[1..len], so every result lies in [1, len] and can be
// used directly as an index into cnt[].
static void compress_weights(int *a, int len){
    vector<int> vals(a + 1, a + 1 + len);
    sort(vals.begin(), vals.end());
    vals.erase(unique(vals.begin(), vals.end()), vals.end());
    for(int i = 1; i <= len; i++)
        a[i] = int(lower_bound(vals.begin(), vals.end(), a[i]) - vals.begin()) + 1;
}

// Solves one test case read from stdin: n and k, the n vertex weights, the
// n-1 edges, then q queries. Prints "Case #kase:" and one answer per query.
// Consecutive cases are separated by a blank line.
void solve(int kase){
    if(kase != 1) puts("");
    scanf("%d %d", &n, &k);
    for(int i = 1; i <= n; i++) G[i].clear();
    for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
    compress_weights(w, n);
    for(int i = 1; i < n; i++){
        int u, v;
        scanf("%d %d", &u, &v);
        G[u].emplace_back(v);
        G[v].emplace_back(u);
    }
    dfs(1, 0);
    // Root call with clear == true also returns cnt/app to all zeros, which is
    // what lets the next test case start without an explicit reset.
    search(1, 0, true);
    printf("Case #%d:\n", kase);
    int q;
    scanf("%d", &q);
    while(q--){
        int x;
        scanf("%d", &x);
        printf("%d\n", ret[x]);
    }
}
// Reads the number of test cases T and solves each one in turn.
int main(){
    // Initialise and check the read: on empty or malformed input T would
    // otherwise be used uninitialised (undefined behaviour).
    int T = 0;
    if(scanf("%d", &T) != 1) return 0;
    for(int kase = 1; kase <= T; kase++) solve(kase);
    return 0;
}