您的位置:首页 > 博客中心 > 数据库 >

第七篇:Spark SQL 源码分析之Physical Plan 到 RDD的具体实现

时间:2022-03-16 11:12

 /** */

  接上一篇文章,本文将介绍Physical Plan的toRDD的具体实现细节:

  我们都知道一段sql,真正的执行是当你调用它的collect()方法才会执行Spark Job,最后计算得到RDD。

[java]    
  1. lazy val toRdd: RDD[Row] = executedPlan.execute()  

  Spark Plan基本包含4种操作类型,即BasicOperator基本类型,还有就是Join、Aggregate和Sort这种稍复杂的。

  如图:

  技术分享

一、BasicOperator

1.1、Project

  Project 的大致含义是:传入一系列表达式Seq[NamedExpression],给定输入的Row,经过Convert(Expression的计算eval)操作,生成一个新的Row。   Project的实现是调用其child.execute()方法,然后调用mapPartitions对每一个Partition进行操作。
  这个f函数其实是new了一个MutableProjection,然后循环的对每个partition进行Convert。 [java]    
  1. case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {  
  2.   override def output = projectList.map(_.toAttribute)  
  3.   override def execute() = child.execute().mapPartitions { iter => //对每个分区进行f映射  
  4.     @transient val reusableProjection = new MutableProjection(projectList)   
  5.     iter.map(reusableProjection)  
  6.   }  
  7. }  
  通过观察MutableProjection的定义,可以发现,就是bind references to a schema 和 eval的过程:   将一个Row转换为另一个已经定义好schema column的Row。
  如果输入的Row已经有Schema了,则传入的Seq[Expression]也会bound到当前的Schema。 [java]    
  1. case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {  
  2.   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =  
  3.     this(expressions.map(BindReferences.bindReference(_, inputSchema))) //bound schema  
  4.   
  5.   private[this] val exprArray = expressions.toArray  
  6.   private[this] val mutableRow = new GenericMutableRow(exprArray.size) //新的Row  
  7.   def currentValue: Row = mutableRow  
  8.   def apply(input: Row): Row = {  
  9.     var i = 0  
  10.     while (i < exprArray.length) {  
  11.       mutableRow(i) = exprArray(i).eval(input)  //根据输入的input,即一个Row,计算生成的Row  
  12.       i += 1  
  13.     }  
  14.     mutableRow //返回新的Row  
  15.   }  
  16. }  

1.2、Filter

 Filter的具体实现是传入的condition进行对input row的eval计算,最后返回的是一个Boolean类型,  如果表达式计算成功,返回true,则这个分区的这条数据就会保存下来,否则会过滤掉。 [java]    
  1. case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {  
  2.   override def output = child.output  
  3.   
  4.   override def execute() = child.execute().mapPartitions { iter =>  
  5.     iter.filter(condition.eval(_).asInstanceOf[Boolean]) //计算表达式 eval(input row)  
  6.   }  
  7. }  

1.3、Sample

  Sample取样操作其实是调用了child.execute()的结果后,返回的是一个RDD,对这个RDD调用其sample函数,原生方法。 [java]    
  1. case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)  
  2.   extends UnaryNode  
  3. {  
  4.   override def output = child.output  
  5.   
  6.   // TODO: How to pick seed?  
  7.   override def execute() = child.execute().sample(withReplacement, fraction, seed)  
  8. }  

1.4、Union

  Union操作支持多个子查询的Union,所以传入的child是一个Seq[SparkPlan]   execute()方法的实现是对其所有的children,每一个进行execute(),即select查询的结果集合RDD。   通过调用SparkContext的union方法,将所有子查询的结果合并起来。 [java]    
  1. case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {  
  2.   // TODO: attributes output by union should be distinct for nullability purposes  
  3.   override def output = children.head.output  
  4.   override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) //子查询的结果进行union  
  5.   
  6.   override def otherCopyArgs = sqlContext :: Nil  
  7. }  

1.5、Limit

  Limit操作在RDD的原生API里也有,即take().   但是Limit的实现分2种情况:   第一种是 limit作为结尾的操作符,即select xxx from yyy limit zzz。 并且是被executeCollect调用,则直接在driver里使用take方法。   第二种是 limit不是作为结尾的操作符,即limit后面还有查询,那么就在每个分区调用limit,最后repartition到一个分区来计算global limit. [java]    
  1. case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)  
  2.   extends UnaryNode {  
  3.   // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:  
  4.   // partition local limit -> exchange into one partition -> partition local limit again  
  5.   
  6.   override def otherCopyArgs = sqlContext :: Nil  
  7.   
  8.   override def output = child.output  
  9.   
  10.   override def executeCollect() = child.execute().map(_.copy()).take(limit) //直接在driver调用take  
  11.   
  12.   override def execute() = {  
  13.     val rdd = child.execute().mapPartitions { iter =>  
  14.       val mutablePair = new MutablePair[Boolean, Row]()  
  15.       iter.take(limit).map(row => mutablePair.update(false, row)) //每个分区先计算limit  
  16.     }  
  17.     val part = new HashPartitioner(1)  
  18.     val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part) //需要shuffle,来repartition  
  19.     shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))  
  20.     shuffled.mapPartitions(_.take(limit).map(_._2)) //最后单独一个partition来take limit  
  21.   }  
  22. }  

1.6、TakeOrdered

  TakeOrdered是经过排序后的limit N,一般是用在sort by 操作符后的limit。   可以简单理解为TopN操作符。 [java]    
  1. case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)  
  2.                       (@transient sqlContext: SQLContext) extends UnaryNode {  
  3.   override def otherCopyArgs = sqlContext :: Nil  
  4.   
  5.   override def output = child.output  
  6.   
  7.   @transient  
  8.   lazy val ordering = new RowOrdering(sortOrder) //这里是通过RowOrdering来实现排序的  
  9.   
  10.   override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)  
  11.   
  12.   // TODO: Terminal split should be implemented differently from non-terminal split.  
  13.   // TODO: Pick num splits based on |limit|.  
  14.   override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)  
  15. }  

1.7、Sort

  Sort也是通过RowOrdering这个类来实现排序的,child.execute()对每个分区进行map,每个分区根据RowOrdering的order来进行排序,生成一个新的有序集合。   也是通过调用Spark RDD的sorted方法来实现的。 [java]    
  1. case class Sort(  
  2.     sortOrder: Seq[SortOrder],  
  3.     global: Boolean,  
  4.     child: SparkPlan)  
  5.   extends UnaryNode {  
  6.   override def requiredChildDistribution =  
  7.     if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil  
  8.   
  9.   @transient  
  10.   lazy val ordering = new RowOrdering(sortOrder) //排序顺序  
  11.   
  12.   override def execute() = attachTree(this, "sort") {  
  13.     // TODO: Optimize sorting operation?  
  14.     child.execute()  
  15.       .mapPartitions(  
  16.         iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, //每个分区调用sorted方法,传入<span start="1">
  17. object ExistingRdd {  
  18.   def convertToCatalyst(a: Any): Any = a match {  
  19.     case o: Option[_] => o.orNull  
  20.     case s: Seq[Any] => s.map(convertToCatalyst)  
  21.     case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)  
  22.     case other => other  
  23.   }  
  24.   
  25.   def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {  
  26.     data.mapPartitions { iterator =>  
  27.       if (iterator.isEmpty) {  
  28.         Iterator.empty  
  29.       } else {  
  30.         val bufferedIterator = iterator.buffered  
  31.         val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)  
  32.   
  33.         bufferedIterator.map { r =>  
  34.           var i = 0  
  35.           while (i < mutableRow.length) {  
  36.             mutableRow(i) = convertToCatalyst(r.productElement(i))  
  37.             i += 1  
  38.           }  
  39.   
  40.           mutableRow  
  41.         }  
  42.       }  
  43.     }  
  44.   }  
  45.   
  46.   def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = {  
  47.     ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))  
  48.   }  
  49. }  
 

二、 Join Related Operators

  HashJoin:

  在讲解Join Related Operator之前,有必要了解一下HashJoin这个位于execution包下的joins.scala文件里的trait。   Join操作主要包含BroadcastHashJoinLeftSemiJoinHashShuffledHashJoin均实现了HashJoin这个trait.   主要类图如下:   技术分享      HashJoin这个trait的主要成员有:   buildSide是左连接还是右连接,有一种基准的意思。   leftKeys是左孩子的expressions, rightKeys是右孩子的expressions。   left是左孩子物理计划,right是右孩子物理计划。   buildSideKeyGenerator是一个Projection是根据传入的Row对象来计算buildSide的Expression的。   streamSideKeyGenerator是一个MutableProjection是根据传入的Row对象来计算streamSide的Expression的。   这里buildSide如果是left的话,可以理解为buildSide是左表,那么去连接这个左表的右表就是streamSide。   技术分享   HashJoin关键的操作是joinIterators,简单来说就是join两个表,把每个表看着Iterators[Row].   方式:   1、首先遍历buildSide,计算buildKeys然后利用一个HashMap,形成 (buildKeys, Iterators[Row])的格式。   2、遍历StreamedSide,计算streamedKey,去HashMap里面去匹配key,来进行join   3、最后生成一个joinRow,这个将2个row对接。   见代码注释: [java]    
  1. trait HashJoin {  
  2.   val leftKeys: Seq[Expression]  
  3.   val rightKeys: Seq[Expression]  
  4.   val buildSide: BuildSide  
  5.   val left: SparkPlan  
  6.   val right: SparkPlan  
  7.   
  8.   lazy val (buildPlan, streamedPlan) = buildSide match {  //模式匹配,将physical plan封装形成Tuple2,如果是buildLeft,那么就是(left,right),否则是(right,left)  
  9.     case BuildLeft => (left, right)  
  10.     case BuildRight => (right, left)  
  11.   }  
  12.   
  13.   lazy val (buildKeys, streamedKeys) = buildSide match { //模式匹配,将expression进行封装<span start="1">
  14. class JoinedRow extends Row {  
  15.   private[this] var row1: Row = _  
  16.   private[this] var row2: Row = _  
  17.   .........  
  18.    def copy() = {  
  19.     val totalSize = row1.size + row2.size   
  20.     val copiedValues = new Array[Any](totalSize)  
  21.     var i = 0  
  22.     while(i < totalSize) {  
  23.       copiedValues(i) = apply(i)  
  24.       i += 1  
  25.     }  
  26.     new GenericRow(copiedValues) //返回一个新的合并后的Row  
  27.   }  

2.1、LeftSemiJoinHash

 left semi join,不多说了,hive早期版本里替代IN和EXISTS 的版本。  将右表的join keys放到HashSet里,然后遍历左表,查找左表的join key是否能匹配。 [java]    
  1. case class LeftSemiJoinHash(  
  2.     leftKeys: Seq[Expression],  
  3.     rightKeys: Seq[Expression],  
  4.     left: SparkPlan,  
  5.     right: SparkPlan) extends BinaryNode with HashJoin {  
  6.   
  7.   val buildSide = BuildRight //buildSide是以右表为基准  
  8.   
  9.   override def requiredChildDistribution =  
  10.     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil  
  11.   
  12.   override def output = left.output  
  13.   
  14.   def execute() = {  
  15.     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => //右表的物理计划执行后生成RDD,利用zipPartitions对Partition进行合并。然后用上述方法实现。  
  16.       val hashSet = new java.util.HashSet[Row]()  
  17.       var currentRow: Row = null  
  18.   
  19.       // Create a Hash set of buildKeys  
  20.       while (buildIter.hasNext) {  
  21.         currentRow = buildIter.next()  
  22.         val rowKey = buildSideKeyGenerator(currentRow)  
  23.         if(!rowKey.anyNull) {  
  24.           val keyExists = hashSet.contains(rowKey)  
  25.           if (!keyExists) {  
  26.             hashSet.add(rowKey)  
  27.           }  
  28.         }  
  29.       }  
  30.   
  31.       val joinKeys = streamSideKeyGenerator()  
  32.       streamIter.filter(current => {  
  33.         !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)  
  34.       })  
  35.     }  
  36.   }  
  37. }  

2.2、BroadcastHashJoin

 名约: 广播HashJoin,呵呵。   是InnerHashJoin的实现。这里用到了concurrent并发里的future,异步的广播buildPlan的表执行后的的RDD。   如果接收到了广播后的表,那么就用streamedPlan来匹配这个广播的表。   实现是RDD的mapPartitions和HashJoin里的joinIterators最后生成join的结果。 [java]    
  1. case class BroadcastHashJoin(  
  2.      leftKeys: Seq[Expression],  
  3.      rightKeys: Seq[Expression],  
  4.      buildSide: BuildSide,  
  5.      left: SparkPlan,  
  6.      right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin {  
  7.   
  8.   override def otherCopyArgs = sqlContext :: Nil  
  9.   
  10.   override def outputPartitioning: Partitioning = left.outputPartitioning  
  11.   
  12.   override def requiredChildDistribution =  
  13.     UnspecifiedDistribution :: UnspecifiedDistribution :: Nil  
  14.   
  15.   @transient  
  16.   lazy val broadcastFuture = future {  //利用SparkContext广播表  
  17.     sqlContext.sparkContext.broadcast(buildPlan.executeCollect())  
  18.   }  
  19.   
  20.   def execute() = {  
  21.     val broadcastRelation = Await.result(broadcastFuture, 5.minute)  
  22.   
  23.     streamedPlan.execute().mapPartitions { streamedIter =>  
  24.       joinIterators(broadcastRelation.value.iterator, streamedIter) //调用joinIterators对每个分区map  
  25.     }  
  26.   }  
  27. }  

2.3、ShuffleHashJoin

ShuffleHashJoin顾名思义就是需要shuffle数据,outputPartitioning是左孩子的的Partitioning。 会根据这个Partitioning进行shuffle。然后利用SparkContext里的zipPartitions方法对每个分区进行zip。 这里的requiredChildDistribution,的是ClusteredDistribution,这个会在HashPartitioning里面进行匹配。 关于这里面的分区这里不赘述,可以去org.apache.spark.sql.catalyst.plans.physical下的partitioning里面去查看。 [java]    
  1. case class ShuffledHashJoin(  
  2.     leftKeys: Seq[Expression],  
  3.     rightKeys: Seq[Expression],  
  4.     buildSide: BuildSide,  
  5.     left: SparkPlan,  
  6.     right: SparkPlan) extends BinaryNode with HashJoin {  
  7.   
  8.   override def outputPartitioning: Partitioning = left.outputPartitioning  
  9.   
  10.   override def requiredChildDistribution =  
  11.     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil  
  12.   
  13.   def execute() = {  
  14.     buildPlan.execute().zipPartitions(streamedPlan.execute()) {  
  15.       (buildIter, streamIter) => joinIterators(buildIter, streamIter)  
  16.     }  
  17.   }  
  18. }  


未完待续 :)  

原创文章,转载请注明:

转载自:,作者: 

本文链接地址:

注:本文基于协议,欢迎转载、转发和评论,但是请保留本文作者署名和文章链接。如若需要用于商业目的或者与授权方面的协商,请联系我。

技术分享

转自:http://blog.csdn.net/oopsoom/article/details/38274621

本类排行

今日推荐

热门手游