@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSh
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.utils.GlutenSuiteUtils
 
 import scala.collection.mutable.ArrayBuffer
@@ -66,12 +67,53 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
       .write
       .format("parquet")
       .saveAsTable("tmp3")
+    // ORC files are written with DECIMAL(38, 18) (Hive's native storage precision).
+    // tmp4/tmp5 declare DECIMAL(20, 0) pointing to the same ORC files,
+    // so the reader must handle a precision/scale mismatch.
+    spark
+      .range(100)
+      .selectExpr(
+        "cast(id as decimal(38, 18)) as c1",
+        "cast(id % 3 as int) as c2",
+        "cast(id % 9 as timestamp) as c3")
+      .write
+      .format("orc")
+      .saveAsTable("tmp4_wide")
+    spark
+      .range(100)
+      .selectExpr(
+        "cast(id as decimal(38, 18)) as c1",
+        "cast(id % 3 as int) as c2",
+        "cast(id % 5 as timestamp) as c3")
+      .write
+      .format("orc")
+      .saveAsTable("tmp5_wide")
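+    // Look up each table's storage location so the narrow-schema tables created
+    // below can point at the same ORC files.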
+    val loc4 = spark
+      .sql("DESCRIBE FORMATTED tmp4_wide")
+      .filter("col_name = 'Location'")
+      .select("data_type")
+      .collect()(0)
+      .getString(0)
+    val loc5 = spark
+      .sql("DESCRIBE FORMATTED tmp5_wide")
+      .filter("col_name = 'Location'")
+      .select("data_type")
+      .collect()(0)
+      .getString(0)
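+    // External tables over the same files, declared with the narrower DECIMAL(20, 0).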
+    spark.sql(
+      s"CREATE TABLE tmp4 (c1 DECIMAL(20, 0), c2 INT, c3 TIMESTAMP) USING ORC LOCATION '$loc4'")
+    spark.sql(
+      s"CREATE TABLE tmp5 (c1 DECIMAL(20, 0), c2 INT, c3 TIMESTAMP) USING ORC LOCATION '$loc5'")
   }
 
   override protected def afterAll(): Unit = {
     spark.sql("drop table tmp1")
     spark.sql("drop table tmp2")
     spark.sql("drop table tmp3")
+    spark.sql("drop table tmp4_wide")
+    spark.sql("drop table tmp5_wide")
+    spark.sql("drop table tmp4")
+    spark.sql("drop table tmp5")
 
     super.afterAll()
   }
@@ -420,4 +462,127 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
       spark.sparkContext.removeSparkListener(listener)
     }
   }
+
+  test("For decimal-key joins, if one side falls back to Spark, force fallback the other side") {
+    // ORC files are written with DECIMAL(38, 18) (Hive's native storage precision).
+    // The metastore tables tmp4/tmp5 declare DECIMAL(20, 0) and point to the
+    // same ORC files, so the reader must handle a precision/scale mismatch.
+    // Selecting only c2 (INT) -> native FileSourceScanExecTransformer.
+    // Selecting c3 (TIMESTAMP) in addition -> native validation fails ->
+    // vanilla FileSourceScanExec.
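+    // A sketch of how the scan type could be checked (hypothetical assertion, not
+    // part of this test; `collect` is inherited from AdaptiveSparkPlanHelper):
+    //   val plan = spark.sql("SELECT c2, c3 FROM tmp4").queryExecution.executedPlan
+    //   assert(collect(plan) { case scan: FileSourceScanExec => scan }.nonEmpty)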
+
+    // -- SortMergeJoin ------------------------------------------------------------------
+
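+    // The MERGE hint plus the two disabled shuffled-hash-join flags keep the planner
+    // on a sort-merge join. sql1 projects c3 on both sides, so both scans fall back.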
+    val sql1 = "SELECT /*+ MERGE(tmp4) */ tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
+      "tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
+    withSQLConf(
+      GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
+      GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
+      checkAnswer(
+        spark.sql(sql1),
+        spark.sql(
+          "SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
+            "tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
+            "FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
+      )
+    }
+
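+    // Only the tmp4 side projects c3, so only that scan falls back on its own; the
+    // rule under test must then force the tmp5 side to fall back as well.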
+    val sql2 = "SELECT /*+ MERGE(tmp4) */ tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
+      "tmp5.c2 AS 5c2 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
+    withSQLConf(
+      GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
+      GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
+      checkAnswer(
+        spark.sql(sql2),
+        spark.sql(
+          "SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
+            "tmp5_wide.c2 AS 5c2 " +
+            "FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
+      )
+    }
+
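+    // Mirror case: only the tmp5 side projects c3.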
+    val sql3 = "SELECT /*+ MERGE(tmp4) */ tmp4.c2 AS 4c2, " +
+      "tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
+    withSQLConf(
+      GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false",
+      GlutenConfig.COLUMNAR_SHUFFLED_HASH_JOIN_ENABLED.key -> "false") {
+      checkAnswer(
+        spark.sql(sql3),
+        spark.sql(
+          "SELECT tmp4_wide.c2 AS 4c2, " +
+            "tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
+            "FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
+      )
+    }
+
+    // -- ShuffledHashJoin ---------------------------------------------------------------
+
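+    // Disabling broadcast (threshold -1) lets the SHUFFLE_HASH hint pick a shuffled
+    // hash join. Same three projection variants as the sort-merge cases above.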
+    val sql4 = "SELECT /*+ SHUFFLE_HASH(tmp4) */ tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
+      "tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      checkAnswer(
+        spark.sql(sql4),
+        spark.sql(
+          "SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
+            "tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
+            "FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
+      )
+    }
+
+    val sql5 = "SELECT /*+ SHUFFLE_HASH(tmp4) */ tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
+      "tmp5.c2 AS 5c2 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      checkAnswer(
+        spark.sql(sql5),
+        spark.sql(
+          "SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
+            "tmp5_wide.c2 AS 5c2 " +
+            "FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
+      )
+    }
+
+    val sql6 = "SELECT /*+ SHUFFLE_HASH(tmp4) */ tmp4.c2 AS 4c2, " +
+      "tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      checkAnswer(
+        spark.sql(sql6),
+        spark.sql(
+          "SELECT tmp4_wide.c2 AS 4c2, " +
+            "tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
+            "FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
+      )
+    }
+
+    // -- BroadcastHashJoin --------------------------------------------------------------
+
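+    // No hint or conf needed: the 100-row tables sit well under the default broadcast
+    // threshold, so the planner chooses a broadcast hash join. Same three variants.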
+    val sql7 = "SELECT tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
+      "tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
+    checkAnswer(
+      spark.sql(sql7),
+      spark.sql(
+        "SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
+          "tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
+          "FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
+    )
+
+    val sql8 = "SELECT tmp4.c2 AS 4c2, tmp4.c3 AS 4c3, " +
+      "tmp5.c2 AS 5c2 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
+    checkAnswer(
+      spark.sql(sql8),
+      spark.sql(
+        "SELECT tmp4_wide.c2 AS 4c2, tmp4_wide.c3 AS 4c3, " +
+          "tmp5_wide.c2 AS 5c2 " +
+          "FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
+    )
+
+    val sql9 = "SELECT tmp4.c2 AS 4c2, " +
+      "tmp5.c2 AS 5c2, tmp5.c3 AS 5c3 FROM tmp4 JOIN tmp5 ON tmp4.c1 = tmp5.c1"
+    checkAnswer(
+      spark.sql(sql9),
+      spark.sql(
+        "SELECT tmp4_wide.c2 AS 4c2, " +
+          "tmp5_wide.c2 AS 5c2, tmp5_wide.c3 AS 5c3 " +
+          "FROM tmp4_wide JOIN tmp5_wide ON tmp4_wide.c1 = tmp5_wide.c1")
+    )
+  }
 }