在很多场景下我们可能都会遇到根据用户某信息获取其位置信息的徐求,比较常见的可能就是根据ip或电话号码来计算了。这里介绍以ip地址来计算归属地,以具体需求来说明。使用scala语言编写spark程序来实现。
需求:根据访问日志的ip地址计算用户归属地信息,并且按其归属地分类统计访问量。
(1)、获取ip地址对应地理信息,作为字典数据。对于省市地区来说每一个行政区域都有一个ip段,将ip转换为十进制来计算。
(2)、加载规则,整理规则,取出计算需要的字段,将数据缓存到内存中。、
(3)、将访问的log与ip规则进行匹配(二分法)
(4)、分组
(5)、聚合(省份)
Log数据示例:
字典数据示例:
首先构造一个工具类,将ip规则数据进行处理:
object IPUtils { def ip2Long(ip: String): Long = { val fragments = ip.split("[.]") var ipNum = 0L for (i <- 0 until fragments.length){ ipNum = fragments(i).toLong | ipNum << 8L } ipNum } def readIpRules(path: String): Array[(Long, Long, String)] = { //读取ip规则 val bf: BufferedSource = Source.fromFile(path) val lines: Iterator[String] = bf.getLines() //对ip规则进行整理,并放入到内存 val rules: Array[(Long, Long, String)] = lines.map(line => { val fileds = line.split("[|]") val startNum = fileds(2).toLong val endNum = fileds(3).toLong val province = fileds(6) (startNum, endNum, province) }).toArray rules } /*
*使用二分法查找
*/ def binarySearch(lines: Array[(Long, Long, String)], ip: Long) : Int = { var low = 0 var high = lines.length - 1 while (low <= high) { val middle = (low + high) / 2 if ((ip >= lines(middle)._1) && (ip <= lines(middle)._2)) return middle if (ip < lines(middle)._1) high = middle - 1 else { low = middle + 1 } } -1 }
/*
*将计算结果存储mysql
*/ def data2MySQL(it: Iterator[(String, Int)]): Unit = { //一个迭代器代表一个分区,分区中有多条数据 //先获得一个JDBC连接 val conn: Connection = DriverManager.getConnection("jdbc:mysql://localhost:3306/bigdata?characterEncoding=UTF-8", "root", "123568") //将数据通过Connection写入到数据库 val pstm: PreparedStatement = conn.prepareStatement("INSERT INTO access_log VALUES (?, ?)") //将分区中的数据一条一条写入到MySQL中 it.foreach(tp => { pstm.setString(1, tp._1) pstm.setInt(2, tp._2) pstm.executeUpdate() }) //将分区中的数据全部写完之后,在关闭连接 if(pstm != null) { pstm.close() } if (conn != null) { conn.close() } } }
使用spark调用方法进行计算:
1 object IpLoactAndAggByProvince { 2 3 def main(args: Array[String]): Unit = { 4 5 val conf = new SparkConf().setAppName("IpLoaction1").setMaster("local[4]") 6 7 val sc = new SparkContext(conf) 8 9 //取到HDFS中的ip规则 10 val rulesLines:RDD[String] = sc.textFile(args(0)) 11 //整理ip规则数据 12 val ipRulesRDD: RDD[(Long, Long, String)] = rulesLines.map(line => { 13 val fields = line.split("[|]") 14 val startNum = fields(2).toLong 15 val endNum = fields(3).toLong 16 val province = fields(6) 17 (startNum, endNum, province) 18 }) 19 20 //将分散在多个Executor中的部分IP规则收集到Driver端 21 val rulesInDriver: Array[(Long, Long, String)] = ipRulesRDD.collect() 22 23 //将Driver端的数据广播到Executor 24 //广播变量的引用(还在Driver端) 25 val broadcastRef: Broadcast[Array[(Long, Long, String)]] = sc.broadcast(rulesInDriver) 26 27 //创建RDD,读取访问日志 28 val accessLines: RDD[String] = sc.textFile(args(1)) 29 30 //整理数据 31 val proviceAndOne: RDD[(String, Int)] = accessLines.map(log => { 32 //将log日志的每一行进行切分 33 val fields = log.split("[|]") 34 val ip = fields(1) 35 //将ip转换成十进制 36 val ipNum = MyUtils.ip2Long(ip) 37 //进行二分法查找,通过Driver端的引用或取到Executor中的广播变量 38 //(该函数中的代码是在Executor中别调用执行的,通过广播变量的引用,就可以拿到当前Executor中的广播的规则了) 39 //Driver端广播变量的引用是怎样跑到Executor中的呢? 40 //Task是在Driver端生成的,广播变量的引用是伴随着Task被发送到Executor中的 41 val rulesInExecutor: Array[(Long, Long, String)] = broadcastRef.value 42 //查找 43 var province = "未知" 44 val index = MyUtils.binarySearch(rulesInExecutor, ipNum) 45 if (index != -1) { 46 province = rulesInExecutor(index)._3 47 } 48 (province, 1) 49 }) 50 51 //聚合 52 //val sum = (x: Int, y: Int) => x + y 53 val reduced: RDD[(String, Int)] = proviceAndOne.reduceByKey(_+_) 54 55 //将结果打印 56 //val r = reduced.collect() 57 //println(r.toBuffer) 58 59 60 /** 61 reduced.foreach(tp => { 62 //将数据写入到MySQL中 63 //问?在哪一端获取到MySQL的链接的? 64 //是在Executor中的Task获取的JDBC连接 65 val conn: Connection = DriverManager.getConnection("jdbc:mysql://localhost:3306/bigdata?charatorEncoding=utf-8", "root", "123568") 66 //写入大量数据的时候,有没有问题? 67 val pstm = conn.prepareStatement("...") 68 pstm.setString(1, tp._1) 69 pstm.setInt(2, tp._2) 70 pstm.executeUpdate() 71 pstm.close() 72 conn.close() 73 }) 74 */ 75 76 //一次拿出一个分区(一个分区用一个连接,可以将一个分区中的多条数据写完在释放jdbc连接,这样更节省资源) 77 // reduced.foreachPartition(it => { 78 // val conn: Connection = DriverManager.getConnection("jdbc:mysql://localhost:3306/bigdata?characterEncoding=UTF-8", "root", "123568") 79 // //将数据通过Connection写入到数据库 80 // val pstm: PreparedStatement = conn.prepareStatement("INSERT INTO access_log VALUES (?, ?)") 81 // //将一个分区中的每一条数据拿出来 82 // it.foreach(tp => { 83 // pstm.setString(1, tp._1) 84 // pstm.setInt(2, tp._2) 85 // pstm.executeUpdate() 86 // }) 87 // pstm.close() 88 // conn.close() 89 // }) 90 91 reduced.foreachPartition(it => MyUtils.data2MySQL(it)) 92 93 94 sc.stop() 95 96 97 98 } 99 }