diff --git a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/execution/datasources/v2/OffloadDeltaCommand.scala b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/execution/datasources/v2/OffloadDeltaCommand.scala index f4b74e4dbec2..d57a9ec157e2 100644 --- a/backends-velox/src-delta33/main/scala/org/apache/spark/sql/execution/datasources/v2/OffloadDeltaCommand.scala +++ b/backends-velox/src-delta33/main/scala/org/apache/spark/sql/execution/datasources/v2/OffloadDeltaCommand.scala @@ -49,11 +49,9 @@ case class OffloadDeltaCommand() extends OffloadSingleNode with DeltaCommand { } } - // Currently only plain OPTIMIZE bin-packing is supported for command offload. OPTIMIZE - // variants with layout-specific semantics, such as ZORDER, REORG, OPTIMIZE FULL, or - // liquid clustering, continue to use Delta's original command path. + // Currently OPTIMIZE bin-packing and ZORDER are supported for command offload. + // REORG, OPTIMIZE FULL, and liquid clustering continue to use Delta's original command path. private def shouldOffloadOptimize(optimize: OptimizeTableCommand): Boolean = { - optimize.zOrderBy.isEmpty && optimize.optimizeContext.reorg.isEmpty && !optimize.optimizeContext.isFull && !isClusteredOptimize(optimize) diff --git a/backends-velox/src-delta33/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala b/backends-velox/src-delta33/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala index 262057118695..d8aacdf307be 100644 --- a/backends-velox/src-delta33/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala +++ b/backends-velox/src-delta33/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala @@ -263,6 +263,119 @@ class DeltaNativeWriteSuite extends DeltaSQLCommandTest { } } + test("native delta optimize zorder command should be offloaded") { + withNativeWriteOffloadConf { + withTempDir { + dir => + val path = dir.getCanonicalPath + spark + .range(0, 128, 1, 8) + .selectExpr("id", "cast(id % 4 as int) as part") + .write + .format("delta") + .mode("append") + .save(path) + spark + .range(128, 256, 1, 8) + .selectExpr("id", "cast(id % 4 as int) as part") + .write + .format("delta") + .mode("append") + .save(path) + + val deltaLog = DeltaLog.forTable(spark, path) + val beforeFiles = files(deltaLog) + + withSQLConf( + DeltaSQLConf.DELTA_OPTIMIZE_ZORDER_COL_STAT_CHECK.key -> "false", + DeltaSQLConf.MDC_ADD_NOISE.key -> "false") { + val optimizeDf = sql(s"OPTIMIZE delta.`$path` ZORDER BY (id, part)") + assertContainsNativeWriteCommand( + optimizeDf.queryExecution.executedPlan, + "OPTIMIZE ZORDER BY") + val metrics = collectOptimizeMetrics(optimizeDf) + + val afterFiles = files(deltaLog) + assertCompactionMetrics(metrics, beforeFiles.size, afterFiles.size, "OPTIMIZE ZORDER") + assertOptimizeCommit(deltaLog, "OPTIMIZE ZORDER") + } + + val result = spark.read.format("delta").load(path) + val summary = result.selectExpr("count(*)", "min(id)", "max(id)").head() + assert(summary.getLong(0) == 256L) + assert(summary.getLong(1) == 0L) + assert(summary.getLong(2) == 255L) + } + } + } + + test("native delta optimize zorder partition predicate command should be offloaded") { + withNativeWriteOffloadConf { + withTempDir { + dir => + val path = dir.getCanonicalPath + spark + .range(0, 40, 1, 4) + .selectExpr("id", "cast(id % 2 as int) as part") + .write + .format("delta") + .partitionBy("part") + .mode("append") + .save(path) + spark + .range(40, 80, 1, 4) + .selectExpr("id", "cast(id % 2 as int) as part") + .write + .format("delta") + .partitionBy("part") + .mode("append") + .save(path) + + val deltaLog = DeltaLog.forTable(spark, path) + val beforeFiles = files(deltaLog) + val beforePart0Paths = beforeFiles + .filter(_.partitionValues.get("part").contains("0")) + .map(_.path) + val beforePart1Count = beforeFiles.count(_.partitionValues.get("part").contains("1")) + + withSQLConf( + DeltaSQLConf.DELTA_OPTIMIZE_ZORDER_COL_STAT_CHECK.key -> "false", + DeltaSQLConf.MDC_ADD_NOISE.key -> "false") { + val optimizeDf = sql(s"OPTIMIZE delta.`$path` WHERE part = 1 ZORDER BY (id)") + assertContainsNativeWriteCommand( + optimizeDf.queryExecution.executedPlan, + "OPTIMIZE WHERE ZORDER BY") + val metrics = collectOptimizeMetrics(optimizeDf) + + val afterFiles = files(deltaLog) + val afterPart0Paths = afterFiles + .filter(_.partitionValues.get("part").contains("0")) + .map(_.path) + val afterPart1Count = afterFiles.count(_.partitionValues.get("part").contains("1")) + assert( + beforePart0Paths.subsetOf(afterPart0Paths), + "OPTIMIZE WHERE part = 1 ZORDER should not remove files from part = 0") + assert( + afterPart1Count < beforePart1Count, + s"Expected fewer active files in part = 1, before=$beforePart1Count " + + s"after=$afterPart1Count") + assertCompactionMetrics( + metrics, + beforeFiles.size, + afterFiles.size, + "partition predicate OPTIMIZE ZORDER", + expectedPartitionsOptimized = Some(1L)) + assertOptimizeCommit(deltaLog, "partition predicate OPTIMIZE ZORDER") + } + + val result = spark.read.format("delta").load(path) + assert(result.select("id").collect().map(_.getLong(0)).toSet == (0L until 80L).toSet) + assert(result.where("part = 0").count() == 40) + assert(result.where("part = 1").count() == 40) + } + } + } + test("delta optimize command should not be offloaded when native write is disabled") { withNativeWriteOffloadConf { withTempDir { diff --git a/backends-velox/src-delta40/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala b/backends-velox/src-delta40/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala index bca0a66d1ad6..3d9423b368f7 100644 --- a/backends-velox/src-delta40/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala +++ b/backends-velox/src-delta40/test/scala/org/apache/spark/sql/delta/DeltaNativeWriteSuite.scala @@ -468,6 +468,119 @@ class DeltaNativeWriteSuite extends DeltaSQLCommandTest { } } + test("native delta optimize zorder command should be offloaded") { + withNativeWriteOffloadConf { + withTempDir { + dir => + val path = dir.getCanonicalPath + spark + .range(0, 128, 1, 8) + .selectExpr("id", "cast(id % 4 as int) as part") + .write + .format("delta") + .mode("append") + .save(path) + spark + .range(128, 256, 1, 8) + .selectExpr("id", "cast(id % 4 as int) as part") + .write + .format("delta") + .mode("append") + .save(path) + + val deltaLog = DeltaLog.forTable(spark, path) + val beforeFiles = files(deltaLog) + + withSQLConf( + DeltaSQLConf.DELTA_OPTIMIZE_ZORDER_COL_STAT_CHECK.key -> "false", + DeltaSQLConf.MDC_ADD_NOISE.key -> "false") { + val optimizeDf = sql(s"OPTIMIZE delta.`$path` ZORDER BY (id, part)") + assertContainsNativeWriteCommand( + Seq(optimizeDf.queryExecution.executedPlan), + "OPTIMIZE ZORDER BY") + val metrics = collectOptimizeMetrics(optimizeDf) + + val afterFiles = files(deltaLog) + assertCompactionMetrics(metrics, beforeFiles.size, afterFiles.size, "OPTIMIZE ZORDER") + assertOptimizeCommit(deltaLog, "OPTIMIZE ZORDER") + } + + val result = spark.read.format("delta").load(path) + val summary = result.selectExpr("count(*)", "min(id)", "max(id)").head() + assert(summary.getLong(0) == 256L) + assert(summary.getLong(1) == 0L) + assert(summary.getLong(2) == 255L) + } + } + } + + test("native delta optimize zorder partition predicate command should be offloaded") { + withNativeWriteOffloadConf { + withTempDir { + dir => + val path = dir.getCanonicalPath + spark + .range(0, 40, 1, 4) + .selectExpr("id", "cast(id % 2 as int) as part") + .write + .format("delta") + .partitionBy("part") + .mode("append") + .save(path) + spark + .range(40, 80, 1, 4) + .selectExpr("id", "cast(id % 2 as int) as part") + .write + .format("delta") + .partitionBy("part") + .mode("append") + .save(path) + + val deltaLog = DeltaLog.forTable(spark, path) + val beforeFiles = files(deltaLog) + val beforePart0Paths = beforeFiles + .filter(_.partitionValues.get("part").contains("0")) + .map(_.path) + val beforePart1Count = beforeFiles.count(_.partitionValues.get("part").contains("1")) + + withSQLConf( + DeltaSQLConf.DELTA_OPTIMIZE_ZORDER_COL_STAT_CHECK.key -> "false", + DeltaSQLConf.MDC_ADD_NOISE.key -> "false") { + val optimizeDf = sql(s"OPTIMIZE delta.`$path` WHERE part = 1 ZORDER BY (id)") + assertContainsNativeWriteCommand( + Seq(optimizeDf.queryExecution.executedPlan), + "OPTIMIZE WHERE ZORDER BY") + val metrics = collectOptimizeMetrics(optimizeDf) + + val afterFiles = files(deltaLog) + val afterPart0Paths = afterFiles + .filter(_.partitionValues.get("part").contains("0")) + .map(_.path) + val afterPart1Count = afterFiles.count(_.partitionValues.get("part").contains("1")) + assert( + beforePart0Paths.subsetOf(afterPart0Paths), + "OPTIMIZE WHERE part = 1 ZORDER should not remove files from part = 0") + assert( + afterPart1Count < beforePart1Count, + s"Expected fewer active files in part = 1, before=$beforePart1Count " + + s"after=$afterPart1Count") + assertCompactionMetrics( + metrics, + beforeFiles.size, + afterFiles.size, + "partition predicate OPTIMIZE ZORDER", + expectedPartitionsOptimized = Some(1L)) + assertOptimizeCommit(deltaLog, "partition predicate OPTIMIZE ZORDER") + } + + val result = spark.read.format("delta").load(path) + assert(result.select("id").collect().map(_.getLong(0)).toSet == (0L until 80L).toSet) + assert(result.where("part = 0").count() == 40) + assert(result.where("part = 1").count() == 40) + } + } + } + test("delta optimize command should not be offloaded when native write is disabled") { withNativeWriteOffloadConf { withTempDir { diff --git a/cpp/velox/operators/functions/DeltaZOrder.h b/cpp/velox/operators/functions/DeltaZOrder.h new file mode 100644 index 000000000000..b8dfca892b04 --- /dev/null +++ b/cpp/velox/operators/functions/DeltaZOrder.h @@ -0,0 +1,358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/functions/Udf.h" + +#include +#include +#include +#include + +namespace gluten { + +namespace { +constexpr size_t kInterleaveBitsStackInputCapacity = 16; +constexpr size_t kRangePartitionLinearSearchBoundCount = 128; +} // namespace + +template +struct DeltaInterleaveBitsFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type>& inputs) { + interleave(result, inputs); + } + + FOLLY_ALWAYS_INLINE bool callNullable( + out_type& result, + const arg_type>* inputs) { + interleave(result, inputs); + return true; + } + + private: + FOLLY_ALWAYS_INLINE void + writeInterleavedBits(out_type& result, const int32_t* values, size_t valueCount) { + const auto byteCount = valueCount * sizeof(int32_t); + result.resize(byteCount); + if (byteCount == 0) { + return; + } + std::memset(result.data(), 0, byteCount); + + size_t outputBit = 0; + for (int bit = 31; bit >= 0; --bit) { + for (size_t i = 0; i < valueCount; ++i) { + if ((static_cast(values[i]) >> bit) & 1U) { + result.data()[outputBit >> 3] |= static_cast(1U << (7 - (outputBit & 7))); + } + ++outputBit; + } + } + } + + template + FOLLY_ALWAYS_INLINE void interleave(out_type& result, const TInputs* inputs) { + if (inputs == nullptr) { + result.resize(0); + return; + } + + const auto inputCount = inputs->size(); + if (inputCount == 0) { + result.resize(0); + return; + } + + if (inputCount <= kInterleaveBitsStackInputCapacity) { + std::array values; + for (size_t i = 0; i < inputCount; ++i) { + const auto input = inputs->at(i); + values[i] = input.has_value() ? input.value() : 0; + } + writeInterleavedBits(result, values.data(), inputCount); + return; + } + + std::vector values(inputCount); + for (size_t i = 0; i < inputCount; ++i) { + const auto input = inputs->at(i); + values[i] = input.has_value() ? input.value() : 0; + } + writeInterleavedBits(result, values.data(), inputCount); + } + + template + FOLLY_ALWAYS_INLINE void interleave(out_type& result, const TInputs& inputs) { + interleave(result, &inputs); + } +}; + +template +FOLLY_ALWAYS_INLINE void deltaRangePartitionId(int32_t& result, const TInputView* inputs) { + result = 0; + if (inputs == nullptr) { + return; + } + const auto inputCount = inputs->size(); + if (inputCount == 0) { + return; + } + + const auto valueArg = inputs->at(0); + if (!valueArg.has_value()) { + return; + } + + const auto value = valueArg.value(); + const auto boundCount = inputCount - 1; + if (boundCount <= kRangePartitionLinearSearchBoundCount) { + for (size_t i = 1; i < inputCount; ++i) { + const auto bound = inputs->at(i); + if (!bound.has_value() || value <= bound.value()) { + return; + } + ++result; + } + return; + } + + size_t lower = 0; + size_t upper = boundCount; + while (lower < upper) { + const auto mid = lower + (upper - lower) / 2; + const auto bound = inputs->at(mid + 1); + if (!bound.has_value() || value <= bound.value()) { + upper = mid; + } else { + lower = mid + 1; + } + } + result = static_cast(lower); +} + +template +FOLLY_ALWAYS_INLINE void +deltaRangePartitionIdFromBounds(int32_t& result, const TValue& value, const TBoundsView* bounds) { + result = 0; + if (bounds == nullptr) { + return; + } + + const auto boundCount = bounds->size(); + if (boundCount <= kRangePartitionLinearSearchBoundCount) { + for (size_t i = 0; i < boundCount; ++i) { + const auto bound = bounds->at(i); + if (!bound.has_value() || value <= bound.value()) { + return; + } + ++result; + } + return; + } + + size_t lower = 0; + size_t upper = boundCount; + while (lower < upper) { + const auto mid = lower + (upper - lower) / 2; + const auto bound = bounds->at(mid); + if (!bound.has_value() || value <= bound.value()) { + upper = mid; + } else { + lower = mid + 1; + } + } + result = static_cast(lower); +} + +template +struct DeltaRangePartitionIdTinyintFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type>& inputs) { + deltaRangePartitionId(result, &inputs); + } + + FOLLY_ALWAYS_INLINE bool callNullable(int32_t& result, const arg_type>* inputs) { + deltaRangePartitionId(result, inputs); + return true; + } +}; + +template +struct DeltaRangePartitionIdTinyintArrayFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call(int32_t& result, int8_t value, const arg_type>& bounds) { + deltaRangePartitionIdFromBounds(result, value, &bounds); + } + + FOLLY_ALWAYS_INLINE bool + callNullable(int32_t& result, const int8_t* value, const arg_type>* bounds) { + result = 0; + if (value != nullptr) { + deltaRangePartitionIdFromBounds(result, *value, bounds); + } + return true; + } +}; + +template +struct DeltaRangePartitionIdSmallintFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type>& inputs) { + deltaRangePartitionId(result, &inputs); + } + + FOLLY_ALWAYS_INLINE bool callNullable(int32_t& result, const arg_type>* inputs) { + deltaRangePartitionId(result, inputs); + return true; + } +}; + +template +struct DeltaRangePartitionIdSmallintArrayFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void + call(int32_t& result, int16_t value, const arg_type>& bounds) { + deltaRangePartitionIdFromBounds(result, value, &bounds); + } + + FOLLY_ALWAYS_INLINE bool + callNullable(int32_t& result, const int16_t* value, const arg_type>* bounds) { + result = 0; + if (value != nullptr) { + deltaRangePartitionIdFromBounds(result, *value, bounds); + } + return true; + } +}; + +template +struct DeltaRangePartitionIdIntegerFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type>& inputs) { + deltaRangePartitionId(result, &inputs); + } + + FOLLY_ALWAYS_INLINE bool callNullable(int32_t& result, const arg_type>* inputs) { + deltaRangePartitionId(result, inputs); + return true; + } +}; + +template +struct DeltaRangePartitionIdIntegerArrayFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void + call(int32_t& result, int32_t value, const arg_type>& bounds) { + deltaRangePartitionIdFromBounds(result, value, &bounds); + } + + FOLLY_ALWAYS_INLINE bool + callNullable(int32_t& result, const int32_t* value, const arg_type>* bounds) { + result = 0; + if (value != nullptr) { + deltaRangePartitionIdFromBounds(result, *value, bounds); + } + return true; + } +}; + +template +struct DeltaRangePartitionIdBigintFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type>& inputs) { + deltaRangePartitionId(result, &inputs); + } + + FOLLY_ALWAYS_INLINE bool callNullable(int32_t& result, const arg_type>* inputs) { + deltaRangePartitionId(result, inputs); + return true; + } +}; + +template +struct DeltaRangePartitionIdBigintArrayFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void + call(int32_t& result, int64_t value, const arg_type>& bounds) { + deltaRangePartitionIdFromBounds(result, value, &bounds); + } + + FOLLY_ALWAYS_INLINE bool + callNullable(int32_t& result, const int64_t* value, const arg_type>* bounds) { + result = 0; + if (value != nullptr) { + deltaRangePartitionIdFromBounds(result, *value, bounds); + } + return true; + } +}; + +template +struct DeltaRangePartitionIdDateFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call( + int32_t& result, + const arg_type>& inputs) { + deltaRangePartitionId(result, &inputs); + } + + FOLLY_ALWAYS_INLINE bool callNullable( + int32_t& result, + const arg_type>* inputs) { + deltaRangePartitionId(result, inputs); + return true; + } +}; + +template +struct DeltaRangePartitionIdDateArrayFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call( + int32_t& result, + const arg_type& value, + const arg_type>& bounds) { + deltaRangePartitionIdFromBounds(result, value, &bounds); + } + + FOLLY_ALWAYS_INLINE bool callNullable( + int32_t& result, + const arg_type* value, + const arg_type>* bounds) { + result = 0; + if (value != nullptr) { + deltaRangePartitionIdFromBounds(result, *value, bounds); + } + return true; + } +}; + +} // namespace gluten diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index dd1be7805c75..ac419ab01f02 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -17,6 +17,7 @@ #include "operators/functions/RegistrationAllFunctions.h" #include "operators/functions/Arithmetic.h" +#include "operators/functions/DeltaZOrder.h" #include "operators/functions/RowConstructorWithNull.h" #include "operators/functions/RowFunctionWithNull.h" #include "velox/expression/SpecialFormRegistry.h" @@ -56,6 +57,27 @@ void registerFunctionOverwrite() { velox::registerFunction({"round"}); velox::registerFunction({"round"}); velox::registerFunction({"round"}); + velox::registerFunction>({"interleave_bits"}); + velox::registerFunction>( + {"range_partition_id"}); + velox::registerFunction>( + {"range_partition_id"}); + velox::registerFunction>( + {"range_partition_id"}); + velox::registerFunction>( + {"range_partition_id"}); + velox::registerFunction>( + {"range_partition_id"}); + velox::registerFunction>( + {"range_partition_id"}); + velox::registerFunction>( + {"range_partition_id"}); + velox::registerFunction>( + {"range_partition_id"}); + velox::registerFunction>( + {"range_partition_id"}); + velox::registerFunction>( + {"range_partition_id"}); auto kRowConstructorWithNull = RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull; velox::exec::registerVectorFunction( diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index c50bae6e77ba..0ba99ba378c1 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -24,10 +24,11 @@ import org.apache.gluten.utils.DecimalArithmeticUtil import org.apache.spark.{SPARK_REVISION, SPARK_VERSION_SHORT} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} import org.apache.spark.sql.catalyst.expressions.{StringTrimBoth, _} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke, StructsToJsonInvoke} import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.hive.HiveUDFTransformer import org.apache.spark.sql.internal.SQLConf @@ -41,6 +42,11 @@ trait Transformable { object ExpressionConverter extends SQLConfHelper with Logging { + private val DeltaInterleaveBitsClassName = + "org.apache.spark.sql.delta.expressions.InterleaveBits" + private val DeltaPartitionerExprClassName = + "org.apache.spark.sql.delta.expressions.PartitionerExpr" + def replaceWithExpressionTransformer( exprs: Seq[Expression], attributeSeq: Seq[Attribute]): Seq[ExpressionTransformer] = { @@ -246,6 +252,95 @@ object ExpressionConverter extends SQLConfHelper with Logging { } } + private def replaceDeltaInterleaveBitsWithExpressionTransformer( + expr: Expression, + attributeSeq: Seq[Attribute], + expressionsMap: Map[Class[_], String]): ExpressionTransformer = { + GenericExpressionTransformer( + ExpressionNames.INTERLEAVE_BITS, + expr.children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), + expr) + } + + private def replaceDeltaPartitionerExprWithExpressionTransformer( + expr: Expression, + attributeSeq: Seq[Attribute], + expressionsMap: Map[Class[_], String]): ExpressionTransformer = { + val child = expr.children.headOption.getOrElse { + throw new GlutenNotSupportException(s"Delta PartitionerExpr has no child: $expr") + } + validateDeltaPartitionerType(child.dataType, expr) + val partitioner = invokeNoArg(expr, "partitioner") + if (partitioner == null || !isSupportedDeltaRangePartitioner(partitioner)) { + throw new GlutenNotSupportException(s"Unsupported Delta partitioner: $partitioner") + } + val ascending = getFieldValue(partitioner, "ascending").asInstanceOf[Boolean] + if (!ascending) { + throw new GlutenNotSupportException( + "Delta PartitionerExpr with descending bounds is not supported") + } + val rangeBounds = getFieldValue(partitioner, "rangeBounds") + val boundsDataType = ArrayType(child.dataType, containsNull = true) + val boundsLiteral = + LiteralTransformer( + Literal( + new GenericArrayData( + arrayToSeq(rangeBounds).map(extractRangeBoundValue(_, child.dataType)).toArray), + boundsDataType)) + + FunctionArgumentExpressionTransformer( + ExpressionNames.RANGE_PARTITION_ID, + Seq(replaceWithExpressionTransformer0(child, attributeSeq, expressionsMap), boundsLiteral), + expr, + Seq(child.dataType, boundsDataType) + ) + } + + private def validateDeltaPartitionerType(dataType: DataType, expr: Expression): Unit = { + dataType match { + case ByteType | ShortType | IntegerType | LongType | DateType => + case _ => + throw new GlutenNotSupportException( + s"Delta PartitionerExpr is not supported for $dataType: $expr") + } + } + + private def isSupportedDeltaRangePartitioner(partitioner: AnyRef): Boolean = { + partitioner.getClass.getName == "org.apache.spark.RangePartitioner" + } + + private def invokeNoArg(target: AnyRef, methodName: String): AnyRef = { + val method = target.getClass.getMethod(methodName) + method.invoke(target).asInstanceOf[AnyRef] + } + + private def getFieldValue(target: AnyRef, fieldName: String): AnyRef = { + val field = target.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(target).asInstanceOf[AnyRef] + } + + private def arrayToSeq(array: AnyRef): Seq[AnyRef] = { + if (array == null || !array.getClass.isArray) { + throw new GlutenNotSupportException(s"Expected RangePartitioner bounds array, got: $array") + } + (0 until java.lang.reflect.Array.getLength(array)).map { + index => java.lang.reflect.Array.get(array, index).asInstanceOf[AnyRef] + } + } + + private def extractRangeBoundValue(bound: AnyRef, dataType: DataType): Any = { + bound match { + case row: InternalRow => + if (row.isNullAt(0)) { + null + } else { + row.get(0, dataType) + } + case other => other + } + } + private def replaceWithExpressionTransformer0( expr: Expression, attributeSeq: Seq[Attribute], @@ -297,6 +392,12 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceStaticInvokeWithExpressionTransformer(staticInvoke, attributeSeq, expressionsMap)) case invoke: Invoke => Option(replaceInvokeWithExpressionTransformer(invoke, attributeSeq, expressionsMap)) + case _ if expr.getClass.getName == DeltaInterleaveBitsClassName => + Option( + replaceDeltaInterleaveBitsWithExpressionTransformer(expr, attributeSeq, expressionsMap)) + case _ if expr.getClass.getName == DeltaPartitionerExprClassName => + Option( + replaceDeltaPartitionerExprWithExpressionTransformer(expr, attributeSeq, expressionsMap)) case _ => None } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala index cfda2d8782e3..2e8eecbea831 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionTransformer.scala @@ -67,6 +67,22 @@ case class GenericExpressionTransformer( original: Expression) extends ExpressionTransformer +case class FunctionArgumentExpressionTransformer( + substraitExprName: String, + children: Seq[ExpressionTransformer], + original: Expression, + functionTypes: Seq[DataType] = Nil) + extends ExpressionTransformer { + override def doTransform(context: SubstraitContext): ExpressionNode = { + val inputTypes = if (functionTypes.nonEmpty) functionTypes else children.map(_.dataType) + val funcName = ConverterUtils.makeFuncName(substraitExprName, inputTypes) + val functionId = context.registerFunction(funcName) + val childNodes = children.map(_.doTransform(context)).asJava + val typeNode = ConverterUtils.getTypeNode(dataType, nullable) + ExpressionBuilder.makeScalarFunction(functionId, childNodes, typeNode) + } +} + object GenericExpressionTransformer { def apply( substraitExprName: String, diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index d4afb7ff739f..3d2174ba82af 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -353,6 +353,8 @@ object ExpressionNames { final val NULLIF = "nullif" final val NVL = "nvl" final val NVL2 = "nvl2" + final val INTERLEAVE_BITS = "interleave_bits" + final val RANGE_PARTITION_ID = "range_partition_id" // Directly use child expression transformer final val KNOWN_NULLABLE = "known_nullable"