关于spark 读取 elasticsearch时,空字符串被转成null的问题排查与解决

【版本介绍】

  本次问题所使用的代码版本是spark 2.2.0 和 elasticsearch-spark-20_2.11 

【情景介绍】

  今天公司的小伙伴发现了一个问题,在spark 中,使用 elasticsearch-spark 读取es的数据,"" 这种空字符串的值,在spark中会被转成null,导致计算结果异常

  代码如下:

 1 def getTable()(implicit spark:SparkSession)={
 2   var query=
 3     s"""
 4        |{
 5        |  "query": {
 6        |    "bool": {
 7        |      "must": [
 8        |        {
 9        |          "term": {
10        |            "revise_status": {
11        |              "value": ""
12        |            }
13        |          }
14        |        }
15        |      ]
16        |    }
17        |  }
18        |}
19     """.stripMargin
20   //读取es数据
21   EsSparkSQL.esDF(spark,s"""aaa/bbb""", query)
22 }
23 
24 def main(args: Array[String]): Unit = {
25   implicit val spark = SparkAndRelevantCptUtil.getSparkSession("test", "local[3]")
26   getTable().select("revise_status").show(1000, false)
27 }

  显示的结果

1 +-------------+
2 |revise_status|
3 +-------------+
4 |null         |
5 |null         |
6 |null         |
7 |null         |
8 |null         |
9 +-------------+

  按理来说,空字符串和null是两个概念差很多的东西,elasticsearch-spark 读取出来为什么会转成null呢,第一个想法会不会是什么参数没配上?在百度谷歌都没找到答案的情况下,只能看是看源码了

【解决方法】

  这个问题的原因是因为 EsSparkSQL 自己代码中,将空字符串识别为 无效数据,是一个bug,解决该问题的步骤如下:

  1、将 org.elasticsearch.spark.serialization.ScalaValueReader 的内容拷贝到一个文本编辑器中

  2、在自己的项目中创建一个 org.elasticsearch.spark.serialization.ScalaValueReader 的Scala类,包名必须一致

关于spark 读取 elasticsearch时,空字符串被转成null的问题排查与解决

  3、将刚刚拷贝的 原版 ScalaValueReader 内容,全部粘贴到这个 刚创建的 ScalaValueReader  类中

  4、修改新 ScalaValueReader  类里面的 checkNull() 方法 的代码

def checkNull(converter: (String, Parser) => Any, value: String, parser: Parser) = {
  if (value != null) {
    //解决掉es误把空字符串弄成null的bug
    //if (!StringUtils.hasText(value) && emptyAsNull) {
    if (!"".equals(value) && !StringUtils.hasText(value) && emptyAsNull) {
      nullValue()
    }
    else {
      converter(value, parser).asInstanceOf[AnyRef]
    }
  }
  else {
    nullValue()
  }
}

  5、重新启动项目测试即可

 

【排查步骤】

  从代码的 EsSparkSQL.esDF(spark,s"""aaa/bbb""", query) 开始研究

 

  【stage1】

  在 org.elasticsearch.spark.sql.EsSparkSQL#esDF 下

  这句没啥可看,继续追踪esDF() 方法

 1 def esDF(ss: SparkSession, resource: String, query: String, cfg: Map[String, String]): DataFrame = esDF(ss.sqlContext, resource, query, cfg) 

  【stage2】

  在 org.elasticsearch.spark.sql.EsSparkSQL#esDF 下

      这句没啥可看,继续追踪esDF() 方法

1 def esDF(sc: SQLContext, resource: String, query: String, cfg: Map[String, String]): DataFrame = {
2   //继续追踪里面的代码
3   esDF(sc, collection.mutable.Map(cfg.toSeq: _*) += (ES_RESOURCE_READ -> resource, ES_QUERY -> query))
4 }

  【stage3】

  在org.elasticsearch.spark.sql.EsSparkSQL#esDF下

 1 def esDF(sc: SQLContext, cfg: Map[String, String]): DataFrame = {
 2   //获取spark的配置
 3   val esConf = new SparkSettingsManager().load(sc.sparkContext.getConf).copy()
 4   //外部如果有传入参数,就合并到esConf里面来,将所有参数整一块
 5   esConf.merge(cfg.asJava)
 6 
 7   /**
 8    * 通过format方法设置数据格式的实现类
 9    * 用options方法传入配置
10    * 【重点】load方法生成DataFrame
11    */
12   sc.read.format("org.elasticsearch.spark.sql").options(esConf.asProperties.asScala.toMap).load
13 }

  【stage4】

   load方法明显是核心的逻辑,所以我们追踪一下load方法

  在 org.apache.spark.sql.DataFrameReader#load 下

1 def load(): DataFrame = {
2   //追踪下去
3   load(Seq.empty: _*) // force invocation of `load(...varargs...)`
4 }

  【stage5】

   在 org.apache.spark.sql.DataFrameReader#load 下

  load方法是很重要,他负责生成DataFrame,我们在看这个load方法时,里面主要两个内容重要:

  1、首先看到了:sparkSession.baseRelationToDataFrame( ... DataSource实例 ... ) ,它的作用是将内部的 DataSource实例参数转化成一个DataFrame,具体做法会在后面讲解

  2、这个DataSource实例是由 DataSource.apply( ... ).resolveRelation() 生成

    (1)、apply() 方法是用来构造DataSource这个类

    (2)、resolveRelation()方法作用 是使用反射创建出对应 DataSource 实例

 1 /**
 2  * load方法最重要的功能就是将baseRelation转换成DataFrame,
 3  * 该功能是通过sparkSession的 def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame
 4  * 接口实现的,其中的参数baseRelation通过DataSource类的resolveRelation方法提供。
 5  */
 6 def load(paths: String*): DataFrame = {
 7   if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
 8     throw new AnalysisException("Hive data source can only be used with tables, you can not " +
 9       "read files of Hive data source directly.")
10   }
11 
12   /**
13    * baseRelationToDataFrame() 方法 接受 baseRelation 参数返回 DataFrame,是通过 Dataset.ofRows(sparkSession,logicalPlan) 方法实现的,
14    * 其中的参数 logicPlan 是由 LogicalRelation(baseRelation) 得到。
15    */
16   sparkSession.baseRelationToDataFrame(
17     DataSource
18       //创建一个DataSource元数据信息类
19       .apply(
20         sparkSession,
21         paths = paths,
22         userSpecifiedSchema = userSpecifiedSchema,
23         className = source,
24         options = extraOptions.toMap
25       )
26 
27       /**
28        * DataSource的resolveRelation() 方法中使用反射创建出对应 DataSource 实例,协同用户指定的 userSpecifiedSchema 进行匹配,匹配成功返回对应的 baseRelation:
29        * 1、如果是基于文件的,返回HadoopFsRelation实例
30        * 2、非文件的,返回如KafkaRelation或者JDBCRelation
31        */
32       .resolveRelation()
33   )
34 }

   【stage6】

  现在我们先关注  DataSouce 的 resolveRelation() 方法

  在org.apache.spark.sql.execution.datasources.DataSource#resolveRelation下

 1 def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
 2   //【追踪点1】这里的 providingClass 是 DataSource 的类,providingClass.newInstance() 就是数据源用反射的方式创建 DataSource 实例
 3   val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
 4     ...
 5     
 6     //由于 EsSparkSQL 并没有提供设置schema的方法,所以schema为空,如果有兴趣的小伙伴可以自己改造 EsSparkSQL ,给他加上设置 schema 的方法,就可以显示设置字段类型
 7     case (dataSource: RelationProvider, None) =>
 8       //【追踪点2】 用 org.elasticsearch.spark.sql.DefaultSource 的实例 创建 ElasticsearchRelation
 9       //ElasticsearchRelation 是一种提供数据获取buildScan,数据插入更新insert等操作的数据源实际操作类实例
10       dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
11    
12     ...
13   }
14 
15   //返回 ElasticsearchRelation
16   relation
17 }

   【stage7】

  我们先看【追踪点1】,这里的 providingClass 是 DataSource 的类,那它是什么时候被赋值的呢?

  在 org.apache.spark.sql.execution.datasources.DataSource#providingClass 下

1 //DataSource.lookupDataSource 用了 className 去查找 提供数据源支持的真正的那个类,这边的className就是 org.elasticsearch.spark.sql
2 lazy val providingClass: Class[_] = DataSource.lookupDataSource(sparkSession, className)

  这里我们可以看到className,它的值是,那么这个值是从哪里来的呢?就是在 【stage3】的时候,在 formt( className ) 方法设置的

1 sc.read.format("org.elasticsearch.spark.sql")...

  那我们继续追踪到 DataSource 的 lookupDataSource 方法中

   【stage8】

  在 org.apache.spark.sql.execution.datasources.DataSource#lookupDataSource 下

  在这里,我们要拿到数据源的类

 1 //查找DataSource的类,注意这时候的 provider 的值是 org.elasticsearch.spark.sql
 2 def lookupDataSource(sparkSession: SparkSession, provider: String): Class[_] = {
 3   //backwardCompatibilityMap 会保存一些过时的 数据源类,如果在这之中,就会替换成最新的 数据源类,否则还是按照用来之前的类名
 4   var provider1 = backwardCompatibilityMap.getOrElse(provider, provider)
 5   //如果是orc、org.apache.spark.sql.hive.orc.OrcFileFormat 这两种特殊情况,设置 数据源类为 OrcFileFormat
 6   if (Seq("orc", "org.apache.spark.sql.hive.orc.OrcFileFormat").contains(provider1.toLowerCase) &&
 7       sparkSession.conf.get(SQLConf.ORC_ENABLED)) {
 8     logInfo(s"$provider1 is replaced with ${classOf[OrcFileFormat].getCanonicalName}")
 9     provider1 = classOf[OrcFileFormat].getCanonicalName
10   }
11   
12   //对 provider1 加工出完整的 数据源类
13   val provider2 = s"$provider1.DefaultSource"
14   //拿到spark当前线程上下文中的类加载器,如果没有,就用当前创建Utils类的类加载器
15   val loader = Utils.getContextOrSparkClassLoader
16   //拿到所有已注册的格式集合,比如TEXT、JSON、CSV等等
17   val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)
18 
19   try {
20     //过滤出符合spark内置格式的数据
21     serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider1)).toList match {
22       // the provider format did not match any given registered aliases
23       //由于 provider1 值是:org.elasticsearch.spark.sql,不是spark提供的常规格式,所以进入到这步骤
24       case Nil =>
25         try {
26           //注意此刻的 provider1 值是:org.elasticsearch.spark.sql, provider2是 org.elasticsearch.spark.sql.DefaultSource
27         · //尝试用类加载器去加载 provider1 和 provider2类,谁能加载成功,就用谁做数据源,
28           //由于 provider1 的值是 org.elasticsearch.spark.sql,是scala 中objct类型,并不是一个类,所以无法加载成功,所以最终加载成功的是 provider2
29           Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))) match {
30             case Success(dataSource) =>
31               // Found the data source using fully qualified path
32               //返回 类型为  provider2,即 org.elasticsearch.spark.sql.DefaultSource
33               dataSource
34             ...
35           }
36         } catch {
37           ...
38         }
39       ...
40     }
41   } catch {
42     ...
43   }
44 }

   【stage9】  

  所以回到 【stage6】 中,providingClass 的值就是 org.elasticsearch.spark.sql.DefaultSource,然后再看【追踪点2】的这段代码

1  dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)

  就知道这就是调用了 org.elasticsearch.spark.sql.DefaultSource 的 createRelation() 方法,所以接下来我们要追踪的就是 org.elasticsearch.spark.sql.DefaultSource#createRelation

   【stage10】 

  在 org.elasticsearch.spark.sql.DefaultSource#createRelation 下

1 /**
2  * ElasticsearchRelation 是一种提供数据获取buildScan,数据插入更新insert等操作的数据源实际操作类实例
3  */
4 override def createRelation(@transient sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
5   //创建 ElasticsearchRelation 
6   ElasticsearchRelation(params(parameters), sqlContext)
7 }

 

   【stage11】 

  在 org.elasticsearch.spark.sql.ElasticsearchRelation 下

  我们可以大致看一下 ElasticsearchRelation 的重要的方法,其中 buildScan() 方法是本次的重点,里面包含了本次问题的关键代码

1 //【本次核心代码】创建一个读取es数据的RDD
2 def buildScan(requiredColumns: Array[String], filters: Array[Filter]) { ... }
3 //插入更新
4 def insert(data: DataFrame, overwrite: Boolean): Unit = { ... }

  【stage12】

  回到【stage5】,已经得出结论 DataSource.apply( ... ).resolveRelation() 生成的是 org.elasticsearch.spark.sql.ElasticsearchRelation 的实例,那我们接着看

  sparkSession.baseRelationToDataFrame( ... DataSource实例 ... ) ,刚刚也知道它的作用是将内部的 DataSource实例参数转化成一个DataFrame,接下来我们重点分析 sparkSession 的 baseRelationToDataFrame( ... ) 这个方法

   

  在 org.apache.spark.sql.SparkSession#baseRelationToDataFrame 下

1 def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
2   /**
3    * 这里做两件重要的事
4    * 1、用LogicalRelation(baseRelation) 生成计划任务
5    * 2、用 Dataset.ofRows 获取 DataFrame
6    */
7   Dataset.ofRows(self, LogicalRelation(baseRelation))
8 }

   

  【stage13】

  我们首先看 LogicalRelation(baseRelation), 它的作用事生成计划任务,所以看看它究竟在做什么

   在 org.apache.spark.sql.execution.datasources.LogicalRelation#apply 下

1 def apply(relation: BaseRelation): LogicalRelation =
2   //创建一个 LogicalRelation 计划任务
3   LogicalRelation(relation, relation.schema.toAttributes, None)

  【stage14】

  这个 LogicalRelation 内部具体什么东西就先不看了,因为现在似乎看不出什么东西,先进入下一步

 1 case class LogicalRelation(
 2     relation: BaseRelation,
 3     output: Seq[AttributeReference],
 4     catalogTable: Option[CatalogTable])
 5   extends LeafNode with MultiInstanceRelation {
 6   
 7   
 8   override def equals(other: Any): Boolean
 9 
10   override def hashCode: Int
11 
12   override def preCanonicalized: LogicalPlan
13 
14   @transient override def computeStats(conf: SQLConf): Statistics
15 
16   override def newInstance(): LogicalRelation
17 
18   override def refresh(): Unit
19 
20   override def simpleString: String
21 }

  【stage15】

  回到 【stage12】 中,查看代码 Dataset.ofRows( ... }

  在 org.apache.spark.sql.Dataset#ofRows 下

1 def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
2   //执行逻辑计划,此处为懒加载,只新建QueryExecution实例,并不会触发实际动作。需要注意的是QueryExecution其实是包含了SQL解析执行的4个阶段计划(解析、分析、优化、执行)
3   val qe = sparkSession.sessionState.executePlan(logicalPlan)
4   //触发语法分析,得到分析计划(Analyzed Logical Plan)
5   qe.assertAnalyzed()
6   //新建一个DataSet 来 获取数据,并将Dataset返回成DataFrame
7   new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
8 }

  【stage16】

  我们先对 【stage15】这段代码进行分析

1 val qe = sparkSession.sessionState.executePlan(logicalPlan)

  在 org.apache.spark.sql.internal.SessionState#executePlan 下

1 //执行逻辑计划,此处为懒加载,只新建QueryExecution实例,并不会触发实际动作。需要注意的是QueryExecution其实是包含了SQL解析执行的4个阶段计划(解析、分析、优化、执行)
2 def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan)

  【stage17】

  其中 createQueryExecution 是来源于SessionState 类的构造函数

  在 org.apache.spark.sql.internal.SessionState 下

  这里createQueryExecution: LogicalPlan => QueryExecution是一枚函数,将logicalplan转换为QueryExecution,它其执行整个workflow

  这里顺便备注一下,planner 就是 【stage12】里的 LogicalRelation(baseRelation),可以从 【stage17】 倒看得到 【stage12】,就能看出来,这个东西在后面有用到

private[sql] class SessionState(
    ...
   val planner: SparkPlanner,
    ...
    createQueryExecution: LogicalPlan => QueryExecution,
    ...) { ... }

  【stage18】

  看到 createQueryExecution 这个参数是 在创建 SessionState 传进来的,所以我们要去找 SessionState 的创建代码,

  所以回到 【stage15】我们看一下这段代码中的 sparkSession.sessionState 是怎么来的

1 val qe = sparkSession.sessionState.executePlan(logicalPlan)

  在 org.apache.spark.sql.SparkSession#sessionState 下

 1 lazy val sessionState: SessionState = {
 2   parentSessionState
 3     .map(_.clone(this))
 4     .getOrElse {
 5       //用反射的方式把它实例化出来一个Builder,然后再通过build()方法创建一个 SessionState实例,然后返回
 6       val state = SparkSession.instantiateSessionState( 
 7         SparkSession.sessionStateClassName(sparkContext.conf),
 8         self)
 9       initialSessionOptions.foreach { case (k, v) => state.conf.setConfString(k, v) }
10       state
11     }
12 }

 

  【stage19】

  我们先追踪 SparkSession.sessionStateClassName(sparkContext.conf) 这句代码

  在 org.apache.spark.sql.SparkSession#instantiateSessionState 下

   这个是获取session状态的类名

 1 private def sessionStateClassName(conf: SparkConf): String = {
 2   conf.get(CATALOG_IMPLEMENTATION) match {
 3     case "hive" =>
 4       if (isLLAPEnabled(conf)) {
 5         LLAP_SESSION_STATE_BUILDER_CLASS_NAME
 6       }
 7       else {
 8         //【会到这步】这个静态变量的值是:"org.apache.spark.sql.hive.HiveSessionStateBuilder"
 9         HIVE_SESSION_STATE_BUILDER_CLASS_NAME
10       }
11     case "in-memory" => classOf[SessionStateBuilder].getCanonicalName
12   }
13 }

  【stage20】

    回到【stage18】,追踪 SparkSession.instantiateSessionState( ... ) 代码

 1 private def instantiateSessionState(
 2     className: String,
 3     sparkSession: SparkSession): SessionState = {
 4   try {
 5     // invoke `new [Hive]SessionStateBuilder(SparkSession, Option[SessionState])`
 6     //【看这里】className的值是:org.apache.spark.sql.hive.HiveSessionStateBuilder ,用反射的方式把它实例化出来一个Builder,然后再通过build()方法创建一个 SessionState实例,然后返回
 7     val clazz = Utils.classForName(className)
 8     val ctor = clazz.getConstructors.head
 9     ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build()
10   } catch {
11     case NonFatal(e) =>
12       throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
13   }
14 }

  【stage21】

    接下来我们根据上面的代码

1 ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build()

    就知道了这就是调用了 org.apache.spark.sql.hive.HiveSessionStateBuilder 的 build() 方法,那我们追踪下这个代码

    由于build() 方法是 HiveSessionStateBuilder 父类 BaseSessionStateBuilder 的方法,所以我们到 BaseSessionStateBuilder 下查看

    

    在 org.apache.spark.sql.internal.BaseSessionStateBuilder#build 下

 1 def build(): SessionState = {
 2   new SessionState(
 3     session.sharedState,
 4     conf,
 5     experimentalMethods,
 6     functionRegistry,
 7     udfRegistration,
 8     catalog,
 9     sqlParser,
10     analyzer,
11     optimizer,
12     planner,
13     streamingQueryManager,
14     listenerManager,
15     resourceLoader,
16     
17     //这就是我们在追寻的 createQueryExecution
18     createQueryExecution,
19     createClone)
20 }

  【stage22】

  在 org.apache.spark.sql.internal.BaseSessionStateBuilder#createQueryExecution 下

  这里会返回一个函数  LogicalPlan => QueryExecution ,这就是 【stage16】 中, def executePlan(plan: LogicalPlan): QueryExecution = createQueryExecution(plan) 中的 createQueryExecution 这个元素的具体实现

1 protected def createQueryExecution: LogicalPlan => QueryExecution = { plan =>
2   //追踪这里
3   new QueryExecution(session, plan)
4 }

  【stage23】

  回到 【stage15】 ,val qe = sparkSession.sessionState.executePlan(logicalPlan) 这句代码,经过 【stage16】 到 【stage22】 的分析,已经知道 qe就是 一个 QueryExecution 的实例,相当于

1 val qe = new QueryExecution(session, logicalPlan)

  【stage24】

   接着回到 【stage15】 ,继续看这段代码

1 //触发语法分析,得到分析计划(Analyzed Logical Plan)
2 qe.assertAnalyzed()

  我们已经知道了qe就是 QueryExecution, 所以我们要去看 QueryExecution 下的 assertAnalyzed() 方法

   在 org.apache.spark.sql.execution.QueryExecution#assertAnalyzed 下

 1 def assertAnalyzed(): Unit = {
 2   // Analyzer is invoked outside the try block to avoid calling it again from within the
 3   // catch block below.
 4   //【追踪】analyzed是个懒加载的属性,执行去加载它,它的作用是对逻辑计划进行分析,得到分析后的逻辑计划
 5   analyzed
 6   try {
 7     //检查
 8     sparkSession.sessionState.analyzer.checkAnalysis(analyzed)
 9   } catch {
10     case e: AnalysisException =>
11       val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed))
12       ae.setStackTrace(e.getStackTrace)
13       throw ae
14   }
15 }

  【stage25】 

  在 org.apache.spark.sql.execution.QueryExecution#analyzed 下

  对逻辑计划进行分析,得到分析后的逻辑计划,即分析计划,分析计划的生成逻辑这里就不再细追下去,有兴趣自己看即可

1 lazy val analyzed: LogicalPlan = {
2   SparkSession.setActiveSession(sparkSession)
3   //对逻辑计划进行分析,得到分析后的逻辑计划,这里就不再细追下去
4   sparkSession.sessionState.analyzer.execute(logical)
5 }

   得到分析计划后,现在先暂时止步到这里,当然后面还有 逻辑计划转换成一个或多个物理执行计划 等操作,后面会讲到

  【stage26】 

  回到 【stage15】中,我们接着看下这段代码

1 //新建一个DataSet 来 获取数据
2 new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))

  这个代码就是把 sparkSession、 逻辑计划、 数据行结构 作为参数,生成一个Dataset,具体也不进行解析了,自己看即可

 

  现在上面的步骤,分析了从调用 EsSparkSQL.esDF( ... )  这个方法开始,到最后输出 DataFrame 的每一步过程

  接下来,我们会从 DataFrrame.show(false) 来讲解步骤

 

  【stage27】 

   现在 EsSparkSQL.esDF(spark,s"""aaa/bbb""", query) .show(false) 这句代码,我们已经分析完 EsSparkSQL.esDF(spark,s"""aaa/bbb""", query) 部分,现在我们想分析 .show(false) 方法

  

  在 def show(truncate: Boolean): Unit = show(20, truncate) 下

1//继续追踪show() 方法
2 def show(truncate: Boolean): Unit = show(20, truncate)

  【stage28】 

  在 org.apache.spark.sql.Dataset#show 下

1 def show(numRows: Int, truncate: Boolean): Unit = if (truncate) {
2   println(showString(numRows, truncate = 20))
3 } else {
4   //走这步showString() 方法
5   println(showString(numRows, truncate = 0))
6 }

  【stage29】 

   在 org.apache.spark.sql.Dataset#showString 下

1 private[sql] def showString(_numRows: Int, truncate: Int = 20): String = {
2   val numRows = _numRows.max(0)
3   
4   //获取数据
5   val takeResult = toDF().take(numRows + 1)
6   ...
7 }

  【stage30】 

  在 org.apache.spark.sql.Dataset#take 下

1 //head就是获取头几条数据
2 def take(n: Int): Array[T] = head(n)

  【stage31】 

  在 org.apache.spark.sql.Dataset#head 下

1 //追踪withAction
2 def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan)

  【stage32】 

  在 org.apache.spark.sql.Dataset#withAction 下

 1 private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
 2   try {
 3     /**
 4     * 【继续追踪】
 5     * 1、拿到缓存的解析计划,使用遍历优化器执行解析计划,得到若干优化计划。
 6     * 2、获取第一个优化计划,遍历执行前优化获得物理执行计划,这是已经可以执行的计划了。
 7     */
 8     qe.executedPlan.foreach { plan =>
 9       plan.resetMetrics()
10     }
11     val start = System.nanoTime()
12     
13     //执行物理计划,返回实际结果。至此,这条sql之旅就结束了
14     val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
15       action(qe.executedPlan)
16     }
17     val end = System.nanoTime()
18     sparkSession.listenerManager.onSuccess(name, qe, end - start)
19     result
20   } catch {
21     case e: Exception =>
22       sparkSession.listenerManager.onFailure(name, qe, e)
23       throw e
24   }
25 }

 

  【stage33】 

   在 org.apache.spark.sql.execution.QueryExecution#executedPlan 下

//这里有两个地方要注意看,一个是sparkPlan,另一个是prepareForExecution() 方法
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)

  【stage34】 

  首先看 sparkPlan

   在 org.apache.spark.sql.execution.QueryExecution#sparkPlan 下

  这里的重点是把优化后的逻辑计划(即执行计划)转换成一个或多个物理执行计划

 1 lazy val sparkPlan: SparkPlan = {
 2   //设置当前活动的sparkSession
 3   SparkSession.setActiveSession(sparkSession)
 4   // TODO: We use next(), i.e. take the first plan returned by the planner, here for now,
 5   //       but we will implement to choose the best plan.
 6   
 7   //QueryExecution获取一个sparkSession.sessionState.planner,这是一个优化器,其实现类是SparkPlanner, 该planner会把一个优化后的逻辑计划转换成一个或多个物理执行计划。
 8   //注意看 optimizedPlan,先从这里继续跟踪
 9   planner.plan(ReturnAnswer(optimizedPlan)).next()
10 }

  【stage35】 

  在 org.apache.spark.sql.execution.QueryExecution#optimizedPlan 下

  对分析后的逻辑计划(分析计划)进行优化,得到优化后的逻辑执行计划,即执行计划

1 //optimizedPlain是通过sparkSession.sessionState.optimizer对逻辑执行计划进行优化,得到优化后的逻辑执行计划
2 //在这里,我们关注 withCachedData 
3 lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData)

  【stage36】 

   在 org.apache.spark.sql.execution.QueryExecution#withCachedData 下

  缓存分析后的逻辑计划(分析计划)

1 //对分析后的逻辑计划进行缓存,更新缓存中的计划
2 lazy val withCachedData: LogicalPlan = {
3   assertAnalyzed()
4   assertSupported()
5   //通过sparkSession.sharedState.cacheManager.useCachedData把analyzed进行缓存,或更新cacheManager中的检查后的逻辑执行计划。
6   sparkSession.sharedState.cacheManager.useCachedData(analyzed)
7 }

  【stage37】 

  在 org.apache.spark.sql.execution.QueryExecution#analyzed 下

  得到分析后的逻辑计划(分析计划)

1 //对逻辑计划进行分析,得到分析后的逻辑计划,即分析计划
2 lazy val analyzed: LogicalPlan = {
3   SparkSession.setActiveSession(sparkSession)
4   //通过sparkSession.sessionState.analyzer.executeAndCheck来检查一个逻辑执行计划,并得到一个分析和检查后的逻辑计划:analyzed。
5   sparkSession.sessionState.analyzer.execute(logical)
6 }

  【核心stage38】 

   回到【stage34】中,查看下面代码

  整体做的事就是把优化后的逻辑计划(即执行计划)转换成一个或多个物理执行计划

1 planner.plan(ReturnAnswer(optimizedPlan)).next()

  我们接着分析 planner.plan( ... ) 这个方法

  在 org.apache.spark.sql.catalyst.planning.QueryPlanner#plan 下

 1 def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = {
 2   // Obviously a lot to do here still...
 3 
 4   // Collect physical plan candidates.
 5   /**
 6    * 【重点】
 7    * 将strategies 进行遍历,将strategies逐个去应用在逻辑计划上,然后做flat操作,返回一个PhysicalPlan的iterator。
 8    * 因为这个 strategies 是 DataSourceStrategy,Spark针对DataSource预定义了四种scan接口,
 9    *   1、TableScan
10    *   2、PrunedScan
11    *   3、PrunedFilteredScan
12    *   4、CatalystScan(其中CatalystScan是unstable的,也是不常用的),
13    *  
14    *  如果开发者(用户)自己实现的DataSource是实现了这四种接口之一的,在scan到执行计划的底层Relation时,就会调用来扫描文件。
15    *  这样最终得到一个Iterator[SparkPlan],每个SparkPlan就是可执行的物理操作了。
16    *
17    *  strategies 的值此刻有包含 DataSourceStrategy 类,所以会执行 DataSourceStrategy 的 apply的方法
18    */
19   val candidates = strategies.iterator.flatMap(_(plan))
20 
21   // The candidates may contain placeholders marked as [[planLater]],
22   // so try to replace them by their child plans.
23   //这个没仔细看,和这次的逻辑不太相关,大概是childPlan代替什么占位符
24   val plans = candidates.flatMap { candidate =>
25     val placeholders = collectPlaceholders(candidate)
26 
27     if (placeholders.isEmpty) {
28       // Take the candidate as is because it does not contain placeholders.
29       Iterator(candidate)
30     } else {
31       // Plan the logical plan marked as [[planLater]] and replace the placeholders.
32       placeholders.iterator.foldLeft(Iterator(candidate)) {
33         case (candidatesWithPlaceholders, (placeholder, logicalPlan)) =>
34           // Plan the logical plan for the placeholder.
35           val childPlans = this.plan(logicalPlan)
36 
37           candidatesWithPlaceholders.flatMap { candidateWithPlaceholders =>
38             childPlans.map { childPlan =>
39               // Replace the placeholder by the child plan
40               candidateWithPlaceholders.transformUp {
41                 case p if p == placeholder => childPlan
42               }
43             }
44           }
45       }
46     }
47   }
48 
49   val pruned = prunePlans(plans)
50 
51   assert(pruned.hasNext, s"No plan for $plan")
52   
53   //最终得到一个Iterator[SparkPlan],每个SparkPlan就是可执行的物理操作了。
54   pruned
55 }

  【核心stage39】 

  在 org.apache.spark.sql.execution.datasources.DataSourceStrategy#apply 下

  可以看到 读取数据的核心的逻辑了,就是 t.buildScan(a.map(_.name).toArray,这个代码会作为读取数据的主要逻辑被封装到一个rowRdd中

 1 def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
 2   ...
 3 
 4   //程序会进入这段代码
 5   case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) =>
 6     pruneFilterProject(
 7       l,
 8       projects,
 9       filters,
10       
11       //【重点】重点看toCatalystRDD() 方法 和 t.buildScan() 方法
12       (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil
13 
14   ...
15 }

  【核心stage40】 

  在 org.apache.spark.sql.execution.datasources.DataSourceStrategy#toCatalystRDD 下

  在这里可以看到rdd被转换成rowRdd的代码

 1 private[this] def toCatalystRDD(
 2     relation: LogicalRelation,
 3     output: Seq[Attribute],
 4     rdd: RDD[Row]): RDD[InternalRow] = {
 5   if (relation.relation.needConversion) {
 6     //将rdd,转换成RowRdd,这里的每行读取数据的核心代码,也就是外面传入的 t.buildScan()
 7     execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
 8   } else {
 9     rdd.asInstanceOf[RDD[InternalRow]]
10   }
11 }

  【核心stage41】 

  这里开始查看【核心stage39】 中的t.buildScan( ... ) 方法

       这个t的类是 org.elasticsearch.spark.sql.ElasticsearchRelation,这里要查看 ElasticsearchRelation 下的buildScan方法

  

  在 org.elasticsearch.spark.sql.ElasticsearchRelation#buildScan 下

1 def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = {
2   ...
3   
4   //【重点】输出一个ScalaEsRowRDD,这就是读取es数据的核心RDD
5   new ScalaEsRowRDD(sqlContext.sparkContext, paramWithScan, lazySchema)
6 }

  【核心stage42】 

   在 org.elasticsearch.spark.sql.ScalaEsRowRDDIterator 下

  这时候,重点关注一下ScalaEsRowRDDIterator的 父类 AbstractEsRDDIterator

 1 private[spark] class ScalaEsRowRDDIterator(
 2   context: TaskContext,
 3   partition: PartitionDefinition,
 4   schema: SchemaUtils.Schema)
 5   extends AbstractEsRDDIterator[Row](context, partition) { //【重点关注】继承了AbstractEsRDDIterator
 6 
 7   override def getLogger() = LogFactory.getLog(classOf[ScalaEsRowRDD])
 8 
 9   //初始化reader
10   override def initReader(settings: Settings, log: Log) = {
11     InitializationUtils.setValueReaderIfNotSet(settings, classOf[ScalaRowValueReader], log)
12 
13     // parse the structure and save the order (requested by Spark) for each Row (root and nested)
14     // since the data returned from Elastic is likely to not be in the same order
15     SchemaUtils.setRowInfo(settings, schema.struct)
16   }
17   
18   //输出es的value值
19   override def createValue(value: Array[Object]): Row = {
20     // drop the ID
21     value(1).asInstanceOf[ScalaEsRow]
22   }
23 }

  【核心stage43】 

  提醒一下,从现在开始,会有一大段代码都是在追踪如何取值,如果觉得没必要看,就直接跳到 【核心stage44】 看bug所在代码即可

  在 org.elasticsearch.spark.rdd.AbstractEsRDDIterator 下,我们要关注的是hasNext() 代码

  因为在读取数据的时候,会调用 AbstractEsRDDIterator 下的 hasNext() 方法,继续追踪代码

  

  在 org.elasticsearch.spark.rdd.AbstractEsRDDIterator#hasNext 下

  读取数据依靠 reader.hasNext() 来读取,这个reader就是

1 def hasNext: Boolean = {
2   if (CompatUtils.isInterrupted(context)) {
3     throw new TaskKilledException
4   }
5 
6   //【重点】重点看 reader.hasNext()
7   !finished && reader.hasNext()
8 }

 

  在 org.elasticsearch.hadoop.rest.ScrollQuery#hasNext 下

       可以看到用scroll 方式去从es读取数据过来,距离我们拿到数据的代码很近了

 1 public boolean hasNext() {
 2     ...
 3 
 4     if (!initialized) {
 5         initialized = true;
 6         
 7         try {
 8             //【重点】这里用 scroll 方式去从es读取数据过来,query是我们传入的dsl语句,body是查询体
 9             Scroll scroll = repository.scroll(query, body, reader);
10             // size is passed as a limit (since we can't pass it directly into the request) - if it's not specified (<1) just scroll the whole index
11             size = (size < 1 ? scroll.getTotalHits() : size);
12             scrollId = scroll.getScrollId();
13             batch = scroll.getHits();
14         } catch (IOException ex) {
15             throw new EsHadoopIllegalStateException(String.format("Cannot create scroll for query [%s/%s]", query, body), ex);
16         }
17 
18         // no longer needed
19         body = null;
20         query = null;
21     }
22     
23     ...
24 
25     return true;
26 }

 

  在 org.elasticsearch.hadoop.rest.RestRepository#scroll(java.lang.String, org.elasticsearch.hadoop.util.BytesArray, org.elasticsearch.hadoop.serialization.ScrollReader) 下

 1 Scroll scroll(String query, BytesArray body, ScrollReader reader) throws IOException {
 2     InputStream scroll = client.execute(POST, query, body).body();
 3     try {
 4         //【重点】读取数据
 5         return reader.read(scroll);
 6     } finally {
 7         if (scroll instanceof StatsAware) {
 8             stats.aggregate(((StatsAware) scroll).stats());
 9         }
10     }
11 }

 

  在 org.elasticsearch.hadoop.serialization.ScrollReader#read(java.io.InputStream) 下

 1 public Scroll read(InputStream content) throws IOException {
 2     Assert.notNull(content);
 3 
 4     BytesArray copy = null;
 5 
 6     if (log.isTraceEnabled() || returnRawJson) {
 7         //copy content
 8         copy = IOUtils.asBytes(content);
 9         content = new FastByteArrayInputStream(copy);
10         log.trace("About to parse scroll content " + copy);
11     }
12 
13     this.parser = new JacksonJsonParser(content);
14 
15     try {
16         //【重点】读取数据
17         return read(copy);
18     } finally {
19         parser.close();
20     }
21 }

 

   在 org.elasticsearch.hadoop.serialization.ScrollReader#read(org.elasticsearch.hadoop.util.BytesArray) 下

 1 private Scroll read(BytesArray input) {
 2     ...
 3     
 4     for (token = parser.nextToken(); token != Token.END_ARRAY; token = parser.nextToken()) {
 5         //【重点】从hit读取数据
 6         results.add(readHit());
 7     }
 8 
 9     ...
10 }

 

  在 org.elasticsearch.hadoop.serialization.ScrollReader#readHit 下

1 private Object[] readHit() {
2     Token t = parser.currentToken();
3     Assert.isTrue(t == Token.START_OBJECT, "expected object, found " + t);
4     //【重点】这里走readHitAsMap(),因为我们读取的数据是正常数据,没有特别设置返回json格式
5     return (returnRawJson ? readHitAsJson() : readHitAsMap());
6 }

 

  在 org.elasticsearch.hadoop.serialization.ScrollReader#readHitAsMap 下

 1 private Object[] readHitAsMap() {
 2     Object[] result = new Object[2];
 3     Object metadata = null;
 4     Object id = null;
 5 
 6     ...
 7 
 8         //读取数据出来
 9         data = read(StringUtils.EMPTY, t, null);
10 
11     ...
12 }

  

  在 org.elasticsearch.hadoop.serialization.ScrollReader#read(java.lang.String, org.elasticsearch.hadoop.serialization.Parser.Token, java.lang.String) 下

  这里的 fieldMapping 就是我们要查询的字段 revise_status

1 protected Object read(String fieldName, Token t, String fieldMapping) {
2     ...
3       return map(fieldMapping);
4     ...
5 }

 

   在 org.elasticsearch.hadoop.serialization.ScrollReader#map 下

1 protected Object map(String fieldMapping) {
2     ...
3       //reader读取数据,存入一个map,在方法的最后会输出出去,这里继续追踪read() 方法
4       reader.addToMap(map, fieldName, read(absoluteName, parser.nextToken(), nodeMapping));
5     ...
6 }

   

  在 org.elasticsearch.hadoop.serialization.ScrollReader#read(java.lang.String, org.elasticsearch.hadoop.serialization.Parser.Token, java.lang.String)  下

 1 protected Object read(String fieldName, Token t, String fieldMapping) {
 2     ...
 3 
 4     if (t.isValue()) {
 5         String rawValue = parser.text();
 6         try {
 7             if (isArrayField(fieldMapping)) {
 8                 return singletonList(fieldMapping, parseValue(esType));
 9             } else {
10                 //【重点】按照字段类型来解析数据值,这里的esType值是keyword,对应es的mapping中的结构
11                 return parseValue(esType);
12             }
13         } catch (Exception ex) {
14             throw new EsHadoopParsingException(String.format(Locale.ROOT, "Cannot parse value [%s] for field [%s]", rawValue, fieldName), ex);
15         }
16     }
17     return null;
18 }

  

  在 org.elasticsearch.hadoop.serialization.ScrollReader#parseValue 下

private Object parseValue(FieldType esType) {
    Object obj;
    // special case of handing null (as text() will return "null")
    if (parser.currentToken() == Token.VALUE_NULL) {
        obj = null;
    }
    else {
        //【重点】读取值
        obj = reader.readValue(parser, parser.text(), esType);
    }
    parser.nextToken();
    return obj;
}

  

  在 org.elasticsearch.spark.sql.ScalaRowValueReader#readValue 下

 1 override def readValue(parser: Parser, value: String, esType: FieldType) = {
 2   sparkRowField = if (getCurrentField == null) null else getCurrentField.getFieldName
 3 
 4   if (sparkRowField == null) {
 5     sparkRowField = Utils.ROOT_LEVEL_NAME
 6   }
 7 
 8   //【重点】读取数据值
 9   super.readValue(parser, value, esType)
10 }

  

  在 org.elasticsearch.spark.serialization.ScalaValueReader#readValue 下

 1 def readValue(parser: Parser, value: String, esType: FieldType) = {
 2   if (esType == null || parser.currentToken() == VALUE_NULL) {
 3     nullValue()
 4 
 5   } else {
 6     esType match {
 7       case NULL => nullValue()
 8       case STRING => textValue(value, parser)
 9       case TEXT => textValue(value, parser)
10       //【重点】从这里进入
11       case KEYWORD => textValue(value, parser)
12       ...
13     }
14   }
15 }

  

  在 org.elasticsearch.spark.serialization.ScalaValueReader#textValue 下

//【重点】追踪checkNull() 方法
def textValue(value: String, parser: Parser) = { checkNull (parseText, value, parser) }

  【核心stage44】

   在 org.elasticsearch.spark.serialization.ScalaValueReader#checkNull 下

  这里就是bug的所在了

 1 def checkNull(converter: (String, Parser) => Any, value: String, parser: Parser) = {
 2   if (value != null) {
 3     //【重点】当我的value是"" 空字符串的时候,StringUtils.hasText(value) 会判断为 false,即不认为空字符串是有效值
 4     if (!StringUtils.hasText(value) && emptyAsNull) {
 5       nullValue()
 6     }
 7     else {
 8       converter(value, parser).asInstanceOf[AnyRef]
 9     }
10   }
11   else {
12     nullValue()
13   }
14 }

   

  修复方案就是在这句话之前判断空字符串为false即可,然后在项目中创建一个一模一样的类,在相同包名下,运行的时候就会自动覆盖原来的类,就能解决问题

 1 def checkNull(converter: (String, Parser) => Any, value: String, parser: Parser) = {
 2   if (value != null) {
 3     //【修复】加了一个判断,!"".equals(value) ,当值为空字符串的时候,就到elsed 逻辑输出值即可
 4     if (!"".equals(value) && !StringUtils.hasText(value) && emptyAsNull) {
 5       nullValue()
 6     }
 7     else {
 8       converter(value, parser).asInstanceOf[AnyRef]
 9     }
10   }
11   else {
12     nullValue()
13   }
14 }

 

   这里虽然解决了问题,但是觉得还可以继续再讲一讲后续是如何把rdd提交到Job执行

   【stage45】

   回到【stage33】中,回头看这句话,就知道 executedPlan 就是一个或多个物理执行计划

1 lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)

 

  再回到【stage32】中,看这段代码,执行物理计划,返回实际结果

1 //执行物理计划,返回实际结果。至此,这条sql之旅就结束了
2 val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
3   action(qe.executedPlan)
4 }

   【stage46】

   在 org.apache.spark.sql.execution.SQLExecution#withNewExecutionId 下

  执行body函数,通过 【stage31】  可以知道,body就是 collectFromPlan() 方法,我们接着追踪 collectFromPlan() 方法

1 def withNewExecutionId[T](
2     sparkSession: SparkSession,
3     queryExecution: QueryExecution)(body: => T): T = {
4   
5   ...
6     //执行body 动作,也就是外面传进来的action
7     body
8   ...
9 }

  【stage47】 

   在 org.apache.spark.sql.Dataset#collectFromPlan 下

1 private def collectFromPlan(plan: SparkPlan): Array[T] = {
2   plan.executeCollect().map(boundEnc.fromRow)
3 }

  【stage48】 

  在 org.apache.spark.sql.execution.CollectLimitExec#executeCollect 下

1 //child为SparkPlan,所以是调用SparkPlan.executeTake(limit)
2 override def executeCollect(): Array[InternalRow] = child.executeTake(limit)

  【stage49】 

  在 org.apache.spark.sql.execution.SparkPlan#executeTake 下 

  在这段代码中,做了很重要的两件事

    1、获取前面计划任务生成的RowRdd的方法
    2、提交job,获取结果

 1 def executeTake(n: Int): Array[InternalRow] = {
 2   if (n == 0) {
 3     return new Array[InternalRow](0)
 4   }
 5 
 6   //【重点】获取前面计划任务生成的RowRdd的方法
 7   val childRDD = getByteArrayRdd(n)
 8 
 9   val buf = new ArrayBuffer[InternalRow]
10   val totalParts = childRDD.partitions.length
11   var partsScanned = 0
12   while (buf.size < n && partsScanned < totalParts) {
13     // The number of partitions to try in this iteration. It is ok for this number to be
14     // greater than totalParts because we actually cap it at totalParts in runJob.
15     var numPartsToTry = 1L
16     if (partsScanned > 0) {
17       // If we didn't find any rows after the previous iteration, quadruple and retry.
18       // Otherwise, interpolate the number of partitions we need to try, but overestimate
19       // it by 50%. We also cap the estimation in the end.
20       val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
21       if (buf.isEmpty) {
22         numPartsToTry = partsScanned * limitScaleUpFactor
23       } else {
24         // the left side of max is >=1 whenever partsScanned >= 2
25         numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1)
26         numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
27       }
28     }
29 
30     val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
31     val sc = sqlContext.sparkContext
32     
33     //【重点】提交job
34     val res = sc.runJob(childRDD,
35       (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty[Byte], p)
36 
37     buf ++= res.flatMap(decodeUnsafeRows)
38 
39     partsScanned += p.size
40   }
41 
42   if (buf.size > n) {
43     buf.take(n).toArray
44   } else {
45     buf.toArray
46   }
47 }

   至此,EsSparkSQL获取数据的源码解析基本完毕

上一篇:静态环形队列(C语言)


下一篇:protocol buffer中的一些点(pacakge、go_package、proto依赖等)