Spark MLlib 多层感知器分类器（MultilayerPerceptronClassifier）在情感分析中的实际应用

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, Word2Vec}
import org.apache.spark.sql.SparkSession
// 读取数据源，格式如下：字段以空格分隔，最后一列数字是人工分析标题后打上的标签，
// 其值按照情绪程度，取自 {-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1} 之一。
// 10090 C779C882AA39436A89C463BCB406B838 涨停板,复盘,全,靠,新,股,撑,门面,万科,A,尾盘,封板 0.75
// 10091 519A9C6AD0A845298B0B3924117C0B4F 一,行业,再现,重大,利好,板块,反弹,仍,将,继续 0.75
// 10092 C86CEC7DB9794311AF386C3D7B0B7CBD 藁城区,3,大,项目,新,获,规划证,开发,房企,系,同,一家 0
// 10093 FCEA2FFC1C2F4D6C808F2CBC2FF18A8C 完善,对,*,企业,和,对外,投资,统计,监测 0.5
// 10094 204A77847F03404986331810E039DFC2 财联社,电报 0
// 10095 E571B9EF451F4D5F8426A1FA06CD9EE6 审计署,部分,央企,业绩,不,实 -0.5
// 10096 605264A2F6684CC4BB4B2A0B6A8FA078 厨卫,品牌,新,媒体,榜,看看,谁家,的,官微,最,爱,卖萌 0.25

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, Word2Vec}
import org.apache.spark.sql.SparkSession

/**
 * Sentiment classification of headlines with a multilayer perceptron.
 *
 * Each input line is expected to contain four whitespace-separated fields:
 * `id hash comma,separated,words label`. The words are embedded with
 * [[Word2Vec]] and classified by a [[MultilayerPerceptronClassifier]]; the
 * model is trained on a random 80% of the records and evaluated on the rest.
 */
object mllib {

  /** Input file used when no path is given on the command line. */
  private val DefaultInputPath = "D:\\data\\mlpc.txt"

  /** Word2Vec embedding size; also the size of the MLP input layer. */
  private val VectorSize = 2

  /** Seed shared by the classifier and the train/test split for reproducible runs. */
  private val Seed = 1234L

  /**
   * Trains the pipeline and prints the test-set accuracy.
   *
   * @param args optional; `args(0)` overrides the input file path
   */
  def main(args: Array[String]): Unit = {
    val inputPath = args.headOption.getOrElse(DefaultInputPath)
    val spark = SparkSession.builder().appName(this.getClass.getSimpleName).master("local").getOrCreate()
    try {
      // Skip lines without exactly four fields. Mapping them to an empty label
      // would turn "" into a spurious class.
      val parsedRDD = spark.sparkContext.textFile(inputPath)
        .map(_.trim.split("\\s+"))
        .filter(_.length == 4)
        .map(fields => (fields(3), fields(2).split(",")))
      val msgDF = spark.createDataFrame(parsedRDD).toDF("label", "message")
      msgDF.printSchema()
      msgDF.show(false)

      // Fitted on the full data set, before splitting, so every label in the
      // test split is known to the indexer.
      val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(msgDF)
      val numClasses = labelIndexer.labels.length
      require(numClasses > 0, s"no well-formed records found in $inputPath")

      val word2Vec = new Word2Vec()
        .setInputCol("message")
        .setOutputCol("features")
        .setVectorSize(VectorSize)
        .setMinCount(1)

      // The input layer must equal the feature vector size and the output layer
      // the number of classes. Extra output units could produce prediction
      // indices that have no label, which IndexToString cannot convert.
      val layers = Array[Int](VectorSize, 250, 500, numClasses)
      val mlpc = new MultilayerPerceptronClassifier()
        .setLayers(layers)
        .setBlockSize(512)
        .setSeed(Seed)
        .setMaxIter(128)
        .setFeaturesCol("features")
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")

      // Maps numeric predictions back to the original sentiment label strings.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

      val Array(trainingData, testData) = msgDF.randomSplit(Array(0.8, 0.2), Seed)
      val pipeline = new Pipeline().setStages(Array(labelIndexer, word2Vec, mlpc, labelConverter))
      val model = pipeline.fit(trainingData)
      val predictionResultDF = model.transform(testData)
      // The next two lines are for debugging.
      predictionResultDF.printSchema()
      predictionResultDF.select("message", "label", "predictedLabel").show(30)

      // "accuracy" is the supported metric name. "precision" is rejected by
      // the evaluator's parameter validation in Spark 2.x.
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val predictionAccuracy = evaluator.evaluate(predictionResultDF)
      println(f"Testing Accuracy is ${predictionAccuracy * 100}%2.4f%%")
    } finally {
      // Release the Spark context even if parsing, training or evaluation fails.
      spark.stop()
    }
  }
}

上一篇：Eclipse 通过 Git 将项目提交到 GitHub


下一篇:JDK1.8 HashMap 扩容 对链表(长度小于默认的8)处理时重新定位的过程