[SPARK-51596][SS] Fix concurrent StateStoreProvider maintenance and closing #50391

Open. Wants to merge 3 commits into base: master.
StateStore.scala
@@ -554,7 +554,11 @@ trait StateStoreProvider {
*/
def stateStoreId: StateStoreId

/** Called when the provider instance is unloaded from the executor */
/**
* Called when the provider instance is unloaded from the executor
* WARNING: IF THE PROVIDER IS IN [[StateStore.loadedProviders]],
* CLOSE MUST ONLY BE CALLED FROM MAINTENANCE THREAD!
*/
def close(): Unit
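
For readers implementing a custom provider, here is a minimal standalone sketch (not part of this diff) of how close() could surface violations of this threading contract; the thread-name prefix is an assumption borrowed from the test added below.

// Illustrative sketch only, not Spark code. Assumes maintenance pool threads carry the
// "state-store-maintenance-thread" prefix that the suite below asserts on.
class ThreadCheckedCloseable {
  def close(): Unit = {
    val threadName = Thread.currentThread().getName
    require(threadName.contains("state-store-maintenance-thread"),
      s"close() invoked from unexpected thread: $threadName")
    // ... release file handles, native resources, caches, etc.
  }
}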

/**
@@ -960,14 +964,45 @@ object StateStore extends Logging {

val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
providerIdsToUnload.foreach(unload(_))
providerIdsToUnload.foreach(id => {
loadedProviders.remove(id).foreach( provider => {
// Trigger maintenance thread to immediately do maintenance on and close the provider.
// Doing maintenance first allows us to do maintenance for a constantly-moving state
// store.
logInfo(log"Task thread trigger maintenance on " +
log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER, id)}")
doMaintenanceOnProvider(id, provider, alreadyRemovedFromLoadedProviders = true)
})
})
provider
}
}

/** Unload a state store provider */
def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized {
loadedProviders.remove(storeProviderId).foreach(_.close())
/**
* Unload a state store provider.
* If alreadyRemovedStoreFromLoadedProviders is None, the provider will be
* removed from loadedProviders and closed.
* If alreadyRemovedStoreFromLoadedProviders is Some(provider), the passed-in
* provider will be closed directly.
* WARNING: CAN ONLY BE CALLED FROM MAINTENANCE THREAD!
*/
def unload(storeProviderId: StateStoreProviderId,
alreadyRemovedStoreFromLoadedProviders: Option[StateStoreProvider] = None): Unit = {
var toCloseProviders: List[StateStoreProvider] = Nil

alreadyRemovedStoreFromLoadedProviders match {
case Some(provider) =>
toCloseProviders = provider :: toCloseProviders
case None =>
// Copy the provider to a local list so we can release the loadedProviders lock before closing.
loadedProviders.synchronized {
loadedProviders.remove(storeProviderId).foreach { provider =>
toCloseProviders = provider :: toCloseProviders
}
}
}

toCloseProviders.foreach(_.close())
}
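
As a side illustration, a self-contained sketch (not Spark code) of the locking discipline unload() follows: remove the provider while holding the map's monitor, but run the potentially slow close() only after the lock is released.

// Standalone sketch of the remove-under-lock, close-outside-lock pattern used above.
import scala.collection.mutable

object RemoveThenCloseSketch {
  trait SimpleCloseable { def close(): Unit }

  private val loaded = new mutable.HashMap[String, SimpleCloseable]()

  def unload(id: String, alreadyRemoved: Option[SimpleCloseable] = None): Unit = {
    // Either the caller (e.g. a task thread) has already removed the instance,
    // or we remove it here while holding the map's monitor.
    val toClose = alreadyRemoved.orElse {
      loaded.synchronized { loaded.remove(id) }
    }
    // close() may be expensive, so it runs without holding the lock.
    toClose.foreach(_.close())
  }
}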

/** Unload all state store providers: unit test purpose */
@@ -1038,6 +1073,14 @@ object StateStore extends Logging {
}
}

// Block until we can process this partition
private def awaitProcessThisPartition(id: StateStoreProviderId): Unit =
maintenanceThreadPoolLock.synchronized {
while (!processThisPartition(id)) {
maintenanceThreadPoolLock.wait()
}
}
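
For context, the monitor idiom above needs a wake-up counterpart: whichever thread finishes a partition must notifyAll() on the same lock so a blocked awaitProcessThisPartition re-checks its condition. The following self-contained sketch (an assumption about the overall design, not code lifted from this PR) shows the pairing.

// Standalone sketch of the wait/notifyAll pairing assumed by a blocking claim like the one above.
object PartitionClaimSketch {
  private val lock = new Object
  private val inFlight = scala.collection.mutable.HashSet[Int]()

  // Block until this partition can be claimed.
  def awaitClaim(partition: Int): Unit = lock.synchronized {
    while (!inFlight.add(partition)) { // add() returns false while another thread holds it
      lock.wait()
    }
  }

  // Release the partition and wake any blocked claimers.
  def release(partition: Int): Unit = lock.synchronized {
    inFlight.remove(partition)
    lock.notifyAll()
  }
}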

/**
* Execute background maintenance task in all the loaded store providers if they are still
* the active instances according to the coordinator.
@@ -1051,47 +1094,7 @@
loadedProviders.toSeq
}.foreach { case (id, provider) =>
if (processThisPartition(id)) {
maintenanceThreadPool.execute(() => {
val startTime = System.currentTimeMillis()
try {
provider.doMaintenance()
if (!verifyIfStoreInstanceActive(id)) {
unload(id)
logInfo(log"Unloaded ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}")
}
} catch {
case NonFatal(e) =>
logWarning(log"Error managing ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}, " +
log"unloading state store provider", e)
// When we get a non-fatal exception, we just unload the provider.
//
// By not bubbling the exception to the maintenance task thread or the query execution
// thread, it's possible for a maintenance thread pool task to continue failing on
// the same partition. Additionally, if there is some global issue that will cause
// all maintenance thread pool tasks to fail, then bubbling the exception and
// stopping the pool is faster than waiting for all tasks to see the same exception.
//
// However, we assume that repeated failures on the same partition and global issues
// are rare. The benefit to unloading just the partition with an exception is that
// transient issues on a given provider do not affect any other providers; so, in
// most cases, this should be a more performant solution.
unload(id)
} finally {
val duration = System.currentTimeMillis() - startTime
val logMsg =
log"Finished maintenance task for " +
log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" +
log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}\n"
if (duration > 5000) {
logInfo(logMsg)
} else {
logDebug(logMsg)
}
maintenanceThreadPoolLock.synchronized {
maintenancePartitions.remove(id)
}
}
})
doMaintenanceOnProvider(id, provider)
} else {
logInfo(log"Not processing partition ${MDC(LogKeys.PARTITION_ID, id)} " +
log"for maintenance because it is currently " +
@@ -1100,6 +1103,69 @@
}
}

private def doMaintenanceOnProvider(id: StateStoreProviderId, provider: StateStoreProvider,
alreadyRemovedFromLoadedProviders: Boolean = false): Unit = {
maintenanceThreadPool.execute(() => {
val startTime = System.currentTimeMillis()
if (alreadyRemovedFromLoadedProviders) {
// If provider is already removed from loadedProviders, we MUST process
// this partition to close it, so we block until we can.
awaitProcessThisPartition(id)
}
val awaitingPartitionDuration = System.currentTimeMillis() - startTime
try {
provider.doMaintenance()
// If alreadyRemovedFromLoadedProviders is true, we don't need to verify
// with the coordinator since we already know the provider must be unloaded.
if (alreadyRemovedFromLoadedProviders || !verifyIfStoreInstanceActive(id)) {
if (alreadyRemovedFromLoadedProviders) {
unload(id, Some(provider))
} else {
unload(id)
}
logInfo(log"Unloaded ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}")
}
} catch {
case NonFatal(e) =>
logWarning(log"Error managing ${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}, " +
log"unloading state store provider", e)
// When we get a non-fatal exception, we just unload the provider.
//
// By not bubbling the exception to the maintenance task thread or the query execution
// thread, it's possible for a maintenance thread pool task to continue failing on
// the same partition. Additionally, if there is some global issue that will cause
// all maintenance thread pool tasks to fail, then bubbling the exception and
// stopping the pool is faster than waiting for all tasks to see the same exception.
//
// However, we assume that repeated failures on the same partition and global issues
// are rare. The benefit to unloading just the partition with an exception is that
// transient issues on a given provider do not affect any other providers; so, in
// most cases, this should be a more performant solution.
if (alreadyRemovedFromLoadedProviders) {
unload(id, Some(provider))
} else {
unload(id)
}
} finally {
val duration = System.currentTimeMillis() - startTime
val logMsg =
log"Finished maintenance task for " +
log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" +
log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}" +
log" and awaiting_partition_time=" +
log"${MDC(LogKeys.TIME_UNITS, awaitingPartitionDuration)}\n"
if (duration > 5000) {
logInfo(logMsg)
} else {
logDebug(logMsg)
}
maintenanceThreadPoolLock.synchronized {
maintenancePartitions.remove(id)
}
}
})
}

private def reportActiveStoreInstance(
storeProviderId: StateStoreProviderId,
otherProviderIds: Seq[StateStoreProviderId]): Seq[StateStoreProviderId] = {
StateStoreSuite.scala
@@ -122,6 +122,38 @@ private object FakeStateStoreProviderWithMaintenanceError {
val errorOnMaintenance = new AtomicBoolean(false)
}

class FakeStateStoreProviderTracksCloseThread extends StateStoreProvider {
import FakeStateStoreProviderTracksCloseThread._
private var id: StateStoreId = null

override def init(
stateStoreId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
useColumnFamilies: Boolean,
storeConfs: StateStoreConf,
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean = false,
stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
id = stateStoreId
}

override def stateStoreId: StateStoreId = id

override def close(): Unit = {
closeThreadNames = Thread.currentThread.getName :: closeThreadNames
}

override def getStore(version: Long, uniqueId: Option[String]): StateStore = null

override def doMaintenance(): Unit = {}
}

private object FakeStateStoreProviderTracksCloseThread {
var closeThreadNames: List[String] = Nil
}
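
One optional hardening, offered only as a suggestion and not as part of this change: closeThreadNames is a plain var written from maintenance threads and read from the test thread, so a concurrent collector such as the hypothetical sketch below would avoid lost updates if more than one close() ever ran at the same time.

// Hypothetical thread-safe alternative for recording close() threads; not in this PR.
import java.util.concurrent.ConcurrentLinkedQueue
import scala.jdk.CollectionConverters._

object CloseThreadRecorderSketch {
  private val names = new ConcurrentLinkedQueue[String]()
  def record(): Unit = names.add(Thread.currentThread().getName)
  def snapshot(): List[String] = names.asScala.toList
}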

@ExtendedSQLTest
class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
with BeforeAndAfter {
@@ -563,8 +595,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
StateStore.get(storeProviderId2, keySchema, valueSchema,
NoPrefixKeyStateEncoderSpec(keySchema),
0, None, None, useColumnFamilies = false, storeConf, hadoopConf)
assert(!StateStore.isLoaded(storeProviderId1))
assert(StateStore.isLoaded(storeProviderId2))
// Close runs asynchronously, so we need to call eventually with a small timeout
eventually(timeout(5.seconds)) {
assert(!StateStore.isLoaded(storeProviderId1))
assert(StateStore.isLoaded(storeProviderId2))
}
}
}

@@ -1082,7 +1117,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
}

abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
extends StateStoreCodecsTest with PrivateMethodTester {
extends StateStoreCodecsTest with PrivateMethodTester with BeforeAndAfter {
import StateStoreTestsHelper._

type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
@@ -1718,6 +1753,60 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
assert(encoderSpec == deserializedEncoderSpec)
}

test("SPARK-51596: unloading only occurs on maintenance thread but occurs promptly") {
// Reset closeThreadNames
FakeStateStoreProviderTracksCloseThread.closeThreadNames = Nil

val sqlConf = getDefaultSQLConf(
SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get
)
// Make maintenance interval very large (30s) so that task thread runs before maintenance.
sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 30000L)
// Use the `MaintenanceErrorOnCertainPartitionsProvider` to run the test
[Review comment, Contributor] nit - this seems copied from below -- this should now be FakeStateStoreProviderTracksCloseThread, right?

sqlConf.setConf(
SQLConf.STATE_STORE_PROVIDER_CLASS,
classOf[FakeStateStoreProviderTracksCloseThread].getName
)

val conf = new SparkConf().setMaster("local").setAppName("test")

withSpark(SparkContext.getOrCreate(conf)) { sc =>
withCoordinatorRef(sc) { coordinatorRef =>
val rootLocation = s"${Utils.createTempDir().getAbsolutePath}/spark-48997"
[Review comment, Contributor] nit
Suggested change: replace
val rootLocation = s"${Utils.createTempDir().getAbsolutePath}/spark-48997"
with
val rootLocation = s"${Utils.createTempDir().getAbsolutePath}/spark-51596"

val providerId =
StateStoreProviderId(StateStoreId(rootLocation, 0, 0), UUID.randomUUID)
val providerId2 =
StateStoreProviderId(StateStoreId(rootLocation, 0, 1), UUID.randomUUID)

// Create provider to start the maintenance task + pool
StateStore.get(
providerId,
keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
0, None, None, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration()
)

// Report instance active on another executor
coordinatorRef.reportActiveInstance(providerId, "otherhost", "otherexec", Seq.empty)

// Load another provider to trigger task unload
StateStore.get(
providerId2,
keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
0, None, None, useColumnFamilies = false, new StateStoreConf(sqlConf), new Configuration()
)

// Wait for close to occur. Timeout is less than maintenance interval,
// so should only close by task triggering.
eventually(timeout(5.seconds)) {
assert(FakeStateStoreProviderTracksCloseThread.closeThreadNames.size == 1)
FakeStateStoreProviderTracksCloseThread.closeThreadNames.foreach { name =>
assert(name.contains("state-store-maintenance-thread"))
}
}
}
}
}

/** Return a new provider with a random id */
def newStoreProvider(): ProviderClass
