观察数据范围是跟k有关的,因此我们考虑建立虚树,对于维护三个值
总和就是常规的按每条路左右两边点数算贡献,注意是特殊点的数量
之后我们维护mi[i],mx[i]表示对于当前点,子树中离他最近的特殊点在哪
#include<bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<ll,ll> pll; const int N=2e6+10; const int M=2e6+10; const int inf=0x3f3f3f3f; const ll mod=998244353; struct node{ int x; ll w; }; vector<node> g[N],g1[N]; int dfn[N],times,depth[N]; int n,f[N][25],cost[N][25]; int st[N]; int cnt[N]; ll sz[N]; int q[N],k; ll dp[N][2],sum,mi,mx; ll ans1[N],ans2[N]; void dfs(int u,int fa,ll w){ dfn[u]=++times; depth[u]=depth[fa]+1; f[u][0]=fa,cost[u][0]=w; int i; for(i=1;i<=21;i++){ f[u][i]=f[f[u][i-1]][i-1]; cost[u][i]=cost[f[u][i-1]][i-1]+cost[u][i-1]; } for(i=0;i<g1[u].size();i++){ int v=g1[u][i].x; if(v==fa) continue; dfs(v,u,g1[u][i].w); } } bool cmp(int a,int b){ return dfn[a]<dfn[b]; } int lca(int a,int b){ if(depth[a]<depth[b]) swap(a,b); int i; for(i=21;i>=0;i--){ if(depth[f[a][i]]>=depth[b]){ a=f[a][i]; } } if(a==b) return a; for(i=21;i>=0;i--){ if(f[a][i]!=f[b][i]){ a=f[a][i]; b=f[b][i]; } } return f[a][0]; } ll query(int a,int b){ ll ans=0; if(depth[a]>depth[b]) swap(a,b); for(int i=21;i>=0;i--){ if(depth[f[b][i]]>=depth[a]){ ans+=cost[b][i]; b=f[b][i]; } } return ans; } void add(int a,int b){ ll num=query(a,b); g[a].push_back({b,num}); } void get(int u,int fa){ int i; if(st[u]){ sz[u]=1; ans1[u]=ans2[u]=0; } else{ ans1[u]=inf,ans2[u]=-inf,sz[u]=0; } for(i=0;i<g[u].size();i++){ int v=g[u][i].x; if(v==fa) continue; get(v,u); sum+=sz[v]*(k-sz[v])*g[u][i].w; sz[u]+=sz[v]; mi=min(mi,ans1[u]+g[u][i].w+ans1[v]); mx=max(mx,ans2[u]+g[u][i].w+ans2[v]); ans1[u]=min(ans1[u],g[u][i].w+ans1[v]); ans2[u]=max(ans2[u],g[u][i].w+ans2[v]); st[v]=0; } g[u].clear(); } int main(){ ios::sync_with_stdio(false); cin>>n; int i; for(i=1;i<n;i++){ int a,b; cin>>a>>b; g1[a].push_back({b,1}); g1[b].push_back({a,1}); } dfs(1,0,0); int m; cin>>m; while(m--){ cin>>k; for(i=1;i<=k;i++){ int x; cin>>x; st[x]=1; cnt[i]=x; } sort(cnt+1,cnt+1+k,cmp); q[1]=1; int tt=1; for(i=1;i<=k;i++){ if(cnt[i]==1) continue; int p=lca(cnt[i],q[tt]); if(p!=q[tt]){ while(dfn[p]<dfn[q[tt-1]]){ add(q[tt-1],q[tt]); tt--; } if(dfn[p]!=dfn[q[tt-1]]){ add(p,q[tt]); q[tt]=p; } else{ add(p,q[tt]); tt--; } } q[++tt]=cnt[i]; } for(i=1;i<tt;i++){ add(q[i],q[i+1]); } sum=0,mi=1e18,mx=-1e18; get(1,-1); cout<<sum<<" "<<mi<<" "<<mx<<endl; st[1]=0; } return 0; }View Code