diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index a8a3df0c1..3b63e6311 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1980,6 +1980,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         case _: ArraysOverlap => convert(CometArraysOverlap)
         case _ @ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
           convert(CometArrayCompact)
+        case _: ArrayExcept =>
+          convert(CometArrayExcept)
         case _ =>
           withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
           None
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 8550d5201..0e96b543d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -19,7 +19,9 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, ArrayRemove, Attribute, Expression, Literal}
+import scala.annotation.tailrec
+
+import org.apache.spark.sql.catalyst.expressions.{ArrayExcept, ArrayJoin, ArrayRemove, Attribute, Expression, Literal}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -197,6 +199,44 @@ object CometArrayCompact extends CometExpressionSerde with IncompatExpr {
   }
 }
 
+object CometArrayExcept extends CometExpressionSerde with CometExprShim with IncompatExpr {
+
+  @tailrec
+  def isTypeSupported(dt: DataType): Boolean = {
+    import DataTypes._
+    dt match {
+      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
+          _: DecimalType | DateType | TimestampType | TimestampNTZType | StringType =>
+        true
+      case BinaryType => false
+      case ArrayType(elementType, _) => isTypeSupported(elementType)
+      case _: StructType =>
+        false
+      case _ => false
+    }
+  }
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val arrayExceptExpr = expr.asInstanceOf[ArrayExcept]
+    val inputTypes = arrayExceptExpr.children.map(_.dataType).toSet
+    for (dt <- inputTypes) {
+      if (!isTypeSupported(dt)) {
+        withInfo(expr, s"data type not supported: $dt")
+        return None
+      }
+    }
+    val leftArrayExprProto = exprToProto(arrayExceptExpr.left, inputs, binding)
+    val rightArrayExprProto = exprToProto(arrayExceptExpr.right, inputs, binding)
+
+    val arrayExceptScalarExpr =
+      scalarExprToProto("array_except", leftArrayExprProto, rightArrayExprProto)
+    optExprWithInfo(arrayExceptScalarExpr, expr, expr.children: _*)
+  }
+}
+
 object CometArrayJoin extends CometExpressionSerde with IncompatExpr {
   override def convert(
       expr: Expression,
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index cef48c50c..48d51842a 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.functions.{array, col, expr, lit, udf}
 import org.apache.spark.sql.types.StructType
 
 import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus
+import org.apache.comet.serde.CometArrayExcept
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
@@ -318,4 +319,100 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
+  test("array_except - basic test (only integer values)") {
+    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+      Seq(true, false).foreach { dictionaryEnabled =>
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_except(array(_2, _3, _4), array(_3, _4)) from t1"))
+          checkSparkAnswerAndOperator(sql("SELECT array_except(array(_18), array(_19)) from t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql(
+              "SELECT array_except(array(_2, _2, _4), array(_4)) FROM t1 WHERE _2 IS NOT NULL"))
+        }
+      }
+    }
+  }
+
+  test("array_except - test all types (native Parquet reader)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        // test with array of each column
+        val fields =
+          table.schema.fields.filter(field => CometArrayExcept.isTypeSupported(field.dataType))
+        for (field <- fields) {
+          val fieldName = field.name
+          val typeName = field.dataType.typeName
+          sql(
+            s"SELECT cast(array($fieldName, $fieldName) as array<$typeName>) as a, cast(array($fieldName) as array<$typeName>) as b FROM t1")
+            .createOrReplaceTempView("t2")
+          val df = sql("SELECT array_except(a, b) FROM t2")
+          checkSparkAnswerAndOperator(df)
+        }
+      }
+    }
+  }
+
+  test("array_except - test all types (convert from Parquet)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val options = DataGenOptions(
+          allowNull = true,
+          generateNegativeZero = true,
+          generateArray = true,
+          generateStruct = true,
+          generateMap = false)
+        ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
+      }
+      withSQLConf(
+        CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true",
+        CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        // test with array of each column
+        val fields =
+          table.schema.fields.filter(field => CometArrayExcept.isTypeSupported(field.dataType))
+        for (field <- fields) {
+          val fieldName = field.name
+          sql(s"SELECT array($fieldName, $fieldName) as a, array($fieldName) as b FROM t1")
+            .createOrReplaceTempView("t2")
+          val df = sql("SELECT array_except(a, b) FROM t2")
+          field.dataType match {
+            case _: StructType =>
+            // skip due to https://github.com/apache/datafusion-comet/issues/1314
+            case _ =>
+              checkSparkAnswer(df)
+          }
+        }
+      }
+    }
+  }
 }
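For reviewers who want to exercise the new path by hand, a minimal spark-shell sketch (not part of the patch; it assumes a session with Comet installed per the Comet docs, and it opts in via CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE because CometArrayExcept is flagged IncompatExpr):

  import org.apache.comet.CometConf

  // CometArrayExcept is marked IncompatExpr, so it only runs natively when
  // incompatible expressions are explicitly allowed.
  spark.conf.set(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key, "true")

  // Build two array columns of a supported element type (LongType).
  spark.range(5)
    .selectExpr("array(id, id + 1, id + 2) AS a", "array(id + 1) AS b")
    .createOrReplaceTempView("t")

  // With the QueryPlanSerde change above, array_except is serialized to the
  // native array_except scalar function instead of falling back to Spark;
  // the plan should show a Comet projection rather than a Spark Project.
  val df = spark.sql("SELECT array_except(a, b) FROM t")
  df.explain()
  df.show()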