Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.spark.SparkConf
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.internal.SQLConf

/**
 * Insert tests that check how `spark.sql.storeAssignmentPolicy` is honoured when rows are written
 * through the Velox backend.
 *
 * Every test creates single-column parquet tables (column `c`), inserts from a source table into
 * a target table of a different type, and then checks either that the insert is rejected with an
 * [[AnalysisException]] or that the stored value matches the expected cast result.
 */
class VeloxInsertSuite extends VeloxWholeStageTransformerSuite {
  override protected val resourcePath: String = "placeholder"
  override protected val fileFormat: String = "parquet"

  /** Extends the parent configuration with the shuffle and memory settings used by this suite. */
  override protected def sparkConf: SparkConf = {
    super.sparkConf
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
      .set("spark.sql.shuffle.partitions", "1")
      .set("spark.memory.offHeap.size", "2g")
      .set("spark.unsafe.exceptionOnMemoryLeak", "true")
  }

  test("storeAssignmentPolicy default ANSI is independent from ANSI mode") {
    withTable("store_assignment_ansi_src", "store_assignment_ansi") {
      withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
        // The policy is not overridden here, so this pins the session default.
        assert(SQLConf.get.storeAssignmentPolicy == SQLConf.StoreAssignmentPolicy.ANSI)

        createTableWithValue("store_assignment_ansi_src", "STRING", "'2147483648'")
        createTable("store_assignment_ansi", "INT")
        // With ANSI mode off, the insert is still expected to be rejected during analysis.
        assertUnsafeCastAnalysisException("STRING", "INT") {
          insertIntoFrom("store_assignment_ansi", "store_assignment_ansi_src").collect()
        }

        // Switching only the policy to LEGACY lets the same insert through; the stored value is
        // expected to be null.
        withSQLConf(
          SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString) {
          val insert = insertIntoFrom("store_assignment_ansi", "store_assignment_ansi_src")
          insert.collect()
          checkGlutenPlan[ProjectExecTransformer](insert)
          checkAnswer(spark.table("store_assignment_ansi"), Row(null))
        }
      }
    }
  }

  test("storeAssignmentPolicy preserves configured cast modes") {
    // ANSI policy with ANSI mode disabled: the insert is rejected and the target stays empty.
    withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
      withTable("store_assignment_ansi_src", "store_assignment_ansi") {
        createTableWithValue("store_assignment_ansi_src", "STRING", "'2147483648'")
        createTable("store_assignment_ansi", "INT")

        withSQLConf(
          SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.ANSI.toString) {
          assertUnsafeCastAnalysisException("STRING", "INT") {
            insertIntoFrom("store_assignment_ansi", "store_assignment_ansi_src").collect()
          }
          checkAnswer(spark.table("store_assignment_ansi"), Seq.empty[Row])
        }
      }
    }

    // LEGACY policy with ANSI mode enabled: the insert succeeds and a null is expected.
    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
      withTable("store_assignment_legacy_src", "store_assignment_legacy") {
        createTableWithValue("store_assignment_legacy_src", "STRING", "'2147483648'")
        createTable("store_assignment_legacy", "INT")

        withSQLConf(
          SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString) {
          val insert = insertIntoFrom("store_assignment_legacy", "store_assignment_legacy_src")
          insert.collect()
          checkGlutenPlan[ProjectExecTransformer](insert)
          checkAnswer(spark.table("store_assignment_legacy"), Row(null))
        }
      }
    }
  }

  test("storeAssignmentPolicy strict rejects unsafe insert casts") {
    withTable("store_assignment_strict_src", "store_assignment_strict") {
      withSQLConf(
        SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.STRICT.toString) {
        createTableWithValue("store_assignment_strict_src", "INT", "1")
        createTable("store_assignment_strict", "TINYINT")

        assertUnsafeCastAnalysisException("INT", "TINYINT") {
          insertIntoFrom("store_assignment_strict", "store_assignment_strict_src").collect()
        }
        checkAnswer(spark.table("store_assignment_strict"), Seq.empty[Row])
      }
    }
  }

  /** Creates a parquet table named `table` with a single column `c` of SQL type `dataType`. */
  private def createTable(table: String, dataType: String): Unit =
    spark.sql(s"CREATE TABLE $table (c $dataType) USING PARQUET")

  /**
   * Creates `table` via [[createTable]] and inserts one row.
   *
   * @param value
   *   SQL literal text placed verbatim inside `VALUES (...)`.
   */
  private def createTableWithValue(table: String, dataType: String, value: String): Unit = {
    createTable(table, dataType)
    spark.sql(s"INSERT INTO $table VALUES ($value)").collect()
  }

  /**
   * Runs `INSERT INTO target SELECT c FROM source` and returns the result of `spark.sql`. The
   * result type is left to inference so callers receive exactly what `spark.sql` returns.
   */
  private def insertIntoFrom(target: String, source: String) =
    spark.sql(s"INSERT INTO $target SELECT c FROM $source")

  /**
   * Asserts that evaluating `f` throws an [[AnalysisException]] whose message chain mentions both
   * type names and the word "cast".
   *
   * Matching ignores case, so the check does not depend on how a given Spark version spells the
   * type names or capitalises "cast" in the message. NOTE(review): some Spark releases appear to
   * print lower-case catalog type names in this error; confirm against the supported versions.
   *
   * @param fromType
   *   source type name expected in the message.
   * @param toType
   *   target type name expected in the message.
   * @param f
   *   code that is expected to fail.
   */
  private def assertUnsafeCastAnalysisException(
      fromType: String,
      toType: String)(f: => Unit): Unit = {
    val exception = intercept[AnalysisException](f)
    val message = exceptionMessages(exception)
    assert(containsIgnoreCase(message, fromType), message)
    assert(containsIgnoreCase(message, toType), message)
    assert(containsIgnoreCase(message, "cast"), message)
  }

  /** Locale-independent, case-insensitive substring check. */
  private def containsIgnoreCase(text: String, fragment: String): Boolean =
    text
      .toLowerCase(java.util.Locale.ROOT)
      .contains(fragment.toLowerCase(java.util.Locale.ROOT))

  /**
   * Joins the message of `e` and of every throwable in its cause chain with newlines. A null
   * message contributes an empty line. The chain is walked iteratively, so its depth does not
   * consume stack.
   */
  private def exceptionMessages(e: Throwable): String =
    Iterator
      .iterate(e)(_.getCause)
      .takeWhile(_ != null)
      .map(t => Option(t.getMessage).getOrElse(""))
      .mkString("\n")
}
1 change: 1 addition & 0 deletions cpp/velox/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ set(VELOX_SRCS
memory/VeloxMemoryManager.cc
operators/functions/RegistrationAllFunctions.cc
operators/functions/RowConstructorWithNull.cc
operators/functions/SparkCastModeSpecialForms.cc
operators/functions/SparkExprToSubfieldFilterParser.cc
operators/plannodes/RowVectorStream.cc
operators/hashjoin/HashTableBuilder.cc
Expand Down
2 changes: 2 additions & 0 deletions cpp/velox/operators/functions/RegistrationAllFunctions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "operators/functions/Arithmetic.h"
#include "operators/functions/RowConstructorWithNull.h"
#include "operators/functions/RowFunctionWithNull.h"
#include "operators/functions/SparkCastModeSpecialForms.h"
#include "velox/expression/SpecialFormRegistry.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/iceberg/Register.h"
Expand Down Expand Up @@ -83,6 +84,7 @@ void registerFunctionOverwrite() {

void registerAllFunctions() {
velox::functions::sparksql::registerFunctions("");
registerSparkCastModeSpecialForms();
velox::aggregate::prestosql::registerAllAggregateFunctions(
"", true /*registerCompanionFunctions*/, false /*onlyPrestoSignatures*/, true /*overwrite*/);
velox::functions::aggregate::sparksql::registerAggregateFunctions(
Expand Down
118 changes: 118 additions & 0 deletions cpp/velox/operators/functions/SparkCastModeSpecialForms.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "operators/functions/SparkCastModeSpecialForms.h"

#include "velox/expression/SpecialFormRegistry.h"
#include "velox/functions/sparksql/specialforms/SparkCastExpr.h"
#include "velox/functions/sparksql/specialforms/SparkCastHooks.h"

namespace gluten {
namespace {

using namespace facebook::velox;
using facebook::velox::functions::sparksql::SparkCastExpr;
using facebook::velox::functions::sparksql::SparkCastHooks;

// Returns true when `type` is the TINYINT, SMALLINT, INTEGER or BIGINT type
// object, compared by pointer identity against the accessor results.
bool isIntegralType(const TypePtr& type) {
  if (type == TINYINT()) {
    return true;
  }
  if (type == SMALLINT()) {
    return true;
  }
  if (type == INTEGER()) {
    return true;
  }
  return type == BIGINT();
}

// Reports whether a cast from `fromType` to `toType` is treated as
// ANSI-capable by this file. Only casts from VARCHAR to BOOLEAN, DATE, or a
// type accepted by isIntegralType() return true; every other combination
// returns false.
//
// Keep this in sync with Velox's SparkCastCallToSpecialForm::isAnsiSupported.
// Velox's helper is private today; this local copy is needed for expression-level
// ANSI and legacy cast modes.
bool isAnsiSupported(const TypePtr& fromType, const TypePtr& toType) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's also a check on the Velox side. Will we maintain the ANSI support check only here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intended to mirror the existing Velox ANSI support side check not to replace it, since that check is private today, I added this local helper for the expression level ANSI/legacy cast mode. I’ll add a comment to make sure we keep it aligned with Velox

if (fromType->isVarchar()) {
return toType->isBoolean() || toType->isDate() || isIntegralType(toType);
}
// Any source type other than VARCHAR is reported as unsupported.
return false;
}

// Builds a SparkCastExpr that casts `input` to `type`.
//
// `isTryCast` and `trackCpuUsage` are forwarded to the expression unchanged.
// `allowOverflow` and `config` are used only to construct the SparkCastHooks
// instance handed to the expression.
exec::ExprPtr makeSparkCastExpr(
    const TypePtr& type,
    exec::ExprPtr&& input,
    bool trackCpuUsage,
    bool isTryCast,
    bool allowOverflow,
    const core::QueryConfig& config) {
  auto castHooks = std::make_shared<SparkCastHooks>(config, allowOverflow);
  return std::make_shared<SparkCastExpr>(
      type, std::move(input), trackCpuUsage, isTryCast, std::move(castHooks));
}

// Special-form builder registered under kSparkAnsiCast.
//
// For source/target pairs accepted by isAnsiSupported() the cast is built with
// try-cast and overflow both disabled. For every other pair both flags are
// enabled.
class SparkAnsiCastCallToSpecialForm : public exec::CastCallToSpecialForm {
 public:
  exec::ExprPtr constructSpecialForm(
      const TypePtr& type,
      std::vector<exec::ExprPtr>&& compiledChildren,
      bool trackCpuUsage,
      const core::QueryConfig& config) override {
    VELOX_CHECK_EQ(
        compiledChildren.size(),
        1,
        "ANSI CAST statements expect exactly 1 argument, received {}.",
        compiledChildren.size());

    auto& source = compiledChildren.front();
    // Decide the mode before `source` is moved into the cast expression.
    const bool ansiUnsupported = !isAnsiSupported(source->type(), type);
    return makeSparkCastExpr(
        type,
        std::move(source),
        trackCpuUsage,
        /*isTryCast=*/ansiUnsupported,
        /*allowOverflow=*/ansiUnsupported,
        config);
  }
};

// Special-form builder registered under kSparkLegacyCast. Always builds the
// cast with try-cast and overflow both enabled.
class SparkLegacyCastCallToSpecialForm : public exec::CastCallToSpecialForm {
 public:
  exec::ExprPtr constructSpecialForm(
      const TypePtr& type,
      std::vector<exec::ExprPtr>&& compiledChildren,
      bool trackCpuUsage,
      const core::QueryConfig& config) override {
    VELOX_CHECK_EQ(
        compiledChildren.size(),
        1,
        "LEGACY CAST statements expect exactly 1 argument, received {}.",
        compiledChildren.size());

    auto& source = compiledChildren.front();
    return makeSparkCastExpr(
        type,
        std::move(source),
        trackCpuUsage,
        /*isTryCast=*/true,
        /*allowOverflow=*/true,
        config);
  }
};

} // namespace

// Registers the ANSI and legacy Spark cast special forms under the names
// kSparkAnsiCast and kSparkLegacyCast, in that order.
void registerSparkCastModeSpecialForms() {
  auto ansiForm = std::make_unique<SparkAnsiCastCallToSpecialForm>();
  exec::registerFunctionCallToSpecialForm(kSparkAnsiCast, std::move(ansiForm));

  auto legacyForm = std::make_unique<SparkLegacyCastCallToSpecialForm>();
  exec::registerFunctionCallToSpecialForm(
      kSparkLegacyCast, std::move(legacyForm));
}

} // namespace gluten
27 changes: 27 additions & 0 deletions cpp/velox/operators/functions/SparkCastModeSpecialForms.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

namespace gluten {

// Function name under which the ANSI-mode Spark cast special form is registered.
constexpr const char* kSparkAnsiCast = "spark_ansi_cast";
// Function name under which the legacy-mode Spark cast special form is registered.
constexpr const char* kSparkLegacyCast = "spark_legacy_cast";

// Registers the special forms for kSparkAnsiCast and kSparkLegacyCast so that
// function calls using those names are compiled into Spark cast expressions.
void registerSparkCastModeSpecialForms();

} // namespace gluten
32 changes: 26 additions & 6 deletions cpp/velox/substrait/SubstraitToVeloxExpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "SubstraitToVeloxExpr.h"
#include "TypeUtils.h"
#include "operators/functions/SparkCastModeSpecialForms.h"
#include "velox/vector/FlatVector.h"
#include "velox/vector/VariantToVector.h"

Expand Down Expand Up @@ -146,16 +147,24 @@ TypePtr getScalarType(const ::substrait::Expression::Literal& literal) {
}
}

/// Whether is try cast.
bool isTryCast(::substrait::Expression::Cast::FailureBehavior failureBehavior) {
// Cast flavor derived from the Substrait failure behavior of a cast expression.
enum class SparkCastMode {
// Selected for FAILURE_BEHAVIOR_UNSPECIFIED.
kLegacy,
// Selected for FAILURE_BEHAVIOR_THROW_EXCEPTION.
kAnsi,
// Selected for FAILURE_BEHAVIOR_RETURN_NULL.
kTry,
};

/// Maps the Substrait failure behavior of a cast onto the cast mode used when
/// building the Velox expression.
///
/// UNSPECIFIED selects kLegacy, THROW_EXCEPTION selects kAnsi and RETURN_NULL
/// selects kTry. Any other value is reported through VELOX_NYI.
SparkCastMode sparkCastMode(
    ::substrait::Expression::Cast::FailureBehavior failureBehavior) {
  switch (failureBehavior) {
    case ::substrait::Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_UNSPECIFIED:
      return SparkCastMode::kLegacy;
    case ::substrait::Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_THROW_EXCEPTION:
      return SparkCastMode::kAnsi;
    case ::substrait::Expression_Cast_FailureBehavior_FAILURE_BEHAVIOR_RETURN_NULL:
      return SparkCastMode::kTry;
    default:
      // The default branch is kept so unlisted enum values are rejected
      // explicitly.
      VELOX_NYI(
          "The given failure behavior is NOT supported: '{}'", std::to_string(failureBehavior));
  }
}

Expand Down Expand Up @@ -564,7 +573,18 @@ core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr(
const RowTypePtr& inputType) {
auto type = SubstraitParser::parseType(castExpr.type());
std::vector<core::TypedExprPtr> inputs{toVeloxExpr(castExpr.input(), inputType)};
return std::make_shared<core::CastTypedExpr>(type, inputs, isTryCast(castExpr.failure_behavior()));
switch (sparkCastMode(castExpr.failure_behavior())) {
case SparkCastMode::kLegacy:
return std::make_shared<const core::CallTypedExpr>(
type, std::move(inputs), kSparkLegacyCast);
case SparkCastMode::kAnsi:
return std::make_shared<const core::CallTypedExpr>(
type, std::move(inputs), kSparkAnsiCast);
case SparkCastMode::kTry:
return std::make_shared<core::CastTypedExpr>(type, std::move(inputs), true);
default:
VELOX_UNREACHABLE();
}
}

core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr(
Expand Down
Loading
Loading