Commit a8cb80a

[SPARK-56482][SQL][FOLLOWUP] Simplify UnionExec codegen and narrow partition-index gate
### What changes were proposed in this pull request?

Followup to SPARK-56482 (#55425). Two groups of changes to `UnionExec`'s whole-stage codegen path.

**Code cleanup:**

- Hoist `metricTerm("numOutputRows")` to `doProduce` and store it on the instance. `doConsume` runs once per child during emission, so the previous code registered the same metric N times in `references[]` for an N-child Union; now it is registered once.
- Drop the dead `assert` in `perChildProjections` and the duplicate `allChildOutputDataTypesMatch` lazy val. The dataType comparison now has a single source of truth in the `type-mismatch` branch of the gate.
- Inline the one-shot `hasAnyPartitionIndexDependentDescendant` lazy val.
- Drop the unreachable `case other` in the `UnionPartition` match and replace it with `asInstanceOf`. `unionedInputRDD` is built as `new UnionRDD(...)` two lines up, and its `getPartitions` only ever returns `UnionPartition[_]`.
- Factor out an `isPlainUnion` helper used by both the gate and `doExecute`, so the invariant "the codegen path matches `sparkContext.union` semantics" lives in one place.
- Bind `currentPartitionIndexVar` to the array-deref expression `((int[]) refs[K])[partitionIndex]` directly. An earlier revision hoisted this into a `childLocalIdx` local at helper entry, but `SampleExec.doConsume` reads `currentPartitionIndexVar` from inside an `addMutableState` initializer, which is emitted into the state-init function (outside the per-child helper), so the local was not in scope and the generated code failed to compile. The expression form resolves in any emission scope (helper parameter or `BufferedRowIterator` field).
- Drop the `try/finally` around codegen state restoration. Codegen failure aborts the whole stage, so the restoration is unreachable.

**Gate narrowing:**

- Narrow `hasPartitionIndexDependentCodegen` to exclude `InputFileName`, `InputFileBlockStart`, and `InputFileBlockLength`. These are `Nondeterministic` but read from `InputFileBlockHolder` (a per-task thread-local) and do not embed `partitionIndex`, so they are safe under fusion. Queries like `SELECT input_file_name() FROM a UNION ALL SELECT input_file_name() FROM b` now fuse.

### Why are the changes needed?

The cleanups remove accidental complexity in the fused code path: an N-fold metric reference, two duplicated dataType comparisons, an unreachable defensive guard, and a `try/finally` that protects against an unreachable case. The gate narrowing turns a missed optimization (file-scan unions) into a fused plan.

### Does this PR introduce _any_ user-facing change?

No. `spark.sql.codegen.wholeStage.union.enabled` remains off by default; when on, the new behavior fuses additional plans (file-scan unions with `input_file_name()`) that the previous gate over-rejected.

### How was this patch tested?

`UnionCodegenSuite`, `UnionCodegenAnsiSuite`, `UnionCodegenAqeSuite`, and the relevant `SQLMetricsSuite` test all pass. Three tests added:

- `partitioning-aware union falls back to non-codegen`: covers a `supportCodegenFailureReason` branch that lacked explicit coverage.
- `input_file_name child fuses (Nondeterministic but partition-index-free)`: validates the gate narrowing.
- `union with sample children fuses (or falls back) without crashing`: regression test for the `currentPartitionIndexVar` binding (caught by LuciferYang in review).

The `columnar` fallback branch is not covered by a new test: reliably constructing a plan where `Union.supportsColumnar` is true via the user-facing API turned out to be brittle, since `ApplyColumnarRulesAndInsertTransitions` aggressively rebalances columnar/row transitions.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude Code

Closes #55719 from cloud-fan/SPARK-56482-followup.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit d905e73)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
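The `UnionPartition` bullet above rests on a small mapping: every partition of the unioned RDD records which parent RDD it came from and its index there, and `doProduce` flattens that into two lookup arrays. A minimal Scala sketch of that derivation, using a hypothetical `UnionPartitionStub` stand-in rather than Spark's real `UnionPartition` class:

```scala
// Hypothetical stand-in for Spark's UnionPartition: each partition of the
// unioned RDD records its parent RDD and its index within that parent.
final case class UnionPartitionStub(parentRddIndex: Int, parentPartitionIndex: Int)

// A union of two children with 3 and 2 partitions yields 5 global partitions.
val partitions = Seq(
  UnionPartitionStub(0, 0), UnionPartitionStub(0, 1), UnionPartitionStub(0, 2),
  UnionPartitionStub(1, 0), UnionPartitionStub(1, 1))

// The same map-and-unzip shape the cleaned-up doProduce uses: which child
// owns each global partition, and the child-local index the fused code
// must hand to that child.
val (partitionToChild, partitionToLocalIdx) =
  partitions.map(p => (p.parentRddIndex, p.parentPartitionIndex)).unzip
```

Because the fields are read directly off each partition object, the arrays stay correct even if the partitions were not laid out in child order.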
1 parent fd05859 commit a8cb80a

2 files changed

Lines changed: 128 additions & 82 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 77 additions & 82 deletions
```diff
@@ -901,61 +901,47 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
     }
   }
 
-  // `WidenSetOperationTypes` inserts a `Project(Cast)` above each child whose
-  // dataType differs from the widened set type, so on the codegen path
-  // `src.dataType == tgt.dataType` holds. The Alias only remaps each child
-  // attribute onto the union's output exprId/name/metadata. Mismatched cases
-  // are gated upstream by `allChildOutputDataTypesMatch`, so the assert is a
-  // defensive guard.
+  // True when the codegen path applies: `outputPartitioning` is `UnknownPartitioning`,
+  // and `unionedInputRDD` matches the semantics of `sparkContext.union(...)` in `doExecute`.
+  private def isPlainUnion: Boolean = outputPartitioning.isInstanceOf[UnknownPartitioning]
+
+  // Per-child projection from the child's output to the union's output. The wrapped
+  // child is always the source `Attribute` (deterministic by construction); the Alias
+  // only remaps the exprId/name/metadata. `WidenSetOperationTypes` aligns top-level
+  // dataTypes, but nested nullability differences bypass it; those cases are caught
+  // by the `type-mismatch` gate below, which is the single source of truth for the
+  // `src.dataType == tgt.dataType` invariant `doConsume` relies on.
   @transient private lazy val perChildProjections: IndexedSeq[Seq[NamedExpression]] =
     children.toIndexedSeq.map { child =>
       child.output.zip(output).map { case (src, tgt) =>
-        assert(src.dataType == tgt.dataType,
-          s"UnionExec child output dataType ${src.dataType} does not match " +
-          s"union output dataType ${tgt.dataType}; supportCodegen should " +
-          "have returned false via the 'type-mismatch' reason.")
         Alias(src, tgt.name)(
           exprId = tgt.exprId,
           qualifier = tgt.qualifier,
           explicitMetadata = Some(tgt.metadata))
       }
     }
 
-  // True iff every child output dataType matches the corresponding union
-  // output dataType, including all nested nullabilities.
-  // `Union.allChildrenCompatible` ignores nested nullability, so children
-  // differing only there bypass `WidenSetOperationTypes`; `UnionExec.output`
-  // then merges those flags via `StructType.unionLikeMerge`, leaving src/tgt
-  // mismatched.
-  @transient private lazy val allChildOutputDataTypesMatch: Boolean =
-    children.forall { c =>
-      c.output.zip(output).forall { case (src, tgt) => src.dataType == tgt.dataType }
-    }
-
-  // Memoized: `supportCodegen` is called multiple times during planning.
-  @transient private lazy val hasAnyPartitionIndexDependentDescendant: Boolean =
-    children.exists(UnionExec.hasPartitionIndexDependentCodegen)
-
   // Memoized: consulted by `supportCodegen` (called multiple times by
   // `CollapseCodegenStages`) and by `metrics`. Conf and children are stable
   // for a given UnionExec instance; cross-plan staleness is impossible since
   // UnionExec is a case class and `withNewChildren` produces a fresh instance.
   @transient private lazy val supportCodegenFailureReason: Option[String] = {
     if (!conf.getConf(SQLConf.WHOLESTAGE_UNION_CODEGEN_ENABLED)) {
       Some("union-codegen-disabled")
-    } else if (!outputPartitioning.isInstanceOf[UnknownPartitioning]) {
+    } else if (!isPlainUnion) {
       Some("partitioning-aware")
     } else if (children.exists(_.exists(_.isInstanceOf[UnionExec]))) {
       Some("nested-union")
     } else if (children.exists(_.exists(UnionExec.isKnownMultiInputRDDCodegen))) {
       Some("multi-rdd-child")
-    } else if (hasAnyPartitionIndexDependentDescendant) {
+    } else if (children.exists(UnionExec.hasPartitionIndexDependentCodegen)) {
       Some("partition-index-dependent-child")
     } else if (children.size > conf.getConf(SQLConf.WHOLESTAGE_UNION_MAX_CHILDREN)) {
       Some("max-children-exceeded")
     } else if (supportsColumnar) {
       Some("columnar")
-    } else if (!allChildOutputDataTypesMatch) {
+    } else if (children.exists(c =>
+        c.output.zip(output).exists { case (src, tgt) => src.dataType != tgt.dataType })) {
       Some("type-mismatch")
     } else {
       None
@@ -1002,61 +988,66 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(unionedInputRDD)
 
-  // Driver-side cursor written by `doProduce` and read by `doConsume` during
-  // single-threaded code emission; resets to -1 once emission completes.
+  // Set in `doProduce`, read in `doConsume` during single-threaded code
+  // emission. `numOutputRowsTerm` is registered once per stage so the
+  // metric appears in `references[]` exactly once instead of once per
+  // child. `currentEmittingChild` tells `doConsume` which child's
+  // projection to bind.
+  @transient private var numOutputRowsTerm: String = _
   @transient private var currentEmittingChild: Int = -1
 
   override protected def doProduce(ctx: CodegenContext): String = {
+    numOutputRowsTerm = metricTerm(ctx, "numOutputRows")
+
     // For each partition of the unioned RDD, record its owning child and its
     // index within that child's RDD. Read both fields directly off the
     // `UnionPartition` so the lookup arrays do not assume `UnionRDD` lays
     // partitions out in child order.
-    val (partitionToChild, partitionToLocalIdx) =
-      unionedInputRDD.partitions.map {
-        case up: UnionPartition[_] => (up.parentRddIndex, up.parentPartition.index)
-        case other =>
-          throw SparkException.internalError(
-            s"UnionExec: Unexpected partition type ${other.getClass.getName}")
-      }.unzip
+    val (partitionToChild, partitionToLocalIdx) = unionedInputRDD.partitions.map { p =>
+      val up = p.asInstanceOf[UnionPartition[_]]
+      (up.parentRddIndex, up.parentPartition.index)
+    }.unzip
     val p2cRef = ctx.addReferenceObj("partitionToChild", partitionToChild)
     val p2lRef = ctx.addReferenceObj("partitionToLocalIdx", partitionToLocalIdx)
     val childIndexVar = ctx.freshName("unionChildIdx")
 
-    // Each child's produce output is wrapped in its own helper method. The
-    // outer `switch` in `doProduce`'s return value dispatches to the helper.
+    // Each child's produced code is wrapped in its own helper method.
     // Without this, the fused method's bytecode grows linearly with the
     // number of children and quickly exceeds HotSpot's per-method limit,
     // forcing the whole stage to run interpreted.
     //
-    // `partitionIndex` is passed as a parameter (shadowing the superclass
-    // field) rather than read from the enclosing scope. `addNewFunction` may
-    // spill helpers into a nested class when the outer class fills up, and a
-    // nested class cannot access the protected
-    // `BufferedRowIterator.partitionIndex` field. Using the parameter name
-    // `partitionIndex` keeps any child-emitted reference to that identifier
-    // resolving locally.
+    // The helper takes `int partitionIndex` as a parameter; `addNewFunction`
+    // may spill helpers into a nested class once the outer class fills up,
+    // and a nested class cannot access the protected
+    // `BufferedRowIterator.partitionIndex` field.
+    //
+    // `currentPartitionIndexVar` is rebound to an array-deref expression
+    // (rather than a local) so leaf operators (`RangeExec`, `SampleExec`)
+    // see the child-local index regardless of where their code is emitted.
+    // `SampleExec.doConsume` uses `addMutableState`, whose initializer is
+    // emitted into the state-init function, not the helper - a local in
+    // the helper would not be in scope there. The expression resolves
+    // against `partitionIndex` (the helper parameter inside the helper,
+    // and the `BufferedRowIterator` field elsewhere) in every context.
     val savedPartIdxVar = ctx.currentPartitionIndexVar
-    val cases = try {
-      children.zipWithIndex.map { case (c, i) =>
-        currentEmittingChild = i
-        ctx.currentPartitionIndexVar = s"((int[]) $p2lRef)[partitionIndex]"
-        val producedCode = c.asInstanceOf[CodegenSupport].produce(ctx, this)
-        val helper = ctx.freshName("unionChildProcess")
-        val qualifiedHelper = ctx.addNewFunction(helper,
-          s"""
-             |private void $helper(int partitionIndex) throws java.io.IOException {
-             |  $producedCode
-             |}
-           """.stripMargin)
-        s"""case $i: {
-           |  $qualifiedHelper(partitionIndex);
-           |  break;
-           |}""".stripMargin
-      }
-    } finally {
-      currentEmittingChild = -1
-      ctx.currentPartitionIndexVar = savedPartIdxVar
+    ctx.currentPartitionIndexVar = s"((int[]) $p2lRef)[partitionIndex]"
+    val cases = children.zipWithIndex.map { case (c, i) =>
+      currentEmittingChild = i
+      val producedCode = c.asInstanceOf[CodegenSupport].produce(ctx, this)
+      val helper = ctx.freshName("unionChildProcess")
+      val qualifiedHelper = ctx.addNewFunction(helper,
+        s"""
+           |private void $helper(int partitionIndex) throws java.io.IOException {
+           |  $producedCode
+           |}
+         """.stripMargin)
+      s"""case $i: {
+         |  $qualifiedHelper(partitionIndex);
+         |  break;
+         |}""".stripMargin
     }
+    currentEmittingChild = -1
+    ctx.currentPartitionIndexVar = savedPartIdxVar
 
     s"""
       |int $childIndexVar = ((int[]) $p2cRef)[partitionIndex];
@@ -1071,24 +1062,17 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
 
   override def doConsume(
       ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    require(currentEmittingChild >= 0,
-      "UnionExec.doConsume invoked outside doProduce emission window")
     val i = currentEmittingChild
-    // The wrapped child in each `perChildProjections(i)` element is always an
-    // `Attribute`, which is deterministic by definition; no
-    // `evaluateRequiredVariables` call is needed to force single-evaluation
-    // of non-deterministic expressions.
-    val bound = BindReferences.bindReferences(perChildProjections(i), children(i).output)
-
+    require(i >= 0, "UnionExec.doConsume invoked outside doProduce emission window")
     // Route BoundReference reads through `currentVars` (the incoming row is
    // delivered as variables under WSCG, not via ctx.INPUT_ROW).
+    val bound = BindReferences.bindReferences(perChildProjections(i), children(i).output)
     ctx.currentVars = input
     ctx.INPUT_ROW = null
     val projectedExprCodes = bound.map(_.genCode(ctx))
 
-    val numOutput = metricTerm(ctx, "numOutputRows")
     s"""
-       |$numOutput.add(1L);
+       |$numOutputRowsTerm.add(1L);
       |${consume(ctx, projectedExprCodes)}
     """.stripMargin
   }
@@ -1103,7 +1087,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan with CodegenSup
   override def usedInputs: AttributeSet = AttributeSet.empty
 
   protected override def doExecute(): RDD[InternalRow] = {
-    if (outputPartitioning.isInstanceOf[UnknownPartitioning]) {
+    if (isPlainUnion) {
       sparkContext.union(children.map(_.execute()))
     } else {
       // This union has a known partitioning, i.e., its children have the same partitioning
@@ -1138,13 +1122,24 @@ object UnionExec {
   }
 
   /**
-   * True if any expression in the subtree is [[Nondeterministic]]. Such
-   * expressions may embed the raw `partitionIndex` field via
-   * `addPartitionInitializationStatement`, which would read the global
+   * True if any expression in the subtree embeds the raw `partitionIndex` field
+   * via `addPartitionInitializationStatement`, which would read the global
    * UnionRDD index instead of the child-local one under fusion.
+   *
+   * The check uses [[Nondeterministic]] as the proxy: every catalyst expression
+   * that calls `addPartitionInitializationStatement` referencing `partitionIndex`
+   * is `Nondeterministic`. The `InputFile*` expressions are `Nondeterministic`
+   * but read from `InputFileBlockHolder` (a per-task thread-local) and do not
+   * embed `partitionIndex`, so they are safe under fusion.
   */
-  def hasPartitionIndexDependentCodegen(p: SparkPlan): Boolean = p.exists {
-    plan => plan.expressions.exists(_.exists(_.isInstanceOf[Nondeterministic]))
+  def hasPartitionIndexDependentCodegen(p: SparkPlan): Boolean = p.exists { plan =>
+    plan.expressions.exists(_.exists {
+      case _: InputFileName => false
+      case _: InputFileBlockStart => false
+      case _: InputFileBlockLength => false
+      case _: Nondeterministic => true
+      case _ => false
+    })
   }
 }
```
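The fused dispatch that `doProduce` emits can be modeled in a few lines of plain Scala: the generated `switch` picks the owning child via one lookup array and hands it the child-local index via the other. In this sketch, Scala functions stand in for the generated per-child helper methods; the names (`helpers`, `processPartition`) are illustrative, not Spark's.

```scala
// Lookup arrays for a union of children with 3 and 2 partitions, as
// doProduce would derive them from the UnionRDD's partitions.
val partitionToChild = Array(0, 0, 0, 1, 1)
val partitionToLocalIdx = Array(0, 1, 2, 0, 1)

// One "helper" per child, analogous to the generated unionChildProcess
// methods; each sees the child-local index, mirroring the
// ((int[]) p2lRef)[partitionIndex] expression in the generated Java.
val helpers: Array[Int => String] = Array(
  localIdx => s"child0:partition$localIdx",
  localIdx => s"child1:partition$localIdx")

// The generated switch: route the global partitionIndex to the owning
// child's helper, passing along its child-local index.
def processPartition(partitionIndex: Int): String =
  helpers(partitionToChild(partitionIndex))(partitionToLocalIdx(partitionIndex))
```

Binding the child-local index as an expression over `partitionIndex` (rather than a local captured at helper entry) is what lets code emitted outside the helper, such as `addMutableState` initializers, still resolve it.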

sql/core/src/test/scala/org/apache/spark/sql/execution/UnionCodegenSuite.scala

Lines changed: 51 additions & 0 deletions
```diff
@@ -602,6 +602,57 @@ class UnionCodegenSuite extends QueryTest with SharedSparkSession {
         "numOutputRows should be 0 for all-empty union")
     }
   }
+
+  test("SPARK-56482: partitioning-aware union falls back to non-codegen") {
+    // After repartition, both children expose a `HashPartitioning` on the same key,
+    // so `UnionExec.outputPartitioning` is non-Unknown and the codegen path is denied.
+    // AQE is disabled here so the executedPlan exposes the UnionExec directly
+    // (under AQE the plan is wrapped in `AdaptiveSparkPlanExec`, which does not
+    // surface its inputPlan via `children`).
+    withSQLConf(
+      SQLConf.UNION_OUTPUT_PARTITIONING.key -> "true",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+      val a = rangeDF(100).repartition(4, col("id"))
+      val b = rangeDF(100, 200).repartition(4, col("id"))
+      val df = a.union(b)
+      assert(!unionInsideWSCG(df),
+        "Partitioning-aware union must not fuse into WSCG")
+      val unionExec = df.queryExecution.executedPlan.collectFirst {
+        case u: UnionExec => u
+      }.get
+      assert(!unionExec.metrics.contains("numOutputRows"),
+        "numOutputRows metric must not be registered on the partitioning-aware path")
+      assertFlagParity(() => a.union(b).orderBy("id"))
+    }
+  }
+
+  test("SPARK-56482: input_file_name child fuses (Nondeterministic but partition-index-free)") {
+    // `InputFileName` is `Nondeterministic` but reads from `InputFileBlockHolder`
+    // (a per-task thread-local) and does not embed `partitionIndex`. The gate's
+    // narrow check should let this fuse.
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      rangeDF(20).write.parquet(path)
+      val a = spark.read.parquet(path).select(col("id"), input_file_name().as("f"))
+      val b = spark.read.parquet(path).select(col("id"), input_file_name().as("f"))
+      val df = a.union(b).filter(col("id") > 0)
+      assert(unionInsideWSCG(df),
+        "Union with input_file_name child should fuse into WSCG")
+      assertFlagParity(() => a.union(b).orderBy("id", "f"))
+    }
+  }
+
+  test("SPARK-56482: union with sample children fuses (or falls back) without crashing") {
+    // `SampleExec.doConsume` reads `currentPartitionIndexVar` from inside an
+    // `addMutableState` initializer, which is emitted into the state-init
+    // function rather than the per-child helper. The bound expression must
+    // therefore resolve in any emission scope, not just inside the helper.
+    val a = rangeDF(20).sample(false, 0.5, 1L)
+    val b = rangeDF(20).sample(false, 0.5, 1L)
+    val df = a.union(b).filter(col("id") > 0)
+    df.collect()
+    assertFlagParity(() => a.union(b).orderBy("id"))
+  }
 }
 
 /** Runs [[UnionCodegenSuite]] with ANSI mode enabled. */
```
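The `input_file_name` test above exercises the narrowed gate, whose essential property is pattern-match ordering: the `InputFile*` exclusions must precede the `Nondeterministic` catch-all. A self-contained sketch with stand-in expression classes (`Expr`, `Rand`, `Add` are illustrative, not catalyst's real hierarchy):

```scala
// Minimal stand-ins for the expression classes involved in the gate.
sealed trait Expr { def children: Seq[Expr] = Nil }
trait Nondeterministic extends Expr
case class InputFileName() extends Nondeterministic
case class Rand(seed: Long) extends Nondeterministic
case class Add(left: Expr, right: Expr) extends Expr {
  override def children: Seq[Expr] = Seq(left, right)
}

// True if the tree contains a Nondeterministic node other than the
// partition-index-free InputFile* readers. The exclusion case must come
// before the Nondeterministic catch-all, exactly as in the patched gate.
def partitionIndexDependent(e: Expr): Boolean = e match {
  case _: InputFileName => false    // safe: reads a per-task thread-local
  case _: Nondeterministic => true  // may embed partitionIndex
  case other => other.children.exists(partitionIndexDependent)
}
```

Under this gate, a union whose children only use `input_file_name()` fuses, while one containing `rand()` still falls back.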
