和倍增法求lca差不多,维护每个点往上跳2^i步能到达的点,以及之间的边的最大值和次大值,先求出最小生成树,对于每个非树边枚举其端点在树上的路径的最大值,如果最大值和非树边权值一样则找次大值,然后维护答案即可。
代码
#include<cstdio>
#include<algorithm>
using namespace std;
const int N = ;
const int M = ;
const int inf = ;
int f[N],n,m,i;
int dp,p[N],pre[M],tt[M],ww[M],flag[M];
int deep[N],jump[N][],mi[N][],Mi[N][];
long long ans,Ans;
struct g{
int l,r,v;
}a[M];
void link(int x,int y,int z)
{
dp++;pre[dp]=p[x];p[x]=dp;tt[dp]=y;ww[dp]=z;
}
bool cmp(g a,g b)
{
return a.v<b.v;
}
int gf(int x)
{
while (x!=f[x]) x=f[x];return x;
}
void dfs(int x,int fa,int va)
{
deep[x]=deep[fa]+;
jump[x][]=fa;
mi[x][]=va;
Mi[x][]=-inf;
int i;
for (i=;i<=;i++)
{
jump[x][i]=jump[jump[x][i-]][i-];
mi[x][i]=max(mi[x][i-],mi[jump[x][i-]][i-]);
if (mi[x][i-]==mi[jump[x][i-]][i-])
Mi[x][i]=max(Mi[x][i-],Mi[jump[x][i-]][i-]);
else
if (mi[x][i-]<mi[jump[x][i-]][i-])
Mi[x][i]=max(mi[x][i-],Mi[jump[x][i-]][i-]);
else
Mi[x][i]=max(Mi[x][i-],mi[jump[x][i-]][i-]);
}
i=p[x];
while (i)
{
if (tt[i]!=fa)
dfs(tt[i],x,ww[i]);
i=pre[i];
}
}
void updata(int x,int y,int &ans,int &Ans)
{
if (x>=ans)
{
if (x>ans)
Ans=max(Ans,ans);
Ans=max(Ans,y);
ans=x;
}
else
if (x>Ans)
Ans=max(Ans,x);
}
int get(int a,int b,int c)
{
if (deep[a]<deep[b]) a^=b^=a^=b;
int i,ans=-inf,Ans=-inf;
for (i=;i>=;i--)
if (deep[jump[a][i]]>=deep[b])
{
updata(mi[a][i],Mi[a][i],ans,Ans);
a=jump[a][i];
}
if (a==b)
{
if (ans==c) return Ans;else return ans;
}
for (i=;i>=;i--)
if (jump[a][i]!=jump[b][i])
{
updata(mi[a][i],Mi[a][i],ans,Ans);
updata(mi[b][i],Mi[b][i],ans,Ans);
a=jump[a][i];b=jump[b][i];
}
updata(mi[a][],Mi[a][],ans,Ans);
updata(mi[b][],Mi[b][],ans,Ans);
if (ans==c) return Ans;else return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for (i=;i<=m;i++)
scanf("%d%d%d",&a[i].l,&a[i].r,&a[i].v);
sort(a+,a++m,cmp);
for (i=;i<=n;i++) f[i]=i;
for (i=;i<=m;i++)
{
int l=a[i].l,r=a[i].r;
if (gf(l)!=gf(r))
{
link(l,r,a[i].v);
link(r,l,a[i].v);
f[gf(l)]=gf(r);
Ans+=a[i].v;
flag[i]=;
}
}
dfs(,,);
ans=(long long) inf*inf;
for (i=;i<=m;i++)
if (!flag[i])
{
long long tmp=Ans-get(a[i].l,a[i].r,a[i].v)+a[i].v;
if ((tmp>Ans)&&(tmp<ans)) ans=tmp;
}
printf("%lld\n",ans);
}