SPARK-45846: gate the single-column null-aware anti-join rule on broadcast size.

Hunk 1 (SparkStrategies.scala): the ExtractSingleColumnNullAwareAntiJoin case
now only plans a BroadcastHashJoinExec (isNullAwareAntiJoin = true) when
canBroadcastBySize(j.right, conf) holds for the build (right) side.
Hunk 2 (JoinSuite.scala): adds a test that sets AUTO_BROADCASTJOIN_THRESHOLD
to -1 and asserts the spark plan of a NOT IN subquery contains no
BroadcastHashJoinExec with isNullAwareAntiJoin set.

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5efad83bcba78..154a58309cff7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -328,7 +328,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
               .getOrElse(createJoinWithoutHint())
           }
 
-      case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys) =>
+      case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys)
+        if canBroadcastBySize(j.right, conf) =>
         Seq(joins.BroadcastHashJoinExec(leftKeys, rightKeys, LeftAnti, BuildRight,
           None, planLater(j.left), planLater(j.right), isNullAwareAntiJoin = true))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 885512d4d1980..b5393cec451a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1271,6 +1271,22 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
     }
   }
 
+  test("SPARK-45846: optimizeNullAwareAntiJoin should respect autoBroadcastJoinThreshold") {
+    withSQLConf(SQLConf.OPTIMIZE_NULL_AWARE_ANTI_JOIN.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      // When broadcast is disabled, null-aware anti-join should not use BroadcastHashJoinExec
+      val df = sql("select * from testData where key not in (select a from testData2)")
+      val physical = df.queryExecution.sparkPlan
+      // Collect all BroadcastHashJoinExec nodes
+      val broadcastHashJoins = collect(physical) {
+        case j: BroadcastHashJoinExec => j
+      }
+      // Verify no BroadcastHashJoinExec with isNullAwareAntiJoin is used
+      assert(broadcastHashJoins.forall(!_.isNullAwareAntiJoin),
+        "Null-aware anti-join should not use BroadcastHashJoinExec when broadcast is disabled")
+    }
+  }
+
   test("SPARK-32399: Full outer shuffled hash join") {
     val inputDFs = Seq(
       // Test unique join key