1. 方式一：继承 UserDefinedAggregateFunction 类（弱类型 UDAF）——多行输入聚合为一个输出值。
package sparkRdd_practice
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
/**
 * @ClassName UDAF1
 * @Description 自定义弱类型 UDAF，基于 DataFrame 计算年龄平均值
 * @Author 黄仁议<613024710@qq.com>
 * @Version V1.0
 * @Since 1.0
 * @Date 2019/6/4 17:39
 */
//自定义弱类型,使用dataframe
//多个输入,对应一个输出,实现计算年龄均值
// Weakly-typed (DataFrame-based) user-defined aggregate that computes the
// average age: many input rows are folded into a single Double output.
class UDAF1 extends UserDefinedAggregateFunction{
  // Schema of the input column: one nullable integer "age".
  override def inputSchema: StructType =
    new StructType().add("age", IntegerType, true)

  // Schema of the intermediate buffer: running (sum, count) per partition.
  override def bufferSchema: StructType =
    new StructType().add("sum", IntegerType).add("count", IntegerType)

  // Type of the value produced by evaluate().
  override def dataType: DataType = DoubleType

  // The same inputs always produce the same output.
  override def deterministic: Boolean = true

  // Reset a buffer to the empty aggregate: no ages seen, sum of zero.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0 // running sum of ages
    buffer(1) = 0 // number of non-null ages folded in
  }

  // Fold one input row into the partition-local buffer; null ages are skipped
  // so they affect neither the sum nor the count.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getInt(0) + input.getInt(0)
      buffer(1) = buffer.getInt(1) + 1
    }
  }

  // Combine two partial buffers coming from different partitions.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }

  // Final result: sum / count as a Double (NaN when no non-null ages were seen).
  override def evaluate(buffer: Row): Double =
    buffer.getInt(0) / buffer.getInt(1).toDouble
}
// Driver: registers the weakly-typed UDAF and invokes it from Spark SQL.
object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("UDFDemo")
      .master("local[2]")
      .getOrCreate()

    // Load the sample people data (one JSON object per line).
    val people: DataFrame = spark.read.json("jsonres")

    // Register the custom aggregate so SQL can refer to it by name.
    spark.udf.register("myAvg", new UDAF1)
    people.createOrReplaceTempView("t_people")

    spark.sql("select myAvg(age) as ageavg FROM t_people").show()
    spark.stop()
  }
}
2. 方式二：继承 Aggregator（强类型 UDAF）——同样是多行输入聚合为一个输出值，但基于 Dataset 的强类型 API。
package sparkRdd_practice
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
/**
 * @ClassName UDAF2
 * @Description 自定义强类型 UDAF，基于 Dataset 计算年龄平均值
 * @Author 黄仁议<613024710@qq.com>
 * @Version V1.0
 * @Since 1.0
 * @Date 2019/6/4 17:42
 */
// Intermediate buffer for the typed aggregator below: running sum of ages and
// the number of records folded in so far. Fields are vars because reduce/merge
// mutate the buffer in place.
case class Average(var sum:Int,var count:Int)
// Strongly-typed (Dataset-based) aggregator computing the average age.
// Type parameters: input record People, intermediate buffer Average, result Double.
class UDAF2 extends Aggregator[People,Average,Double]{
  // A fresh, empty buffer: zero sum, zero count.
  override def zero: Average = Average(0, 0)

  // Fold a single record into the partition-local buffer. The buffer is
  // mutated in place and returned, which Spark permits for aggregation buffers.
  override def reduce(acc: Average, person: People): Average = {
    acc.sum += person.age
    acc.count += 1
    acc
  }

  // Combine two partial buffers produced by different partitions.
  override def merge(left: Average, right: Average): Average = {
    left.sum += right.sum
    left.count += right.count
    left
  }

  // Final value: sum / count as a Double.
  override def finish(acc: Average): Double = acc.sum / acc.count.toDouble

  // Encoders tell Spark how to serialize the buffer and the result.
  override def bufferEncoder: Encoder[Average] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
// Placeholder driver object for the typed aggregator; no demo code added yet.
object UDAFDemo2 {
}