Spark中自定义累加器Accumulator

1. 自定义累加器

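自定义累加器需要继承AccumulatorParam[T],并实现两个方法:addInPlace(合并两个累加值)和zero(返回初始零值)。注意:AccumulatorParam等v1累加器API自Spark 2.0起已被标记为弃用,并在Spark 3.0中移除,新代码应改用AccumulatorV2。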

例1:实现Long类型的累加器

/** Example custom AccumulatorParam for Long values.
  *
  * Both callbacks print their arguments so the demo shows when they are invoked.
  */
object LongAccumulatorParam extends AccumulatorParam[Long] {

  /** Adds two partial Long values and returns their sum (after printing both operands). */
  override def addInPlace(r1: Long, r2: Long): Long = {
    println(s"$r1\t$r2")
    r1 + r2
  }

  /** Returns the additive identity `0L`; `initialValue` is only printed, not used. */
  override def zero(initialValue: Long): Long = {
    println(initialValue)
    0L
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("testLongAccumulator"))
    // Pass this param explicitly. Otherwise implicit resolution selects Spark's built-in
    // AccumulatorParam[Long] (or fails on versions that need `import SparkContext._`),
    // and the custom addInPlace/zero above would never be called.
    val acc = sc.accumulator(0L, "LongAccumulator")(LongAccumulatorParam)
    sc.parallelize(Array(1L, 2L, 3L, 4L, 5L)).foreach(acc.add)
    // Read the merged result on the driver after the action has completed.
    println(acc.value)
    sc.stop()
  }
}

例2:实现Set[String]类型的累加器,可用于收集错误日志

/** AccumulatorParam that accumulates a Set[String] by set union,
  * e.g. to collect distinct error-log entries.
  */
object StringSetAccumulatorParam extends AccumulatorParam[Set[String]] {

  /** Merges two partial sets into their union. */
  override def addInPlace(r1: Set[String], r2: Set[String]): Set[String] =
    r1.union(r2)

  /** The identity element is always the empty set; `initialValue` is ignored. */
  override def zero(initialValue: Set[String]): Set[String] =
    Set.empty[String]
}


/** Holds a lazily created, process-wide Set[String] accumulator and a small demo `main`. */
object ErrorLogHostSet extends Serializable {
  // Singleton accumulator; @volatile so the unsynchronized first null check in
  // getInstance observes a fully published reference.
  @volatile private var instanceErr: Accumulator[Set[String]] = null

  /** Returns the shared Set[String] accumulator, creating it on the first call.
    *
    * Uses double-checked locking on this object so only one accumulator is created.
    * NOTE(review): `sc` is only used on the first call; later calls return the accumulator
    * registered with whichever SparkContext created it, even if that context was stopped.
    *
    * @param sc context used to create the accumulator if it does not exist yet
    * @return the singleton accumulator backed by [[StringSetAccumulatorParam]]
    */
  def getInstance(sc: SparkContext): Accumulator[Set[String]] = {
    if (instanceErr == null) {
      synchronized {
        if (instanceErr == null) {
          instanceErr = sc.accumulator(Set[String]())(StringSetAccumulatorParam)
        }
      }
    }
    instanceErr
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("testSetStringAccumulator"))

    val dataRdd = sc.parallelize(Array("a2","c4","6v","67s","3d","45s","2c6","35d","7c8d9","34dc5"))
    val errorHostSet = getInstance(sc)

    dataRdd.filter { ele =>
      val res = ele.contains("d")
      // Record matching elements. Updates made inside a transformation may be applied
      // more than once if a task is re-executed; set union makes duplicates harmless here.
      if (res) errorHostSet += Set(ele)
      res
    }.foreach(println)

    // The accumulated value is read on the driver after the action has run.
    errorHostSet.value.foreach(println)

    sc.stop()
  }
}

2. AccumulableCollection的使用

/** Demo of SparkContext.accumulableCollection: tasks add (id -> Employee) pairs
  * into a mutable HashMap that is read back on the driver.
  */
object AccumulableCollectionTest {

  /** Employee record stored as the map value. */
  case class Employee(id: String, name: String, dept: String)

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("AccumulableCollectionTest").setMaster("local[4]")
    val sc = new SparkContext(conf)

    // Collection accumulator whose element type is the (key, value) pair of the HashMap.
    val byId = sc.accumulableCollection(mutable.HashMap[String, Employee]())

    val staff = List(
      Employee("10001", "Tom", "Eng"),
      Employee("10002", "Roger", "Sales"),
      Employee("10003", "Rafael", "Sales"),
      Employee("10004", "David", "Sales"),
      Employee("10005", "Moore", "Sales"),
      Employee("10006", "Dawn", "Sales"),
      Employee("10007", "Stud", "Marketing"),
      Employee("10008", "Brown", "QA")
    )

    println(s"employee count ${staff.size}")

    sc.parallelize(staff).foreach { emp =>
      byId += (emp.id -> emp)
    }

    val collected = byId.value
    println(s"empAccumulator size ${collected.size}")
    collected.foreach { case (id, emp) =>
      println(s"emp id = $id name = ${emp.name}")
    }
    sc.stop()
  }

}

 

上一篇:误删libc.so.6文件补救


下一篇:阿里开发者招聘节 | 阿里云CDN团队诚招技术人才啦!