参考官网:GenericUDAFCaseStudy - Apache Hive - Apache Software Foundationhttps://cwiki.apache.org/confluence/display/Hive/GenericUDAFCaseStudy
package comxxx.hive;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import java.text.DecimalFormat;
import java.util.*;
/**
* 1.Writing the resolver -- 负责解析函数的元数据,函数传入的参数的类型检查。函数返回值的说明等
* 2.Writing the evaluator --负责计算
* 2.1getNewAggregationBuffer
* 2.2iterate
* 2.3terminatePartial
* 2.4merge
* 2.5terminate
*
* 3.UDAF的运行原理:
* ①在group by 分组后运行
* ②运行的范围是分组的一组内
* ③依次对组中的每一行进行计算,最终得到一行结果
* 4.函数如何用? --分组后直接调用函数,传入spu_name
* select
* coupon_id,myudaf(spu_name)
* from test6
* group by coupon_id
*/
public class MyUDAF extends AbstractGenericUDAFResolver {
// 创建一个自己定义的Evaluator
@Override
public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
// Type-checking goes here! --进行函数输入的参数类型检查
//获取函数输入的参数
TypeInfo[] parameters = info.getParameters();
//对参数个数进行验证 -- 我们要求只传入一列
if (parameters.length != 1) {
throw new UDFArgumentException("参数个数只能是一个!");
}
//校验类型 --判断是不是基本数据类型
if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentException("参数类型必须是基础数据类型!");
}
// 校验类型 --我们要求必须是String
if (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()
!= PrimitiveObjectInspector.PrimitiveCategory.STRING) {
throw new UDFArgumentException("参数类型必须是String!");
}
return new MyEvaluator();
}
// 定义Evaluator
public static class MyEvaluator extends GenericUDAFEvaluator {
/**
* 需要手动调出init方法
* 目的是给函数标识当前处于计算的哪个阶段(mode)
* ObjectInspector :是类型检查器
*/
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
//调用此方法获取当前的mode
super.init(m, parameters);
// This function should be overriden in every sub class
// And the sub class should call super.init(m, parameters) to get mode set.
//根据当前所处的阶段,获取当前阶段要使用的类型检查器
/**
* @param parameters In PARTIAL1 and COMPLETE mode, the parameters are original data;
* In PARTIAL2 and FINAL mode, the parameters are just partial aggregations
*/
if (m == Mode.PARTIAL2 || m == Mode.FINAL) {
// 给mapOI赋值
mapOI = (StandardMapObjectInspector) parameters[0];
}
/**
@return In PARTIAL1 and PARTIAL2 mode, the ObjectInspector for the return value of terminatePartial() call;
* In FINAL and COMPLETE mode, the ObjectInspector for the return value of terminate() call.
*/
if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
//返回terminatePartial()返回值(是map)对应的类型检查器--map的检查器(k,v的检查器)
return ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
PrimitiveObjectInspectorFactory.javaIntObjectInspector);
} else {
//返回terminate()返回值(是string)对应的类型检查器--string
return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
}
}
//创建一个新的缓冲区 --缓冲区是自己定义的!--我们需要存的是品牌名称和下单次数,需要map结构
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
return new MyBuf();
}
// 重置缓冲区,清空缓冲区
@Override
public void reset(AggregationBuffer agg) throws HiveException {
// 先将缓冲区对象强转成自己定义的缓冲区对象类型,再调出自己定义的缓冲区对象
((MyBuf) agg).buff.clear();
}
/**
* 迭代输入的每一行,将结果存入缓冲区
*
* @param agg 缓冲区对象
* @param parameters 输入的一行 --一列spu_name
*/
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
//取出输入的一行中的spu_name --参数只有一列
String spu_name = parameters[0].toString();
//获取map
HashMap<String, Integer> buff = ((MyBuf) agg).buff;
//将spu_name累加到map中 --先取出buff里原来的value,如果没有就默认为0,取出来之后再+1作为新的value,放入map中
buff.put(spu_name, buff.getOrDefault(spu_name, 0) + 1);
// 统计品牌次数
}
/**
* 负责缓冲区序列化
* Here persistable means the return value can only be built up in terms of Java primitives,
* arrays, primitive wrappers (e.g. Double), Hadoop Writables, Lists, and Maps
* 返回值只能是基础数据类型、arrays, primitive wrappers (e.g. Double), Hadoop Writables, Lists, and Maps等
*/
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
return ((MyBuf) agg).buff;
}
/**
* 将两个缓冲区进行合并,得到一个缓冲区
*
* @param agg 当前task的缓冲区
* @param partial 从网络中接收的其他task序列化后的缓冲区对象,使用时需要反序列化
* 反序列化之前需要先用缓冲区对应的ObjectInspector进行类型检查,检查通过,才能反序列化
*/
//声明一个缓冲区对应的ObjectInspector --map类型的对象检查器
//声明后是在Init()中为mapOI赋值,如果在此处赋值,后续仍会将此值清空
private StandardMapObjectInspector mapOI;
@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
//取出当前缓冲区的map
HashMap<String, Integer> map1 = ((MyBuf) agg).buff;
//对从网络中接收的缓冲区对象partial进行类型检查
Map<?, ?> map2 = mapOI.getMap(partial);
//对key,value的类型继续检查
PrimitiveObjectInspector mapKeyObjectInspector = (PrimitiveObjectInspector)mapOI.getMapKeyObjectInspector();
PrimitiveObjectInspector mapValueObjectInspector = (PrimitiveObjectInspector)mapOI.getMapValueObjectInspector();
// 使用key,value的类型检测器,检测key,value是不是此类型,如果是,反序列化获取Key,value
for (Map.Entry<?, ?> entry : map2.entrySet()) {
String key = PrimitiveObjectInspectorUtils.getString(entry.getKey(), mapKeyObjectInspector);
int value = PrimitiveObjectInspectorUtils.getInt(entry.getValue(), mapValueObjectInspector);
//将当前缓冲区的map中的元素与从网络中接收的map中的元素进行合并
map1.put(key,map1.getOrDefault(key,0) + value);
}
}
/**
* 基于最后合并的最终的缓冲区,计算得到函数输出的结果
*
* @param agg 最终合并的缓冲区
*/
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
//先取出合并后的缓冲区map
HashMap<String, Integer> buff = ((MyBuf) agg).buff;
//1.计算总次数
double totalTimes = 0d;
for (Integer value : buff.values()) { //for循环遍历得到每个spu_name对应的次数,进行累加
totalTimes += value;
}
// 2.对每一个spu_name的times进行排序,取前三名 --map是非线性的,不能排序,要转成list等线性的结构进行排序
// 先转成set,再转成list
ArrayList<Map.Entry<String, Integer>> entryArrayList = new ArrayList<>(buff.entrySet());
// 需要传入一个比较器
entryArrayList.sort(new Comparator<Map.Entry<String, Integer>>() {
// 降序排序(默认是升序,所以前面加个-号)
@Override
public int compare(Map.Entry<String, Integer> o1, Map.Entry<String, Integer> o2) {
return -o1.getValue().compareTo(o2.getValue());
}
});
// 取前三 --subList(0,3)截取前三个 --但集合中可能没有三个,所以要取最小值Math.min(3,entryArrayList.size())
List<Map.Entry<String, Integer>> top3Spu_name = entryArrayList.subList(0, Math.min(3, entryArrayList.size()));
// 3.计算前三的比例之和,求其他的比例
double top3Percent = 0d;
// 声明一个存放每一个spu_name最终字符串的集合
ArrayList<String> strs = new ArrayList<>();
//声明一个百分数格式化器
DecimalFormat decimalFormat = new DecimalFormat("##.##%");
for (Map.Entry<String, Integer> entry : top3Spu_name) {
// 计算前三的每个spu_name的占比
double spu_percent = entry.getValue() / totalTimes;
strs.add(entry.getKey() + ":" + decimalFormat.format(spu_percent));
top3Percent += spu_percent;
}
// 计算其他的比例 --只有当前coupon_id下的spu_name>3才有其他 (spu_name有三个以上才有其他)
if (entryArrayList.size() > 3) {
strs.add("其他:" + decimalFormat.format(1 - top3Percent));
}
//将集合中的字符串拼接为结果--使用一个工具类
String result = StringUtils.join(strs, ',');
return result;
}
// UDAF logic goes here! --定义缓冲区对象
static class MyBuf implements AggregationBuffer {
//自定定义存储想存储的数据的结构 --我们需要存的是品牌名称和下单次数,需要map结构
private HashMap<String, Integer> buff = new HashMap<>();
}
}
}