diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index fc7bba4118..7607c128bb 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -190,48 +190,60 @@ public Map<String, Metric> getMetrics() {
     }
   }

+  private boolean isShuffleBlock(String[] blockIdParts) {
+    // length == 4: ShuffleBlockId
+    // length == 5: ContinuousShuffleBlockId
+    return (blockIdParts.length == 4 || blockIdParts.length == 5) &&
+      blockIdParts[0].equals("shuffle");
+  }
+
   private class ManagedBufferIterator implements Iterator<ManagedBuffer> {

     private int index = 0;
     private final String appId;
     private final String execId;
     private final int shuffleId;
-    // An array containing mapId and reduceId pairs.
-    private final int[] mapIdAndReduceIds;
+    // An array containing mapId, reduceId and numBlocks tuples.
+    private final int[] shuffleBlockIds;

     ManagedBufferIterator(String appId, String execId, String[] blockIds) {
       this.appId = appId;
       this.execId = execId;
       String[] blockId0Parts = blockIds[0].split("_");
-      if (blockId0Parts.length != 4 || !blockId0Parts[0].equals("shuffle")) {
+      if (!isShuffleBlock(blockId0Parts)) {
         throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[0]);
       }
       this.shuffleId = Integer.parseInt(blockId0Parts[1]);
-      mapIdAndReduceIds = new int[2 * blockIds.length];
+      shuffleBlockIds = new int[3 * blockIds.length];
       for (int i = 0; i < blockIds.length; i++) {
         String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+        if (!isShuffleBlock(blockIdParts)) {
           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
         }
         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
             ", got:" + blockIds[i]);
         }
-        mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
-        mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
+        shuffleBlockIds[3 * i] = Integer.parseInt(blockIdParts[2]);
+        shuffleBlockIds[3 * i + 1] = Integer.parseInt(blockIdParts[3]);
+        if (blockIdParts.length == 4) {
+          shuffleBlockIds[3 * i + 2] = 1;
+        } else {
+          shuffleBlockIds[3 * i + 2] = Integer.parseInt(blockIdParts[4]);
+        }
       }
     }

     @Override
     public boolean hasNext() {
-      return index < mapIdAndReduceIds.length;
+      return index < shuffleBlockIds.length;
     }

     @Override
     public ManagedBuffer next() {
       final ManagedBuffer block = blockManager.getBlockData(appId, execId, shuffleId,
-        mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
-      index += 2;
+        shuffleBlockIds[index], shuffleBlockIds[index + 1], shuffleBlockIds[index + 2]);
+      index += 3;
       metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
       return block;
     }
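Reviewer note (not part of the patch): with this change the external shuffle service accepts both the classic four-part id, shuffle_<shuffleId>_<mapId>_<reduceId>, and the new five-part id, shuffle_<shuffleId>_<mapId>_<reduceId>_<numBlocks>; a four-part id is treated as numBlocks = 1. A minimal standalone sketch of that parsing, using a hypothetical helper name (parseShuffleBlockId is illustrative, not in the patch):

    object ParseDemo extends App {
      // Mirrors ManagedBufferIterator's handling of the two id formats.
      def parseShuffleBlockId(id: String): (Int, Int, Int, Int) = {
        val parts = id.split("_")
        require((parts.length == 4 || parts.length == 5) && parts(0) == "shuffle",
          s"Unexpected shuffle block id format: $id")
        val numBlocks = if (parts.length == 4) 1 else parts(4).toInt
        (parts(1).toInt, parts(2).toInt, parts(3).toInt, numBlocks)
      }
      assert(parseShuffleBlockId("shuffle_0_2_7") == ((0, 2, 7, 1)))
      assert(parseShuffleBlockId("shuffle_0_2_7_3") == ((0, 2, 7, 3)))
    }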
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index e6399897be..795d36d7aa 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -157,7 +157,7 @@ public void registerExecutor(
   }

   /**
-   * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions
+   * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId, numBlocks). We make assumptions
    * about how the hash and sort based shuffles store their data.
    */
@@ -165,13 +165,14 @@ public ManagedBuffer getBlockData(
       String execId,
       int shuffleId,
       int mapId,
-      int reduceId) {
+      int reduceId,
+      int numBlocks) {
     ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
     if (executor == null) {
       throw new RuntimeException(
         String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
     }
-    return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
+    return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId, numBlocks);
   }

@@ -232,13 +233,13 @@ private void deleteExecutorDirs(String[] dirs) {
    * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId.
    */
   private ManagedBuffer getSortBasedShuffleBlockData(
-    ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) {
+    ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId, int numBlocks) {
     File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir,
       "shuffle_" + shuffleId + "_" + mapId + "_0.index");

     try {
       ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile);
-      ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId);
+      ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId, numBlocks);
       return new FileSegmentManagedBuffer(
         conf,
         getFile(executor.localDirs, executor.subDirsPerLocalDir,
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java
index 386738ece5..470e1040e9 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java
@@ -59,9 +59,9 @@ public int getSize() {
   /**
    * Get index offset for a particular reducer.
    */
-  public ShuffleIndexRecord getIndex(int reduceId) {
+  public ShuffleIndexRecord getIndex(int reduceId, int numBlocks) {
     long offset = offsets.get(reduceId);
-    long nextOffset = offsets.get(reduceId + 1);
+    long nextOffset = offsets.get(reduceId + numBlocks);
     return new ShuffleIndexRecord(offset, nextOffset - offset);
   }
 }
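Reviewer note (not part of the patch): a shuffle index file stores numPartitions + 1 longs, where entry r is the byte offset at which partition r's data starts. Because consecutive partitions are laid out back to back in the data file, getIndex(reduceId, numBlocks) can return one segment spanning numBlocks partitions just by skipping ahead in the offsets array. A worked sketch with an in-memory array standing in for ShuffleIndexInformation's offsets buffer:

    object IndexDemo extends App {
      val offsets = Array(0L, 100L, 250L, 400L) // 3 partitions -> 4 offsets
      // Returns (offset, length), as ShuffleIndexRecord does.
      def getIndex(reduceId: Int, numBlocks: Int): (Long, Long) = {
        val offset = offsets(reduceId)
        val nextOffset = offsets(reduceId + numBlocks)
        (offset, nextOffset - offset)
      }
      assert(getIndex(0, 1) == ((0L, 100L))) // single block, same as before
      assert(getIndex(0, 2) == ((0L, 250L))) // partitions 0 and 1 in one segment
    }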
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 7846b71d5a..baa7146604 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -83,8 +83,8 @@ public void testOpenShuffleBlocks() {
     ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3]));
     ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
-    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(block0Marker);
-    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(block1Marker);
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0, 1)).thenReturn(block0Marker);
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1, 1)).thenReturn(block1Marker);
     ByteBuffer openBlocks = new OpenBlocks("app0", "exec1",
       new String[] { "shuffle_0_0_0", "shuffle_0_0_1" })
       .toByteBuffer();
@@ -106,8 +106,8 @@ public void testOpenShuffleBlocks() {
     assertEquals(block0Marker, buffers.next());
     assertEquals(block1Marker, buffers.next());
     assertFalse(buffers.hasNext());
-    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
-    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1);
+    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0, 1);
+    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1, 1);

     // Verify open block request latency metrics
     Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler)
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
index 6d201b8fe8..850fc0942d 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
@@ -65,7 +65,7 @@ public void testBadRequests() throws IOException {
     ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
     // Unregistered executor
     try {
-      resolver.getBlockData("app0", "exec1", 1, 1, 0);
+      resolver.getBlockData("app0", "exec1", 1, 1, 0, 1);
       fail("Should have failed");
     } catch (RuntimeException e) {
       assertTrue("Bad error message: " + e, e.getMessage().contains("not registered"));
@@ -74,7 +74,7 @@ public void testBadRequests() throws IOException {
     // Invalid shuffle manager
     try {
       resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar"));
-      resolver.getBlockData("app0", "exec2", 1, 1, 0);
+      resolver.getBlockData("app0", "exec2", 1, 1, 0, 1);
       fail("Should have failed");
     } catch (UnsupportedOperationException e) {
       // pass
@@ -84,7 +84,7 @@ public void testBadRequests() throws IOException {
     resolver.registerExecutor("app0", "exec3",
       dataContext.createExecutorInfo(SORT_MANAGER));
     try {
-      resolver.getBlockData("app0", "exec3", 1, 1, 0);
+      resolver.getBlockData("app0", "exec3", 1, 1, 0, 1);
       fail("Should have failed");
     } catch (Exception e) {
       // pass
@@ -98,18 +98,25 @@ public void testSortShuffleBlocks() throws IOException {
       dataContext.createExecutorInfo(SORT_MANAGER));

     InputStream block0Stream =
-      resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream();
+      resolver.getBlockData("app0", "exec0", 0, 0, 0, 1).createInputStream();
     String block0 = CharStreams.toString(
         new InputStreamReader(block0Stream, StandardCharsets.UTF_8));
     block0Stream.close();
     assertEquals(sortBlock0, block0);

     InputStream block1Stream =
-      resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream();
+      resolver.getBlockData("app0", "exec0", 0, 0, 1, 1).createInputStream();
     String block1 = CharStreams.toString(
         new InputStreamReader(block1Stream, StandardCharsets.UTF_8));
     block1Stream.close();
     assertEquals(sortBlock1, block1);
+
+    InputStream block01Stream =
+      resolver.getBlockData("app0", "exec0", 0, 0, 0, 2).createInputStream();
+    String block01 = CharStreams.toString(
+        new InputStreamReader(block01Stream, StandardCharsets.UTF_8));
+    block01Stream.close();
+    assertEquals(sortBlock0 + sortBlock1, block01);
   }

   @Test
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 195fd4f818..95eb369a2f 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -31,10 +31,12 @@ import scala.util.control.NonFatal
 import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.MetadataFetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManagerId, ContinuousShuffleBlockId, ShuffleBlockId}
 import org.apache.spark.util._

 /**
@@ -280,10 +282,23 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
     }
   }

+  protected def supportsContinuousBlockBatchFetch(serializerRelocatable: Boolean): Boolean = {
+    if (!serializerRelocatable) {
+      false
+    } else {
+      if (!conf.getBoolean("spark.shuffle.compress", true)) {
+        true
+      } else {
+        val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
+        CompressionCodec.supportsConcatenationOfSerializedStreams(compressionCodec)
+      }
+    }
+  }
+
   // For testing
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
       : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
-    getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
+    getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1, serializerRelocatable = false)
   }

   /**
@@ -295,8 +310,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
    * and the second item is a sequence of (shuffle block id, shuffle block size) tuples
    * describing the shuffle blocks that are stored at that block manager.
    */
-  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Seq[(BlockManagerId, Seq[(BlockId, Long)])]
+  def getMapSizesByExecutorId(
+      shuffleId: Int,
+      startPartition: Int,
+      endPartition: Int,
+      serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])]

   /**
    * Deletes map output status information for the specified shuffle stage.
@@ -633,13 +651,17 @@
   }

   // This method is only called in local-mode.
-  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+  def getMapSizesByExecutorId(
+      shuffleId: Int,
+      startPartition: Int,
+      endPartition: Int,
+      serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     shuffleStatuses.get(shuffleId) match {
       case Some (shuffleStatus) =>
         shuffleStatus.withMapStatuses { statuses =>
-          MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+          MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses,
+            supportsContinuousBlockBatchFetch(serializerRelocatable))
         }
       case None =>
         Seq.empty
@@ -669,12 +691,16 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
   /** Remembers which map output locations are currently being fetched on an executor. */
   private val fetching = new HashSet[Int]

-  override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+  override def getMapSizesByExecutorId(
+      shuffleId: Int,
+      startPartition: Int,
+      endPartition: Int,
+      serializerRelocatable: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     val statuses = getStatuses(shuffleId)
     try {
-      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses,
+        supportsContinuousBlockBatchFetch(serializerRelocatable))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
@@ -849,6 +875,7 @@ private[spark] object MapOutputTracker extends Logging {
    * @param startPartition Start of map output partition ID range (included in range)
    * @param endPartition End of map output partition ID range (excluded from range)
    * @param statuses List of map statuses, indexed by map ID.
+   * @param supportsContinuousBlockBatchFetch if true, merge contiguous partitions of a map
+   *                                          output so they can be fetched in one IO
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
    *         and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
    *         describing the shuffle blocks that are stored at that block manager.
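Reviewer note (not part of the patch): batch fetch is only correct when concatenated per-partition byte streams are still decodable, so the gate requires a serializer with relocatable serialized objects and, if spark.shuffle.compress is on, a codec whose concatenated streams decompress cleanly (CompressionCodec.supportsConcatenationOfSerializedStreams). When the gate passes, convertMapStatuses emits one ContinuousShuffleBlockId(shuffleId, mapId, startPartition, endPartition - startPartition) per map output, sized as the sum of the per-partition sizes, instead of one ShuffleBlockId per partition. A hypothetical standalone restatement of the predicate (canBatchFetch is illustrative, not in the patch):

    object BatchFetchGate extends App {
      def canBatchFetch(
          serializerRelocatable: Boolean,
          compressEnabled: Boolean,
          codecSupportsConcat: Boolean): Boolean =
        serializerRelocatable && (!compressEnabled || codecSupportsConcat)

      assert(!canBatchFetch(false, true, true))  // non-relocatable serializer: never
      assert(canBatchFetch(true, false, false))  // no compression: always ok
      assert(canBatchFetch(true, true, true))    // concat-friendly codec: ok
      assert(!canBatchFetch(true, true, false))  // codec can't concatenate: fall back
    }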
@@ -857,7 +884,8 @@ def convertMapStatuses(
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      statuses: Array[MapStatus],
+      supportsContinuousBlockBatchFetch: Boolean): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     assert (statuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
     for ((status, mapId) <- statuses.zipWithIndex) {
@@ -866,9 +894,16 @@
         logError(errorMessage)
         throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
       } else {
-        for (part <- startPartition until endPartition) {
+        if (endPartition - startPartition > 1 && supportsContinuousBlockBatchFetch) {
+          val totalSize: Long = (startPartition until endPartition).map(status.getSizeForBlock).sum
           splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
-            ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
+            ((ContinuousShuffleBlockId(shuffleId, mapId,
+              startPartition, endPartition - startPartition), totalSize))
+        } else {
+          for (part <- startPartition until endPartition) {
+            splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
+              ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
+          }
+        }
       }
     }
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 1d4b05caaa..99ee035c37 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -109,6 +109,7 @@
   private def shouldCompress(blockId: BlockId): Boolean = {
     blockId match {
       case _: ShuffleBlockId => compressShuffle
+      case _: ContinuousShuffleBlockId => compressShuffle
       case _: BroadcastBlockId => compressBroadcast
       case _: RDDBlockId => compressRdds
       case _: TempLocalBlockId => compressShuffleSpill
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 0562d45ff5..e72e0d4e06 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -46,7 +46,9 @@ private[spark] class BlockStoreShuffleReader[K, C](
       context,
       blockManager.shuffleClient,
       blockManager,
-      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+      mapOutputTracker.getMapSizesByExecutorId(
+        handle.shuffleId, startPartition, endPartition,
+        dep.serializer.supportsRelocationOfSerializedObjects),
       serializerManager.wrapStream,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 449f60273b..80b7d4ff50 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -191,7 +191,7 @@ private[spark] class IndexShuffleBlockResolver(
     }
   }

-  override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
+  override def getBlockData(blockId: ShuffleBlockIdBase): ManagedBuffer = {
     // The block is actually going to be a range of a single map output file for this map, so
     // find out the consolidated file, then the offset within that from our index
     val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
@@ -203,13 +203,21 @@ private[spark] class IndexShuffleBlockResolver(
     // class of issue from re-occurring in the future which is why they are left here even though
     // SPARK-22982 is fixed.
     val channel = Files.newByteChannel(indexFile.toPath)
-    channel.position(blockId.reduceId * 8L)
     val in = new DataInputStream(Channels.newInputStream(channel))
     try {
+      channel.position(blockId.reduceId * 8L)
       val offset = in.readLong()
+      var expectedPosition = 0L
+      blockId match {
+        case bid: ContinuousShuffleBlockId =>
+          val tempId = blockId.reduceId + bid.numBlocks
+          channel.position(tempId * 8L)
+          expectedPosition = tempId * 8L + 8
+        case _ =>
+          expectedPosition = blockId.reduceId * 8L + 16
+      }
       val nextOffset = in.readLong()
       val actualPosition = channel.position()
-      val expectedPosition = blockId.reduceId * 8L + 16
       if (actualPosition != expectedPosition) {
         throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
           s"expected $expectedPosition but actual position was $actualPosition.")
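Reviewer note (not part of the patch): the index file holds numPartitions + 1 longs, so for a block starting at reducer r spanning n partitions, the start offset is read at byte r * 8 and the end offset at byte (r + n) * 8, which is why the expected channel position after the second read is (r + n) * 8 + 8; for a plain ShuffleBlockId (n = 1) the end offset immediately follows the start offset and no reposition is needed. A toy check of that arithmetic (illustrative only):

    object IndexPositionDemo extends App {
      // Byte positions at which the start and end offsets are read.
      def positions(reduceId: Int, numBlocks: Int): (Long, Long) =
        (reduceId.toLong * 8, (reduceId + numBlocks).toLong * 8)

      // numBlocks = 1 behaves like ShuffleBlockId: end offset follows directly.
      assert(positions(3, 1) == ((24L, 32L)))
      // ContinuousShuffleBlockId(_, _, 3, 4): skip to offset slot 7 for the end.
      assert(positions(3, 4) == ((24L, 56L)))
    }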
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
index d1ecbc1bf0..8b62c00bc6 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.shuffle

 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.ShuffleBlockIdBase

 private[spark]
 /**
@@ -34,7 +34,7 @@ trait ShuffleBlockResolver {
    * Retrieve the data for the specified block. If the data for that block is not available,
    * throws an unspecified exception.
    */
-  def getBlockData(blockId: ShuffleBlockId): ManagedBuffer
+  def getBlockData(blockId: ShuffleBlockIdBase): ManagedBuffer

   def stop(): Unit
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7ac2c71c18..14a4df5393 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -38,7 +38,7 @@ sealed abstract class BlockId {
   // convenience methods
   def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
   def isRDD: Boolean = isInstanceOf[RDDBlockId]
-  def isShuffle: Boolean = isInstanceOf[ShuffleBlockId]
+  def isShuffle: Boolean = isInstanceOf[ShuffleBlockIdBase]
   def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]

   override def toString: String = name
@@ -51,11 +51,25 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {

 // Format of the shuffle block ids (including data and index) should be kept in sync with
 // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData().
+trait ShuffleBlockIdBase extends BlockId {
+  def shuffleId: Int
+  def mapId: Int
+  def reduceId: Int
+}
+
 @DeveloperApi
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
+  extends ShuffleBlockIdBase {
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
 }

+@DeveloperApi
+case class ContinuousShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int, numBlocks: Int)
+  extends ShuffleBlockIdBase {
+  override def name: String =
+    "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + "_" + numBlocks
+}
+
 @DeveloperApi
 case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
@@ -104,6 +118,7 @@ class UnrecognizedBlockId(name: String)
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val CONTINUE_SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
   val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
@@ -118,6 +133,8 @@ object BlockId {
       RDDBlockId(rddId.toInt, splitIndex.toInt)
     case SHUFFLE(shuffleId, mapId, reduceId) =>
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+    case CONTINUE_SHUFFLE(shuffleId, mapId, reduceId, length) =>
+      ContinuousShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt, length.toInt)
     case SHUFFLE_DATA(shuffleId, mapId, reduceId) =>
       ShuffleDataBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
     case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
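Reviewer note (not part of the patch): ordering the SHUFFLE case before CONTINUE_SHUFFLE is safe because Scala regex patterns in a match anchor against the whole string, so a four-field id like shuffle_1_2_3_4 cannot be swallowed by the three-field SHUFFLE pattern. A short usage sketch against the patched API (assumes spark-core with this patch on the classpath):

    import org.apache.spark.storage.{BlockId, ContinuousShuffleBlockId, ShuffleBlockId}

    object BlockIdDemo extends App {
      assert(BlockId("shuffle_1_2_3") == ShuffleBlockId(1, 2, 3))
      assert(BlockId("shuffle_1_2_3_4") == ContinuousShuffleBlockId(1, 2, 3, 4))
      assert(BlockId("shuffle_1_2_3_4").isShuffle) // both id types count as shuffle
    }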
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e0276a4dc4..5f2750af73 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -372,7 +372,7 @@ private[spark] class BlockManager(
    */
   override def getBlockData(blockId: BlockId): ManagedBuffer = {
     if (blockId.isShuffle) {
-      shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
+      shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockIdBase])
     } else {
       getLocalBytes(blockId) match {
         case Some(blockData) =>
@@ -577,7 +577,7 @@ private[spark] class BlockManager(
       // TODO: This should gracefully handle case where local block is not available. Currently
       // downstream code will throw an exception.
       val buf = new ChunkedByteBuffer(
-        shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+        shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockIdBase]).nioByteBuffer())
       Some(new ByteBufferBlockData(buf, true))
     } else {
       blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) }
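Reviewer note (not part of the patch): the guard/cast pair here stays type-safe because isShuffle now tests ShuffleBlockIdBase, which both concrete shuffle id types extend. A minimal sketch of the invariant (resolveShuffle is an illustrative helper, not in the patch):

    import org.apache.spark.storage.{BlockId, ShuffleBlockIdBase}

    // If isShuffle is true, the cast cannot fail: isShuffle is defined as
    // isInstanceOf[ShuffleBlockIdBase] after this patch.
    def resolveShuffle(blockId: BlockId): ShuffleBlockIdBase = {
      require(blockId.isShuffle, s"$blockId is not a shuffle block")
      blockId.asInstanceOf[ShuffleBlockIdBase]
    }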
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 98b5a735a4..a34edc0009 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -515,8 +515,9 @@ final class ShuffleBlockFetcherIterator(
   private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = {
     blockId match {
-      case ShuffleBlockId(shufId, mapId, reduceId) =>
-        throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
+      case blockId: ShuffleBlockIdBase =>
+        throw new FetchFailedException(
+          address, blockId.shuffleId, blockId.mapId, blockId.reduceId, e)
       case _ =>
         throw new SparkException(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 50b8ea754d..d1af26ad8b 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -26,8 +26,9 @@ import org.apache.spark.LocalSparkContext._
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
 import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
+import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{BlockManagerId, ContinuousShuffleBlockId, ShuffleBlockId}

 class MapOutputTrackerSuite extends SparkFunSuite {
   private val conf = new SparkConf
@@ -298,4 +299,37 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     }
   }

+  test("fetch contiguous partitions") {
+    val rpcEnv = createRpcEnv("test")
+    val serializer = new KryoSerializer(conf)
+    val tracker = newTrackerMaster()
+    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+    tracker.registerShuffle(10, 2)
+    assert(tracker.containsShuffle(10))
+    val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+    val size2000 = MapStatus.decompressSize(MapStatus.compressSize(2000L))
+    val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
+    tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
+      Array(1000L, 10000L, 2000L)))
+    tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
+      Array(10000L, 2000L, 1000L)))
+    val statuses1 = tracker.getMapSizesByExecutorId(10, 0, 2, serializerRelocatable = true)
+    assert(statuses1.toSet ===
+      Seq((BlockManagerId("a", "hostA", 1000),
+        ArrayBuffer((ContinuousShuffleBlockId(10, 0, 0, 2), size1000 + size10000))),
+        (BlockManagerId("b", "hostB", 1000),
+          ArrayBuffer((ContinuousShuffleBlockId(10, 1, 0, 2), size10000 + size2000))))
+      .toSet)
+    val statuses2 = tracker.getMapSizesByExecutorId(10, 2, 3, serializerRelocatable = true)
+    assert(statuses2.toSet ===
+      Seq((BlockManagerId("a", "hostA", 1000),
+        ArrayBuffer((ShuffleBlockId(10, 0, 2), size2000))),
+        (BlockManagerId("b", "hostB", 1000),
+          ArrayBuffer((ShuffleBlockId(10, 1, 2), size1000))))
+      .toSet)
+    assert(0 == tracker.getNumCachedSerializedBroadcast)
+    tracker.stop()
+    rpcEnv.shutdown()
+  }
 }
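Reviewer note (not part of the patch): the expectations above wrap sizes in decompressSize(compressSize(...)) because MapStatus stores each block size as a single byte on a log-1.1 scale, so the tracker only ever sees approximations; the merged block's size is the sum of those approximations, matching convertMapStatuses. A tiny sketch (must live under the org.apache.spark package, since object MapStatus is private[spark]; the demo package name is hypothetical):

    package org.apache.spark.demo

    import org.apache.spark.scheduler.MapStatus

    object SizeRoundTrip extends App {
      val approx1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
      println(approx1000) // slightly above 1000: the encoding rounds up
    }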
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index dba1172d5f..3a567a1aed 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -101,7 +101,9 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
     // shuffle data to read.
     val mapOutputTracker = mock(classOf[MapOutputTracker])
-    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn {
+    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1,
+      serializer.supportsRelocationOfSerializedObjects))
+      .thenReturn {
       // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
       // for the code to read data over the network.
       val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index ff4755833a..7acfba6530 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -51,7 +51,7 @@ class BlockIdSuite extends SparkFunSuite {
     assertSame(id, BlockId(id.toString))
   }

-  test("shuffle") {
+  test("shuffle block") {
     val id = ShuffleBlockId(1, 2, 3)
     assertSame(id, ShuffleBlockId(1, 2, 3))
     assertDifferent(id, ShuffleBlockId(3, 2, 3))
@@ -64,6 +64,20 @@ class BlockIdSuite extends SparkFunSuite {
     assertSame(id, BlockId(id.toString))
   }

+  test("continuous shuffle block") {
+    val id = ContinuousShuffleBlockId(1, 2, 3, 4)
+    assertSame(id, ContinuousShuffleBlockId(1, 2, 3, 4))
+    assertDifferent(id, ContinuousShuffleBlockId(3, 2, 3, 4))
+    assert(id.name === "shuffle_1_2_3_4")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 1)
+    assert(id.mapId === 2)
+    assert(id.reduceId === 3)
+    assert(id.numBlocks === 4)
+    assert(id.isShuffle)
+    assertSame(id, BlockId(id.toString))
+  }
+
   test("shuffle data") {
     val id = ShuffleDataBlockId(4, 5, 6)
     assertSame(id, ShuffleDataBlockId(4, 5, 6))