需求
有udaf.json数据内容如下
{"name":"Michael","salary":3000}
{"name":"Andy","salary":4500}
{"name":"Justin","salary":3500}
{"name":"Berta","salary":4000}
求取 平均工资
●继承UserDefinedAggregateFunction方法重写说明
inputSchema:输入数据的类型
bufferSchema:产生中间结果的数据类型
dataType:最终返回的结果类型
deterministic:确保一致性,一般用true
initialize:指定初始值
update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
merge:全局聚合(将每个分区的结果进行聚合)
evaluate:计算最终的结果
package SparkSql
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
* Created by 一个蔡狗 on 2020/4/13.
*
* UDAF 自定义函数
*
*/
object UDAF_01 {
// 继承 UserDefinedAggregateFunction 重写方法
// inputSchema:输入数据的类型
// bufferSchema:产生中间结果的数据类型
// dataType:最终返回的结果类型
// deterministic:确保一致性,一般用true
// initialize:指定初始值
// update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
// merge:全局聚合(将每个分区的结果进行聚合)
// evaluate:计算最终的结果
class GetAvg extends UserDefinedAggregateFunction {
// inputSchema:输入数据的类型 StructType 表结构
override def inputSchema: StructType = {
// ::Nil 创建 list ::Nil
StructType(StructField("input",LongType)::Nil)
}
// bufferSchema:产生中间结果的数据类型
// sum : 每次的 临时的 总和
// total : 临时的总次数
override def bufferSchema: StructType = {
StructType(StructField("sum",LongType)::StructField("total",LongType)::Nil)
}
// dataType:最终返回的结果类型
override def dataType: DataType = {
DoubleType
}
// deterministic:确保一致性,一般用true
override def deterministic: Boolean = {
true
}
//初始化数据 1 设置 x 个变量 2 每个变量 进行 初始化数据
override def initialize(buffer: MutableAggregationBuffer): Unit = {
// buffer(0) 作用 : 用于 记录临时的 数据 和
buffer(0)=0L
// buffer(1) 作用 : 用于 记录临时的 数据 条数
buffer(1)=0L
}
// RDD 中 有多个分区 update 计算一个 分区内的 数据
// update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// buffer 记录临时的 数据的 和 input 输入的数据
buffer(0)=buffer.getLong(0)+input.getLong(0)
//临时输入的总数量(条数)
buffer(1)=buffer.getLong(1)+1
}
//merge:全局聚合(将每个分区的结果进行聚合)
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//累加 第一个分区的 金额总和 与 第二个 分区的 金额总和
buffer1(0) =buffer1.getLong(0)+buffer2.getLong(0)
//离家 第一个分区的次数 和 第二个 分区的 次数
buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)
}
// evaluate:计算最终的结果
override def evaluate(buffer: Row): Any = {
//平均值
buffer.getLong(0).toDouble/buffer.getLong(1).toDouble
}
}
def main(args: Array[String]): Unit = {
// 1创建 SparkSeesion
val spark: SparkSession = SparkSession.builder().master("local[*]").appName("01").getOrCreate()
val udafJson: DataFrame = spark.read.json("E:\\udaf.json")
udafJson.show()
//注册一个 udaf 函数
spark.udf.register("GetAvg", new GetAvg())
//注册成一张表
udafJson.createOrReplaceTempView("UdafTabel")
//查看薪水 注册一个 GetAvg() 方法 实现 平均工资
spark.sql("select GetAvg(salary) from UdafTabel ").show()
spark.sql("select avg(salary) from UdafTabel ").show()
//关闭 spark
spark.stop()
}
//熟练掌握 基于 Spark的 UDAF
}