From 74c3bec0521960253888a931d9979b88231e4309 Mon Sep 17 00:00:00 2001
From: Artem Kupchinskiy
Date: Fri, 15 Aug 2025 18:30:27 +0400
Subject: [PATCH 1/4] bugfix: fix misaligned partitions in plans with reused
 broadcast exchanges

---
 .../comet/CometBroadcastExchangeExec.scala    |  2 +-
 .../apache/spark/sql/comet/operators.scala    | 21 +++++++++--
 .../apache/comet/exec/CometExecSuite.scala    | 37 +++++++++++++++++++
 3 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 21b395982b..1760e498a5 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -277,7 +277,7 @@ object CometBroadcastExchangeExec {
 class CometBatchRDD(
     sc: SparkContext,
     numPartitions: Int,
-    value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
+    private[comet] val value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
     extends RDD[ColumnarBatch](sc, Nil) {
 
   override def getPartitions: Array[Partition] = (0 until numPartitions).toArray.map { i =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 593f4f3a45..708cd3d16e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -233,13 +233,19 @@ abstract class CometNativeExec extends CometExec {
 
     foreachUntilCometInput(this)(sparkPlans += _)
 
-    // Find the first non broadcast plan
-    val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find {
+    val nonBroadcastPredicate: (SparkPlan, Int) => Boolean = {
       case (_: CometBroadcastExchangeExec, _) => false
       case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false
       case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false
       case _ => true
     }
+    // Find the first non-broadcast plan that is not a reused exchange, if possible
+    val firstNonBroadcastPlan = sparkPlans.zipWithIndex
+      .find {
+        case (p, idx) if !p.isInstanceOf[ReusedExchangeExec] => nonBroadcastPredicate(p, idx)
+        case _ => false
+      }
+      .orElse(sparkPlans.zipWithIndex.find(nonBroadcastPredicate.tupled))
 
     val containsBroadcastInput = sparkPlans.exists {
       case _: CometBroadcastExchangeExec => true
@@ -303,7 +309,16 @@ abstract class CometNativeExec extends CometExec {
       }
 
       if (inputs.nonEmpty) {
-        ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
+        val fixedInputs = inputs.map {
+          case cometBatchRDD: CometBatchRDD
+              if cometBatchRDD.getNumPartitions != firstNonBroadcastPlanNumPartitions =>
+            new CometBatchRDD(
+              sparkContext,
+              firstNonBroadcastPlanNumPartitions,
+              cometBatchRDD.value)
+          case other => other
+        }
+        ZippedPartitionsRDD(sparkContext, fixedInputs.toSeq)(createCometExecIter)
       } else {
         val partitionNum = firstNonBroadcastPlanNumPartitions
         CometExecRDD(sparkContext, partitionNum)(createCometExecIter)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 339f90e81c..c350ee79af 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -2107,6 +2107,43 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("ReusedExchange broadcast with incompatible number of partitions does not fail") {
+    // enforce different numbers of partitions for future broadcasts/exchanges
+    spark
+      .range(50)
+      .withColumnRenamed("id", "x")
+      .repartition(2)
+      .writeTo("tbl1")
+      .using("parquet")
+      .create()
+    spark
+      .range(50)
+      .withColumnRenamed("id", "y")
+      .repartition(3)
+      .writeTo("tbl2")
+      .using("parquet")
+      .create()
+    spark
+      .range(50)
+      .withColumnRenamed("id", "z")
+      .repartition(4)
+      .writeTo("tbl3")
+      .using("parquet")
+      .create()
+    val df1 = spark.table("tbl1")
+    val df2 = spark.table("tbl2")
+    val df3 = spark.table("tbl3")
+
+    val dfWithReusedExchange = df1
+      .join(df3.hint("broadcast").join(df1, $"x" === $"z"), "x", "left")
+      .join(
+        df3.hint("broadcast").join(df2, $"y" === $"z").withColumnRenamed("z", "z1"),
+        $"x" === $"y")
+
+    checkSparkAnswerAndOperator(dfWithReusedExchange)
+
+  }
+
   test("SparkToColumnar override node name for columnar input") {
     withSQLConf(
       SQLConf.USE_V1_SOURCE_LIST.key -> "",

From 8eadd56fbb7b4079e98156e648a3af81625cdd23 Mon Sep 17 00:00:00 2001
From: Artem Kupchinskiy
Date: Fri, 15 Aug 2025 23:13:21 +0400
Subject: [PATCH 2/4] fix unit test: check both AQE modes and assert
 ReusedExchangeExec is present

---
 .../org/apache/comet/exec/CometExecSuite.scala | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index c350ee79af..37f4cb25c1 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -2134,13 +2134,15 @@ class CometExecSuite extends CometTestBase {
     val df2 = spark.table("tbl2")
     val df3 = spark.table("tbl3")
 
-    val dfWithReusedExchange = df1
-      .join(df3.hint("broadcast").join(df1, $"x" === $"z"), "x", "left")
-      .join(
-        df3.hint("broadcast").join(df2, $"y" === $"z").withColumnRenamed("z", "z1"),
-        $"x" === $"y")
-
-    checkSparkAnswerAndOperator(dfWithReusedExchange)
+    Seq("true", "false").foreach(aqeEnabled =>
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
+        val dfWithReusedExchange = df1
+          .join(df3.hint("broadcast").join(df1, $"x" === $"z"), "x", "right")
+          .join(
+            df3.hint("broadcast").join(df2, $"y" === $"z", "right").withColumnRenamed("z", "z1"),
+            $"x" === $"y")
+        checkSparkAnswerAndOperator(dfWithReusedExchange, classOf[ReusedExchangeExec])
+      })
 
   }
 

From 3ed46b8d185102949ab43109992af044b832a599 Mon Sep 17 00:00:00 2001
From: Artem Kupchinskiy
Date: Sat, 16 Aug 2025 00:20:05 +0400
Subject: [PATCH 3/4] fix cleanup in unit test: drop test tables via withTable

---
 .../apache/comet/exec/CometExecSuite.scala | 75 ++++++++++---------
 1 file changed, 39 insertions(+), 36 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 37f4cb25c1..ce71a2973c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -2108,42 +2108,45 @@ class CometExecSuite extends CometTestBase {
   }
 
   test("ReusedExchange broadcast with incompatible number of partitions does not fail") {
-    // enforce different numbers of partitions for future broadcasts/exchanges
-    spark
-      .range(50)
-      .withColumnRenamed("id", "x")
-      .repartition(2)
-      .writeTo("tbl1")
-      .using("parquet")
-      .create()
-    spark
-      .range(50)
-      .withColumnRenamed("id", "y")
-      .repartition(3)
-      .writeTo("tbl2")
-      .using("parquet")
-      .create()
-    spark
-      .range(50)
-      .withColumnRenamed("id", "z")
-      .repartition(4)
-      .writeTo("tbl3")
-      .using("parquet")
-      .create()
-    val df1 = spark.table("tbl1")
-    val df2 = spark.table("tbl2")
-    val df3 = spark.table("tbl3")
-
-    Seq("true", "false").foreach(aqeEnabled =>
-      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
-        val dfWithReusedExchange = df1
-          .join(df3.hint("broadcast").join(df1, $"x" === $"z"), "x", "right")
-          .join(
-            df3.hint("broadcast").join(df2, $"y" === $"z", "right").withColumnRenamed("z", "z1"),
-            $"x" === $"y")
-        checkSparkAnswerAndOperator(dfWithReusedExchange, classOf[ReusedExchangeExec])
-      })
-
+    withTable("tbl1", "tbl2", "tbl3") {
+      // enforce different numbers of partitions for future broadcasts/exchanges
+      spark
+        .range(50)
+        .withColumnRenamed("id", "x")
+        .repartition(2)
+        .writeTo("tbl1")
+        .using("parquet")
+        .create()
+      spark
+        .range(50)
+        .withColumnRenamed("id", "y")
+        .repartition(3)
+        .writeTo("tbl2")
+        .using("parquet")
+        .create()
+      spark
+        .range(50)
+        .withColumnRenamed("id", "z")
+        .repartition(4)
+        .writeTo("tbl3")
+        .using("parquet")
+        .create()
+      val df1 = spark.table("tbl1")
+      val df2 = spark.table("tbl2")
+      val df3 = spark.table("tbl3")
+      Seq("true", "false").foreach(aqeEnabled =>
+        withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
+          val dfWithReusedExchange = df1
+            .join(df3.hint("broadcast").join(df1, $"x" === $"z"), "x", "right")
+            .join(
+              df3
+                .hint("broadcast")
+                .join(df2, $"y" === $"z", "right")
+                .withColumnRenamed("z", "z1"),
+              $"x" === $"y")
+          checkSparkAnswerAndOperator(dfWithReusedExchange, classOf[ReusedExchangeExec])
+        })
+    }
   }
 
   test("SparkToColumnar override node name for columnar input") {

From 9c62ea5c5151244b13418e733f35318684c7e4c5 Mon Sep 17 00:00:00 2001
From: Artem Kupchinskiy
Date: Sat, 16 Aug 2025 17:16:17 +0400
Subject: [PATCH 4/4] clarify and simplify the non-broadcast plan picking logic

---
 .../main/scala/org/apache/spark/sql/comet/operators.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 708cd3d16e..2ac51e91db 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -242,10 +242,10 @@ abstract class CometNativeExec extends CometExec {
     // Find the first non-broadcast plan that is not a reused exchange, if possible
     val firstNonBroadcastPlan = sparkPlans.zipWithIndex
       .find {
-        case (p, idx) if !p.isInstanceOf[ReusedExchangeExec] => nonBroadcastPredicate(p, idx)
-        case _ => false
+        case (ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) => false
+        case (p, idx) => nonBroadcastPredicate(p, idx)
       }
-      .orElse(sparkPlans.zipWithIndex.find(nonBroadcastPredicate.tupled))
+      .orElse(sparkPlans.zipWithIndex.find(_._1.isInstanceOf[ReusedExchangeExec]))
 
     val containsBroadcastInput = sparkPlans.exists {
       case _: CometBroadcastExchangeExec => true
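
A note for reviewers, outside the patch series: the realignment in PATCH 1 is
safe because a broadcast-backed CometBatchRDD serves the entire broadcast value
from every partition, so its partition count is arbitrary and can be rewritten
to match the first non-broadcast plan before ZippedPartitionsRDD zips the
inputs. The sketch below models that invariant without Spark; the names
RegularInput, BroadcastInput, and alignAndZip are hypothetical stand-ins for
the RDD inputs and the zipping step, not Comet or Spark APIs.

// Toy model of plan inputs: a partitioned dataset vs. a broadcast value.
sealed trait PartitionedInput[A] {
  def numPartitions: Int
  def partition(i: Int): Iterator[A]
}

// A regular input: data is genuinely split across partitions, so its
// partition count is fixed by the data layout.
final case class RegularInput[A](parts: Vector[Vector[A]]) extends PartitionedInput[A] {
  def numPartitions: Int = parts.length
  def partition(i: Int): Iterator[A] = parts(i).iterator
}

// A broadcast-style input: every partition serves the same full value, which
// is what makes rebuilding it with a different partition count safe.
final case class BroadcastInput[A](value: Vector[A], numPartitions: Int)
    extends PartitionedInput[A] {
  def partition(i: Int): Iterator[A] = value.iterator
}

object RealignmentSketch {
  // Mirror of ZippedPartitionsRDD's invariant: all inputs must agree on the
  // partition count. Broadcast inputs with a stale count (e.g. one taken from
  // a reused exchange) are rebuilt; a mismatched regular input is a real error.
  def alignAndZip[A](target: Int, inputs: Seq[PartitionedInput[A]]): Seq[Seq[List[A]]] = {
    val aligned = inputs.map {
      case b: BroadcastInput[A] if b.numPartitions != target =>
        b.copy(numPartitions = target)
      case other =>
        require(
          other.numPartitions == target,
          s"cannot realign non-broadcast input: ${other.numPartitions} != $target")
        other
    }
    (0 until target).map(i => aligned.map(_.partition(i).toList))
  }

  def main(args: Array[String]): Unit = {
    // The "first non-broadcast plan" has 3 partitions; the reused broadcast
    // still reports 2. Realignment lets the zip proceed over 3 partitions.
    val probe = RegularInput(Vector(Vector(1, 2), Vector(3), Vector(4, 5)))
    val reused = BroadcastInput(Vector(10, 20), numPartitions = 2)
    alignAndZip(probe.numPartitions, Seq(probe, reused)).zipWithIndex.foreach {
      case (slices, i) => println(s"partition $i: ${slices.mkString(" zip ")}")
    }
  }
}

Running main prints three partitions, each zipping one slice of the regular
input with the full broadcast value, which mirrors the shape the fixed
ZippedPartitionsRDD call sees after the CometBatchRDD is rebuilt.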