一、题目
二、解法
你感觉这道题有点像生成树模型,但是因为边有方向所以麻烦。
可以巧妙地转化成无向生成树模型,我们把 \((i,j)\) 之间边的权值设置成 \(a_i+a_j\),那么如果是 \(i\) 把 \(j\) 拉进连通块,我们多算了 \(a_j\),如果是 \(j\) 把 \(i\) 拉进连通块我们多算了 \(a_i\),所以每个点的点权会被多算一次,我们增加 \(a_{n+1}=0\) 连向所有点来表示直接进入组织的情况,那么用最大生成树减去点权和就是答案。
那么问题转化成了求最大生成树,可以考虑 \(\tt kruskal\) 算法,也就是从大到小枚举边权 \(w\),接下来就看哪些边有这种权值,其实就是把 \(w\) 的二进制拆分成两个不相交的子集,那么就可以套子集枚举了,用并查集暴力连起来即可,时间复杂度 \(O(3^{18}\cdot \alpha)\)
还有一种更好的方法,考虑 \(\tt Borůvka\) 算法,大概内容是对于当前的每一个连通块寻找权值最大的一个出边,这一轮把找到的边都连起来,因为每次连通块个数减半,所以一共只会有 \(O(\log n)\) 轮,复杂度关键在于找边。
二进制的东西不好用数据结构维护的,干脆每一轮我们现场算。注意到边权是 \(a_i+a_j\),我们先固定 \(a_i\),然后对于 \(i\) 关于全集的补集 \(s\),我们找到其中最大的 \(a_j\),考虑它们的连边,但是要注意 \(a_i\) 和 \(a_j\) 需要在不同的连通块内。
这相当于一个子集内求最大值的问题,可以用 \(\tt fwt\) 正变换的方法来做,也就是我们从小到大枚举数位 \(i\),然后对于包含数位 \(i\) 的集合 \(s\) 考虑 \(s-2^i\) 来更新 \(s\),这本质上是一个分治统计的过程,时间复杂度 \(O(2^{18}\cdot 18)\)
因为不能在同一连通块内所以我们要算的是不在同一联通块内的最大值和次大值,时间复杂度 \(O(2^{18}\cdot18\cdot\log n)\)
最后小结一下吧,为什么 \(\tt Borůvka\) 算法能做到更优的复杂度?我觉得是这道题的边权是和点权高度相关的,而且我们不是很清楚具体的边,这样 \(\tt kruskal\) 算法可能不优,而 \(\tt prim\) 算法是单源扩展所以可能也不行,\(\tt Borůvka\) 算法是和点高度相关的。
#include <cstdio>
#include <iostream>
using namespace std;
#define pii pair<int,int>
#define make make_pair
#define fi first
#define se second
const int M = 300005;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,cnt,a[M],fa[M];pii tmp[M],b[M][2];long long ans;
int find(int x)
{
if(x!=fa[x]) fa[x]=find(fa[x]);
return fa[x];
}
void merge(int u,int v)
{
int x=find(u),y=find(v);fa[x]=y;
}
void solve()
{
for(int i=0;i<m;i++)
b[i][0]=b[i][1]=tmp[i]=make(-1,-1);
for(int i=1;i<=n;i++)
{
int id=find(i);
if(b[a[i]][0].se==-1)
b[a[i]][0]=make(a[i],id);
else if(b[a[i]][1].se==-1 && b[a[i]][0].se!=id)
b[a[i]][1]=make(a[i],id);
}
for(int i=0;i<18;i++)
for(int s=0;s<m;s++) if((s>>i)&1)
{
int t=s^(1<<i);
for(int j=0;j<2;j++)
{
if(b[t][j].fi>b[s][0].fi)
{
if(b[s][0].se!=b[t][j].se) b[s][1]=b[s][0];
b[s][0]=b[t][j];
}
else if(b[t][j].fi>b[s][1].fi && b[s][0].se!=b[t][j].se)
b[s][1]=b[t][j];
}
}
for(int i=1;i<=n;i++)
{
int s=(m-1)^a[i],id=find(i);
if(b[s][0].fi!=-1 && b[s][0].fi+a[i]>tmp[id].fi && b[s][0].se!=id)
tmp[id]=b[s][0],tmp[id].fi+=a[i];
if(b[s][1].fi!=-1 && b[s][1].fi+a[i]>tmp[id].fi && b[s][1].se!=id)
tmp[id]=b[s][1],tmp[id].fi+=a[i];
}
for(int i=1;i<=n;i++)
if(tmp[i].fi!=-1 && find(i)!=find(tmp[i].se))
merge(i,tmp[i].se),ans+=tmp[i].fi,cnt--;
}
signed main()
{
cnt=n=read()+1;m=1<<18;
for(int i=1;i<=n;i++)
fa[i]=i;
for(int i=1;i<n;i++)
a[i]=read(),ans-=a[i];
while(cnt>1) solve();
printf("%lld\n",ans);
}