Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ import org.apache.gluten.events.GlutenPlanFallbackEvent

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarShuffleExchangeExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarShuffleExchangeExec, FileSourceScanExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.utils.GlutenSuiteUtils

import scala.collection.mutable.ArrayBuffer
Expand Down Expand Up @@ -66,12 +67,32 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
.write
.format("parquet")
.saveAsTable("tmp3")
spark
.range(100)
.selectExpr(
"cast(id as decimal) as c1",
"cast(id % 3 as int) as c2",
"cast(id % 9 as timestamp) as c3")
.write
.format("orc")
.saveAsTable("tmp4")
spark
.range(100)
.selectExpr(
"cast(id as decimal) as c1",
"cast(id % 3 as int) as c2",
"cast(id % 5 as timestamp) as c3")
.write
.format("orc")
.saveAsTable("tmp5")
}

override protected def afterAll(): Unit = {
  // Drop every table registered in beforeAll so later suites start from a clean catalog.
  Seq("tmp1", "tmp2", "tmp3", "tmp4", "tmp5").foreach(table => spark.sql(s"drop table $table"))

  super.afterAll()
}
Expand Down Expand Up @@ -420,4 +441,133 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
spark.sparkContext.removeSparkListener(listener)
}
}

test("For decimal-key joins, if one side falls back to Spark, force fallback the other side") {
  // Every case below joins tmp4/tmp5 on the decimal column c1 and verifies the
  // symmetric-fallback invariant: both table scans fall back to vanilla Spark
  // while the join operator itself stays native.
  //
  // checkSymmetricFallback runs one query and asserts:
  //   - 2 vanilla FileSourceScanExec nodes (both sides fell back),
  //   - 0 native FileSourceScanExecTransformer nodes,
  //   - 0 vanilla join nodes and exactly 1 native join node
  //     (counted by the caller-supplied functions, per join flavor).
  def checkSymmetricFallback(
      sql: String,
      countVanillaJoin: SparkPlan => Int,
      countNativeJoin: SparkPlan => Int): Unit = {
    runQueryAndCompare(sql) {
      df =>
        val plan = df.queryExecution.executedPlan
        assert(collect(plan) { case scan: FileSourceScanExec => scan }.size == 2)
        assert(collect(plan) { case scan: FileSourceScanExecTransformer => scan }.size == 0)
        assert(countVanillaJoin(plan) == 0)
        assert(countNativeJoin(plan) == 1)
    }
  }

  // Three projection variants per join flavor:
  //   1. both sides project their timestamp column -> both scans fall back directly;
  //   2. only the left side projects it -> the right (native) scan must be forced back;
  //   3. only the right side projects it -> the left (native) scan must be forced back.
  val selectLists = Seq(
    "tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, tmp5.c2 AS 5c2, tmp5.c3 AS 5c3",
    "tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, tmp5.c2 AS 5c2",
    "tmp4.c2 AS 4c2, tmp5.c2 AS 5c2, tmp5.c3 AS 5c3"
  )

  // Builds the three query variants for a given join hint ("" means no hint).
  def queries(hint: String): Seq[String] =
    selectLists.map(cols => s"SELECT $hint$cols FROM tmp4 join tmp5 on tmp4.c1 = tmp5.c1")

  // SortMergeJoin: disable (forced) shuffled hash join so the MERGE hint yields SMJ.
  withSQLConf(
    GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
    GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
    queries("/*+ MERGE(tmp4) */ ").foreach {
      sql =>
        checkSymmetricFallback(
          sql,
          plan => collect(plan) { case smj: SortMergeJoinExec => smj }.size,
          plan => collect(plan) { case smj: SortMergeJoinExecTransformer => smj }.size)
    }
  }

  // ShuffledHashJoin: disable auto-broadcast so the SHUFFLE_HASH hint is honored.
  withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
    queries("/*+ SHUFFLE_HASH(tmp4) */ ").foreach {
      sql =>
        checkSymmetricFallback(
          sql,
          plan => collect(plan) { case shj: ShuffledHashJoinExec => shj }.size,
          plan => collect(plan) { case shj: ShuffledHashJoinExecTransformer => shj }.size)
    }
  }

  // BroadcastHashJoin: default planning broadcasts the small side, no hint needed.
  queries("").foreach {
    sql =>
      checkSymmetricFallback(
        sql,
        plan => collect(plan) { case bhj: BroadcastHashJoinExec => bhj }.size,
        plan => collect(plan) { case bhj: BroadcastHashJoinExecTransformer => bhj }.size)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.extension.columnar.validator.Validator

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.types.DecimalType

// Add fallback tags when validator returns negative outcome.
case class AddFallbackTags(validator: Validator) extends Rule[SparkPlan] {
Expand All @@ -29,6 +32,9 @@ case class AddFallbackTags(validator: Validator) extends Rule[SparkPlan] {
case p if FallbackTags.maybeOffloadable(p) => addFallbackTag(p)
case _ =>
}

plan.foreach(validateJoin)

plan
}

Expand All @@ -40,4 +46,131 @@ case class AddFallbackTags(validator: Validator) extends Rule[SparkPlan] {
case Validator.Passed =>
}
}

/**
 * Traverses the plan tree looking for join nodes (SortMergeJoin, ShuffledHashJoin,
 * BroadcastHashJoin) whose join keys include at least one decimal column.
 *
 * For each such join, delegates to [[setFallbackTagForOtherSide]] to ensure that if one side's
 * scan ([[FileSourceScanExec]] or `HiveTableScanExec`) cannot be offloaded to the native engine,
 * the other side is also forced to fall back. This prevents a decimal-value mismatch that would
 * produce incorrect (typically empty) join results when one side applies Spark's precision
 * coercion and the other side reads raw native values.
 *
 * AdaptiveSparkPlanExec is handled by descending into its `initialPlan`; all other non-join nodes
 * are handled recursively through their children.
 */
private def validateJoin(plan: SparkPlan): Unit =
  plan match {
    case smj: SortMergeJoinExec
        if (smj.leftKeys ++ smj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
      setFallbackTagForOtherSide(smj.left, smj.right)
    case shj: ShuffledHashJoinExec
        if (shj.leftKeys ++ shj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
      setFallbackTagForOtherSide(shj.left, shj.right)
    case bhj: BroadcastHashJoinExec
        if (bhj.leftKeys ++ bhj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
      setFallbackTagForOtherSide(bhj.left, bhj.right)
    case a: AdaptiveSparkPlanExec =>
      validateJoin(a.initialPlan)
    case _ => plan.children.foreach(validateJoin(_))
  }

/**
 * Enforces symmetric scan fallback for the two sides of a decimal-key join.
 *
 * When the join key is a decimal type, a native (Velox) scan and a vanilla Spark scan
 * ([[FileSourceScanExec]] or `HiveTableScanExec`) may produce different representations of the
 * same decimal value: the native reader may surface raw uncoerced int128_t values while the
 * vanilla reader applies Spark's precision coercion (returning NULL for out-of-range values). If
 * only one side falls back, the join key values diverge and the join returns 0 rows.
 *
 * This method detects the asymmetric case (exactly one side contains a fallback scan) and adds a
 * fallback tag to the scans on the other side, so that both sides end up using the same read
 * path.
 *
 * @param leftChild
 *   the left subtree of the join
 * @param rightChild
 *   the right subtree of the join
 */
private def setFallbackTagForOtherSide(leftChild: SparkPlan, rightChild: SparkPlan): Unit = {
  val leftHasFallbackScan = hasFallbackScan(leftChild)
  val rightHasFallbackScan = hasFallbackScan(rightChild)
  // Only the asymmetric case needs correction: one side fell back, the other stayed native.
  if (leftHasFallbackScan != rightHasFallbackScan) {
    val reason = "asymmetric DataSource scan fallback on " +
      s"decimal join key: left=$leftHasFallbackScan right=$rightHasFallbackScan"
    if (leftHasFallbackScan) {
      addFallbackTagToNativeScan(rightChild, reason)
    } else {
      addFallbackTagToNativeScan(leftChild, reason)
    }
  }
}

/**
 * Whether the node is a DataSource scan tracked by the decimal-key symmetry check.
 *
 * [[FileSourceScanExec]] is matched directly (a compile-time dependency of this module), while
 * `HiveTableScanExec` is matched by simple class name so this module does not need to depend on
 * `spark-hive`.
 */
private def isDataSourceScan(plan: SparkPlan): Boolean =
  plan match {
    case _: FileSourceScanExec => true
    case other => other.getClass.getSimpleName == "HiveTableScanExec"
  }

/**
 * Whether the given subtree contains at least one DataSource scan ([[FileSourceScanExec]] or
 * `HiveTableScanExec`) that fails native validation and would therefore fall back to vanilla
 * Spark execution.
 *
 * AQE stage wrappers ([[ShuffleQueryStageExec]] / [[BroadcastQueryStageExec]]) are unwrapped so
 * that already-materialized stages are inspected correctly during AQE re-planning; every other
 * node delegates the check to its children.
 *
 * @param plan
 *   the subtree to inspect
 * @return
 *   true if any tracked scan in the subtree fails validation
 */
private def hasFallbackScan(plan: SparkPlan): Boolean =
  plan match {
    case stage: ShuffleQueryStageExec => hasFallbackScan(stage.plan)
    case stage: BroadcastQueryStageExec => hasFallbackScan(stage.plan)
    case scan if isDataSourceScan(scan) =>
      validator.validate(scan) match {
        case Validator.Failed(_) => true
        case Validator.Passed => false
      }
    case other => other.children.exists(hasFallbackScan)
  }

/**
 * Walks the subtree and adds a fallback tag, with the supplied reason, to every DataSource scan
 * ([[FileSourceScanExec]] or `HiveTableScanExec`) found in it.
 *
 * Like [[hasFallbackScan]], this descends transparently through [[ShuffleQueryStageExec]] and
 * [[BroadcastQueryStageExec]] wrappers so it behaves the same in the initial planning pass and
 * the AQE re-planning pass.
 *
 * @param plan
 *   the subtree to walk
 * @param reason
 *   human-readable explanation of why the scan is forced to fall back (surfaced in Gluten's
 *   fallback reporting)
 */
private def addFallbackTagToNativeScan(plan: SparkPlan, reason: String): Unit =
  plan match {
    case stage: ShuffleQueryStageExec => addFallbackTagToNativeScan(stage.plan, reason)
    case stage: BroadcastQueryStageExec => addFallbackTagToNativeScan(stage.plan, reason)
    case scan if isDataSourceScan(scan) => FallbackTags.add(scan, reason)
    case other => other.children.foreach(child => addFallbackTagToNativeScan(child, reason))
  }
}
Loading
Loading