@@ -18,11 +18,14 @@ package org.apache.spark.sql.execution

import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.events.GlutenPlanFallbackEvent
import org.apache.gluten.execution.VeloxWholeStageTransformerSuite
import org.apache.gluten.execution.{ExpandExecTransformer, VeloxWholeStageTransformerSuite}

import org.apache.spark.SparkConf
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.{DataFrame, Row}

import java.sql.Date

import scala.collection.mutable.ArrayBuffer

@@ -33,10 +36,18 @@ class VeloxExpandSuite extends VeloxWholeStageTransformerSuite {
override def sparkConf: SparkConf = {
super.sparkConf
.set(GlutenConfig.GLUTEN_UI_ENABLED.key, "true")
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
// The Gluten UI event test suite expects the Spark UI to be enabled.
.set(UI_ENABLED, true)
}

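// Asserts that the Expand stayed native, i.e. the executed plan contains an ExpandExecTransformer.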
private def assertContainsNativeExpand(df: DataFrame): Unit = {
assert(
getExecutedPlan(df).exists(_.isInstanceOf[ExpandExecTransformer]),
s"Expected ExpandExecTransformer in plan, got:\n${df.queryExecution.executedPlan}"
)
}

test("Expand with duplicated group keys") {
withTable("t1") {
val events = new ArrayBuffer[GlutenPlanFallbackEvent]
@@ -77,4 +88,128 @@ class VeloxExpandSuite extends VeloxWholeStageTransformerSuite {
}
}
}

test("Expand with round(avg(decimal)) and multiple distinct aggregates") {
Contributor:

@jianzhenwu Thanks for your fix.

I tried reproducing the issue in my local environment using the SQL you provided (#12008 (comment)) and the two tests here, on the latest main branch (82644d3) with Spark 3.5, but I was unable to reproduce it.

In Spark, the projection expressions are passed directly from ExpandExec and should already be aligned with the output schema: https://github.com/apache/spark/blob/c26a127ba33137f36d55bf95cac71471e2a1704f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala#L1398-L1407. Could you provide more details on your environment, or help investigate why this occurs on your side?

Thanks for your help!

Author:

I encountered this problem on Spark 3.2, and I believe it can also be reproduced on Spark 3.3. I used AI to help explain the issue:

Spark 3.3 can reproduce the issue because its physical ExpandExec contains this decimal expression shape:

CAST((case_dd_decimal26 + case_ccb_decimal26) AS DECIMAL(27,10)) + case_fsv_decimal27

Spark declares the Expand output column as:

DECIMAL(27,10)

The null rows in the same Expand column are also:

CAST(NULL AS DECIMAL(27,10))

But when Velox compiles the non-null decimal arithmetic row, it infers:

DECIMAL(28,10)

So native ExpandNode sees mixed types in the same output column:

row 0: DECIMAL(28,10)
row 1: DECIMAL(27,10)

Then Velox fails with:

The projections type does not match across different rows in the same column.
Got: DECIMAL(27, 10), DECIMAL(28, 10)

Spark 3.5 does not reproduce it because the generated ExpandExec expression is different:

(case_dd_decimal25 + case_ccb_decimal25) + case_fsv_decimal25

It does not insert the intermediate:

CAST(... AS DECIMAL(27,10))

that Spark 3.3 has. With this Spark 3.5 plan shape, Velox’s inferred type stays compatible with Spark’s Expand output type, so all projection rows in the Expand column remain consistent.

So the difference is not the SQL result type. Both Spark versions declare the Expand output as DECIMAL(27,10). The difference is the internal decimal expression tree Spark generates before Gluten/Velox conversion. Spark 3.3’s tree causes Velox to widen one projection row to DECIMAL(28,10); Spark 3.5’s tree does not.
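
To make the widening concrete, here is a minimal sketch (illustrative only, not part of this PR or of Spark/Gluten code) of the standard decimal addition typing rule that both Spark and Velox apply before any precision capping; the object and method names are made up for the example:

object DecimalAddTypeSketch {
  // For DECIMAL(p1, s1) + DECIMAL(p2, s2):
  //   scale     = max(s1, s2)
  //   precision = max(p1 - s1, p2 - s2) + scale + 1   (later capped at 38)
  def addResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
    val scale = math.max(s1, s2)
    val precision = math.max(p1 - s1, p2 - s2) + scale + 1
    (precision, scale)
  }

  def main(args: Array[String]): Unit = {
    // DECIMAL(26,10) + DECIMAL(26,10) -> (27,10): the inner sum that Spark 3.3 casts to DECIMAL(27,10).
    println(addResultType(26, 10, 26, 10))
    // DECIMAL(27,10) + DECIMAL(27,10) -> (28,10): one step wider than the declared
    // Expand output column, which is exactly the mismatch Velox rejects.
    println(addResultType(27, 10, 27, 10))
  }
}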

Author:

Hi @JkSelf, do you think this fix is correct to address the issue in the Spark 3.2 scenario?

Contributor:

@jianzhenwu Yes, this issue persists in Spark 3.2 and 3.3. Would it be possible to introduce a new rule to align the output types for Spark 3.2 and 3.3? That would make the logic much clearer.
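
For illustration, a minimal sketch of what such an alignment rule could look like, assuming a physical-plan rule that casts every Expand projection expression to the declared output attribute type (the name AlignExpandProjectionTypes is hypothetical, not an existing Gluten rule):

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ExpandExec, SparkPlan}

// Hypothetical rule: force each Expand projection expression to the declared output
// attribute type so that Spark 3.2/3.3 plan shapes match what Spark 3.5 produces.
object AlignExpandProjectionTypes extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
    case expand: ExpandExec =>
      val aligned = expand.projections.map {
        projection =>
          projection.zip(expand.output).map {
            case (expr, attr) if expr.dataType != attr.dataType => Cast(expr, attr.dataType)
            case (expr, _) => expr
          }
      }
      expand.copy(projections = aligned)
  }
}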

withTempPath {
pendingPath =>
withTempPath {
verifiedPath =>
withTempView("pending_events", "verified_events") {
spark
.sql("""
|SELECT * FROM VALUES
| (1L, DATE'2026-04-22', 'A24', 0L),
| (2L, DATE'2026-04-22', 'A24', 0L)
|AS pending_events(order_id, pending_date, pending_reason, pending_timestamp)
|""".stripMargin)
.write
.mode("overwrite")
.parquet(pendingPath.getCanonicalPath)

spark
.sql("""
|SELECT * FROM VALUES
| (1L, 90000L),
| (2L, 180000L)
|AS verified_events(order_id, verified_timestamp)
|""".stripMargin)
.write
.mode("overwrite")
.parquet(verifiedPath.getCanonicalPath)

spark.read
.parquet(pendingPath.getCanonicalPath)
.createOrReplaceTempView("pending_events")
spark.read
.parquet(verifiedPath.getCanonicalPath)
.createOrReplaceTempView("verified_events")

val df = spark.sql(
"""
|WITH sla_calc AS (
| SELECT
| p.pending_date,
| p.pending_reason,
| p.order_id,
| round(
| cast((v.verified_timestamp - p.pending_timestamp) as decimal(38, 18)) /
| 3600.000000000000000000,
| 1) AS sla_hours
| FROM pending_events p
| JOIN verified_events v
| ON p.order_id = v.order_id
|)
|SELECT
| pending_date,
| pending_reason,
| COUNT(DISTINCT order_id) AS total_order,
| round(AVG(sla_hours), 1) AS avg_sla_hours,
| COUNT(DISTINCT CASE WHEN sla_hours > 24 THEN order_id END) AS backlog_24,
| COUNT(DISTINCT CASE WHEN sla_hours > 48 THEN order_id END) AS backlog_48
|FROM sla_calc
|GROUP BY pending_date, pending_reason
|""".stripMargin)

checkAnswer(
df,
Row(Date.valueOf("2026-04-22"), "A24", 2L, BigDecimal("37.5"), 2L, 1L))
assertContainsNativeExpand(df)
}
}
}
}

test("Expand with decimal case-when sum and multiple distinct aggregates") {
withTempPath {
eventsPath =>
withTempView("smart_events") {
spark
.sql("""
|SELECT * FROM VALUES
| (1, 101L, 1001L, 1, 0, 1,
| CAST(1.1000000000 AS DECIMAL(25, 10)),
| CAST(2.2000000000 AS DECIMAL(25, 10)),
| CAST(3.3000000000 AS DECIMAL(25, 10))),
| (1, 102L, 1002L, 0, 1, 0,
| CAST(4.4000000000 AS DECIMAL(25, 10)),
| CAST(5.5000000000 AS DECIMAL(25, 10)),
| CAST(6.6000000000 AS DECIMAL(25, 10)))
|AS smart_events(
| campaign_id,
| order_id,
| checkout_id,
| has_dd,
| has_ccb,
| has_fsv,
| dd_cost_usd,
| ccb_cost_usd,
| fsv_cost_usd)
|""".stripMargin)
.write
.mode("overwrite")
.parquet(eventsPath.getCanonicalPath)

spark.read.parquet(eventsPath.getCanonicalPath).createOrReplaceTempView("smart_events")

val df =
spark.sql("""
|SELECT
| campaign_id,
| COUNT(DISTINCT order_id) AS total_order,
| COUNT(DISTINCT CASE WHEN has_dd = 1 THEN order_id END) AS dd_order,
| COUNT(DISTINCT checkout_id) AS checkout_count,
| SUM(
| CASE WHEN has_dd = 1 THEN dd_cost_usd ELSE 0 END +
| CASE WHEN has_ccb = 1 THEN ccb_cost_usd ELSE 0 END +
| CASE WHEN has_fsv = 1 THEN fsv_cost_usd ELSE 0 END
| ) AS smart_voucher_cost_usd
|FROM smart_events
|GROUP BY campaign_id
|""".stripMargin)

checkAnswer(df, Row(1, 2L, 1L, 2L, BigDecimal("9.9000000000")))
assertContainsNativeExpand(df)
}
}
}
}
@@ -25,18 +25,23 @@ import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.types.DataType

import java.util.{ArrayList => JArrayList, List => JList}

import scala.util.control.NonFatal

case class ExpandExecTransformer(
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
child: SparkPlan)
extends UnaryExecNode
with UnaryTransformSupport {
with UnaryTransformSupport
with Logging {

// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@transient override lazy val metrics =
@@ -59,6 +64,97 @@ case class ExpandExecTransformer(
// as UNKNOWN partitioning
override def outputPartitioning: Partitioning = UnknownPartitioning(0)

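// Data type the expression resolves to after Gluten expression transformation against the
// child output; None when the conversion fails (e.g. the expression is unsupported).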
private def transformerDataType(expression: Expression): Option[DataType] = {
try {
Some(ExpressionConverter.replaceWithExpressionTransformer(expression, child.output).dataType)
} catch {
case NonFatal(_) => None
}
}

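// True when either the Spark type or the transformer-resolved type differs from the declared output column type.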
private def needsTypeAlignment(expression: Expression, outputType: DataType): Boolean = {
expression.dataType != outputType ||
transformerDataType(expression).exists(_ != outputType)
}

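// Re-type null literals to the output column type; other expressions are left unchanged and
// any remaining mismatch is reported by the diagnostics below.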
private def alignExpressionType(expression: Expression, outputType: DataType): Expression = {
if (!needsTypeAlignment(expression, outputType)) {
expression
} else {
expression match {
case Literal(null, _) => Literal.create(null, outputType)
case other => other
}
}
}

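// Aligns every projection row to the Expand output schema, column by column.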
private def alignProjectionsToOutput(projectSets: Seq[Seq[Expression]]): Seq[Seq[Expression]] = {
projectSets.map {
projectSet =>
projectSet.zipWithIndex.map {
case (expression, colIdx) if colIdx < output.length =>
alignExpressionType(expression, output(colIdx).dataType)
case (expression, _) =>
expression
}
}
}

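// Human-readable summary of the Spark and transformer types of each projection row for one
// output column, used in error messages.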
private def columnDiagnostics(projectSets: Seq[Seq[Expression]], colIdx: Int): String = {
val outputAttr = output(colIdx)
val projectionTypes = projectSets.zipWithIndex.map {
case (row, rowIdx) =>
val expression = row(colIdx)
val nativeType = transformerDataType(expression)
.map(_.catalogString)
.getOrElse("<unavailable>")
s"row[$rowIdx]=${expression.sql}:spark=${expression.dataType.catalogString}," +
s"transformer=$nativeType"
}
s"col[$colIdx]=${outputAttr.name}:output=${outputAttr.dataType.catalogString}," +
s" projections=[${projectionTypes.mkString(", ")}]"
}

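// A column is inconsistent when its projection rows disagree with each other or with the
// declared output type, on either the Spark side or the transformer side.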
private def columnTypesMismatch(projectSets: Seq[Seq[Expression]], colIdx: Int): Boolean = {
val outputType = output(colIdx).dataType
val sparkTypes = projectSets.map(_(colIdx).dataType)
val transformerTypes = projectSets.flatMap(row => transformerDataType(row(colIdx)))
sparkTypes.distinct.size > 1 ||
sparkTypes.exists(_ != outputType) ||
transformerTypes.distinct.size > 1 ||
transformerTypes.exists(_ != outputType)
}

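// Builds an error message covering up to five mismatching columns, or None when all columns are consistent.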
private def projectionTypeMismatchMessageIfAny(
projectSets: Seq[Seq[Expression]]): Option[String] = {
if (projectSets.nonEmpty && output.nonEmpty) {
val mismatchColumns = output.indices.filter(columnTypesMismatch(projectSets, _))
if (mismatchColumns.nonEmpty) {
val diagnostics = mismatchColumns
.take(5)
.map(columnDiagnostics(projectSets, _))
.mkString("; ")
val omittedColumns = mismatchColumns.size - 5
val suffix =
if (omittedColumns > 0) s"; ... $omittedColumns more mismatch column(s)" else ""
return Some(
"ExpandExecTransformer detected projection/output type mismatch before " +
"Substrait conversion. Failing validation to avoid native ExpandRel " +
"with inconsistent Spark or transformer column types: " +
s"$diagnostics$suffix")
}
}
None
}

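// Logs and throws when a projection/output type mismatch survives alignment on the execution path.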
private def failOnProjectionTypeMismatch(projectSets: Seq[Seq[Expression]]): Unit = {
projectionTypeMismatchMessageIfAny(projectSets).foreach {
message =>
logError(message)
throw new IllegalStateException(message)
}
}

def getRelNode(
context: SubstraitContext,
projections: Seq[Seq[Expression]],
@@ -104,11 +200,24 @@ case class ExpandExecTransformer(
return ValidationResult.failed("Current backend does not support empty projections in expand")
}

val alignedProjections = alignProjectionsToOutput(projections)
projectionTypeMismatchMessageIfAny(alignedProjections).foreach {
message =>
logError(message)
return ValidationResult.failed(message)
}

val substraitContext = new SubstraitContext
val operatorId = substraitContext.nextOperatorId(this.nodeName)

val relNode =
getRelNode(substraitContext, projections, child.output, operatorId, null, validation = true)
getRelNode(
substraitContext,
alignedProjections,
child.output,
operatorId,
null,
validation = true)

doNativeValidation(substraitContext, relNode)
}
@@ -120,9 +229,18 @@ case class ExpandExecTransformer(
return childCtx
}

val alignedProjections = alignProjectionsToOutput(projections)
failOnProjectionTypeMismatch(alignedProjections)

val operatorId = context.nextOperatorId(this.nodeName)
val currRel =
getRelNode(context, projections, child.output, operatorId, childCtx.root, validation = false)
getRelNode(
context,
alignedProjections,
child.output,
operatorId,
childCtx.root,
validation = false)
assert(currRel != null, "Expand Rel should be valid")
TransformContext(output, currRel)
}