Differences between Model and non-Model stages when chaining multiple stages with a Pipeline

train.csv data:

id,name,age,sex
1,lyy,20,F
2,rdd,20,M
3,nyc,18,M
4,mzy,10,M

Reading the data:

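import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

SparkSession spark = SparkSession.builder().enableHiveSupport().getOrCreate();
Dataset<Row> dataset = spark
        .read()
        .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat")
        .option("header", true)
        .option("inferSchema", true)
        .option("delimiter", ",")
        //.load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/SanFranciscoCrime/document/kaggle-旧金山犯罪分类/train-new.csv") // input for PreProcess1
        .load("file:///E:/git/bigdata_sparkIDE/spark-ide/workspace/test/SparkMLTest/DataPreprocessing/document/train.csv") // input for PreProcess2
        .persist();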

public static void PreProcess2(Dataset<Row> data) {
    data.printSchema();

    // Re-index the label values
    StringIndexerModel labelIndexer = new StringIndexer()
            .setInputCol("sex")
            .setOutputCol("label")
            .fit(data);
    StringIndexerModel nameIndexer = new StringIndexer()
            .setInputCol("name")
            .setOutputCol("namenum")
            .fit(data);

    /* This fails with: Exception in thread "main" java.lang.IllegalArgumentException: Field "namenum" does not exist.
     * Reason: calling fit() here fits the estimator immediately on data, outside any Pipeline, so data itself
     * must already contain the column named by InputCol. It cannot use the output of another Pipeline stage,
     * not even a stage whose OutputCol has exactly this name.
     * StringIndexerModel name1Indexer = new StringIndexer()
     *         .setInputCol("namenum")
     *         .setOutputCol("namenum1")
     *         .fit(data);
     */

    /* Fails for the same reason as the StringIndexerModel above: features is not a column of data.
     * VectorIndexerModel featureIndexer = new VectorIndexer()
     *         .setInputCol("features")
     *         .setOutputCol("indexfeatures")
     *         .setMaxCategories(4)
     *         .fit(data);
     */

    // This works.
    // Explanation: left as a non-Model (an unfitted estimator), the stage is not fitted here; Pipeline.fit()
    // fits it later on the output of the stages before it, so its InputCol can be another stage's output.
    // stages[2], i.e. assembler, already produces features, so the column can be used directly here.
    // Used this way, though, the stage cannot be fitted on its own against data; it depends on the Pipeline.
    VectorIndexer featureIndexer = new VectorIndexer()
            .setInputCol("features")
            .setOutputCol("indexfeatures")
            .setMaxCategories(4);

    // By the same reasoning, the input columns here may be the outputs of several stages: VectorAssembler is not
    // a Model (it is a Transformer that needs no fit), so it can consume intermediate results, several at once.
    VectorAssembler assembler = new VectorAssembler()
            .setInputCols("id,namenum,age".split(","))
            .setOutputCol("features");

    // Stage order matters: stages must be added in dependency order. This order fails with
    // Exception in thread "main" java.lang.IllegalArgumentException: Field "features" does not exist.
    // Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {labelIndexer, nameIndexer, featureIndexer, assembler});
    // Moving featureIndexer after assembler fixes it:
    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[] {labelIndexer, nameIndexer, assembler, featureIndexer});

    // Train the model. This also runs the indexers.
    PipelineModel model = pipeline.fit(data);

    // Make predictions.
    Dataset<Row> result = model.transform(data);
    result.show(false);
}
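To make the fit-time difference concrete, here is a toy model, in plain Java with no Spark dependency, of the loop that Pipeline.fit() runs as we understand it: each estimator is fitted on the intermediate dataset produced by the stages before it. Everything in it (ToyPipeline, Stage, Estimator, Transformer, rawColumns, stages) is hypothetical and not Spark's API; a dataset is reduced to its set of column names, and Spark's up-front schema validation is folded into the same pass. It covers the three cases in the code above: an eager fit on the raw data fails, the same estimator succeeds inside a correctly ordered pipeline, and putting featureIndexer before assembler fails on the missing features column.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

// Toy model of Pipeline.fit(); hypothetical classes, NOT Spark's API.
// A "dataset" is reduced to the sorted set of its column names.
class ToyPipeline {

    interface Stage {}

    // Stand-in for a Transformer: VectorAssembler, or an already-fitted Model.
    static class Transformer implements Stage {
        final List<String> inputs;
        final String output;

        Transformer(List<String> inputs, String output) {
            this.inputs = inputs;
            this.output = output;
        }

        Set<String> transform(Set<String> columns) {
            for (String in : inputs) {
                if (!columns.contains(in)) {
                    throw new IllegalArgumentException("Field \"" + in + "\" does not exist.");
                }
            }
            Set<String> out = new TreeSet<>(columns);
            out.add(output);
            return out;
        }
    }

    // Stand-in for an unfitted Estimator such as StringIndexer or VectorIndexer.
    static class Estimator implements Stage {
        final String input;
        final String output;

        Estimator(String input, String output) {
            this.input = input;
            this.output = output;
        }

        // Fitting needs the input column in whatever dataset fit() receives.
        Transformer fit(Set<String> columns) {
            if (!columns.contains(input)) {
                throw new IllegalArgumentException("Field \"" + input + "\" does not exist.");
            }
            return new Transformer(Arrays.asList(input), output);
        }
    }

    // Simplified Pipeline.fit(): an Estimator is fitted on the *intermediate*
    // columns produced by earlier stages, then its model transforms them.
    static Set<String> fit(List<Stage> stages, Set<String> columns) {
        Set<String> current = new TreeSet<>(columns);
        for (Stage stage : stages) {
            Transformer t = (stage instanceof Estimator)
                    ? ((Estimator) stage).fit(current)
                    : (Transformer) stage;
            current = t.transform(current);
        }
        return current;
    }

    static Set<String> rawColumns() {
        return new TreeSet<>(Arrays.asList("id", "name", "age", "sex"));
    }

    // nameIndexer, assembler, featureIndexer; correctOrder=false swaps the last two.
    static List<Stage> stages(boolean correctOrder) {
        Stage nameIndexer = new Estimator("name", "namenum");
        Stage assembler = new Transformer(Arrays.asList("id", "namenum", "age"), "features");
        Stage featureIndexer = new Estimator("features", "indexfeatures");
        return correctOrder
                ? Arrays.asList(nameIndexer, assembler, featureIndexer)
                : Arrays.asList(nameIndexer, featureIndexer, assembler);
    }

    public static void main(String[] args) {
        // Eager fit, like VectorIndexer...fit(data): "features" is not in the raw data.
        try {
            new Estimator("features", "indexfeatures").fit(rawColumns());
        } catch (IllegalArgumentException e) {
            System.out.println("eager fit: " + e.getMessage());
        }
        // Deferred fit in dependency order succeeds.
        System.out.println("pipeline: " + fit(stages(true), rawColumns()));
        // Wrong order fails on the same missing column.
        try {
            fit(stages(false), rawColumns());
        } catch (IllegalArgumentException e) {
            System.out.println("wrong order: " + e.getMessage());
        }
    }
}
```

Running main prints `eager fit: Field "features" does not exist.`, then `pipeline: [age, features, id, indexfeatures, name, namenum, sex]`, then `wrong order: Field "features" does not exist.`. The whole difference lies in which column set each estimator's fit() receives: the raw data when you call fit() yourself, the intermediate result when the Pipeline calls it.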

root
|-- id: integer (nullable = true)
|-- name: string (nullable = true)
|-- age: integer (nullable = true)
|-- sex: string (nullable = true)

+---+----+---+---+-----+-------+--------------+-------------+
|id |name|age|sex|label|namenum|features      |indexfeatures|
+---+----+---+---+-----+-------+--------------+-------------+
|1  |lyy |20 |F  |1.0  |1.0    |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
|2  |rdd |20 |M  |0.0  |2.0    |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
|3  |nyc |18 |M  |0.0  |0.0    |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
|4  |mzy |10 |M  |0.0  |3.0    |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
+---+----+---+---+-----+-------+--------------+-------------+
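The indexfeatures column can be checked by hand. The sketch below re-implements in plain Java the category rule VectorIndexer follows as we understand it: a feature with at most maxCategories distinct values is treated as categorical, its distinct values are sorted and numbered from 0, and 0.0, when present, keeps index 0; features with more distinct values pass through unchanged. VectorIndexerSketch, categoryMaps and transform are hypothetical names, not Spark code. With maxCategories = 4 (the value in the commented-out VectorIndexerModel), all three features count as categorical (id and namenum have 4 distinct values, age has 3), which is why id 1..4 becomes 0..3 and age 10/18/20 becomes 0/1/2.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

// Hypothetical re-implementation of VectorIndexer's category rule; NOT Spark code.
class VectorIndexerSketch {

    // One value -> index map per feature; null marks a feature left continuous
    // because it has more than maxCategories distinct values.
    static List<Map<Double, Integer>> categoryMaps(double[][] rows, int maxCategories) {
        List<Map<Double, Integer>> maps = new ArrayList<>();
        for (int j = 0; j < rows[0].length; j++) {
            TreeSet<Double> distinct = new TreeSet<>(); // sorted ascending
            for (double[] row : rows) {
                distinct.add(row[j]);
            }
            if (distinct.size() > maxCategories) {
                maps.add(null);
                continue;
            }
            Map<Double, Integer> map = new HashMap<>();
            int next = 0;
            if (distinct.remove(0.0)) {
                map.put(0.0, next++); // 0.0 keeps index 0
            }
            for (double v : distinct) {
                map.put(v, next++);
            }
            maps.add(map);
        }
        return maps;
    }

    // Replaces categorical values by their index; continuous values pass through.
    static double[] transform(double[] row, List<Map<Double, Integer>> maps) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            Map<Double, Integer> map = maps.get(j);
            out[j] = (map == null) ? row[j] : map.get(row[j]);
        }
        return out;
    }

    public static void main(String[] args) {
        // The features column [id, namenum, age] from the output above.
        double[][] features = {
            {1.0, 1.0, 20.0},
            {2.0, 2.0, 20.0},
            {3.0, 0.0, 18.0},
            {4.0, 3.0, 10.0},
        };
        List<Map<Double, Integer>> maps = categoryMaps(features, 4);
        for (double[] row : features) {
            System.out.println(Arrays.toString(transform(row, maps)));
        }
    }
}
```

main prints [0.0, 1.0, 2.0], [1.0, 2.0, 2.0], [2.0, 0.0, 1.0] and [3.0, 3.0, 0.0], the same values as indexfeatures above. With maxCategories = 3, categoryMaps would leave id and namenum continuous and re-index only age. Spark documents only that value 0 maps to index 0 and does not promise identical indices across runs, so treat the sorted numbering here as an approximation.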

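Based on the analysis above, the original code can be simplified: leave every stage unfitted and let the Pipeline fit them in dependency order.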

public static void PreProcess2(Dataset<Row> data) {
    data.printSchema();

    // Re-index the label values
    StringIndexer labelIndexer = new StringIndexer()
            .setInputCol("sex")
            .setOutputCol("label");
    StringIndexer nameIndexer = new StringIndexer()
            .setInputCol("name")
            .setOutputCol("namenum");
    VectorIndexer featureIndexer = new VectorIndexer()
            .setInputCol("features")
            .setOutputCol("indexfeatures")
            .setMaxCategories(4);
    VectorAssembler assembler = new VectorAssembler()
            .setInputCols("id,namenum,age".split(","))
            .setOutputCol("features");
    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[] {labelIndexer, nameIndexer, assembler, featureIndexer});

    // Train the model. This also runs the indexers.
    PipelineModel model = pipeline.fit(data); // every stage is fitted starting from this data
    // Make predictions.
    Dataset<Row> result = model.transform(data);
    result.show(false);
}

Output:

root
|-- id: integer (nullable = true)
|-- name: string (nullable = true)
|-- age: integer (nullable = true)
|-- sex: string (nullable = true)

+---+----+---+---+-----+-------+--------------+-------------+
|id |name|age|sex|label|namenum|features      |indexfeatures|
+---+----+---+---+-----+-------+--------------+-------------+
|1  |lyy |20 |F  |1.0  |1.0    |[1.0,1.0,20.0]|[0.0,1.0,2.0]|
|2  |rdd |20 |M  |0.0  |2.0    |[2.0,2.0,20.0]|[1.0,2.0,2.0]|
|3  |nyc |18 |M  |0.0  |0.0    |[3.0,0.0,18.0]|[2.0,0.0,1.0]|
|4  |mzy |10 |M  |0.0  |3.0    |[4.0,3.0,10.0]|[3.0,3.0,0.0]|
+---+----+---+---+-----+-------+--------------+-------------+
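The label column follows StringIndexer's default ordering (stringOrderType "frequencyDesc"): the most frequent string gets index 0. The minimal plain-Java sketch below, with hypothetical names (StringIndexerSketch, fit) that are not Spark code, reproduces M = 0.0 (3 occurrences) and F = 1.0 (1 occurrence). It breaks ties alphabetically, which is an assumption based on what recent Spark versions document. Every name in train.csv occurs once, so namenum is pure tie-breaking, and the order in the output above (nyc, lyy, rdd, mzy) is not alphabetical, presumably because the Spark version used did not fix the tie order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of StringIndexer's "frequencyDesc" ordering; NOT Spark code.
class StringIndexerSketch {

    // Assigns 0.0 to the most frequent string, 1.0 to the next, and so on.
    // Ties are broken alphabetically (assumption: matches recent Spark docs).
    static Map<String, Double> fit(List<String> values) {
        Map<String, Integer> counts = new HashMap<>();
        for (String v : values) {
            counts.merge(v, 1, Integer::sum);
        }
        List<String> labels = new ArrayList<>(counts.keySet());
        labels.sort((a, b) -> {
            int byCount = Integer.compare(counts.get(b), counts.get(a)); // descending frequency
            return byCount != 0 ? byCount : a.compareTo(b);             // alphabetical tie-break
        });
        Map<String, Double> index = new HashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            index.put(labels.get(i), (double) i);
        }
        return index;
    }

    public static void main(String[] args) {
        // The sex column of train.csv, row by row.
        List<String> sex = Arrays.asList("F", "M", "M", "M");
        Map<String, Double> labelIndex = fit(sex);
        for (String s : sex) {
            System.out.println(s + " -> " + labelIndex.get(s));
        }
    }
}
```

main prints `F -> 1.0` followed by three `M -> 0.0` lines, matching the label column above.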