参考:《算法导论》
@Data @AllArgsConstructor public class WeightGraph { //节点名,前驱节点,最短路径 private List<Node<String,String,Integer>> nodes; //节点名,连接节点索引,边权重 private Table<String, Integer, Integer> edgeTable; public static void main(String[] args) { //构建图 WeightGraph graph = buildGraph(); List<Node<String, String, Integer>> nodes = graph.getNodes(); Table<String, Integer, Integer> edgeTable = graph.getEdgeTable(); //已知最短路径的节点集合S HashSet<String> S = new HashSet<>(); //优先队列,以最短路径为排序key Queue<Node<String,String, Integer>> queue = new PriorityQueue<>(Comparator.comparing(Node::getDistance)); //节点信息的初始化,v1=0,其他未知都是无穷大 nodes.forEach(node->{ if (node.getName().equals("v1")) { node.setDistance(0); } }); //队列初始化,将所有节点加入队列 queue.addAll(nodes); while (!queue.isEmpty()) { //第一个获取的就是v1 Node<String,String, Integer> nodeU = queue.remove(); String nodeUName = nodeU.getName(); S.add(nodeUName); Map<Integer, Integer> row = edgeTable.row(nodeUName); //对u->v节点进行松弛操作 //如果v.d大于u.d+w(u,v),则将更小的值更新为v.d for (Map.Entry<Integer, Integer> entry : row.entrySet()) { Integer nodeVIndex = entry.getKey(); Node<String, String, Integer> nodeV = nodes.get(nodeVIndex); Integer weightUV = entry.getValue(); if (nodeV.getDistance() > nodeU.getDistance() + weightUV) { nodeV.setDistance(nodeU.getDistance() + weightUV); nodeV.setPre(nodeUName); } } } for (Node<String, String, Integer> node : nodes) { System.out.println(node.getName()+":"+node.getDistance()); } System.out.println(S); } /** * 初始化图结构 * * @return */ public static WeightGraph buildGraph() { List<String> nodes = Lists.newArrayList("v1", "v2", "v3", "v4", "v5", "v6"); List<Node<String, String, Integer>> nodeList = nodes.stream().map(node -> { Node<String, String, Integer> nodeObj = new Node<>(); nodeObj.setName(node); nodeObj.setPre(null); nodeObj.setDistance(Integer.MAX_VALUE); return nodeObj; }).collect(Collectors.toList()); Table<String, Integer, Integer> edgeTable = HashBasedTable.create(); edgeTable.put("v1", nodes.indexOf("v2"), 10); edgeTable.put("v2", 
nodes.indexOf("v3"), 7); edgeTable.put("v4", nodes.indexOf("v3"), 4); edgeTable.put("v4", nodes.indexOf("v5"), 7); edgeTable.put("v6", nodes.indexOf("v5"), 1); edgeTable.put("v1", nodes.indexOf("v6"), 3); edgeTable.put("v6", nodes.indexOf("v2"), 2); edgeTable.put("v4", nodes.indexOf("v1"), 3); edgeTable.put("v2", nodes.indexOf("v4"), 5); edgeTable.put("v6", nodes.indexOf("v4"), 6); return new WeightGraph(nodeList,edgeTable); } /** * 节点名,前驱节点,最短路径 * 也用于存储最终的最短路径数据 * * @param <N> * @param <D> */ @Data @AllArgsConstructor @NoArgsConstructor static class Node<N,P, D> { private N name; private P pre; private D distance; } }
输出:
v1:0
v2:5
v3:12
v4:9
v5:4
v6:3