Skip to content

Commit a8ad249

Browse files
committed
feat: enable iceberg compat tests, more tests for complex types
1 parent 81d7eb9 commit a8ad249

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

+11-11
Original file line numberDiff line numberDiff line change
@@ -384,13 +384,13 @@ class CometSparkSessionExtensions
384384

385385
// Comet JVM + native scan for V1 and V2
386386
case op if isCometScan(op) =>
387-
val nativeOp = QueryPlanSerde.operator2Proto(op).get
388-
CometScanWrapper(nativeOp, op)
387+
val nativeOp = QueryPlanSerde.operator2Proto(op)
388+
CometScanWrapper(nativeOp.get, op)
389389

390390
case op if shouldApplySparkToColumnar(conf, op) =>
391391
val cometOp = CometSparkToColumnarExec(op)
392-
val nativeOp = QueryPlanSerde.operator2Proto(cometOp).get
393-
CometScanWrapper(nativeOp, cometOp)
392+
val nativeOp = QueryPlanSerde.operator2Proto(cometOp)
393+
CometScanWrapper(nativeOp.get, cometOp)
394394

395395
case op: ProjectExec =>
396396
val newOp = transform1(op)
@@ -498,15 +498,15 @@ class CometSparkSessionExtensions
498498
val child = op.child
499499
val modes = aggExprs.map(_.mode).distinct
500500

501-
if (!modes.isEmpty && modes.size != 1) {
501+
if (modes.nonEmpty && modes.size != 1) {
502502
// This shouldn't happen as all aggregation expressions should share the same mode.
503503
// Fallback to Spark nevertheless here.
504504
op
505505
} else {
506506
// For a final mode HashAggregate, we only need to transform the HashAggregate
507507
// if there is Comet partial aggregation.
508508
val sparkFinalMode = {
509-
!modes.isEmpty && modes.head == Final && findCometPartialAgg(child).isEmpty
509+
modes.nonEmpty && modes.head == Final && findCometPartialAgg(child).isEmpty
510510
}
511511

512512
if (sparkFinalMode) {
@@ -520,7 +520,7 @@ class CometSparkSessionExtensions
520520
// distinct aggregate functions or only have group by, the aggExprs is empty and
521521
// modes is empty too. If aggExprs is not empty, we need to verify all the
522522
// aggregates have the same mode.
523-
assert(modes.length == 1 || modes.length == 0)
523+
assert(modes.length == 1 || modes.isEmpty)
524524
CometHashAggregateExec(
525525
nativeOp,
526526
op,
@@ -529,7 +529,7 @@ class CometSparkSessionExtensions
529529
aggExprs,
530530
resultExpressions,
531531
child.output,
532-
if (modes.nonEmpty) Some(modes.head) else None,
532+
modes.headOption,
533533
child,
534534
SerializedPlan(None))
535535
case None =>
@@ -540,7 +540,7 @@ class CometSparkSessionExtensions
540540

541541
case op: ShuffledHashJoinExec
542542
if CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
543-
op.children.forall(isCometNative(_)) =>
543+
op.children.forall(isCometNative) =>
544544
val newOp = transform1(op)
545545
newOp match {
546546
case Some(nativeOp) =>
@@ -574,7 +574,7 @@ class CometSparkSessionExtensions
574574

575575
case op: BroadcastHashJoinExec
576576
if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
577-
op.children.forall(isCometNative(_)) =>
577+
op.children.forall(isCometNative) =>
578578
val newOp = transform1(op)
579579
newOp match {
580580
case Some(nativeOp) =>
@@ -1288,7 +1288,7 @@ object CometSparkSessionExtensions extends Logging {
12881288
op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
12891289
}
12901290

1291-
private def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = {
1291+
def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = {
12921292
// Only consider converting leaf nodes to columnar currently, so that all the following
12931293
// operators can have a chance to be converted to columnar. Leaf operators that output
12941294
// columnar batches, such as Spark's vectorized readers, will also be converted to native

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

+4-1
Original file line numberDiff line numberDiff line change
@@ -2721,7 +2721,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
27212721

27222722
case op
27232723
if isCometSink(op) && op.output.forall(a =>
2724-
supportedDataType(a.dataType, allowComplex = usingDataFusionParquetExec(conf))) =>
2724+
supportedDataType(
2725+
a.dataType,
2726+
allowComplex =
2727+
usingDataFusionParquetExec(conf) || op.isInstanceOf[CometSparkToColumnarExec])) =>
27252728
// These operators are source of Comet native execution chain
27262729
val scanBuilder = OperatorOuterClass.Scan.newBuilder()
27272730
val source = op.simpleStringWithNodeId()

0 commit comments

Comments (0)