
Spark Handbook - Partitioner Source Code

Author: 小牛君 | Published: 2017-06-16



1.  Partitioner Source Code

1.1.   abstract class

abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}
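
The contract is deliberately small: a partitioner only reports how many partitions exist (numPartitions) and maps an arbitrary key to one of them (getPartition). As a quick illustration, here is a hypothetical custom partitioner (not part of Spark's source) that routes string keys by their first letter:

import org.apache.spark.Partitioner

// Hypothetical example: send string keys to a partition based on their first
// letter; empty or non-letter keys all land in partition 0.
class FirstLetterPartitioner(override val numPartitions: Int) extends Partitioner {

  def getPartition(key: Any): Int = key match {
    case s: String if s.nonEmpty && s.head.isLetter =>
      (s.head.toLower - 'a') % numPartitions
    case _ => 0
  }

  // Partitioner equality is what lets Spark recognise co-partitioned RDDs
  // and skip unnecessary shuffles.
  override def equals(other: Any): Boolean = other match {
    case p: FirstLetterPartitioner => p.numPartitions == numPartitions
    case _ => false
  }

  override def hashCode: Int = numPartitions
}

Such a class can be passed to partitionBy on a key-value RDD, exactly like the built-in partitioners discussed below.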

1.2.   Methods in the Partitioner companion object

def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
  val rdds = (Seq(rdd) ++ others)
  val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0))
  if (hasPartitioner.nonEmpty) {
    hasPartitioner.maxBy(_.partitions.length).partitioner.get
  } else {
    if (rdd.context.conf.contains("spark.default.parallelism")) {
      new HashPartitioner(rdd.context.defaultParallelism)
    } else {
      new HashPartitioner(rdds.map(_.partitions.length).max)
    }
  }
}
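
In plain terms: if any input RDD already carries a partitioner, the one belonging to the RDD with the most partitions is reused; otherwise a new HashPartitioner is created, sized by spark.default.parallelism when that is set, or else by the largest input partition count. A minimal local-mode sketch of how this surfaces through join (the data and numbers are made up for illustration; the comments state what the logic above predicts):

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

object DefaultPartitionerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("default-partitioner-demo").setMaster("local[2]"))

    val left  = sc.parallelize(Seq("a" -> 1, "b" -> 2), numSlices = 4)  // no partitioner
    val right = sc.parallelize(Seq("a" -> 9, "c" -> 8), numSlices = 2)
      .partitionBy(new HashPartitioner(8))                              // has a partitioner

    // join calls defaultPartitioner(left, right) internally. right's
    // HashPartitioner(8) is the only existing partitioner, so the result
    // should reuse it rather than shuffle into a freshly created one.
    val joined = left.join(right)
    println(joined.partitioner)        // expected: the HashPartitioner with 8 partitions
    println(joined.partitions.length)  // expected: 8

    sc.stop()
  }
}
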
1.3.   HashPartitioner

class HashPartitioner(partitions: Int) extends Partitioner {

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}
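
The only subtle point is Utils.nonNegativeMod: hashCode can be negative, and Java's % operator would then return a negative remainder, so the result is folded back into [0, numPartitions). A standalone sketch of that behaviour (re-implemented here for illustration; the real helper lives in org.apache.spark.util.Utils):

object NonNegativeModDemo {
  // Same idea as Utils.nonNegativeMod: fold a possibly negative remainder
  // back into the range [0, mod).
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    for (key <- Seq("a", "spark", "partitioner", "rangeBounds")) {
      // key.hashCode may be negative; the result is still a valid partition id.
      println(s"$key -> partition ${nonNegativeMod(key.hashCode, numPartitions)}")
    }
  }
}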

1.4.   RangePartitioner

1.4.1.   Core code

rangeBounds is an array of partition boundaries, sorted by key, computed from the target number of result partitions together with random sampling of the input keys.

def numPartitions: Int = rangeBounds.length + 1

private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

def getPartition(key: Any): Int = {
  val k = key.asInstanceOf[K]
  var partition = 0
  // Locate the key's position within rangeBounds.
  if (rangeBounds.length <= 128) {
    // With at most 128 bounds, scan them with a simple while loop.
    while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
      partition += 1
    }
  } else {
    // With more than 128 bounds, use binary search instead.
    partition = binarySearch(rangeBounds, k)
    // binarySearch returns either the matching index or a negative insertion point.
    if (partition < 0) {
      partition = -partition - 1
    }
    if (partition > rangeBounds.length) {
      partition = rangeBounds.length
    }
  }
  if (ascending) {
    partition
  } else {
    rangeBounds.length - partition
  }
}
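
To see the lookup in isolation, here is a small self-contained sketch (plain Scala, not Spark code) that mimics the linear scan for the ascending case with Int keys. Three bounds define four partitions; a key belongs to the first partition whose upper bound is not smaller than it:

object RangeLookupDemo {
  // Simplified re-implementation of the lookup, for illustration only.
  def getPartition(key: Int, bounds: Array[Int]): Int = {
    var partition = 0
    while (partition < bounds.length && key > bounds(partition)) {
      partition += 1
    }
    partition
  }

  def main(args: Array[String]): Unit = {
    val bounds = Array(10, 20, 30)  // 3 bounds => 4 partitions
    for (key <- Seq(5, 10, 15, 30, 99)) {
      println(s"key $key -> partition ${getPartition(key, bounds)}")
    }
    // expected: 5 -> 0, 10 -> 0, 15 -> 1, 30 -> 2, 99 -> 3
  }
}
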
1.4.2.   rangeBounds

private var rangeBounds: Array[K] = {
  if (partitions <= 1) {
    Array.empty
  } else {
    // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
    val sampleSize = math.min(20.0 * partitions, 1e6)
    // Assume the input partitions are roughly balanced and over-sample a little bit.
    val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
    val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
    if (numItems == 0L) {
      Array.empty
    } else {
      // If a partition contains much more than the average number of items, we re-sample from it
      // to ensure that enough items are collected from that partition.
      val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
      val candidates = ArrayBuffer.empty[(K, Float)]
      val imbalancedPartitions = mutable.Set.empty[Int]
      sketched.foreach { case (idx, n, sample) =>
        if (fraction * n > sampleSizePerPartition) {
          imbalancedPartitions += idx
        } else {
          // The weight is 1 over the sampling probability.
          val weight = (n.toDouble / sample.length).toFloat
          for (key <- sample) {
            candidates += ((key, weight))
          }
        }
      }
      if (imbalancedPartitions.nonEmpty) {
        // Re-sample imbalanced partitions with the desired sampling probability.
        val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
        val seed = byteswap32(-rdd.id - 1)
        val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
        val weight = (1.0 / fraction).toFloat
        candidates ++= reSampled.map(x => (x, weight))
      }
      RangePartitioner.determineBounds(candidates, partitions)
    }
  }
}
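
The sampling sizes are plain arithmetic. Assuming, purely for illustration, partitions = 8, an input RDD with 100 partitions, and one million records in total:

object SampleSizeDemo {
  def main(args: Array[String]): Unit = {
    val partitions = 8            // desired number of output partitions (assumed)
    val inputPartitions = 100     // rdd.partitions.length (assumed)
    val numItems = 1000000L       // total record count reported by sketch (assumed)

    val sampleSize = math.min(20.0 * partitions, 1e6)                                // 160.0
    val sampleSizePerPartition = math.ceil(3.0 * sampleSize / inputPartitions).toInt // 5
    val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)                // 1.6e-4

    println(s"sampleSize = $sampleSize")
    println(s"sampleSizePerPartition = $sampleSizePerPartition")
    println(s"fraction = $fraction")
    // A partition with n records counts as imbalanced (and is re-sampled)
    // when fraction * n > sampleSizePerPartition, i.e. here when n > 31250.
    println(s"re-sample threshold = ${(sampleSizePerPartition / fraction).toLong}")
  }
}
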
1.4.3.   sketch

def sketch[K : ClassTag](
    rdd: RDD[K],
    sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
  val shift = rdd.id
  // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
  val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
    val seed = byteswap32(idx ^ (shift << 16))
    val (sample, n) = SamplingUtils.reservoirSampleAndCount(
      iter, sampleSizePerPartition, seed)
    Iterator((idx, n, sample))
  }.collect()
  val numItems = sketched.map(_._2).sum
  (numItems, sketched)
}
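
The heavy lifting is done by SamplingUtils.reservoirSampleAndCount: each partition is scanned once, a fixed-size uniform sample is kept, and the total element count comes back as a by-product. Below is a simplified, standalone version of the idea (my own sketch, not Spark's implementation):

import scala.reflect.ClassTag
import scala.util.Random

object ReservoirDemo {
  // Reservoir sampling: keep a uniform sample of at most k elements from an
  // iterator of unknown length, and count how many elements were seen.
  def reservoirSampleAndCount[T: ClassTag](iter: Iterator[T], k: Int, seed: Long): (Array[T], Long) = {
    val rng = new Random(seed)
    val reservoir = new Array[T](k)
    var n = 0L
    while (iter.hasNext) {
      val item = iter.next()
      if (n < k) {
        reservoir(n.toInt) = item
      } else {
        // Replace a random slot with probability k / (n + 1).
        val j = (rng.nextDouble() * (n + 1)).toLong
        if (j < k) reservoir(j.toInt) = item
      }
      n += 1
    }
    if (n < k) (reservoir.take(n.toInt), n) else (reservoir, n)
  }

  def main(args: Array[String]): Unit = {
    val (sample, n) = reservoirSampleAndCount((1 to 1000).iterator, k = 5, seed = 42L)
    println(s"saw $n elements, sample = ${sample.mkString(", ")}")
  }
}
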
1.4.4.   determineBounds

def determineBounds[K : Ordering : ClassTag](
    candidates: ArrayBuffer[(K, Float)],
    partitions: Int): Array[K] = {
  val ordering = implicitly[Ordering[K]]
  val ordered = candidates.sortBy(_._1)
  val numCandidates = ordered.size
  val sumWeights = ordered.map(_._2.toDouble).sum
  val step = sumWeights / partitions
  var cumWeight = 0.0
  var target = step
  val bounds = ArrayBuffer.empty[K]
  var i = 0
  var j = 0
  var previousBound = Option.empty[K]
  while ((i < numCandidates) && (j < partitions - 1)) {
    val (key, weight) = ordered(i)
    cumWeight += weight
    if (cumWeight >= target) {
      // Skip duplicate values.
      if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
        bounds += key
        target += step
        j += 1
        previousBound = Some(key)
      }
    }
    i += 1
  }
  bounds.toArray
}
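
Because determineBounds is private[spark], the following worked example re-implements the same weighted-quantile idea locally, purely as an illustration with invented keys and weights. Six equally weighted candidates and partitions = 3 give sumWeights = 6 and step = 2, so boundaries are emitted when the cumulative weight crosses 2 and then 4:

import scala.collection.mutable.ArrayBuffer

object DetermineBoundsDemo {
  // Simplified local copy of the boundary-selection loop, Int keys only.
  def determineBounds(candidates: ArrayBuffer[(Int, Float)], partitions: Int): Array[Int] = {
    val ordered = candidates.sortBy(_._1)
    val step = ordered.map(_._2.toDouble).sum / partitions
    var cumWeight = 0.0
    var target = step
    val bounds = ArrayBuffer.empty[Int]
    var i = 0
    var j = 0
    var previousBound = Option.empty[Int]
    while (i < ordered.size && j < partitions - 1) {
      val (key, weight) = ordered(i)
      cumWeight += weight
      if (cumWeight >= target && (previousBound.isEmpty || key > previousBound.get)) {
        bounds += key
        target += step
        j += 1
        previousBound = Some(key)
      }
      i += 1
    }
    bounds.toArray
  }

  def main(args: Array[String]): Unit = {
    val candidates = ArrayBuffer(1 -> 1f, 3 -> 1f, 5 -> 1f, 7 -> 1f, 9 -> 1f, 11 -> 1f)
    // expected output: 3, 7  => partition 0: keys <= 3, partition 1: keys in (3, 7], partition 2: keys > 7
    println(determineBounds(candidates, partitions = 3).mkString(", "))
  }
}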


