Spark UDAF

// Two key points: how to read values from input and from the buffer, and how to write the updated values back into the buffer

// Understanding custom aggregate functions in depth so they can be tailored as needed: what each of the overridden methods does and in what order they run
package areatop3

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
/**
 * Aggregates the two input columns city_id and city_name directly.
 */
class AggregateCityInfoFunction extends UserDefinedAggregateFunction {

  // Input schema: one field per function argument; with several arguments the values are read later as input(0), input(1), ...
  override def inputSchema: StructType = StructType(StructField("city_id", StringType) :: StructField("city_name", StringType) :: Nil)

  // Buffer schema: one field per intermediate value; with several fields the values are read later as buffer(0), buffer(1), ...
  override def bufferSchema: StructType = StructType(StructField("buffer_city_info", StringType) :: Nil)

  // Type of the final result
  override def dataType: DataType = StringType

  // Whether the function is deterministic, i.e. always returns the same output for the same input
  override def deterministic: Boolean = true

  // Initialize the buffer: an empty string here because we build up a concatenated string; a count would start at 0, and a collection-based buffer would start with a freshly created, empty collection
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = ""

  // Per-partition accumulation: fold one input row into this node's buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    var city_info: String = buffer.getString(0)
    val city_id: String = input.getString(0)
    val city_name: String = input.getString(1)
    val inputElement: String = city_id + "_" + city_name
    // Append the element only if the buffer does not already contain it, using "," as the separator
    if (!city_info.contains(inputElement)) {
      if ("" == city_info) {
        city_info += inputElement
      } else {
        city_info += "," + inputElement
      }
    }
    buffer.update(0, city_info)
  }

  // Merge two buffers, e.g. partial results coming from different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var city_info: String = buffer1.getString(0)
    val city_info2: String = buffer2.getString(0)
    // Add every element of the second buffer that the first buffer does not yet contain
    for (inputElement <- city_info2.split(",")) {
      if (!city_info.contains(inputElement)) {
        if ("" == city_info) {
          city_info += inputElement
        } else {
          city_info += "," + inputElement
        }
      }
    }
    buffer1.update(0, city_info)
  }
  // Produce the final result from the buffer
  override def evaluate(buffer: Row): Any = {
    buffer.getString(0)
  }
}
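
For completeness, here is a minimal usage sketch. It assumes a local SparkSession and a hypothetical temp view named area_city_click with area, city_id and city_name columns (these names are illustrative and not part of the class above): the UDAF is registered under a SQL name and then called like any built-in aggregate.

package areatop3

import org.apache.spark.sql.SparkSession

object AggregateCityInfoDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("udaf-demo").master("local[*]").getOrCreate()
    import spark.implicits._

    // Register the UDAF under a name that can be used in SQL
    spark.udf.register("city_infos", new AggregateCityInfoFunction)

    // Hypothetical sample data: (area, city_id, city_name)
    Seq(
      ("east", "01", "Shanghai"),
      ("east", "02", "Hangzhou"),
      ("east", "01", "Shanghai"),
      ("north", "03", "Beijing")
    ).toDF("area", "city_id", "city_name").createOrReplaceTempView("area_city_click")

    // Each area gets one comma-separated, deduplicated "city_id_city_name" string
    spark.sql(
      """
        |SELECT area, city_infos(city_id, city_name) AS city_infos
        |FROM area_city_click
        |GROUP BY area
      """.stripMargin).show(false)

    spark.stop()
  }
}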
