
Commit e934c43

Dylan Wong authored and anishshri-db committed
[SPARK-54585][SS] Fix State Store rollback when thread is in interrupted state
[SPARK-54585][SS] Fix State Store rollback when thread is in interrupted state

### What changes were proposed in this pull request?

1. Modifies `ChecksumCancellableFSDataOutputStream.cancel()` to cancel both the main stream and the checksum stream synchronously instead of using Futures with `awaitResult`.
2. Moves `changelogWriter.foreach(_.abort())` and `changelogWriter = None` into a try/finally block within `RocksDB.rollback()`.

### Why are the changes needed?

For fix 1: When `cancel()` is called while the thread is in an interrupted state (e.g., during task cancellation), the previous implementation would fail. The code submitted Futures to cancel each stream, then called `awaitResult()` to wait for completion. However, `awaitResult()` checks the thread's interrupt flag and throws `InterruptedException` immediately if the thread is interrupted.

For fix 2: Consider the case where `abort()` is called on `RocksDBStateStoreProvider`. This calls `rollback()` on the `RocksDB` instance, which in turn calls `changelogWriter.foreach(_.abort())` and then sets `changelogWriter = None`. However, if `changelogWriter.abort()` throws an exception, the finally block still sets `backingFileStream` and `compressedStream` to `null`. The exception propagates, and we never reach the line that sets `changelogWriter = None`. This leaves the RocksDB instance in an inconsistent state:

- `changelogWriter = Some(changelogWriterWeAttemptedToAbort)`
- `changelogWriterWeAttemptedToAbort.backingFileStream = null`
- `changelogWriterWeAttemptedToAbort.compressedStream = null`

Now consider calling `RocksDB.load()` again. This calls `replayChangelog()`, which calls `put()`, which calls `changelogWriter.put()`. At this point, the assertion `assert(compressedStream != null)` fails, causing an exception while loading the StateStore.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added the test `"SPARK-54585: Interrupted task calling rollback does not throw an exception"`, which simulates a thread that is in the interrupted state when it begins a rollback.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #53313 from dylanwong250/SPARK-54585.

Authored-by: Dylan Wong <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
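For readers unfamiliar with the failure mode behind fix 1, here is a minimal, self-contained sketch (plain Scala, not Spark code; all names are illustrative) showing that awaiting a Future from a thread whose interrupt flag is already set fails immediately, even though the scheduled work itself is healthy:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global

object InterruptedAwaitDemo {
  def main(args: Array[String]): Unit = {
    // The Future itself is healthy; it just takes long enough that the await must block.
    val cancelWork = Future { Thread.sleep(100); "done" }

    // Simulate a killed task: the current thread's interrupt flag is set.
    Thread.currentThread().interrupt()

    try {
      // Await.result (and Spark's ThreadUtils.awaitResult, which builds on it) checks the
      // interrupt flag while blocking and throws InterruptedException right away.
      Await.result(cancelWork, Duration.Inf)
    } catch {
      case e: InterruptedException =>
        println(s"await aborted even though the work would have succeeded: $e")
    }
  }
}
```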
1 parent: 7df7dad

File tree: 3 files changed, +84 −19 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala

Lines changed: 7 additions & 8 deletions

```diff
@@ -39,6 +39,7 @@ import org.apache.spark.internal.LogKeys.{CHECKSUM, NUM_BYTES, PATH, TIMEOUT}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager.CancellableFSDataOutputStream
 import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.Utils
 
 /** Information about the creator of the checksum file. Useful for debugging */
 case class ChecksumFileCreatorInfo(
@@ -500,16 +501,14 @@ class ChecksumCancellableFSDataOutputStream(
   @volatile private var closed = false
 
   override def cancel(): Unit = {
-    val mainFuture = Future {
+    // Cancel both streams synchronously rather than using futures. If the current thread is
+    // interrupted and we call this method, scheduling work on futures would immediately throw
+    // InterruptedException leaving the streams in an inconsistent state.
+    Utils.tryWithSafeFinally {
       mainStream.cancel()
-    }(uploadThreadPool)
-
-    val checksumFuture = Future {
+    } {
       checksumStream.cancel()
-    }(uploadThreadPool)
-
-    awaitResult(mainFuture, Duration.Inf)
-    awaitResult(checksumFuture, Duration.Inf)
+    }
   }
 
   override def close(): Unit = {
```
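For context on the replacement construct: `Utils.tryWithSafeFinally` behaves like try/finally, running the second block even when the first throws, and attaching a secondary failure as suppressed rather than letting it mask the original. A simplified plain-Scala sketch of that contract (illustrative only, not the actual Spark implementation) is below; note that both cancels now run on the calling thread, so no thread-pool handoff can trip over the interrupt flag:

```scala
object TryWithSafeFinallySketch {
  // Simplified stand-in for org.apache.spark.util.Utils.tryWithSafeFinally;
  // illustrative only, not the actual Spark source.
  def tryWithSafeFinallySketch[T](block: => T)(finallyBlock: => Unit): T = {
    var original: Throwable = null
    try {
      block
    } catch {
      case t: Throwable =>
        original = t
        throw t
    } finally {
      try {
        finallyBlock // runs even if `block` threw
      } catch {
        case t: Throwable if original != null && (original ne t) =>
          // Record the second failure as suppressed instead of masking the first.
          original.addSuppressed(t)
      }
    }
  }
}
```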

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 17 additions & 11 deletions

```diff
@@ -1650,17 +1650,23 @@ class RocksDB(
    * Drop uncommitted changes, and roll back to previous version.
    */
   def rollback(): Unit = {
-    numKeysOnWritingVersion = numKeysOnLoadedVersion
-    numInternalKeysOnWritingVersion = numInternalKeysOnLoadedVersion
-    loadedVersion = -1L
-    lastCommitBasedStateStoreCkptId = None
-    lastCommittedStateStoreCkptId = None
-    loadedStateStoreCkptId = None
-    sessionStateStoreCkptId = None
-    lineageManager.clear()
-    changelogWriter.foreach(_.abort())
-    // Make sure changelogWriter gets recreated next time.
-    changelogWriter = None
+    logInfo(
+      log"Rolling back uncommitted changes on version ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
+    try {
+      numKeysOnWritingVersion = numKeysOnLoadedVersion
+      numInternalKeysOnWritingVersion = numInternalKeysOnLoadedVersion
+      loadedVersion = -1L
+      lastCommitBasedStateStoreCkptId = None
+      lastCommittedStateStoreCkptId = None
+      loadedStateStoreCkptId = None
+      sessionStateStoreCkptId = None
+      lineageManager.clear()
+      changelogWriter.foreach(_.abort())
+    } finally {
+      // Make sure changelogWriter gets recreated next time even if the changelogWriter aborts with
+      // an exception.
+      changelogWriter = None
+    }
     logInfo(log"Rolled back to ${MDC(LogKeys.VERSION_NUM, loadedVersion)}")
   }
 
```
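To make the fix concrete, here is a minimal sketch of the guarantee the new try/finally provides (`ChangelogWriter` is a hypothetical stand-in, not the real Spark type): even when `abort()` throws, the stale writer reference is dropped, so the next `load()` cannot replay the changelog through a half-torn-down writer:

```scala
// Hypothetical stand-in for the real changelog writer type, for illustration only.
trait ChangelogWriter { def abort(): Unit }

object RollbackSketch {
  var changelogWriter: Option[ChangelogWriter] = Some(
    new ChangelogWriter {
      def abort(): Unit = throw new RuntimeException("abort failed mid-teardown")
    })

  def rollbackSketch(): Unit = {
    try {
      changelogWriter.foreach(_.abort()) // may throw partway through teardown
    } finally {
      // Guaranteed even on failure: the stale writer is dropped and a fresh one
      // is created on the next load(), instead of failing its internal assertions.
      changelogWriter = None
    }
  }
}
```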

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala

Lines changed: 60 additions & 0 deletions

```diff
@@ -70,6 +70,10 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest
 
   implicit def toArray(str: String): Array[Byte] = if (str != null) str.getBytes else null
 
+  implicit def toStr(bytes: Array[Byte]): String = if (bytes != null) new String(bytes) else null
+
+  def toStr(kv: ByteArrayPair): (String, String) = (toStr(kv.key), toStr(kv.value))
+
   case class FailureConf(ifEnableStateStoreCheckpointIds: Boolean, fileType: String) {
     override def toString: String = {
       s"ifEnableStateStoreCheckpointIds = $ifEnableStateStoreCheckpointIds, " +
@@ -824,6 +828,62 @@ class RocksDBCheckpointFailureInjectionSuite extends StreamTest
     }
   }
 
+  /**
+   * Test that verifies that when a task is interrupted, the store's rollback() method does not
+   * throw an exception and the store can still be used after the rollback.
+   */
+  test("SPARK-54585: Interrupted task calling rollback does not throw an exception") {
+    val hadoopConf = new Configuration()
+    hadoopConf.set(
+      STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key,
+      fileManagerClassName
+    )
+    withTempDirAllowFailureInjection { (remoteDir, _) =>
+      val sqlConf = new SQLConf()
+      sqlConf.setConfString("spark.sql.streaming.checkpoint.fileChecksum.enabled", "true")
+      val rocksdbChangelogCheckpointingConfKey =
+        RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + ".changelogCheckpointing.enabled"
+      sqlConf.setConfString(rocksdbChangelogCheckpointingConfKey, "true")
+      val conf = RocksDBConf(StateStoreConf(sqlConf))
+
+      withDB(
+        remoteDir.getAbsolutePath,
+        version = 0,
+        conf = conf,
+        hadoopConf = hadoopConf
+      ) { db =>
+        db.put("key0", "value0")
+        val checkpointId1 = commitAndGetCheckpointId(db)
+
+        db.load(1, checkpointId1)
+        db.put("key1", "value1")
+        val checkpointId2 = commitAndGetCheckpointId(db)
+
+        db.load(2, checkpointId2)
+        db.put("key2", "value2")
+
+        // Simulate what happens when a task is killed, the thread's interrupt flag is set.
+        // This replicates the scenario where TaskContext.markTaskFailed() is called and
+        // the task failure listener invokes RocksDBStateStore.abort() -> rollback().
+        Thread.currentThread().interrupt()
+
+        // rollback() should not throw an exception
+        db.rollback()
+
+        // Clear the interrupt flag for subsequent operations
+        Thread.interrupted()
+
+        // Reload the store and insert a new value
+        db.load(2, checkpointId2)
+        db.put("key3", "value3")
+
+        // Verify the store has the correct values
+        assert(db.iterator().map(toStr).toSet ===
+          Set(("key0", "value0"), ("key1", "value1"), ("key3", "value3")))
+      }
+    }
+  }
+
   def commitAndGetCheckpointId(db: RocksDB): Option[String] = {
     val (v, ci) = db.commit()
     ci.stateStoreCkptId
```
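One subtlety worth calling out in the test: `Thread.currentThread().interrupt()` sets the calling thread's interrupt flag, while the static `Thread.interrupted()` both reports and clears it (the instance method `isInterrupted` only reads it). A quick standalone illustration:

```scala
object InterruptFlagDemo {
  def main(args: Array[String]): Unit = {
    Thread.currentThread().interrupt()            // set the flag, as the test does
    assert(Thread.currentThread().isInterrupted)  // instance method: reads without clearing
    assert(Thread.interrupted())                  // static method: reads AND clears the flag
    assert(!Thread.currentThread().isInterrupted) // flag is clear for subsequent operations
  }
}
```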
