题面
题解
题意
对于每条道路,求出有多少条最短路经过它。
解析
先看看数据范围,不算大。
所以我们分别以每个点为起点,用 \(SPFA\) 求出每个点到源点的最短距离。
不难发现对于边 \((u,v,w)\),如果 \(dis[u] + w == dis[v]\),那么这条边一定在源点到 \(v\) 的最短路上。
而不满足这个条件的边,显然对答案毫无贡献,所以我们可以在跑完最短路后忽略不满足条件的这些边。
这样原图就变成了一副 \(DAG\) 图,因为不存在 \(dis[u] + w == dis[v]\) ,且 \(dis[v] + w == dis[u]\)。
在 \(DAG\) 图上我们就可以跑拓扑排序了。对于每个点求出从源点到达它的方案数,再按照拓扑排序的逆序求出从后面的节点到达当前节点的方案数。
原因请看下图:
从源点到 \(u\) 的方案数记为 \(cnt1[u]\),从后面的节点到达 \(v\) 的方案数记为 \(cnt2[v]\) ,不难发现根据乘法原理,这条从 \(u\) 到 \(v\) 的边就有 \(cnt1[u] \times cnt2[v]\) 条最短路经过了它。因为当前图只保留了在最短路上的边,这些方案数确实都是最短路,所以答案是正确的。
代码
不知道我按照正常的 \(SPFA\) 的写法为什么只能过 \(40\) 分,这份代码里面的 \(SPFA\) 的内容是参考了其它题解的写法的。
#include<cstdio>
#include<cstring>
#include<queue>
#define re register
using namespace std;
const int N = 1505,M = 5005;
const int mod = 1e9 + 7;
struct edge {
int from,next,to,w;
}a[M];
int head[N],dis[N],ans[M],que[M << 1],n,m,a_size = 0;
bool vis[N],is[M];
inline void add(int u,int v,int w) {
a[++a_size] = (edge){u,head[u],v,w};
head[u] = a_size;
}
void SPFA(int s) {
memset(dis,0x3f,sizeof(dis));
memset(is,0,sizeof(is)); re int len;
dis[que[len = 1] = s] = 0; vis[s] = true;
for (re int i = 1; i <= len; i++) {
int u = que[i]; vis[u] = 0;
for (re int e = head[u],v; e; e = a[e].next)
if (dis[u] + a[e].w < dis[v = a[e].to]) {
dis[v] = dis[u] + a[e].w;
if (!vis[v]) vis[que[++len] = v] = 1;
}
}
for (re int i = 1; i <= m; i++)
if (dis[a[i].from] + a[i].w == dis[a[i].to])
is[i] = 1;
}
int deg[N],cnt1[N],cnt2[N],ord[N],len;
queue<int> q;
void tuopu(int s) {
while(!q.empty()) q.pop();
memset(deg,0,sizeof(deg));
memset(cnt1,0,sizeof(cnt1));
memset(cnt2,0,sizeof(cnt2)); len = 0;
for(re int i = 1; i <= m; i++) if(is[i]) deg[a[i].to]++;
q.push(s); cnt1[s] = 1;
while(!q.empty()) {
int x = q.front();
q.pop(); ord[++len] = x;
for(int i = head[x]; i; i = a[i].next) {
if(!is[i]) continue;
int y = a[i].to;
cnt1[y] = (cnt1[y] + cnt1[x]) % mod;
if(--deg[y] == 0) q.push(y);
}
}
for(re int j = len; j >= 1; j--) {
re int x = ord[j]; cnt2[x]++;
for(re int i = head[x]; i; i = a[i].next) {
if(!is[i]) continue;
cnt2[x] = (cnt2[x] + cnt2[a[i].to]) % mod;
}
}
}
inline int read() {
re int x = 0,flag = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-')flag = -1;ch = getchar();}
while(ch >='0' && ch <='9'){x = (x << 3) + (x << 1) + ch - 48;ch = getchar();}
return x * flag;
}
int main() {
n = read(),m = read();
for(re int i = 1,u,v,w; i <= m; i++) {
u = read(),v = read(),w = read();
add(u,v,w);
}
for(re int i = 1; i <= n; i++) {
SPFA(i); tuopu(i);
for(re int j = 1; j <= m; j++)
if(is[j]) ans[j] = (ans[j] + 1ll * cnt1[a[j].from] * cnt2[a[j].to] % mod) % mod;
}
for(re int i = 1; i <= m; i++) printf("%d\n",ans[i]);
return 0;
}