diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala index af16c41f2ee30..4d33ed81641d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, ExtractSingleColumnNullAwareAntiJoin} import org.apache.spark.sql.catalyst.plans.LeftAnti import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} -import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastPartitioning, IdentityBroadcastMode} import org.apache.spark.sql.classic.Strategy import org.apache.spark.sql.execution.{joins, SparkPlan} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashedRelationBroadcastMode} @@ -41,8 +41,8 @@ object LogicalQueryStageStrategy extends Strategy { plan: LogicalPlan, isNullAware: Boolean): Boolean = plan match { case LogicalQueryStage(_, bqs: BroadcastQueryStageExec) => - bqs.broadcast.mode match { - case HashedRelationBroadcastMode(_, stageIsNullAware) => + bqs.broadcast.outputPartitioning match { + case BroadcastPartitioning(HashedRelationBroadcastMode(_, stageIsNullAware)) => stageIsNullAware == isNullAware case _ => false } @@ -51,8 +51,8 @@ object LogicalQueryStageStrategy extends Strategy { private def isBroadcastStageWithIdentityBroadcastMode(plan: LogicalPlan): Boolean = plan match { case LogicalQueryStage(_, bqs: BroadcastQueryStageExec) => - bqs.broadcast.mode match { - case IdentityBroadcastMode => true + bqs.broadcast.outputPartitioning match { + case BroadcastPartitioning(IdentityBroadcastMode) => true case _ => false } case _ => false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 614822813698b..8c695f4f3958d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -45,8 +45,6 @@ import org.apache.spark.util.{SparkFatalException, ThreadUtils} */ trait BroadcastExchangeLike extends Exchange { - def mode: BroadcastMode - /** * The broadcast run ID in job tag */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index ebe6d8858a7e3..bfcf583a7051b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.{PlanTest, SQLHelper} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AggregateHint, ColumnStat, Limit, LocalRelation, LogicalPlan, Project, Range, Sort, SortHint, Statistics, UnresolvedHint} -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.classic.ClassicConversions._ @@ -1175,7 +1175,6 @@ case class MyShuffleExchangeExec(delegate: ShuffleExchangeExec) extends ShuffleE * whether AQE is enabled. */ case class MyBroadcastExchangeExec(delegate: BroadcastExchangeExec) extends BroadcastExchangeLike { - override def mode: BroadcastMode = delegate.mode override val runId: UUID = delegate.runId override def relationFuture: java.util.concurrent.Future[Broadcast[Any]] = delegate.relationFuture