Commit 8b072b3 (parent: 14a2c76)

Replace SparkEnv.get.conf in ShuffleBlockFetcherIterator with optional FallbackStorage instance, add test
4 files changed: 65 additions, 9 deletions

core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala (3 additions, 2 deletions)

@@ -23,7 +23,7 @@ import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, FallbackStorage, ShuffleBlockFetcherIterator}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter

@@ -88,7 +88,8 @@ private[spark] class BlockStoreShuffleReader[K, C](
       SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
       SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
       readMetrics,
-      fetchContinuousBlocksInBatch).toCompletionIterator
+      fetchContinuousBlocksInBatch,
+      FallbackStorage.getFallbackStorage(SparkEnv.get.conf)).toCompletionIterator

     val serializerInstance = dep.serializer.newInstance()
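Note: the reader above now resolves the optional fallback storage once, via FallbackStorage.getFallbackStorage, and injects it into the iterator. That helper is not part of this diff; the sketch below is a plausible reconstruction, assuming it mirrors the STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH check this commit deletes from ShuffleBlockFetcherIterator. The object name is hypothetical.

package org.apache.spark.storage

import org.apache.spark.SparkConf
import org.apache.spark.internal.config

object FallbackStorageHelperSketch {
  // Hypothetical reconstruction, not code from this commit: hand out an
  // instance only when a fallback storage path is configured.
  def getFallbackStorage(conf: SparkConf): Option[FallbackStorage] = {
    if (conf.get(config.STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH).isDefined) {
      Some(new FallbackStorage(conf))
    } else {
      None
    }
  }
}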

core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala (5 additions, 0 deletions)

@@ -89,6 +89,11 @@ private[storage] class FallbackStorage(conf: SparkConf) extends Logging {
     }
   }

+  /**
+   * Read a ManagedBuffer.
+   */
+  def read(blockId: BlockId): ManagedBuffer = FallbackStorage.read(conf, blockId)
+
   def exists(shuffleId: Int, filename: String): Boolean = {
     val hash = JavaUtils.nonNegativeHash(filename)
     fallbackFileSystem.exists(new Path(fallbackPath, s"$appId/$shuffleId/$hash/$filename"))
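Note: the new read is a one-line delegate to the companion object's existing FallbackStorage.read(conf, blockId). The indirection buys testability: a static call cannot be stubbed, but an instance method can. A minimal sketch, assuming the code sits in the org.apache.spark.storage package (the class is private[storage]) and using NioManagedBuffer as stand-in data; the object name is hypothetical.

package org.apache.spark.storage

import java.nio.ByteBuffer

import org.mockito.Mockito.{mock, when}

import org.apache.spark.network.buffer.NioManagedBuffer

object FallbackReadStubSketch extends App {
  // Stub the instance-level read() so no real fallback file system is needed.
  val storage = mock(classOf[FallbackStorage])
  when(storage.read(ShuffleBlockId(0, 0, 0)))
    .thenReturn(new NioManagedBuffer(ByteBuffer.allocate(127)))

  assert(storage.read(ShuffleBlockId(0, 0, 0)).size() == 127)
}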

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala (5 additions, 5 deletions)

@@ -32,10 +32,10 @@ import scala.util.{Failure, Success}
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.roaringbitmap.RoaringBitmap

-import org.apache.spark.{MapOutputTracker, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkException, TaskContext}
 import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.errors.SparkCoreErrors
-import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle._

@@ -101,6 +101,7 @@ final class ShuffleBlockFetcherIterator(
     checksumAlgorithm: String,
     shuffleMetrics: ShuffleReadMetricsReporter,
     doBatchFetch: Boolean,
+    fallbackStorage: Option[FallbackStorage],
     clock: Clock = new SystemClock())
   extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {

@@ -980,10 +981,9 @@ final class ShuffleBlockFetcherIterator(
             log"${MDC(MAX_ATTEMPTS, maxAttemptsOnNettyOOM)} retries due to Netty OOM"
           logError(logMessage)
           errorMsg = logMessage.message
-        } else if (
-          SparkEnv.get.conf.get(config.STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH).isDefined) {
+        } else if (fallbackStorage.isDefined) {
           try {
-            val buf = FallbackStorage.read(SparkEnv.get.conf, blockId)
+            val buf = fallbackStorage.get.read(blockId)
             results.put(SuccessFetchResult(blockId, mapIndex, address, buf.size(), buf,
               isNetworkReqDone = false))
             result = null
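Note: behavior here is unchanged; on an otherwise-unrecoverable fetch failure the iterator still attempts a fallback read. What changes is where the decision comes from: previously the iterator consulted the global SparkEnv.get.conf, so even a unit test had to stand up a SparkEnv to reach this branch; now the caller injects Option[FallbackStorage]. A self-contained illustration of the pattern, with hypothetical names rather than Spark's API:

// Before: the class reached into a global for its fallback decision.
// After: the caller decides once and injects the dependency.
trait Fallback { def read(id: String): Array[Byte] }

final class Fetcher(fallback: Option[Fallback]) {
  def fetch(id: String): Array[Byte] = fallback match {
    case Some(f) => f.read(id)  // fallback storage configured
    case None => throw new NoSuchElementException(s"block $id not found")
  }
}

object FetcherDemo extends App {
  val stub = new Fallback { def read(id: String): Array[Byte] = Array[Byte](1, 2, 3) }
  println(new Fetcher(Some(stub)).fetch("shuffle_0_0_0").length)  // prints 3
}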

core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala (52 additions, 2 deletions)

@@ -196,7 +196,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       checksumEnabled: Boolean = true,
       checksumAlgorithm: String = "ADLER32",
       shuffleMetrics: Option[ShuffleReadMetricsReporter] = None,
-      doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = {
+      doBatchFetch: Boolean = false,
+      fallbackStorage: Option[FallbackStorage] = None): ShuffleBlockFetcherIterator = {
     val tContext = taskContext.getOrElse(TaskContext.empty())
     new ShuffleBlockFetcherIterator(
       tContext,

@@ -222,7 +223,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       checksumEnabled,
       checksumAlgorithm,
       shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()),
-      doBatchFetch)
+      doBatchFetch,
+      fallbackStorage)
   }
   // scalastyle:on argcount

@@ -1127,6 +1129,54 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     assert(e.getMessage.contains("fetch failed after 10 retries due to Netty OOM"))
   }

+  test("SPARK-XXXXX: missing blocks attempts to read from fallback storage") {
+    val blockManager = createMockBlockManager()
+
+    configureMockTransfer(Map.empty)
+    val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2)
+    val blockId = ShuffleBlockId(0, 0, 0)
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (remoteBmId, Seq((blockId, 1L, 0)))
+    )
+
+    // iterator with no FallbackStorage cannot find the block
+    {
+      val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress = blocksByAddress)
+      val e = intercept[FetchFailedException] {
+        iterator.next()
+      }
+      assert(e.getCause != null)
+      assert(e.getCause.isInstanceOf[BlockNotFoundException])
+      assert(e.getCause.getMessage.contains("Block shuffle_0_0_0 not found"))
+    }
+
+    // iterator with a FallbackStorage that does not store the block cannot find it either
+    val fallbackStorage = mock(classOf[FallbackStorage])
+
+    {
+      when(fallbackStorage.read(ShuffleBlockId(0, 0, 1))).thenReturn(new TestManagedBuffer(127))
+      val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress = blocksByAddress,
+        fallbackStorage = Some(fallbackStorage))
+      val e = intercept[FetchFailedException] {
+        iterator.next()
+      }
+      assert(e.getCause != null)
+      assert(e.getCause.isInstanceOf[BlockNotFoundException])
+      assert(e.getCause.getMessage.contains("Block shuffle_0_0_0 not found"))
+    }
+
+    // iterator with a FallbackStorage that stores the block can find it
+    {
+      when(fallbackStorage.read(ShuffleBlockId(0, 0, 0))).thenReturn(new TestManagedBuffer(127))
+      val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress = blocksByAddress,
+        fallbackStorage = Some(fallbackStorage))
+      assert(iterator.hasNext)
+      val (id, _) = iterator.next()
+      assert(id === ShuffleBlockId(0, 0, 0))
+      assert(!iterator.hasNext)
+    }
+  }
+
   /**
    * Prepares the transfer to trigger success for all the blocks present in blockChunks. It will
    * trigger failure of block which is not part of blockChunks.
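Note on the middle scenario above: it stubs read() only for ShuffleBlockId(0, 0, 1) while the fetch is for block (0, 0, 0), so the unstubbed call returns Mockito's default (null); the fallback attempt inside the iterator's try block then fails, and, as the asserts indicate, the original BlockNotFoundException surfaces as the cause. A minimal sketch of that Mockito default, with a hypothetical trait rather than Spark's types:

import org.mockito.Mockito.{mock, when}

trait Storage { def read(id: String): String }

object MockDefaultDemo extends App {
  val storage = mock(classOf[Storage])
  when(storage.read("shuffle_0_0_1")).thenReturn("bytes")

  assert(storage.read("shuffle_0_0_1") == "bytes")
  assert(storage.read("shuffle_0_0_0") == null)  // unstubbed call -> Mockito default
}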
