
Commit 06674ba

[SPARK-53942][SS] Support changing stateless shuffle partitions upon restart of streaming query
### What changes were proposed in this pull request?

This PR proposes to support changing stateless shuffle partitions upon restart of a streaming query. We don't introduce a new config for this; users can simply do the following to change the number of shuffle partitions:

* stop the query
* change the value of `spark.sql.shuffle.partitions`
* restart the query for the change to take effect

Note that state partitions remain fixed and are unaffected by this. That is, the value of `spark.sql.shuffle.partitions` at batch 0 becomes the number of state partitions and does not change even if the value of the config is changed upon restart.

As an implementation detail, this PR adds a new "internal" SQL config, `spark.sql.streaming.internal.stateStore.partitions`, to distinguish stateless shuffle partitions from stateful shuffle partitions. Unlike other internal configs, where we still expect someone (an admin?) to use them, this config is NOT meant to be user-facing, and no one should set it directly. We add this config purely as a compatibility trick, nothing else. We don't support compatibility for this config, and there is no promise it will remain available in the future. The config description states this as a WARN. That said, the value of the new config is expected to be inherited from `spark.sql.shuffle.partitions`, assuming no one sets it directly.

To support compatibility, we employ a trick in the offset log: for stateful shuffle partitions, we refer to `spark.sql.streaming.internal.stateStore.partitions` in the session config but keep using `spark.sql.shuffle.partitions` in the offset log. We handle the rebinding between the two configs so that the persistence layer is unchanged. This way, the query can be both upgraded and downgraded.

### Why are the changes needed?

Whenever there is a need to change the parallelism of processing (e.g., the input volume changes over time, the size of a static table changes over time, or there is skew in a stream-static join, though AQE may help resolve this a bit), the only official approach was to discard the checkpoint and start a new query, implying a full backfill. (For workloads with a foreachBatch (FEB) sink, advanced (and adventurous) users could change the config in their user function, but that's arguably a hack.) Having to discard the checkpoint is one of the major pain points of using Structured Streaming, and we want to address one of its known causes.

### Does this PR introduce _any_ user-facing change?

Yes. Users can change shuffle partitions for stateless operators upon restart by changing the config `spark.sql.shuffle.partitions`.

### How was this patch tested?

New UTs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52645 from HeartSaVioR/WIP-change-stateless-shuffle-partitions-in-streaming-query.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
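As a concrete illustration of the restart flow described above, here is a minimal sketch of a stateless streaming query (the rate source, static table name, and paths are hypothetical illustrations, not part of this PR):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// A stateless query with a shuffle: a stream-static join on a hypothetical table.
def startQuery() = {
  val static = spark.read.table("dim_table")        // hypothetical static table
  spark.readStream.format("rate").load()
    .selectExpr("value % 100 AS key", "timestamp")
    .join(static, "key")                            // stateless shuffle
    .writeStream
    .format("parquet")
    .option("path", "/tmp/demo/out")                // hypothetical paths
    .option("checkpointLocation", "/tmp/demo/ckpt")
    .start()
}

// Run 1: start with the default parallelism, then stop the query.
spark.conf.set("spark.sql.shuffle.partitions", "200")
val q1 = startQuery()
// ... after some batches ...
q1.stop()

// Run 2: change the config and restart. With this PR, the stateless shuffle now
// uses 50 partitions, while stateful operators would keep the partitioning
// recorded in the checkpoint.
spark.conf.set("spark.sql.shuffle.partitions", "50")
val q2 = startQuery()
```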
1 parent a1f3dcb commit 06674ba

File tree: 15 files changed, +289 −23 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 17 additions & 3 deletions
@@ -877,9 +877,7 @@ object SQLConf {
       .createOptional
 
   val SHUFFLE_PARTITIONS = buildConf("spark.sql.shuffle.partitions")
-    .doc("The default number of partitions to use when shuffling data for joins or aggregations. " +
-      "Note: For structured streaming, this configuration cannot be changed between query " +
-      "restarts from the same checkpoint location.")
+    .doc("The default number of partitions to use when shuffling data for joins or aggregations.")
     .version("1.1.0")
     .intConf
     .checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be positive")
@@ -2627,6 +2625,22 @@ object SQLConf {
       .checkValue(k => k >= 0, "Must be greater than or equal to 0")
       .createWithDefault(5)
 
+  val STATEFUL_SHUFFLE_PARTITIONS_INTERNAL =
+    buildConf("spark.sql.streaming.internal.stateStore.partitions")
+      .doc("WARN: This config is used internally and is not intended to be user-facing. This " +
+        "config can be removed without support of compatibility in any time. " +
+        "DO NOT USE THIS CONFIG DIRECTLY AND USE THE CONFIG `spark.sql.shuffle.partitions`. " +
+        "The default number of partitions to use when shuffling data for stateful operations. " +
+        "If not specified, this config picks up the value of `spark.sql.shuffle.partitions`. " +
+        "Note: For structured streaming, this configuration cannot be changed between query " +
+        "restarts from the same checkpoint location.")
+      .internal()
+      .intConf
+      .checkValue(_ > 0,
+        "The value of spark.sql.streaming.internal.stateStore.partitions must be a positive " +
+        "integer.")
+      .createOptional
+
   val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
     buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
       .internal()

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala

Lines changed: 72 additions & 18 deletions
@@ -17,11 +17,14 @@
 
 package org.apache.spark.sql.execution.streaming.checkpointing
 
+import scala.language.existentials
+
 import org.json4s.{Formats, NoTypeHints}
 import org.json4s.jackson.Serialization
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{CONFIG, DEFAULT_VALUE, NEW_VALUE, OLD_VALUE, TIP}
+import org.apache.spark.internal.config.ConfigEntry
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.RuntimeConfig
 import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream}
@@ -85,6 +88,11 @@ object OffsetSeq {
  * @param batchTimestampMs: The current batch processing timestamp.
  *                          Time unit: milliseconds
  * @param conf: Additional conf_s to be persisted across batches, e.g. number of shuffle partitions.
+ *              CAVEAT: This does not apply the logic we handle in [[OffsetSeqMetadata]] object, e.g.
+ *              deducing the default value of SQL config if the entry does not exist in the offset log,
+ *              resolving the re-bind of config key (the config key in offset log is not same with the
+ *              actual key in session), etc. If you need to get the value with applying the logic, use
+ *              [[OffsetSeqMetadata.readValue()]].
  */
 case class OffsetSeqMetadata(
     batchWatermarkMs: Long = 0,
@@ -101,13 +109,35 @@ object OffsetSeqMetadata extends Logging {
    * log in the checkpoint position.
    */
   private val relevantSQLConfs = Seq(
-    SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
+    STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
     FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION,
     STREAMING_JOIN_STATE_FORMAT_VERSION, STATE_STORE_COMPRESSION_CODEC,
     STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION,
     PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN, STREAMING_STATE_STORE_ENCODING_FORMAT
   )
 
+  /**
+   * This is an extension of `relevantSQLConfs`. The characteristic is the same, but we persist the
+   * value of config A as config B in offset log. This exists for compatibility purpose e.g. if
+   * user upgrades Spark and runs a streaming query, but has to downgrade due to some issues.
+   *
+   * A config should be only bound to either `relevantSQLConfs` or `rebindSQLConfs` (key or value).
+   */
+  private val rebindSQLConfsSessionToOffsetLog: Map[ConfigEntry[_], ConfigEntry[_]] = {
+    Map(
+      // TODO: The proper way to handle this is to make the number of partitions in the state
+      // metadata as the source of truth, but it requires major changes if we want to take care
+      // of compatibility.
+      STATEFUL_SHUFFLE_PARTITIONS_INTERNAL -> SHUFFLE_PARTITIONS
+    )
+  }
+
+  /**
+   * Reversed index of `rebindSQLConfsSessionToOffsetLog`.
+   */
+  private val rebindSQLConfsOffsetLogToSession: Map[ConfigEntry[_], ConfigEntry[_]] =
+    rebindSQLConfsSessionToOffsetLog.map { case (k, v) => (v, k) }.toMap
+
   /**
    * Default values of relevant configurations that are used for backward compatibility.
    * As new configurations are added to the metadata, existing checkpoints may not have those
@@ -132,56 +162,80 @@ object OffsetSeqMetadata extends Logging {
     STREAMING_STATE_STORE_ENCODING_FORMAT.key -> "unsaferow"
   )
 
+  def readValue[T](metadataLog: OffsetSeqMetadata, confKey: ConfigEntry[T]): String = {
+    readValueOpt(metadataLog, confKey).getOrElse(confKey.defaultValueString)
+  }
+
+  def readValueOpt[T](
+      metadataLog: OffsetSeqMetadata,
+      confKey: ConfigEntry[T]): Option[String] = {
+    val actualKey = if (rebindSQLConfsSessionToOffsetLog.contains(confKey)) {
+      rebindSQLConfsSessionToOffsetLog(confKey)
+    } else confKey
+
+    metadataLog.conf.get(actualKey.key).orElse(relevantSQLConfDefaultValues.get(actualKey.key))
+  }
+
   def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
 
   def apply(
       batchWatermarkMs: Long,
       batchTimestampMs: Long,
       sessionConf: RuntimeConfig): OffsetSeqMetadata = {
     val confs = relevantSQLConfs.map { conf => conf.key -> sessionConf.get(conf.key) }.toMap
-    OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs)
+    val confsFromRebind = rebindSQLConfsSessionToOffsetLog.map {
+      case (confInSession, confInOffsetLog) =>
+        confInOffsetLog.key -> sessionConf.get(confInSession.key)
+    }.toMap
+    OffsetSeqMetadata(batchWatermarkMs, batchTimestampMs, confs ++ confsFromRebind)
   }
 
   /** Set the SparkSession configuration with the values in the metadata */
   def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: SQLConf): Unit = {
-    val configs = sessionConf.getAllConfs
-    OffsetSeqMetadata.relevantSQLConfs.map(_.key).foreach { confKey =>
-
-      metadata.conf.get(confKey) match {
+    def setOneSessionConf(confKeyInOffsetLog: String, confKeyInSession: String): Unit = {
+      metadata.conf.get(confKeyInOffsetLog) match {
 
         case Some(valueInMetadata) =>
           // Config value exists in the metadata, update the session config with this value
-          val optionalValueInSession = sessionConf.getConfString(confKey, null)
+          val optionalValueInSession = sessionConf.getConfString(confKeyInSession, null)
           if (optionalValueInSession != null && optionalValueInSession != valueInMetadata) {
-            logWarning(log"Updating the value of conf '${MDC(CONFIG, confKey)}' in current " +
-              log"session from '${MDC(OLD_VALUE, optionalValueInSession)}' " +
+            logWarning(log"Updating the value of conf '${MDC(CONFIG, confKeyInSession)}' in " +
+              log"current session from '${MDC(OLD_VALUE, optionalValueInSession)}' " +
              log"to '${MDC(NEW_VALUE, valueInMetadata)}'.")
           }
-          sessionConf.setConfString(confKey, valueInMetadata)
+          sessionConf.setConfString(confKeyInSession, valueInMetadata)
 
         case None =>
           // For backward compatibility, if a config was not recorded in the offset log,
           // then either inject a default value (if specified in `relevantSQLConfDefaultValues`) or
          // let the existing conf value in SparkSession prevail.
-          relevantSQLConfDefaultValues.get(confKey) match {
+          relevantSQLConfDefaultValues.get(confKeyInOffsetLog) match {

            case Some(defaultValue) =>
-              sessionConf.setConfString(confKey, defaultValue)
-              logWarning(log"Conf '${MDC(CONFIG, confKey)}' was not found in the offset log, " +
-                log"using default value '${MDC(DEFAULT_VALUE, defaultValue)}'")
+              sessionConf.setConfString(confKeyInSession, defaultValue)
+              logWarning(log"Conf '${MDC(CONFIG, confKeyInSession)}' was not found in the offset " +
+                log"log, using default value '${MDC(DEFAULT_VALUE, defaultValue)}'")

            case None =>
-              val value = sessionConf.getConfString(confKey, null)
+              val value = sessionConf.getConfString(confKeyInSession, null)
              val valueStr = if (value != null) {
                s" Using existing session conf value '$value'."
              } else {
                " No value set in session conf."
              }
-              logWarning(log"Conf '${MDC(CONFIG, confKey)}' was not found in the offset log. " +
-                log"${MDC(TIP, valueStr)}")
+              logWarning(log"Conf '${MDC(CONFIG, confKeyInSession)}' was not found in the " +
+                log"offset log. ${MDC(TIP, valueStr)}")
          }
      }
    }
+
+    OffsetSeqMetadata.relevantSQLConfs.map(_.key).foreach { confKey =>
+      setOneSessionConf(confKey, confKey)
+    }
+
+    rebindSQLConfsOffsetLogToSession.foreach {
+      case (confInOffsetLog, confInSession) =>
+        setOneSessionConf(confInOffsetLog.key, confInSession.key)
+    }
   }
 }
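To make the rebinding concrete, here is a self-contained toy model of the round trip (plain Scala maps standing in for the session conf and the offset log; an illustration, not Spark API):

```scala
// In the session, the stateful partition count lives under the internal key,
// but it is persisted in the offset log under the legacy key so that older
// Spark versions can still read the checkpoint.
val internalKey = "spark.sql.streaming.internal.stateStore.partitions"
val legacyKey = "spark.sql.shuffle.partitions"

// Write path: session -> offset log (key rebound, value preserved).
def toOffsetLog(session: Map[String, String]): Map[String, String] =
  session.get(internalKey).map(v => Map(legacyKey -> v)).getOrElse(Map.empty)

// Read path: offset log -> session (key rebound back).
def fromOffsetLog(offsetLog: Map[String, String]): Map[String, String] =
  offsetLog.get(legacyKey).map(v => Map(internalKey -> v)).getOrElse(Map.empty)

// The session may carry a different stateless value (50) without affecting
// what gets persisted for stateful operators (5).
val session = Map(internalKey -> "5", legacyKey -> "50")
val persisted = toOffsetLog(session)
assert(persisted == Map(legacyKey -> "5"))
assert(fromOffsetLog(persisted) == Map(internalKey -> "5"))
```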

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala

Lines changed: 2 additions & 1 deletion
@@ -105,7 +105,8 @@ class IncrementalExecution(
 
   private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf()
 
-  private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
+  private[sql] val numStateStores = OffsetSeqMetadata.readValueOpt(offsetSeqMetadata,
+    SQLConf.STATEFUL_SHUFFLE_PARTITIONS_INTERNAL)
     .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
     .getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
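The resolution order here can be summarized as a small sketch (plain Scala, not the real classes): prefer the value recorded in the offset log (rebound from the legacy key), and fall back to the session's current shuffle partitions only for checkpoints written before this change.

```scala
def numStateStores(offsetLogValue: Option[String], sessionShufflePartitions: Int): Int =
  offsetLogValue.map(_.toInt).getOrElse(sessionShufflePartitions)

assert(numStateStores(Some("5"), 200) == 5)  // checkpoint pins state partitions
assert(numStateStores(None, 200) == 200)     // legacy checkpoint: use session value
```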

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/StreamExecution.scala

Lines changed: 11 additions & 1 deletion
@@ -155,7 +155,7 @@ abstract class StreamExecution(
   protected def sources: Seq[SparkDataStream]
 
   /** Isolated spark session to run the batches with. */
-  protected val sparkSessionForStream: SparkSession = sparkSession.cloneSession()
+  protected[sql] val sparkSessionForStream: SparkSession = sparkSession.cloneSession()
 
   /**
    * Manages the metadata from this checkpoint location.
@@ -320,6 +320,16 @@ abstract class StreamExecution(
       sparkSessionForStream.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
     }
 
+    sparkSessionForStream.conf.get(SQLConf.STATEFUL_SHUFFLE_PARTITIONS_INTERNAL) match {
+      case Some(_) => // no-op
+      case None =>
+        // Take the default value of `spark.sql.shuffle.partitions`.
+        val shufflePartitionValue = sparkSessionForStream.conf.get(SQLConf.SHUFFLE_PARTITIONS)
+        sparkSessionForStream.conf.set(
+          SQLConf.STATEFUL_SHUFFLE_PARTITIONS_INTERNAL.key,
+          shufflePartitionValue)
+    }
+
     getLatestExecutionContext().updateStatusMessage("Initializing sources")
     // force initialization of the logical plan so that the sources can be created
     logicalPlan
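Putting the seeding above together with the offset-log rebind, an end-to-end toy timeline of a restart looks like this (plain Scala maps, not Spark API):

```scala
val internalKey = "spark.sql.streaming.internal.stateStore.partitions"
val statelessKey = "spark.sql.shuffle.partitions"

// Run 1, batch 0: the internal config is unset, so it is seeded from the
// stateless value, and the seeded value is persisted under the legacy key.
var session = Map(statelessKey -> "200")
session += internalKey -> session(statelessKey)
val offsetLog = Map(statelessKey -> session(internalKey))

// Restart with a different stateless parallelism.
session = Map(statelessKey -> "50")
// Restoring the offset log rebinds the legacy key back to the internal key,
// leaving the session's stateless value untouched.
session += internalKey -> offsetLog(statelessKey)

assert(session(statelessKey) == "50")  // stateless shuffles: changed on restart
assert(session(internalKey) == "200")  // stateful operators: pinned by checkpoint
```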
Lines changed: 2 additions & 0 deletions (new checkpoint resource file; path not shown)
@@ -0,0 +1,2 @@
+v1
+{"nextBatchWatermarkMs":0}

Lines changed: 1 addition & 0 deletions (new checkpoint resource file; path not shown)
@@ -0,0 +1 @@
+{"id":"295ee44f-dd99-45cf-a21d-9a760b439c45"}

Lines changed: 3 additions & 0 deletions (new checkpoint resource file; path not shown)
@@ -0,0 +1,3 @@
+v1
+{"batchWatermarkMs":0,"batchTimestampMs":1760948082021,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.stateStore.rocksdb.formatVersion":"5","spark.sql.streaming.stateStore.encodingFormat":"unsaferow","spark.sql.streaming.statefulOperator.useStrictDistribution":"true","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"5","spark.sql.streaming.join.stateFormatVersion":"2","spark.sql.streaming.stateStore.compression.codec":"lz4","spark.sql.optimizer.pruneFiltersCanPruneStreamingSubplan":"false"}}
+0
