\(happybean\)
题目大意
给定一个有向完全图,其中 \(u\rightarrow v\) 的边权为 \(a_u\) 。
进行 \(m\) 次修改,第 \(i\) 次修改给定 \(x,y,z\) ,将 \(x\rightarrow y\) 的有向边边权改为 \(z\) 。
求所有点对 \((i,j)\) 且 \(i\neq j\) 的最短路之和。
对于所有数据,满足 \(1\leq n\leq 10^5,1\leq m\leq 3000,0\leq a_i\leq 10^5\) 。
分析
考虑暴力做法,\(Floyd\) 能够很轻松的过掉 \(O(n^3)\) ,但显然是过不去这题的。
转化
我们发现,在整张图中,除了边权背修改后的点,事实上我们还有很多没有改动的点,于是我们也许能够考虑将边权被改变的子图导出来考虑。
若将其忽视,显然我们能够得到若干个连通块。
怎么计算 \((x,y)\) 的最短路:
-
若 \(x\) 与 \(y\) 不属于同一个连通块,则 \(x\rightarrow\) 的路径如下:
- 若 \(dis_i\) 表示 \(x\rightarrow i\) 的最短路,那么 \(x\rightarrow y\) 的最短路显然是 \(min(dis_i+a_i)\) 其中 \(i\) 是与 \(x\) 连通块相同的点,这不难理解。
于是,我们好像把问题转化为了求单一连通块内的最短路。
对于该最短路,有一个细节,\(x\rightarrow y\) 的路径,可以是 \(x\rightarrow z\) 再从 \(z\rightarrow y\) ,其中 \(z\) 是连通块外一点,可以发现 \(z\) 一定满足 \(a_z\) 最小。
优化
由于 \(m\leq 3000\) ,且我们要求的是全源最短路,且边数巨大,普通的最短路算法显然无法解决,考虑用一颗线段树来模拟 \(DJ\) 的过程,更新时,只需要单点更新当前点有边相连的点,对于其他点可以直接区间修改(注:没有边相连的点实际上是他们之间的连边没有被修改) 。
这样我们就能做到 \(O(m^2log m)\) ,过掉这道题。
统计答案时注意统计连通块内自己的答案。
加上转化中的说的那个细节,总的时间复杂度变成了 \(O(n+m^2logm)\) 。
CODE
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+10,M=3e3+10,INF=1e9;
inline int read()
{
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w*=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
struct tree{ int mi,pos,tag,sum; }tr[4*M];
struct edge{ int to,nex,w; }e[M];
struct node{
int x,y;
bool operator <(const node A)const { return x<A.x; }
};
int n,m;
long long ans;
int X,Y,Z;
int a[N],arr[N],smn[N],fa[N],id[N];
int dis[M];
int tot,first[N];
vector<int> v[N]; //v[i]储存连同块i内的节点
vector<node> E[M];
inline int find(int x)
{
if(fa[x]==x) return x;
return fa[x]=find(fa[x]);
}
inline void Add(int x,int y,int z)
{
e[++tot].nex=first[x];
first[x]=tot;
e[tot].to=y,e[tot].w=z;
}
inline void update(int k)
{
if(!tr[k].sum) { tr[k].mi=INF; return; }
if(tr[k*2].mi<tr[k*2+1].mi) tr[k].mi=tr[k*2].mi,tr[k].pos=tr[k*2].pos;
else tr[k].mi=tr[k*2+1].mi,tr[k].pos=tr[k*2+1].pos;
}
inline void Benew(int k,int v)
{
if(!tr[k].sum) return;
tr[k].tag=min(tr[k].tag,v);
tr[k].mi=min(tr[k].mi,v);
}
inline void pushdown(int k)
{
if(tr[k].tag^INF){
Benew(k*2,tr[k].tag),Benew(k*2+1,tr[k].tag);
tr[k].tag=INF;
}
}
inline void delet(int k,int l,int r) //删除操作
{
--tr[k].sum;
if(l==r) { tr[k].pos=0,tr[k].mi=INF; return; }
pushdown(k);
int mid=(l+r)/2;
if(X<=mid) delet(k*2,l,mid);
else delet(k*2+1,mid+1,r);
update(k);
}
inline void change(int k,int l,int r)
{
if(!tr[k].sum) return;
if(l>=X&&r<=Y) return Benew(k,Z);
pushdown(k);
int mid=(l+r)/2;
if(X<=mid) change(k*2,l,mid);
if(Y>mid) change(k*2+1,mid+1,r);
update(k);
}
inline void build(int k,int l,int r)
{
tr[k].sum=r-l+1,tr[k].tag=INF;
if(l==r) { tr[k].mi=INF-1,tr[k].pos=l; return; }
int mid=(l+r)/2;
build(k*2,l,mid),build(k*2+1,mid+1,r);
update(k);
}
inline void Solve(int x)
{
for(register int i=0;i<v[x].size();i++) id[v[x][i]]=i+1;
int lim=v[x].size();
for(register int i=1;i<=lim;i++) E[i].clear();
for(register int i=0;i<lim;i++)
for(register int j=first[v[x][i]];j;j=e[j].nex)
E[i+1].push_back((node){id[e[j].to],e[j].w}); //新连边
for(register int i=1;i<=lim;i++) sort(E[i].begin(),E[i].end()); //排序方便区间修改
for(register int i=1;i<=lim;i++){
build(1,1,lim);
X=Y=i,Z=0;
change(1,1,lim);
//线段树维护一个 DJ 过程-
for(register int j=1;j<=lim;j++){
int z=tr[1].pos; //找到当前 dis 最小的节点
dis[z]=tr[1].mi; //给该点赋值
// printf("%d %d\n",z,dis[z]);
X=z,delet(1,1,lim); //注意删除该点
int temp=0; //temp维护上次修改的末端
for(register int k=0;k<E[z].size();k++){
if(E[z][k].x-1>temp) X=temp+1,Y=E[z][k].x-1,Z=dis[z]+a[v[x][z-1]],change(1,1,lim); //区间修改
X=Y=E[z][k].x,Z=dis[z]+min(a[v[x][z-1]]+min(arr[x-1],smn[x+1]),E[z][k].y),change(1,1,lim),temp=E[z][k].x; //单点修改
}
if(temp^lim) X=temp+1,Y=lim,Z=dis[z]+a[v[x][z-1]],change(1,1,lim);
}
int mi=a[v[x][i-1]];
for(register int j=1;j<=lim;j++) if(j^i) ans+=dis[j],mi=min(mi,dis[j]+a[v[x][j-1]]);
ans+=1ll*mi*(n-v[x].size());
}
}
int main()
{
n=read(),m=read();
for(register int i=1;i<=n;i++) a[i]=read(),fa[i]=i;
for(register int i=1;i<=m;i++){
int x=read(),y=read(),z=read();
Add(x,y,z);
fa[find(x)]=find(y); //并查集找连通块
}
for(register int i=1;i<=n;i++) v[find(i)].push_back(i);
arr[0]=INF;
for(register int i=1;i<=n;i++){
arr[i]=arr[i-1];
if(i==find(i))
for(register int j=0;j<v[i].size();j++) arr[i]=min(arr[i],a[v[i][j]]);
}
smn[n+1]=INF;
for(register int i=n;i;i--){
smn[i]=smn[i+1];
if(i==find(i))
for(register int j=0;j<v[i].size();j++) smn[i]=min(smn[i],a[v[i][j]]);
}
for(register int i=1;i<=n;i++) if(i==find(i)) Solve(i);
printf("%lld\n",ans);
return 0;
}