diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/VeloxExpandSuite.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/VeloxExpandSuite.scala
index 1af4e9fba0de..6fead0ce2a8a 100644
--- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/VeloxExpandSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/VeloxExpandSuite.scala
@@ -18,11 +18,14 @@ package org.apache.spark.sql.execution
 
 import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.events.GlutenPlanFallbackEvent
-import org.apache.gluten.execution.VeloxWholeStageTransformerSuite
+import org.apache.gluten.execution.{ExpandExecTransformer, VeloxWholeStageTransformerSuite}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql.{DataFrame, Row}
+
+import java.sql.Date
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -33,10 +36,18 @@ class VeloxExpandSuite extends VeloxWholeStageTransformerSuite {
   override def sparkConf: SparkConf = {
     super.sparkConf
       .set(GlutenConfig.GLUTEN_UI_ENABLED.key, "true")
+      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
       // The gluten ui event test suite expects the spark ui to be enable
      .set(UI_ENABLED, true)
   }
 
+  private def assertContainsNativeExpand(df: DataFrame): Unit = {
+    assert(
+      getExecutedPlan(df).exists(_.isInstanceOf[ExpandExecTransformer]),
+      s"Expected ExpandExecTransformer in plan, got:\n${df.queryExecution.executedPlan}"
+    )
+  }
+
   test("Expand with duplicated group keys") {
     withTable("t1") {
       val events = new ArrayBuffer[GlutenPlanFallbackEvent]
@@ -77,4 +88,128 @@ class VeloxExpandSuite extends VeloxWholeStageTransformerSuite {
       }
     }
   }
+
+  test("Expand with round(avg(decimal)) and multiple distinct aggregates") {
+    withTempPath {
+      pendingPath =>
+        withTempPath {
+          verifiedPath =>
+            withTempView("pending_events", "verified_events") {
+              spark
+                .sql("""
+                       |SELECT * FROM VALUES
+                       |  (1L, DATE'2026-04-22', 'A24', 0L),
+                       |  (2L, DATE'2026-04-22', 'A24', 0L)
+                       |AS pending_events(order_id, pending_date, pending_reason, pending_timestamp)
+                       |""".stripMargin)
+                .write
+                .mode("overwrite")
+                .parquet(pendingPath.getCanonicalPath)
+
+              spark
+                .sql("""
+                       |SELECT * FROM VALUES
+                       |  (1L, 90000L),
+                       |  (2L, 180000L)
+                       |AS verified_events(order_id, verified_timestamp)
+                       |""".stripMargin)
+                .write
+                .mode("overwrite")
+                .parquet(verifiedPath.getCanonicalPath)
+
+              spark.read
+                .parquet(pendingPath.getCanonicalPath)
+                .createOrReplaceTempView("pending_events")
+              spark.read
+                .parquet(verifiedPath.getCanonicalPath)
+                .createOrReplaceTempView("verified_events")
+
+              val df = spark.sql(
+                """
+                  |WITH sla_calc AS (
+                  |  SELECT
+                  |    p.pending_date,
+                  |    p.pending_reason,
+                  |    p.order_id,
+                  |    round(
+                  |      cast((v.verified_timestamp - p.pending_timestamp) as decimal(38, 18)) /
+                  |        3600.000000000000000000,
+                  |      1) AS sla_hours
+                  |  FROM pending_events p
+                  |  JOIN verified_events v
+                  |    ON p.order_id = v.order_id
+                  |)
+                  |SELECT
+                  |  pending_date,
+                  |  pending_reason,
+                  |  COUNT(DISTINCT order_id) AS total_order,
+                  |  round(AVG(sla_hours), 1) AS avg_sla_hours,
+                  |  COUNT(DISTINCT CASE WHEN sla_hours > 24 THEN order_id END) AS backlog_24,
+                  |  COUNT(DISTINCT CASE WHEN sla_hours > 48 THEN order_id END) AS backlog_48
+                  |FROM sla_calc
+                  |GROUP BY pending_date, pending_reason
+                  |""".stripMargin)
+
+              checkAnswer(
+                df,
+                Row(Date.valueOf("2026-04-22"), "A24", 2L, BigDecimal("37.5"), 2L, 1L))
+              assertContainsNativeExpand(df)
+            }
+        }
+    }
+  }
+
+  test("Expand with decimal case-when sum and multiple distinct aggregates") {
+    withTempPath {
+      eventsPath =>
+        withTempView("smart_events") {
+          spark
+            .sql("""
+                   |SELECT * FROM VALUES
+                   |  (1, 101L, 1001L, 1, 0, 1,
+                   |    CAST(1.1000000000 AS DECIMAL(25, 10)),
+                   |    CAST(2.2000000000 AS DECIMAL(25, 10)),
+                   |    CAST(3.3000000000 AS DECIMAL(25, 10))),
+                   |  (1, 102L, 1002L, 0, 1, 0,
+                   |    CAST(4.4000000000 AS DECIMAL(25, 10)),
+                   |    CAST(5.5000000000 AS DECIMAL(25, 10)),
+                   |    CAST(6.6000000000 AS DECIMAL(25, 10)))
+                   |AS smart_events(
+                   |  campaign_id,
+                   |  order_id,
+                   |  checkout_id,
+                   |  has_dd,
+                   |  has_ccb,
+                   |  has_fsv,
+                   |  dd_cost_usd,
+                   |  ccb_cost_usd,
+                   |  fsv_cost_usd)
+                   |""".stripMargin)
+            .write
+            .mode("overwrite")
+            .parquet(eventsPath.getCanonicalPath)
+
+          spark.read.parquet(eventsPath.getCanonicalPath).createOrReplaceTempView("smart_events")
+
+          val df =
+            spark.sql("""
+                        |SELECT
+                        |  campaign_id,
+                        |  COUNT(DISTINCT order_id) AS total_order,
+                        |  COUNT(DISTINCT CASE WHEN has_dd = 1 THEN order_id END) AS dd_order,
+                        |  COUNT(DISTINCT checkout_id) AS checkout_count,
+                        |  SUM(
+                        |    CASE WHEN has_dd = 1 THEN dd_cost_usd ELSE 0 END +
+                        |    CASE WHEN has_ccb = 1 THEN ccb_cost_usd ELSE 0 END +
+                        |    CASE WHEN has_fsv = 1 THEN fsv_cost_usd ELSE 0 END
+                        |  ) AS smart_voucher_cost_usd
+                        |FROM smart_events
+                        |GROUP BY campaign_id
+                        |""".stripMargin)
+
+          checkAnswer(df, Row(1, 2L, 1L, 2L, BigDecimal("9.9000000000")))
+          assertContainsNativeExpand(df)
+        }
+    }
+  }
 }
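
Both new tests depend on Spark planning an Expand node for queries that mix multiple distinct aggregates, which is the plan shape the fix targets. A minimal standalone sketch of that shape, outside Gluten (the object name, local master, and toy data are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.Expand

// Hypothetical repro sketch: two COUNT(DISTINCT ...) groups over one GROUP BY make
// Spark's RewriteDistinctAggregates rule insert an Expand, and the decimal column
// produced by round(avg(...)) flows through every Expand projection row.
object ExpandPlanShape {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("expand-shape").getOrCreate()
    import spark.implicits._

    Seq((1L, BigDecimal("25.0")), (2L, BigDecimal("50.0")))
      .toDF("order_id", "sla_hours")
      .createOrReplaceTempView("sla_calc")

    val df = spark.sql("""
      SELECT
        COUNT(DISTINCT order_id) AS total_order,
        round(AVG(sla_hours), 1) AS avg_sla_hours,
        COUNT(DISTINCT CASE WHEN sla_hours > 24 THEN order_id END) AS backlog_24
      FROM sla_calc
    """)

    // Each distinct-aggregate group gets its own projection row in the Expand,
    // padded with null literals for the columns it does not use.
    df.queryExecution.optimizedPlan.foreach {
      case e: Expand => println(e)
      case _ =>
    }
    spark.stop()
  }
}
```

The `assertContainsNativeExpand` helper then pins the expectation that this shape stays on the native `ExpandExecTransformer` path instead of silently falling back to vanilla Spark.
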
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
index 47aa66f16b26..282ae1468faa 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ExpandExecTransformer.scala
@@ -25,18 +25,23 @@ import org.apache.gluten.substrait.expression.ExpressionNode
 import org.apache.gluten.substrait.extensions.ExtensionBuilder
 import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.types.DataType
 
 import java.util.{ArrayList => JArrayList, List => JList}
 
+import scala.util.control.NonFatal
+
 case class ExpandExecTransformer(
     projections: Seq[Seq[Expression]],
     output: Seq[Attribute],
     child: SparkPlan)
   extends UnaryExecNode
-  with UnaryTransformSupport {
+  with UnaryTransformSupport
+  with Logging {
 
   // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
   @transient override lazy val metrics =
@@ -59,6 +64,97 @@ case class ExpandExecTransformer(
   // as UNKNOWN partitioning
   override def outputPartitioning: Partitioning = UnknownPartitioning(0)
 
+  private def transformerDataType(expression: Expression): Option[DataType] = {
+    try {
+      Some(ExpressionConverter.replaceWithExpressionTransformer(expression, child.output).dataType)
+    } catch {
+      case NonFatal(_) => None
+    }
+  }
+
+  private def needsTypeAlignment(expression: Expression, outputType: DataType): Boolean = {
+    expression.dataType != outputType ||
+    transformerDataType(expression).exists(_ != outputType)
+  }
+
+  private def alignExpressionType(expression: Expression, outputType: DataType): Expression = {
+    if (!needsTypeAlignment(expression, outputType)) {
+      expression
+    } else {
+      expression match {
+        case Literal(null, _) => Literal.create(null, outputType)
+        case other => other
+      }
+    }
+  }
+
+  private def alignProjectionsToOutput(projectSets: Seq[Seq[Expression]]): Seq[Seq[Expression]] = {
+    projectSets.map {
+      projectSet =>
+        projectSet.zipWithIndex.map {
+          case (expression, colIdx) if colIdx < output.length =>
+            alignExpressionType(expression, output(colIdx).dataType)
+          case (expression, _) =>
+            expression
+        }
+    }
+  }
+
+  private def columnDiagnostics(projectSets: Seq[Seq[Expression]], colIdx: Int): String = {
+    val outputAttr = output(colIdx)
+    val projectionTypes = projectSets.zipWithIndex.map {
+      case (row, rowIdx) =>
+        val expression = row(colIdx)
+        val nativeType = transformerDataType(expression)
+          .map(_.catalogString)
+          .getOrElse("")
+        s"row[$rowIdx]=${expression.sql}:spark=${expression.dataType.catalogString}," +
+          s"transformer=$nativeType"
+    }
+    s"col[$colIdx]=${outputAttr.name}:output=${outputAttr.dataType.catalogString}," +
+      s" projections=[${projectionTypes.mkString(", ")}]"
+  }
+
+  private def columnTypesMismatch(projectSets: Seq[Seq[Expression]], colIdx: Int): Boolean = {
+    val outputType = output(colIdx).dataType
+    val sparkTypes = projectSets.map(_(colIdx).dataType)
+    val transformerTypes = projectSets.flatMap(row => transformerDataType(row(colIdx)))
+    sparkTypes.distinct.size > 1 ||
+    sparkTypes.exists(_ != outputType) ||
+    transformerTypes.distinct.size > 1 ||
+    transformerTypes.exists(_ != outputType)
+  }
+
+  private def projectionTypeMismatchMessageIfAny(
+      projectSets: Seq[Seq[Expression]]): Option[String] = {
+    if (projectSets.nonEmpty && output.nonEmpty) {
+      val mismatchColumns = output.indices.filter(columnTypesMismatch(projectSets, _))
+      if (mismatchColumns.nonEmpty) {
+        val diagnostics = mismatchColumns
+          .take(5)
+          .map(columnDiagnostics(projectSets, _))
+          .mkString("; ")
+        val omittedColumns = mismatchColumns.size - 5
+        val suffix =
+          if (omittedColumns > 0) s"; ... $omittedColumns more mismatch column(s)" else ""
+        return Some(
+          "ExpandExecTransformer detected projection/output type mismatch before " +
            "Substrait conversion. Failing validation to avoid native ExpandRel " +
+            "with inconsistent Spark or transformer column types: " +
+            s"$diagnostics$suffix")
+      }
+    }
+    None
+  }
+
+  private def failOnProjectionTypeMismatch(projectSets: Seq[Seq[Expression]]): Unit = {
+    projectionTypeMismatchMessageIfAny(projectSets).foreach {
+      message =>
+        logError(message)
+        throw new IllegalStateException(message)
+    }
+  }
+
   def getRelNode(
       context: SubstraitContext,
       projections: Seq[Seq[Expression]],
@@ -104,11 +200,24 @@ case class ExpandExecTransformer(
       return ValidationResult.failed("Current backend does not support empty projections in expand")
     }
 
+    val alignedProjections = alignProjectionsToOutput(projections)
+    projectionTypeMismatchMessageIfAny(alignedProjections).foreach {
+      message =>
+        logError(message)
+        return ValidationResult.failed(message)
+    }
+
     val substraitContext = new SubstraitContext
     val operatorId = substraitContext.nextOperatorId(this.nodeName)
 
     val relNode =
-      getRelNode(substraitContext, projections, child.output, operatorId, null, validation = true)
+      getRelNode(
+        substraitContext,
+        alignedProjections,
+        child.output,
+        operatorId,
+        null,
+        validation = true)
 
     doNativeValidation(substraitContext, relNode)
   }
@@ -120,9 +229,18 @@ case class ExpandExecTransformer(
       return childCtx
     }
 
+    val alignedProjections = alignProjectionsToOutput(projections)
+    failOnProjectionTypeMismatch(alignedProjections)
+
     val operatorId = context.nextOperatorId(this.nodeName)
     val currRel =
-      getRelNode(context, projections, child.output, operatorId, childCtx.root, validation = false)
+      getRelNode(
+        context,
+        alignedProjections,
+        child.output,
+        operatorId,
+        childCtx.root,
+        validation = false)
     assert(currRel != null, "Expand Rel should be valid")
     TransformContext(output, currRel)
   }
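
The alignment above only re-types null padding literals; any remaining mismatch is reported (validation failure or `IllegalStateException`) rather than handed to the native ExpandRel. A self-contained sketch of the literal branch (object and variable names here are illustrative, not from the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{DecimalType, NullType}

// Hypothetical snippet mirroring alignExpressionType's Literal branch: a null padding
// literal whose DataType disagrees with the Expand output attribute is re-created with
// the attribute's type, so every projection row carries one consistent column type.
object NullLiteralAlignment {
  def main(args: Array[String]): Unit = {
    val outputType = DecimalType(38, 6)
    val padding = Literal.create(null, NullType) // as planned in an Expand projection row
    val aligned = padding match {
      case Literal(null, t) if t != outputType => Literal.create(null, outputType)
      case other => other
    }
    // prints: before=void, after=decimal(38,6)
    println(s"before=${padding.dataType.catalogString}, after=${aligned.dataType.catalogString}")
  }
}
```

Re-creating the literal is safe because a null carries no value to convert; non-literal expressions need a real Cast, which is why that case is handled in PullOutPreProject below.
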
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
index 13c14bffc002..a6e35aedcf41 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
@@ -17,6 +17,7 @@
 package org.apache.gluten.extension.columnar.rewrite
 
 import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.expression.ExpressionConverter
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.utils.PullOutProjectHelper
 
@@ -26,8 +27,10 @@
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression}
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.types.{DataType, DecimalType}
 
 import scala.collection.mutable
+import scala.util.control.NonFatal
 
 /**
  * The native engine only supports executing Expressions within the project operator. When there are
@@ -95,11 +98,80 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
           SparkShimLoader.getSparkShims.getWindowGroupLimitExecShim(plan)
         windowGroupLimitExecShim.orderSpec.exists(o => isNotAttribute(o.child)) ||
         windowGroupLimitExecShim.partitionSpec.exists(isNotAttribute)
-      case expand: ExpandExec => expand.projections.flatten.exists(isNotAttributeAndLiteral)
+      case expand: ExpandExec => needsExpandPreProject(expand)
       case _ => false
     }
   }
 
+  private def isNullLiteral(expression: Expression): Boolean = expression match {
+    case Literal(null, _) => true
+    case _ => false
+  }
+
+  private def isCastToType(expression: Expression, dataType: DataType): Boolean = expression match {
+    case Cast(_, castType, _, _) if castType == dataType => true
+    case _ => false
+  }
+
+  private def needsExplicitExpandOutputCast(
+      expression: Expression,
+      outputType: DataType): Boolean = {
+    outputType.isInstanceOf[DecimalType] &&
+    isNotAttributeAndLiteral(expression) &&
+    !isCastToType(expression, outputType)
+  }
+
+  private def needsExpandProjectionTypeAlignment(
+      expression: Expression,
+      outputType: DataType,
+      inputAttributes: Seq[Attribute]): Boolean = {
+    !isNullLiteral(expression) &&
+    (expression.dataType != outputType ||
+      transformerDataType(expression, inputAttributes).exists(_ != outputType) ||
+      needsExplicitExpandOutputCast(expression, outputType))
+  }
+
+  private def transformerDataType(
+      expression: Expression,
+      inputAttributes: Seq[Attribute]): Option[DataType] = {
+    try {
+      Some(ExpressionConverter.replaceWithExpressionTransformer(
+        expression,
+        inputAttributes).dataType)
+    } catch {
+      case NonFatal(_) => None
+    }
+  }
+
+  private def needsExpandPreProject(expand: ExpandExec): Boolean = {
+    expand.projections.exists {
+      projection =>
+        projection.zip(expand.output).exists {
+          case (expression, outputAttr) =>
+            isNotAttributeAndLiteral(expression) ||
+            needsExpandProjectionTypeAlignment(expression, outputAttr.dataType, expand.child.output)
+        }
+    }
+  }
+
+  private def alignExpandProjectionExpression(
+      expression: Expression,
+      outputType: DataType,
+      inputAttributes: Seq[Attribute]): Expression = {
+    if (!needsExpandProjectionTypeAlignment(expression, outputType, inputAttributes)) {
+      expression
+    } else {
+      expression match {
+        case Literal(null, _) =>
+          Literal.create(null, outputType)
+        case _ if isCastToType(expression, outputType) =>
+          expression
+        case other =>
+          Cast(other, outputType)
+      }
+    }
+  }
+
   /**
    * Pull out Expressions in SortOrder's children, and return the new SortOrder that contains only
    * Attributes.
@@ -258,13 +330,23 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
     case expand: ExpandExec if needsPreProject(expand) =>
       val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
       val newProjections =
-        expand.projections.toIndexedSeq.map(
-          _.toIndexedSeq.map(
+        expand.projections.toIndexedSeq.map(_.toIndexedSeq.zipWithIndex.map {
+          case (expression, colIdx) =>
+            val alignedExpression =
+              if (colIdx < expand.output.length) {
+                alignExpandProjectionExpression(
+                  expression,
+                  expand.output(colIdx).dataType,
+                  expand.child.output)
+              } else {
+                expression
+              }
             replaceExpressionWithAttribute(
-              _,
+              alignedExpression,
               expressionMap,
               replaceBoundReference = false,
-              replaceLiteral = false)))
+              replaceLiteral = false)
+        })
       val newProject = ProjectExec(
         eliminateProjectList(expand.child.outputSet, expressionMap.values.toSeq),
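
A sketch of the effect of this rewrite on a decimal projection expression; `castDecimalProjection` is a hypothetical stand-in for `needsExplicitExpandOutputCast` plus the `Cast` branch of `alignExpandProjectionExpression`, and it assumes spark-catalyst on the classpath:

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, Cast, Expression, Literal}
import org.apache.spark.sql.types.{DataType, DecimalType}

// Hypothetical helper mirroring the pre-project alignment: any non-attribute,
// non-literal decimal projection is wrapped in an explicit Cast to the Expand
// output type, so the pulled-out pre-project column already carries that type
// and every Expand projection row agrees on it.
object ExpandCastSketch {
  def castDecimalProjection(e: Expression, outputType: DataType): Expression = e match {
    case a: Attribute => a                                  // attributes pass through
    case l: Literal => l                                    // null literals are re-typed elsewhere
    case Cast(_, t, _, _) if t == outputType => e           // already cast to the target type
    case other if outputType.isInstanceOf[DecimalType] => Cast(other, outputType)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val cost = AttributeReference("dd_cost_usd", DecimalType(25, 10))()
    val sum = Add(cost, cost) // e.g. the CASE WHEN ... END + ... sum from the test query
    // Prints a CAST(... AS DECIMAL(26,10)) wrapper around the raw sum expression.
    println(castDecimalProjection(sum, DecimalType(26, 10)).sql)
  }
}
```

The explicit Cast keeps the per-row types of the native ExpandRel consistent even when the Velox expression transformer would otherwise derive a different decimal precision/scale than Spark does.
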