[GLUTEN-11980][CORE] For decimal-key joins, if one side falls back to Spark, force fallback the other side #12000
```
@@ -20,7 +20,10 @@ import org.apache.gluten.extension.columnar.FallbackTags
 import org.apache.gluten.extension.columnar.validator.Validator

 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.types.DecimalType

 // Add fallback tags when validator returns negative outcome.
 case class AddFallbackTags(validator: Validator) extends Rule[SparkPlan] {
```
```
@@ -29,6 +32,9 @@ case class AddFallbackTags(validator: Validator) extends Rule[SparkPlan] {
       case p if FallbackTags.maybeOffloadable(p) => addFallbackTag(p)
       case _ =>
     }
+
+    plan.foreach(validateJoin)
+
     plan
   }
```
```
@@ -40,4 +46,131 @@ case class AddFallbackTags(validator: Validator) extends Rule[SparkPlan] {
       case Validator.Passed =>
     }
   }
+
+  /**
+   * Traverses the plan tree looking for join nodes (SortMergeJoin, ShuffledHashJoin,
+   * BroadcastHashJoin) whose join keys include at least one decimal column.
+   *
+   * For each such join, delegates to [[setFallbackTagForOtherSide]] to ensure that if one side's
+   * scan ([[FileSourceScanExec]] or `HiveTableScanExec`) cannot be offloaded to the native engine,
+   * the other side is also forced to fall back. This prevents a decimal-value mismatch that would
+   * produce incorrect (typically empty) join results when one side applies Spark's precision
+   * coercion and the other side reads raw native values.
+   *
+   * AdaptiveSparkPlanExec is handled by descending into its `initialPlan`; all other non-join
+   * nodes are handled recursively through their children.
+   */
+  private def validateJoin(plan: SparkPlan): Unit =
+    plan match {
+      case smj: SortMergeJoinExec
+          if (smj.leftKeys ++ smj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+        setFallbackTagForOtherSide(smj.left, smj.right)
+      case shj: ShuffledHashJoinExec
+          if (shj.leftKeys ++ shj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+        setFallbackTagForOtherSide(shj.left, shj.right)
+      case bhj: BroadcastHashJoinExec
+          if (bhj.leftKeys ++ bhj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+        setFallbackTagForOtherSide(bhj.left, bhj.right)
+      case a: AdaptiveSparkPlanExec =>
+        validateJoin(a.initialPlan)
+      case _ => plan.children.foreach(validateJoin(_))
+    }
+
+  /**
+   * Enforces symmetric scan fallback for the two sides of a decimal-key join.
+   *
+   * When the join key is a decimal type, a native (Velox) scan and a vanilla Spark scan
+   * ([[FileSourceScanExec]] or `HiveTableScanExec`) may produce different representations of the
+   * same decimal value: the native reader may surface raw uncoerced int128_t values while the
```
Contributor: Can we update the native side to support this case?

Contributor: If one side falls back, that side should insert a ColumnarToRow transition, so why does this representation issue occur?

Author: The way decimal is implemented determines the differing precision on each side, and it is a difficult problem for me to solve right now.

Contributor: We should find the root cause and fix the result mismatch on the native side; someone else in the community may fix this, so please keep the issue open for now.

Author: The root cause is that decimal precision differs between the Java side and the native side. The implementation mechanism for decimal in Velox is significantly different from that in Java, and I think it will be difficult to make them consistent in the short term.

Author: @jinchengchenghh I added more information to #11980.
```
+   * vanilla reader applies Spark's precision coercion (returning NULL for out-of-range values).
+   * If only one side falls back, the join key values diverge and the join returns 0 rows.
+   *
+   * This method detects the asymmetric case (exactly one side contains a fallback scan) and adds
+   * a fallback tag to the native scan on the other side, so that both sides end up using the same
+   * read path.
+   *
+   * @param leftChild
+   *   the left subtree of the join
+   * @param rightChild
+   *   the right subtree of the join
+   */
+  private def setFallbackTagForOtherSide(leftChild: SparkPlan, rightChild: SparkPlan): Unit = {
+    val leftHasFallbackScan = hasFallbackScan(leftChild)
```
Contributor: Not only a scan fallback can cause this issue; couldn't it also occur after a filter?

Author: Either way, we should keep both sides offloaded at the same time, or neither.
```
+    val rightHasFallbackScan = hasFallbackScan(rightChild)
+    if (leftHasFallbackScan != rightHasFallbackScan) {
+      val reason = "asymmetric DataSource scan fallback on " +
+        s"decimal join key: left=$leftHasFallbackScan right=$rightHasFallbackScan"
+      if (leftHasFallbackScan) {
+        addFallbackTagToNativeScan(rightChild, reason)
+      } else {
+        addFallbackTagToNativeScan(leftChild, reason)
+      }
+    }
+  }
+
+  /**
+   * Returns true if the plan node is a DataSource scan that participates in the decimal-key
+   * symmetry check.
+   *
+   * [[FileSourceScanExec]] is matched directly (compile-time dependency available in
+   * gluten-core). [[org.apache.spark.sql.hive.execution.HiveTableScanExec]] is matched by simple
+   * class name to avoid a direct dependency on `spark-hive` in this module.
+   */
+  private def isDataSourceScan(plan: SparkPlan): Boolean =
+    plan.isInstanceOf[FileSourceScanExec] ||
+      plan.getClass.getSimpleName == "HiveTableScanExec"
+
+  /**
+   * Returns true if the given plan subtree contains at least one DataSource scan
+   * ([[FileSourceScanExec]] or `HiveTableScanExec`) that fails native validation and would fall
+   * back to vanilla Spark execution.
+   *
+   * Transparently descends through AQE stage wrappers ([[ShuffleQueryStageExec]] /
+   * [[BroadcastQueryStageExec]]) so that already-materialized stages are inspected correctly
+   * during AQE re-planning. For all other node types the check is propagated to children.
+   *
+   * @param plan
+   *   the subtree to inspect
+   * @return
+   *   true if any tracked scan in the subtree fails validation
+   */
+  private def hasFallbackScan(plan: SparkPlan): Boolean =
+    plan match {
+      case q: ShuffleQueryStageExec =>
+        hasFallbackScan(q.plan)
+      case q: BroadcastQueryStageExec =>
+        hasFallbackScan(q.plan)
+      case scan if isDataSourceScan(scan) =>
+        validator.validate(scan) match {
+          case Validator.Passed => false
+          case Validator.Failed(_) => true
+        }
+      case _ => plan.children.exists(hasFallbackScan(_))
+    }
+
+  /**
+   * Recursively finds every DataSource scan ([[FileSourceScanExec]] or `HiveTableScanExec`) in
+   * the given plan subtree that currently passes native validation and adds a fallback tag with
+   * the supplied reason.
+   *
+   * Like [[hasFallbackScan]], this method descends transparently through
+   * [[ShuffleQueryStageExec]] and [[BroadcastQueryStageExec]] wrappers so it works correctly in
+   * both the initial planning pass and the AQE re-planning pass.
+   *
+   * @param plan
+   *   the subtree to walk
+   * @param reason
+   *   a human-readable explanation of why the scan is being forced to fall back (logged and
+   *   surfaced in Gluten's fallback reporting)
+   */
+  private def addFallbackTagToNativeScan(plan: SparkPlan, reason: String): Unit =
+    plan match {
+      case q: ShuffleQueryStageExec =>
+        addFallbackTagToNativeScan(q.plan, reason)
+      case q: BroadcastQueryStageExec =>
+        addFallbackTagToNativeScan(q.plan, reason)
+      case scan if isDataSourceScan(scan) =>
+        FallbackTags.add(scan, reason)
+      case _ => plan.children.foreach(addFallbackTagToNativeScan(_, reason))
+    }
 }
```
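To make the divergence concrete: the incorrect-result scenario hinges on Spark's precision coercion turning an out-of-range decimal into NULL while an uncoerced reader keeps the raw value. The following standalone sketch (not part of the patch) emulates that behavior with `java.math.BigDecimal`; the `coerce` helper is a hypothetical stand-in approximating Spark's CheckOverflow semantics for `DecimalType(precision, scale)`, not Gluten or Spark code.

```scala
object DecimalCoercionSketch {
  import java.math.{BigDecimal => JBigDecimal, RoundingMode}

  // Hypothetical approximation of Spark's CheckOverflow for
  // DecimalType(precision, scale): round to the target scale and return
  // None (i.e. NULL) when the rounded value exceeds the target precision.
  def coerce(v: JBigDecimal, precision: Int, scale: Int): Option[JBigDecimal] = {
    val rounded = v.setScale(scale, RoundingMode.HALF_UP)
    if (rounded.precision <= precision) Some(rounded) else None
  }

  def main(args: Array[String]): Unit = {
    val raw = new JBigDecimal("12345.678") // precision 8, scale 3

    // Wide enough target type: both read paths agree on the key value.
    assert(coerce(raw, 8, 3).contains(raw))

    // Narrow target type: the coercing (vanilla) side reads NULL, while a
    // reader that skips coercion would still surface the raw value. An
    // equi-join between NULL and the raw value matches nothing, which is
    // the empty-result symptom the rule above guards against.
    assert(coerce(raw, 7, 3).isEmpty)
  }
}
```

This is only an analogy for the read-path mismatch; the real fix in this PR avoids the situation entirely by keeping both sides of the join on the same read path.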
Contributor: Could we move this to a separate rule?

Author: I created another rule `AddFallbackTagsForJoin` like `AddFallbackTags` and changed the code of `WithRewrites` accordingly. Then we would have to change the constructor of `RewriteSparkPlanRulesManager` to accept `AddFallbackTagsForJoin` and change its logic as well. I don't think that is worth doing, so I merged the code into `AddFallbackTags`.