package cn.itcast.tags.ml.classification
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
 * Trains and evaluates a random-forest multiclass classifier with Spark ML.
 *
 * The job reads a LIBSVM-formatted file, splits it 80/20 into training and
 * testing sets, indexes the label and the categorical features, tunes a
 * [[RandomForestClassifier]] with 4-fold cross-validation over a parameter
 * grid, and finally prints the predictions and the accuracy obtained on the
 * testing split.
 */
object RfModel {

  /**
   * Application entry point.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName.stripSuffix("$"))
      .master("local[4]")
      .getOrCreate()

    // Make sure the session is stopped even if loading, training or
    // evaluation throws.
    try {
      import spark.implicits._

      // 1. Load the data set
      val dataframe: DataFrame = spark.read
        .format("libsvm")
        .load("datas/ship/total001.txt")

      // Split into training data and testing data
      val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2))

      // 2. Feature engineering: feature extraction, transformation and selection
      // 2.1. Convert the label column into indices ranging from 0 to K-1.
      //      The indexer is fitted on the full data set (not only the training
      //      split) - presumably so that every label value gets an index.
      val labelIndexer: StringIndexerModel = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("index_label")
        .fit(dataframe)
      val labelIndexedDF: DataFrame = labelIndexer.transform(dataframe)

      // 2.2. Special handling of categorical features: a feature whose number
      //      of distinct values is at most K is treated as categorical and its
      //      values are automatically re-indexed.
      val featureIndexer: VectorIndexerModel = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("index_features")
        // Decides which feature columns are categorical and indexes their values
        .setMaxCategories(4) // maximum number of categories of a categorical feature
        .fit(labelIndexedDF)

      // 3. The learning algorithm
      val rf = new RandomForestClassifier()
        .setLabelCol("index_label")
        .setFeaturesCol("index_features")
      // .setSubsamplingRate(1.0)

      // 4. Build the Pipeline; every stage is either an algorithm (Estimator)
      //    or a model (Transformer). Stages are listed in dependency order.
      val pipeline: Pipeline = new Pipeline()
        .setStages(
          Array(labelIndexer, featureIndexer, rf)
        )

      // Hyper-parameter grid: 6 * 2 * 2 * 6 * 2 = 288 combinations
      val paramGrid: Array[ParamMap] = new ParamGridBuilder()
        .addGrid(rf.maxDepth, Array(5, 10, 15, 20, 25, 30))
        .addGrid(rf.impurity, Array("gini", "entropy"))
        .addGrid(rf.maxBins, Array(32, 64))
        .addGrid(rf.numTrees, Array(5, 10, 20, 30, 40, 50))
        .addGrid(rf.featureSubsetStrategy, Array("auto", "sqrt"))
        .build()
      // Alternative grid with fewer numTrees values (96 combinations):
      // val paramGrid: Array[ParamMap] = new ParamGridBuilder()
      //   .addGrid(rf.maxDepth, Array(5, 10, 15, 20, 25, 30))
      //   .addGrid(rf.impurity, Array("gini", "entropy"))
      //   .addGrid(rf.maxBins, Array(32, 64))
      //   .addGrid(rf.numTrees, Array(5, 10))
      //   .addGrid(rf.featureSubsetStrategy, Array("auto", "sqrt"))
      //   .build()

      // Multiclass evaluator
      val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("index_label")
        .setPredictionCol("prediction")
        // Metric name; supported: f1, weightedPrecision, weightedRecall, accuracy
        .setMetricName("accuracy")

      // Cross-validation over the parameter grid
      val validator: CrossValidator = new CrossValidator()
        .setEstimator(pipeline)
        .setEvaluator(evaluator)
        .setEstimatorParamMaps(paramGrid)
        .setNumFolds(4)

      // Train the model
      println("正在训练模型...")
      val model: CrossValidatorModel = validator.fit(trainingDF)
      println(model.toString())

      // 5. Model evaluation: compute the accuracy on the testing split
      val predictionDF: DataFrame = model.transform(testingDF)
      predictionDF.printSchema()
      predictionDF
        .select($"probability", $"prediction", $"index_label")
        .show(100, truncate = false)

      val accuracy: Double = evaluator.evaluate(predictionDF)
      println(s"Accuracy = $accuracy")
    } finally {
      spark.stop()
    }
  }
}