题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5876
题意:
有一个 n 个点无向图,再给你 m 对顶点, 代表着这 m 对顶点之间没有边, 除此之外每两个点之间都有一条边, 且权值为 1.然后还有一个源点 S, 让你计算源点到其他各点之间的最短距离,如果不存在则输出 -1.也就是说让你在所给的图的补图上求源点到其他各点的最短路径.
思路:
补图上求最短路径算是比较经典的题.在这里所求的最短路其实并不需要用到 dijkstra 之类的算法,由于每条边之间的距离都为 1,每条边的权值一样.那么就可以想到这个做法:
步骤1:根据题意建图.然后建立一个队列,把源点 S 压入队列,其他的各个点存到另一个集合 V 里, 并建立一个数组来存储源点到其他各点的最短距离,初始化为 -1, dis[S] = 0.
步骤2:从队列首部取出一个节点 v,访问所有与节点 v 不相邻的节点 u,把 u 压入到队列尾部, 并从集合 V 中把 u 除去, 更新dis[u] = dis[v] + 1.
步骤3:反复重复步骤 2, 直到队列为空.之后 dis 数组里保存的就是答案.
解决了怎么做之后,还没有完,由于题目给的点数非常大,所以还需要优化下时间.这里比较费时间的就是步骤 2 中的找不相邻的点.这里就可以用 STL 中的 set 来维护集合 V, 也就是未被访问到的点. 初始化 set1 中为所有点(除去源点 S), 当访问的点 v 时, 在 set1 中除去与点 v 相邻的点 u, 并加入到 set2 中,那么 set1 中剩下的点就是所有与点 v 不相邻的点, 依次遍历压入队列. 之后再把 set2 拷贝到 set1, 那么 set1 就是剩下的未被访问到的点,如此反复下去.可以看到,对于每条边只访问一次.每个点也只进一次队列.所以总的时间复杂度为 O(n * m),可以达到要求了.
notes:建立 无向图 的时候每条边要存两次. 所以数组的大小一定是原题所给的边数的两倍! 两倍! 两倍!顺便求一个用链表实现的代码,自己用链表没写出来.(我好菜啊.jpg
代码:
#include <iostream>
#include <cstdio>
#include <queue>
#include <set>
#include <cstring>
#include <algorithm> using namespace std;
typedef long long LL; const int MAXN = ;
const int MAXE = ;
int n, m, T, S;
int dis[MAXN + ];//保存最终的最短距离 int head[MAXN + ], len; //链式前向星
struct NODE {int to; int next; };
NODE edge[ * MAXE + ]; void addedge(int u, int v) { //链式前向星加边
edge[len].to = v;
edge[len].next = head[u];
head[u] = len++;
} void BFS() { //从起点开始BFS
memset(dis, -, sizeof(dis));
queue <int> Qu;
Qu.push(S); dis[S] = ; //起点初始化
set<int> unsed, hep; //用来维护尚未被访问的点
for(int i = ; i <= n; i++) unsed.insert(i);
unsed.erase(S);
while( !Qu.empty() ) {
int tp = Qu.front(); Qu.pop();
for(int k = head[tp]; k != -; k = edge[k].next) { //从 unsed 中除去和当前拓展节点相邻的点,同时加入到临时辅助的集合中
if(unsed.find(edge[k].to) != unsed.end()) {
unsed.erase(edge[k].to);
hep.insert(edge[k].to);
}
}
for(set<int>::iterator it = unsed.begin(); it != unsed.end(); it++) {// unsed 暂时保存的是和当前拓展节点不相邻的点
Qu.push(*it);
dis[*it] = dis[tp] + ;
}
hep.swap(unsed), hep.clear();//从辅助集合中 copy 剩下的未被访问到的点.
}
} int main() {
scanf("%d", &T);
while(T--) {
memset(head, -, sizeof(head));
scanf("%d%d", &n, &m);
int u, v; len = ;
for(int i = , len = ; i < m; i++) {
scanf("%d%d", &u, &v);
addedge(u, v); addedge(v, u);
}
scanf("%d", &S);
BFS();
for(int i = , j = ; i <= n; i++) if(i != S) printf("%d%c", dis[i], " \n"[++j == n - ]);
}
return ;
}