考虑如果存在两条在最小生成树上的边被换掉了,那么原树会被分成三个联通块。
考虑新加的两条边,保留权值较小的那一条,这样还剩两个连通块。
而删除的两条边至少有一条能联通这两个联通块,所以可以保留那条边。
新加的两条边中权值较大的那一条肯定大于等于我们保留的边,因为它们都起着联通两个连通块的作用,如果小于则可以替换出更小的生成树,与最小生成树的前提冲突。
这样就证明了删除两条边的任意一种方案都存在一种不劣的删除一条边的方案。
对于删除更多原树中边的情况,可以不断找出两条被删除的边,然后找出起对应联通作用的两条新边进行上述操作。
所以仅仅考虑删除原树中一条边的情况即可。
倍增时记录最大值和严格次大值,最后枚举不在原树上的边尝试替换即可。
代码写得 diao 丑,见谅。
code:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define N 20
#define NN 100005
#define Min(x,y)((x)<(y)?x:y)
#define For(i,x,y)for(i=x;i<=(y);i++)
#define Down(i,x,y)for(i=x;i>=(y);i--)
struct node
{
int next,to,w;
}e[200005];
struct edge
{
bool used;
int x,y,z;
}r[300005];
int fa[NN][N],mx[NN][N],mxx[NN][N],depp[NN],bin[NN],faa[NN],dep[NN],head[NN],g;
int read()
{
int A;
bool K;
char C;
C=A=K=0;
while(C<'0'||C>'9')K|=C=='-',C=getchar();
while(C>'/'&&C<':')A=(A<<3)+(A<<1)+(C^48),C=getchar();
return(K?-A:A);
}
inline void add(int u,int v,int w)
{
e[++g].to=v;
e[g].w=w;
e[g].next=head[u];
head[u]=g;
}
inline bool cmp(edge _,edge __)
{
return _.z<__.z;
}
void build(int u)
{
int i=0,v;
dep[u]=dep[fa[u][0]]+1;
while(1<<++i<=dep[u])
{
fa[u][i]=fa[fa[u][i-1]][i-1];
mx[u][i]=mx[fa[u][i-1]][i-1];
mxx[u][i]=mxx[fa[u][i-1]][i-1];
if(mx[u][i-1]>mx[u][i])mxx[u][i]=mx[u][i],mx[u][i]=mx[u][i-1];
else if(mx[u][i-1]<mx[u][i]&&mx[u][i-1]>mxx[u][i])mxx[u][i]=mx[u][i-1];
if(mxx[u][i-1]<mx[u][i]&&mxx[u][i-1]>mxx[u][i])mxx[u][i]=mxx[u][i-1];
}
for(i=head[u];i;i=e[i].next)
{
v=e[i].to;
if(v==fa[u][0])continue;
fa[v][0]=u;
mx[v][0]=mxx[v][0]=e[i].w;
build(v);
}
}
int find(int p)
{
if(p!=faa[p])faa[p]=find(faa[p]);
return faa[p];
}
inline void unite(int p,int q)
{
if(depp[p]<depp[q])faa[p]=q;
else faa[q]=p;
if(depp[p]==depp[q])depp[p]++;
}
pair<int,int>get(int x,int y)
{
int a,b,i;
a=b=-1;
if(dep[x]<dep[y])swap(x,y);
while(dep[x]>dep[y])
{
i=bin[dep[x]-dep[y]];
if(mx[x][i]>a)b=a,a=mx[x][i];
else if(mx[x][i]<a&&mx[x][i]>b)b=mx[x][i];
if(mxx[x][i]<a&&mxx[x][i]>b)b=mxx[x][i];
x=fa[x][i];
}
Down(i,bin[dep[x]],0)
if(fa[x][i]!=fa[y][i])
{
if(mx[x][i]>a)b=a,a=mx[x][i];
else if(mx[x][i]<a&&mx[x][i]>b)b=mx[x][i];
if(mxx[x][i]<a&&mxx[x][i]>b)b=mxx[x][i];
if(mx[y][i]>a)b=a,a=mx[y][i];
else if(mx[y][i]<a&&mx[y][i]>b)b=mx[y][i];
if(mxx[y][i]<a&&mxx[y][i]>b)b=mxx[y][i];
x=fa[x][i],y=fa[y][i];
}
if(mx[x][0]>a)b=a,a=mx[x][0];
else if(mx[x][0]<a&&mx[x][0]>b)b=mx[x][0];
if(mxx[x][0]<a&&mxx[x][0]>b)b=mxx[x][0];
if(mx[y][0]>a)b=a,a=mx[y][0];
else if(mx[y][0]<a&&mx[y][0]>b)b=mx[y][0];
if(mxx[y][0]<a&&mxx[y][0]>b)b=mxx[y][0];
if(!~b)b=a;
return make_pair(a,b);
}
int main()
{
int n,m,i;
ll tot,ans;
pair<int,int>pa;
ans=tot=0;
n=read(),m=read();
For(i,1,m)
{
r[i].x=read(),r[i].y=read(),r[i].z=read();
ans+=r[i].z;
}
bin[0]=-1;
For(i,1,n)bin[i]=bin[i>>1]+1,faa[i]=i;
sort(r+1,r+m+1,cmp);
For(i,1,m)
if(find(r[i].x)!=find(r[i].y))
{
r[i].used=1;
unite(find(r[i].x),find(r[i].y));
tot+=r[i].z;
add(r[i].x,r[i].y,r[i].z),add(r[i].y,r[i].x,r[i].z);
}
build(1);
For(i,1,m)
if(!r[i].used&&r[i].x!=r[i].y)
{
pa=get(r[i].x,r[i].y);
if(pa.first<r[i].z)ans=Min(ans,tot-pa.first+r[i].z);
else if(pa.second<r[i].z)ans=Min(ans,tot-pa.second+r[i].z);
}
cout<<ans;
return 0;
}