正题
题目链接:https://www.luogu.com.cn/problem/AT3611
题目大意
给出 n n n个点的一棵树。
现在有一张完全图,两个点之间的边权为 w x + w y + d i s ( x , y ) w_x+w_y+dis(x,y) wx+wy+dis(x,y)( d i s dis dis表示树上距离)
求这张完全图的最小生成树。
2 ≤ n ≤ 2 × 1 0 5 , 1 ≤ w i , c i ≤ 1 0 9 2\leq n\leq 2\times 10^5,1\leq w_i,c_i\leq 10^9 2≤n≤2×105,1≤wi,ci≤109
解题思路
考虑可能作为最小生成树的边。
一个结论就是对于一个子图。不在最小生成森林上的边一定不在原图的最小生成树上。
这样可以考虑分治,点分治之后对于根节点 x x x,其他的节点定义 f x = d e p x + w x f_x=dep_x+w_x fx=depx+wx,那么两个点之间边权就是 f x + f y f_x+f_y fx+fy了( x , y x,y x,y属于不同子树),对于同一子树的我们也加进去,因为这是不优的边所以不会影响答案。
此时图中的最小生成森林是其他所有点连接 f f f值最小的点。
这样我们可以处理出 n log n n\log n nlogn条可能的边,在这些边上再跑一次最小生成树就好了。
时间复杂度 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=2e5+10,inf=1e18;
struct node{
ll to,next,w;
}a[N<<1];
struct edge{
ll x,y,w;
}e[N<<5];
ll n,tot,mins,root,ans,num,ent;
ll ls[N],f[N],siz[N],w[N],fa[N];
bool v[N];
void addl(ll x,ll y,ll w){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;a[tot].w=w;
return;
}
void groot(ll x,ll fa){
siz[x]=1;f[x]=0;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa||v[y])continue;
groot(y,x);siz[x]+=siz[y];
f[x]=max(f[x],siz[y]);
}
f[x]=max(f[x],num-siz[x]);
if(f[x]<f[root])root=x;
return;
}
void calc(ll x,ll fa,ll dep){
f[x]=w[x]+dep;
if(f[x]<f[mins])mins=x;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa||v[y])continue;
calc(y,x,dep+a[i].w);
}
return;
}
void adde(ll x,ll fa){
e[++ent]=(edge){x,mins,f[x]+f[mins]};
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(y==fa||v[y])continue;
adde(y,x);
}
}
void solve(ll x){
v[x]=1;f[x]=w[mins=x];
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(v[y])continue;
calc(y,x,a[i].w);
}
e[++ent]=(edge){x,mins,f[x]+f[mins]};
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(v[y])continue;
adde(y,x);
}
ll sum=num;
for(ll i=ls[x];i;i=a[i].next){
ll y=a[i].to;
if(v[y])continue;
num=(siz[y]>siz[x])?(sum-siz[x]):siz[y];
root=0;groot(y,x);solve(root);
}
return;
}
bool cmp(edge x,edge y)
{return x.w<y.w;}
ll find(ll x)
{return (fa[x]==x)?x:(fa[x]=find(fa[x]));}
signed main()
{
scanf("%lld",&n);
for(ll i=1;i<=n;i++)
scanf("%lld",&w[i]),fa[i]=i;
for(ll i=1;i<n;i++){
ll x,y,w;
scanf("%lld%lld%lld",&x,&y,&w);
addl(x,y,w);addl(y,x,w);
}
num=n;f[0]=inf;
groot(1,1);solve(root);
sort(e+1,e+1+ent,cmp);
for(ll i=1;i<=ent;i++){
ll x=e[i].x,y=e[i].y;
x=find(x);y=find(y);
if(x!=y)ans+=e[i].w,fa[y]=x;
}
printf("%lld\n",ans);
return 0;
}