Spark Source Code Study: Shuffle

ShuffleMapStage & ResultStage

When a job is divided into stages, the last stage is called the finalStage; it is essentially a ResultStage object.
All stages before it are ShuffleMapStages.
The completion of a ShuffleMapStage is accompanied by writing shuffle files to disk.
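
For a concrete picture, here is a minimal sketch of a hypothetical job (names such as StageDemo are made up for illustration) whose single reduceByKey splits it into exactly two stages: the map side runs as a ShuffleMapStage, and the final action runs in the ResultStage (the finalStage).

import org.apache.spark.{SparkConf, SparkContext}

object StageDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("StageDemo").setMaster("local[*]"))
    val counts = sc.parallelize(Seq("a", "b", "a"))
      .map(word => (word, 1))   // narrow dependency, stays inside the ShuffleMapStage
      .reduceByKey(_ + _)       // ShuffleDependency: this is the stage boundary
    counts.collect()            // the action; the last stage is the ResultStage (finalStage)
    sc.stop()
  }
}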

org.apache.spark.scheduler.DAGScheduler

When the DAGScheduler creates tasks, it pattern-matches on the stage type. Let's first look at the ShuffleMapTask branch (this is where the Shuffle Write happens).

val tasks: Seq[Task[_]] = try {
  val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
  stage match {
    case stage: ShuffleMapStage =>
      stage.pendingPartitions.clear()
      partitionsToCompute.map { id =>
        val locs = taskIdToLocations(id)
        val part = partitions(id)
        stage.pendingPartitions += id
        new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
          taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
          Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
      }

    case stage: ResultStage =>
      partitionsToCompute.map { id =>
        val p: Int = stage.partitions(id)
        val part = partitions(p)
        val locs = taskIdToLocations(id)
        new ResultTask(stage.id, stage.latestInfo.attemptNumber,
          taskBinary, part, locs, id, properties, serializedTaskMetrics,
          Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
          stage.rdd.isBarrier())
      }
  }
} catch {
  case NonFatal(e) =>
    abortStage(stage, s"Task creation failed: $e\n${Utils.exceptionString(e)}", Some(e))
    runningStages -= stage
    return
}

org.apache.spark.scheduler.ShuffleMapTask

ShuffleMapTask has a runTask() method, whose core is a single call:

dep.shuffleWriterProcessor.write(rdd, dep, mapId, context, partition)

org.apache.spark.shuffle.ShuffleWriteProcessor

Jumping into ShuffleWriteProcessor's write method, we see that it obtains a writer and performs the actual write:

// Obtain the shuffleManager. Early Spark versions also had a HashShuffleManager;
// nowadays only SortShuffleManager remains.
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](
  dep.shuffleHandle,
  mapId,
  context,
  createMetricsReporter(context))

writer.write(
  rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

Clicking into this write method, we find it is abstract, so there must be concrete implementations. A quick hierarchy search (Ctrl+H) shows that ShuffleWriter has exactly three implementation classes:
BypassMergeSortShuffleWriter, UnsafeShuffleWriter, and SortShuffleWriter.

org.apache.spark.shuffle.ShuffleWriter

Which writer implementation gets used is decided by the SortShuffleManager mentioned above, based on the shuffle handle:

override def getWriter[K, V](
    handle: ShuffleHandle,
    mapId: Long,
    context: TaskContext,
    metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
  val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent(
    handle.shuffleId, _ => new OpenHashSet[Long](16))
  mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) }
  val env = SparkEnv.get
  handle match {
    case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
      new UnsafeShuffleWriter(
        env.blockManager,
        context.taskMemoryManager(),
        unsafeShuffleHandle,
        mapId,
        context,
        env.conf,
        metrics,
        shuffleExecutorComponents)
    case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
      new BypassMergeSortShuffleWriter(
        env.blockManager,
        bypassMergeSortHandle,
        mapId,
        env.conf,
        metrics,
        shuffleExecutorComponents)
    case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
      new SortShuffleWriter(
        shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
  }
}

The handle itself is determined earlier, in SortShuffleManager's registerShuffle method:

override def registerShuffle[K, V, C](
    shuffleId: Int,
    dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
  if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
    // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
    // need map-side aggregation, then write numPartitions files directly and just concatenate
    // them at the end. This avoids doing serialization and deserialization twice to merge
    // together the spilled files, which would happen with the normal code path. The downside is
    // having multiple files open at a time and thus more memory allocated to buffers.
    new BypassMergeSortShuffleHandle[K, V](
      shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
    // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
    new SerializedShuffleHandle[K, V](
      shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else {
    // Otherwise, buffer map outputs in a deserialized form:
    new BaseShuffleHandle(shuffleId, dependency)
  }
}

Here is a brief summary of the three implementation classes:

ShuffleWriter implementation | Shuffle handle | Conditions for use
BypassMergeSortShuffleWriter | BypassMergeSortShuffleHandle | 1. no map-side aggregation (pre-aggregation); 2. the number of downstream partitions is at most spark.shuffle.sort.bypassMergeThreshold (default 200)
UnsafeShuffleWriter | SerializedShuffleHandle | 1. the serializer supports relocation of serialized objects (Java serialization does not, Kryo does); 2. no map-side aggregation; 3. the number of downstream partitions does not exceed 16777216
SortShuffleWriter | BaseShuffleHandle | used when the conditions for BypassMergeSortShuffleWriter and UnsafeShuffleWriter are not met
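
The two checks referenced above can be paraphrased roughly as follows. This is only a sketch, not the verbatim Spark source; the method names ending in Sketch are made up for illustration.

import org.apache.spark.ShuffleDependency

// Rough paraphrase of SortShuffleWriter.shouldBypassMergeSort: bypass is only possible
// when there is no map-side combine and the reducer count is small enough.
def shouldBypassMergeSortSketch(bypassMergeThreshold: Int,
                                dep: ShuffleDependency[_, _, _]): Boolean =
  !dep.mapSideCombine && dep.partitioner.numPartitions <= bypassMergeThreshold

// Rough paraphrase of SortShuffleManager.canUseSerializedShuffle.
def canUseSerializedShuffleSketch(dep: ShuffleDependency[_, _, _]): Boolean =
  dep.serializer.supportsRelocationOfSerializedObjects &&  // Kryo yes, Java serialization no
    !dep.mapSideCombine &&                                  // no map-side aggregation
    dep.partitioner.numPartitions <= 16777216               // (1 << 24) partition-id limit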

org.apache.spark.shuffle.sort.SortShuffleWriter

Let's take SortShuffleWriter as an example and look at its write method. Since SortShuffleWriter supports map-side pre-aggregation, it first decides whether to combine and then handles the records accordingly.

// 1. Map-side pre-aggregation (if an aggregator is defined) happens here
sorter.insertAll(records)
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
  dep.shuffleId, mapId, dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
// 2. The actual shuffle write: commit the partitioned output and record the MapStatus
val partitionLengths = mapOutputWriter.commitAllPartitions()
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)

Note that inside insertAll, a different in-memory data structure is used depending on whether map-side aggregation is needed: a PartitionedAppendOnlyMap when combining, a PartitionedPairBuffer otherwise (a standalone sketch of the combining branch follows the code below).

@volatile private var map = new PartitionedAppendOnlyMap[K, C]
@volatile private var buffer = new PartitionedPairBuffer[K, C]

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  // TODO: stop combining if we find that the reduction factor isn't high
  val shouldCombine = aggregator.isDefined

  if (shouldCombine) {
    // Combine values in-memory first using our AppendOnlyMap
    val mergeValue = aggregator.get.mergeValue
    val createCombiner = aggregator.get.createCombiner
    var kv: Product2[K, V] = null
    val update = (hadValue: Boolean, oldValue: C) => {
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    }
    while (records.hasNext) {
      addElementsRead()
      kv = records.next()
      map.changeValue((getPartition(kv._1), kv._1), update)
      maybeSpillCollection(usingMap = true)
    }
  } else {
    // Stick values into our buffer
    while (records.hasNext) {
      addElementsRead()
      val kv = records.next()
      buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
      maybeSpillCollection(usingMap = false)
    }
  }
}
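
As referenced above, here is a standalone sketch of what the shouldCombine branch does, using a plain mutable Map instead of Spark's PartitionedAppendOnlyMap (the function name and the example are illustrative only):

import scala.collection.mutable

def combineByKeySketch[K, V, C](records: Iterator[(K, V)],
                                createCombiner: V => C,
                                mergeValue: (C, V) => C): mutable.Map[K, C] = {
  val combined = mutable.Map.empty[K, C]
  records.foreach { case (k, v) =>
    // Same idea as AppendOnlyMap.changeValue: create a combiner the first time a key
    // is seen, otherwise merge the new value into the existing combiner.
    combined.update(k, combined.get(k) match {
      case Some(c) => mergeValue(c, v)
      case None    => createCombiner(v)
    })
  }
  combined
}

// Example: a map-side word count
// combineByKeySketch(Iterator("a" -> 1, "a" -> 1, "b" -> 1), (v: Int) => v, (c: Int, v: Int) => c + v)
// yields Map("a" -> 2, "b" -> 1)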

The commitAllPartitions call also resolves to an abstract method, so we look for the concrete implementation class.

org.apache.spark.shuffle.sort.io.LocalDiskShuffleMapOutputWriter

In LocalDiskShuffleMapOutputWriter's commitAllPartitions method we find:

blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);

Inside that method two core files are handled: the index file and the data file (a small sketch after the code below shows how the index file is used on the read side).

if (indexFile.exists()) {
  indexFile.delete()
}
if (dataFile.exists()) {
  dataFile.delete()
}
if (!indexTmp.renameTo(indexFile)) {
  throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
}
if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
  throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
}
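
As referenced above, here is a minimal sketch (illustrative only, not Spark's IndexShuffleBlockResolver) of how the index file is consumed: it stores numPartitions + 1 cumulative offsets as 8-byte longs, so reducer i reads the data-file bytes in the range [offsets(i), offsets(i + 1)).

import java.io.{DataInputStream, File, FileInputStream}

// Returns the (start, end) byte range in the data file for a given reduce partition,
// assuming the index-file layout of numPartitions + 1 cumulative long offsets.
def partitionRange(indexFile: File, reduceId: Int): (Long, Long) = {
  val in = new DataInputStream(new FileInputStream(indexFile))
  try {
    in.skipBytes(reduceId * 8)  // each offset is an 8-byte long
    val start = in.readLong()
    val end = in.readLong()
    (start, end)
  } finally {
    in.close()
  }
}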

org.apache.spark.scheduler.ResultTask

ResultTask is necessarily where the Shuffle Read side comes in. It too has a runTask() method.

override def runTask(context: TaskContext): U = {
  // ... (deserialization of rdd and func from taskBinary omitted)
  func(context, rdd.iterator(partition, context))
}

// Follow the iterator call into RDD.iterator
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    getOrCompute(split, context)
  } else {
    computeOrReadCheckpoint(split, context)
  }
}

Looking at getOrCompute, inside it there is a call to computeOrReadCheckpoint(partition, context).

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
  if (isCheckpointedAndMaterialized) {
    firstParent[T].iterator(split, context)
  } else {
    compute(split, context)
  }
}

Clicking into compute, we find it is abstract as well. Since what happens here is the shuffle operation, the relevant implementation class is ShuffledRDD.

org.apache.spark.rdd.ShuffledRDD

Look at its compute method.

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  val metrics = context.taskMetrics().createTempShuffleReadMetrics()
  SparkEnv.get.shuffleManager.getReader(
    dep.shuffleHandle, split.index, split.index + 1, context, metrics)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}

read() is again an abstract method; we find its implementation class and look at how read is implemented.

org.apache.spark.shuffle.BlockStoreShuffleReader

This read method is where the Shuffle Read data fetching actually happens (a conceptual sketch follows the stub below).

override def read(): Iterator[Product2[K, C]] = {
  // TODO
}
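
At a high level, the reader has to fetch this reduce partition's blocks from every map output and, if the shuffle dependency defines an aggregator, merge the map-side combiners. The following is only a conceptual sketch of that idea, not the actual BlockStoreShuffleReader logic; the function name is made up.

import scala.collection.mutable

def shuffleReadSketch[K, C](mapOutputs: Seq[Iterator[(K, C)]],
                            mergeCombiners: (C, C) => C): Iterator[(K, C)] = {
  val merged = mutable.Map.empty[K, C]
  for (records <- mapOutputs; (k, c) <- records) {
    // Merge the combiners produced by different map tasks for the same key.
    merged.update(k, merged.get(k).map(mergeCombiners(_, c)).getOrElse(c))
  }
  merged.iterator
}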