0 Background
In my day-to-day work I need a shortest-path algorithm. I had been relying on the built-in functions in Neo4j, but I wanted something that integrates with our big-data platform, which led me to Spark GraphX. Since I had mostly worked in Python and was not familiar with Java or Scala development, I pieced this together from various sources and am writing it down here as a record.
1 Development Environment
IntelliJ IDEA with the Scala plugin, plus the Spark JARs: once the Spark JARs are imported into the Scala project, the Spark-related APIs can be used directly.
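As a concrete illustration, a minimal sbt build could look like the sketch below. This is only an assumption about the setup (the versions shown are placeholders; match them to your Spark cluster and Scala version), and adding the JARs from a Spark distribution's jars/ directory to the IDEA project, as described above, works just as well.

// build.sbt - a minimal sketch; versions are assumptions, not requirements
scalaVersion := "2.11.12"
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core"   % "2.4.8",
  "org.apache.spark" %% "spark-graphx" % "2.4.8"
)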
2 Preparing the Network Data
To keep the data easy to move around, the network's nodes and edges are stored in CSV files. The node data, nodes.csv, is as follows:
node_id,nodes
1,v1
2,v2
3,v3
4,v4
5,v5
6,v6
7,v7
The edge data, edges.csv, is as follows:
source,target,length
1,2,2
1,4,1
2,4,3
2,5,10
4,5,2
4,3,2
4,6,8
4,7,4
3,1,4
3,6,5
5,7,6
7,6,1
3 Building the Graph
The code for reading the nodes and edges is as follows:
import scala.io.Source

class fileExample {
  def fileReader(string: String): Unit = {
    // print the contents of the CSV file
    val ofile = Source.fromFile(string)
    val lines = ofile.getLines()
    lines.foreach(println)
  }

  def readerToArray(string: String): Array[String] = {
    // read the file into an array of lines
    val ofile = Source.fromFile(string)
    val lines = ofile.getLines()
    lines.toArray
  }
}

class getGraph {
  def nodes(): Seq[(Long, String)] = {
    val infile = "./dataset/nodes.csv"
    val obj = new fileExample()
    val context = obj.readerToArray(infile)
    println("nodes_context:" + context.length)
    context.foreach(println)
    // the placeholder element is sliced off before returning
    var seq = Seq((0L, ""))
    for (line <- context.slice(1, context.length)) {   // skip the CSV header
      var nid = line.split(",")(0)
      var nme = line.split(",")(1)
      seq = seq :+ (nid.toLong, nme)
    }
    return seq.slice(1, seq.length)
  }

  def edges(): Seq[(Long, Long, Int)] = {
    val infile = "./dataset/edges.csv"
    val obj = new fileExample()
    val context = obj.readerToArray(infile)
    println("edge_context:" + context.length)
    context.foreach(println)
    var seq = Seq((0L, 0L, 0))
    for (line <- context.slice(1, context.length)) {   // skip the CSV header
      var fid = line.split(",")(0).toLong
      var tid = line.split(",")(1).toLong
      var wht = line.split(",")(2).toInt
      seq = seq :+ (fid, tid, wht)
    }
    return seq.slice(1, seq.length)
  }
}
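For comparison, the same parsing can be written without the placeholder-and-slice pattern by mapping over the lines directly. The sketch below assumes the same file layout (a header line followed by comma-separated values) and is not used by the rest of the demo:

// Alternative sketch: parse the CSVs with map; same layout assumptions as above
def nodesConcise(path: String = "./dataset/nodes.csv"): Seq[(Long, String)] =
  Source.fromFile(path).getLines().drop(1).map { line =>
    val cols = line.split(",")
    (cols(0).toLong, cols(1))                          // (node_id, name)
  }.toSeq

def edgesConcise(path: String = "./dataset/edges.csv"): Seq[(Long, Long, Int)] =
  Source.fromFile(path).getLines().drop(1).map { line =>
    val cols = line.split(",")
    (cols(0).toLong, cols(1).toLong, cols(2).toInt)    // (source, target, length)
  }.toSeq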
The code for building the Spark GraphX graph is as follows:
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

class graphExample {
  val conf = new SparkConf().setAppName("Example").setMaster("local")
  val sc = new SparkContext(conf)

  def example(): Unit = {
    println("start")
    val graph = new getGraph()
    var nodes = graph.nodes()
    nodes.foreach(println)
    println("->nodes\n->edges")
    var edges = graph.edges()
    edges.foreach(println)

    // Build the vertex RDD: (vertexId, (name, vertexId))
    // val nn = Seq((1L, ("Alice", 27)), (2L, ("Bob", 27)))
    var nn = Seq((0L, ("0", 0L)))
    for (node <- nodes) {
      nn = nn :+ (node._1, (node._2, node._1))
    }
    val gnodes: RDD[(Long, (String, Long))] = sc.parallelize(nn.slice(1, nn.length))

    // Build the edge RDD: Edge(srcId, dstId, weight)
    // val gg = Seq(Edge(2L, 1L, 7), Edge(1L, 2L, 2))
    var gg = Seq(Edge(0L, 0L, 0))
    for (e <- edges) {
      gg = gg :+ Edge(e._1, e._2, e._3)
    }
    var gedges: RDD[Edge[Int]] = sc.parallelize(gg.slice(1, gg.length))

    val gx: Graph[(String, Long), Int] = Graph(gnodes, gedges)

    // Sanity check: count the edges with weight greater than 3
    val tmp = gx.edges.filter { case Edge(f, t, w) => w > 3 }.count
    println("tmp:" + tmp)

    // the shortest-path query from the next section continues here
  }
}
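To verify that the vertices and edges were joined as intended, the triplets can also be dumped. This is only an optional sanity-check sketch, not part of the original demo:

// Print every edge together with the names of its endpoints
gx.triplets
  .map(t => s"${t.srcAttr._1} -> ${t.dstAttr._1} (length ${t.attr})")
  .collect()
  .foreach(println)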
4 Shortest-Path Query
The shortest-path query over the constructed GraphX graph proceeds as follows:
// Initialize the graph: the source vertex gets distance 0.0 and path List(sourceId);
// every other vertex starts at +infinity with an empty path
val sourceId: VertexId = 1L
val initialGraph: Graph[(Double, List[VertexId]), Int] = gx.mapVertices((id, _) =>
  if (id == sourceId) (0.0, List[VertexId](sourceId))
  else (Double.PositiveInfinity, List[VertexId]()))

val sssp = initialGraph.pregel((Double.PositiveInfinity, List[VertexId]()), Int.MaxValue, EdgeDirection.Out)(
  // Vertex Program: keep whichever of the current and incoming distances is smaller
  (id, dist, newDist) => if (dist._1 < newDist._1) dist else newDist,
  // Send Message: relax the edge if going through src improves dst's distance
  triplet => {
    if (triplet.srcAttr._1 < triplet.dstAttr._1 - triplet.attr) {
      Iterator((triplet.dstId, (triplet.srcAttr._1 + triplet.attr, triplet.srcAttr._2 :+ triplet.dstId)))
    } else {
      Iterator.empty
    }
  },
  // Merge Message: keep the shorter of two candidate distances
  (a, b) => if (a._1 < b._1) a else b)

println(sssp.vertices.collect.mkString("\n"))
// println(sssp.vertices.filter { case (id, v) => id == 3 })

// Print only the result for the target vertex
val end_ID = 6L
println(end_ID)
println(sssp.vertices.collect.filter { case (id, v) => id == end_ID }.mkString("\n"))
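Spark GraphX also ships with org.apache.spark.graphx.lib.ShortestPaths, which is imported in the complete listing below but not actually used. Note that it computes the number of hops from each vertex to a set of landmark vertices and ignores edge weights, so it is not a drop-in replacement for the weighted Pregel version above. A rough sketch of how it could be called on the same graph gx:

// Hop-count shortest paths to landmark vertex 6 (edge weights are ignored)
import org.apache.spark.graphx.lib.ShortestPaths
val hops = ShortestPaths.run(gx, Seq(6L))
hops.vertices.collect.foreach { case (id, spMap) =>
  println(s"vertex $id -> hops to 6: ${spMap.getOrElse(6L, -1)}")
}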
5 Complete Code
The complete code for the whole demo is as follows:
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.graphx.lib.ShortestPaths
import scala.io.Source

object graphExample {
  def main(args: Array[String]): Unit = {
    val exam = new graphExample()
    exam.example()
  }
}

class graphExample {
  val conf = new SparkConf().setAppName("Example").setMaster("local")
  val sc = new SparkContext(conf)

  def example(): Unit = {
    println("start")
    val graph = new getGraph()
    var nodes = graph.nodes()
    nodes.foreach(println)
    println("->nodes\n->edges")
    var edges = graph.edges()
    edges.foreach(println)

    // Build the vertex RDD: (vertexId, (name, vertexId))
    // val nn = Seq((1L, ("Alice", 27)), (2L, ("Bob", 27)))
    var nn = Seq((0L, ("0", 0L)))
    for (node <- nodes) {
      nn = nn :+ (node._1, (node._2, node._1))
    }
    val gnodes: RDD[(Long, (String, Long))] = sc.parallelize(nn.slice(1, nn.length))

    // Build the edge RDD: Edge(srcId, dstId, weight)
    // val gg = Seq(Edge(2L, 1L, 7), Edge(1L, 2L, 2))
    var gg = Seq(Edge(0L, 0L, 0))
    for (e <- edges) {
      gg = gg :+ Edge(e._1, e._2, e._3)
    }
    var gedges: RDD[Edge[Int]] = sc.parallelize(gg.slice(1, gg.length))

    val gx: Graph[(String, Long), Int] = Graph(gnodes, gedges)

    // Sanity check: count the edges with weight greater than 3
    val tmp = gx.edges.filter { case Edge(f, t, w) => w > 3 }.count
    println("tmp:" + tmp)

    // Initialize the graph for the single-source shortest path
    val sourceId: VertexId = 1L
    val initialGraph: Graph[(Double, List[VertexId]), Int] = gx.mapVertices((id, _) =>
      if (id == sourceId) (0.0, List[VertexId](sourceId))
      else (Double.PositiveInfinity, List[VertexId]()))

    val sssp = initialGraph.pregel((Double.PositiveInfinity, List[VertexId]()), Int.MaxValue, EdgeDirection.Out)(
      // Vertex Program
      (id, dist, newDist) => if (dist._1 < newDist._1) dist else newDist,
      // Send Message
      triplet => {
        if (triplet.srcAttr._1 < triplet.dstAttr._1 - triplet.attr) {
          Iterator((triplet.dstId, (triplet.srcAttr._1 + triplet.attr, triplet.srcAttr._2 :+ triplet.dstId)))
        } else {
          Iterator.empty
        }
      },
      // Merge Message
      (a, b) => if (a._1 < b._1) a else b)

    println(sssp.vertices.collect.mkString("\n"))
    // println(sssp.vertices.filter { case (id, v) => id == 3 })

    val end_ID = 6L
    println(end_ID)
    println(sssp.vertices.collect.filter { case (id, v) => id == end_ID }.mkString("\n"))
    // for (elem <- edges) { println(elem) }
  }
}

class fileExample {
  def fileReader(string: String): Unit = {
    // print the file contents
    val ofile = Source.fromFile(string)
    val lines = ofile.getLines()
    lines.foreach(println)
  }

  def readerToArray(string: String): Array[String] = {
    val ofile = Source.fromFile(string)
    val lines = ofile.getLines()
    lines.toArray
  }
}

class getGraph {
  def nodes(): Seq[(Long, String)] = {
    val infile = "./dataset/nodes.csv"
    val obj = new fileExample()
    val context = obj.readerToArray(infile)
    println("nodes_context:" + context.length)
    context.foreach(println)
    var seq = Seq((0L, ""))
    for (line <- context.slice(1, context.length)) {
      var nid = line.split(",")(0)
      var nme = line.split(",")(1)
      seq = seq :+ (nid.toLong, nme)
    }
    return seq.slice(1, seq.length)
  }

  def edges(): Seq[(Long, Long, Int)] = {
    val infile = "./dataset/edges.csv"
    val obj = new fileExample()
    val context = obj.readerToArray(infile)
    println("edge_context:" + context.length)
    context.foreach(println)
    var seq = Seq((0L, 0L, 0))
    for (line <- context.slice(1, context.length)) {
      var fid = line.split(",")(0).toLong
      var tid = line.split(",")(1).toLong
      var wht = line.split(",")(2).toInt
      seq = seq :+ (fid, tid, wht)
    }
    return seq.slice(1, seq.length)
  }
}
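If everything is wired up correctly, the final println should, for the sample data above, output something like (6,(6.0,List(1, 4, 7, 6))): the cheapest route from vertex 1 to vertex 6 goes 1 -> 4 -> 7 -> 6, with a total length of 1 + 4 + 1 = 6.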