//Two key points: how to read data out of `input` and `buffer`, and how to write the updated value back into `buffer`.
//Understand the custom aggregate function in depth: what each of the six overridden methods does and in what order they execute.
package areatop3
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
/**
 * Aggregates the two input columns (city_id, city_name) directly into a
 * comma-separated list of distinct "city_id_city_name" tokens.
 */
class AggregateCityInfoFunction extends UserDefinedAggregateFunction {

  /** Appends `element` to the comma-separated accumulator `acc` unless it is
    * already present as an exact token.
    *
    * BUG FIX: the original used substring `contains`, which wrongly treated
    * "1_X" as present when the buffer held "11_X". Exact token membership
    * (split on "," and compare whole tokens) avoids that false positive.
    */
  private def appendDistinct(acc: String, element: String): String =
    if (element.isEmpty || acc.split(",").contains(element)) acc
    else if (acc.isEmpty) element
    else acc + "," + element

  // Input row schema: (city_id, city_name) — read later as input.getString(0) / input.getString(1).
  override def inputSchema: StructType =
    StructType(StructField("city_id", StringType) :: StructField("city_name", StringType) :: Nil)

  // Aggregation buffer schema: a single comma-separated string of "id_name" tokens.
  override def bufferSchema: StructType =
    StructType(StructField("buffer_city_info", StringType) :: Nil)

  // Type of the value returned by evaluate().
  override def dataType: DataType = StringType

  // The same input must always produce the same output; Catalyst relies on this
  // to cache/reuse results, so the implementation must not inject wall-clock values.
  override def deterministic: Boolean = true

  // Start from the empty string; for a count the seed would be 0, for a
  // collection-typed buffer it would be a fresh empty container.
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = ""

  // Fold one input row into the partition-local buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val element = input.getString(0) + "_" + input.getString(1)
    buffer.update(0, appendDistinct(buffer.getString(0), element))
  }

  // Combine two partition buffers, keeping each "id_name" token once.
  // BUG FIX: the original appended System.currentTimeMillis() to the merged
  // value, corrupting the aggregate result and violating deterministic = true.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val merged = buffer2.getString(0).split(",").foldLeft(buffer1.getString(0))(appendDistinct)
    buffer1.update(0, merged)
  }

  // Final result: the accumulated comma-separated city list.
  override def evaluate(buffer: Row): Any = buffer.getString(0)
}