[SPARK-52243][CONNECT] Add NERF support for schema-related InvalidPlanInput errors #50997

Closed
wants to merge 3 commits
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -3432,6 +3432,12 @@
],
"sqlState" : "42602"
},
"INVALID_SCHEMA_TYPE_NON_STRUCT" : {
"message" : [
"Invalid schema type. Expect a struct type, but got <dataType>."
],
"sqlState" : "42K09"
},
"INVALID_SET_SYNTAX" : {
"message" : [
"Expected format is 'SET', 'SET key', or 'SET key=value'. If you want to include special characters in key, or include semicolon in value, please use backquotes, e.g., SET `key`=`value`."
Review thread on the new INVALID_SCHEMA_TYPE_NON_STRUCT entry:
Contributor: Shall we use this error condition everywhere? All of those features need to specify a schema, and INVALID_SCHEMA_TYPE_NON_STRUCT is general enough to fit all of them.
Contributor: To indicate where the error happened, error context is a better place.
Author: Good point.
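For reference, a minimal sketch of how this condition renders, using the same SparkThrowableHelper call that both the new helper in InvalidInputErrors and the new test suite rely on (the parameter value is the quoted SQL type string the planner produces):

import org.apache.spark.SparkThrowableHelper

// Substitutes <dataType> into the template above; the rendered text reads
// roughly: Invalid schema type. Expect a struct type, but got "ARRAY<INT>".
val msg: String = SparkThrowableHelper.getMessage(
  "INVALID_SCHEMA_TYPE_NON_STRUCT",
  Map("dataType" -> "\"ARRAY<INT>\""))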
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -61,7 +61,7 @@ def eval(self, a: int):
yield a + 1,

with self.assertRaisesRegex(
InvalidPlanInput, "Invalid Python user-defined table function return type."
InvalidPlanInput, "Invalid schema type. Expect a struct type, but got"
):
TestUDTF(lit(1)).collect()

@@ -85,7 +85,7 @@ private[sql] trait DataTypeErrorsBase {
else value.toString
}

-protected def quoteByDefault(elem: String): String = {
protected[sql] def quoteByDefault(elem: String): String = {
"\"" + elem + "\""
}

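The visibility change above is what lets the connect-side error factory below call quoteByDefault when building the inputSchema message parameter. A one-line sketch of its behavior; illustrative only, since protected[sql] limits callers to code inside the org.apache.spark.sql package:

val quoted: String = quoteByDefault("array<int>")  // yields "\"array<int>\""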
@@ -19,12 +19,22 @@ package org.apache.spark.sql.connect.planner

import scala.collection.mutable

import org.apache.spark.SparkThrowableHelper
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.{InvalidCommandInput, InvalidPlanInput}
import org.apache.spark.sql.errors.DataTypeErrors.{quoteByDefault, toSQLType}
import org.apache.spark.sql.types.DataType

object InvalidInputErrors {

// invalidPlanInput is a helper function to facilitate the migration of InvalidInputErrors
// to support NERF.
private def invalidPlanInput(
errorCondition: String,
messageParameters: Map[String, String] = Map.empty): InvalidPlanInput = {
InvalidPlanInput(SparkThrowableHelper.getMessage(errorCondition, messageParameters))
}

def unknownRelationNotSupported(rel: proto.Relation): InvalidPlanInput =
InvalidPlanInput(s"${rel.getUnknown} not supported.")

@@ -72,11 +82,6 @@
def rowNotSupportedForUdf(errorType: String): InvalidPlanInput =
InvalidPlanInput(s"Row is not a supported $errorType type for this UDF.")

-def invalidUserDefinedOutputSchemaType(actualType: String): InvalidPlanInput =
-InvalidPlanInput(
-s"Invalid user-defined output schema type for TransformWithStateInPandas. " +
-s"Expect a struct type, but got $actualType.")

def notFoundCachedLocalRelation(hash: String, sessionUUID: String): InvalidPlanInput =
InvalidPlanInput(
s"Not found any cached local relation with the hash: " +
@@ -91,8 +96,10 @@
def schemaRequiredForLocalRelation(): InvalidPlanInput =
InvalidPlanInput("Schema for LocalRelation is required when the input data is not provided.")

-def invalidSchema(schema: DataType): InvalidPlanInput =
-InvalidPlanInput(s"Invalid schema $schema")
def invalidSchemaStringNonStructType(schema: String, dataType: DataType): InvalidPlanInput =
invalidPlanInput(
"INVALID_SCHEMA.NON_STRUCT_TYPE",
Map("inputSchema" -> quoteByDefault(schema), "dataType" -> toSQLType(dataType)))

def invalidJdbcParams(): InvalidPlanInput =
InvalidPlanInput("Invalid jdbc params, please specify jdbc url and table.")
@@ -106,8 +113,8 @@
def doesNotSupport(what: String): InvalidPlanInput =
InvalidPlanInput(s"Does not support $what")

-def invalidSchemaDataType(dataType: DataType): InvalidPlanInput =
-InvalidPlanInput(s"Invalid schema dataType $dataType")
def invalidSchemaTypeNonStruct(dataType: DataType): InvalidPlanInput =
invalidPlanInput("INVALID_SCHEMA_TYPE_NON_STRUCT", Map("dataType" -> toSQLType(dataType)))

def expressionIdNotSupported(exprId: Int): InvalidPlanInput =
InvalidPlanInput(s"Expression with ID: $exprId is not supported")
@@ -189,9 +196,6 @@
def usingColumnsOrJoinConditionSetInJoin(): InvalidPlanInput =
InvalidPlanInput("Using columns or join conditions cannot be set at the same time in Join")

-def invalidStateSchemaDataType(dataType: DataType): InvalidPlanInput =
-InvalidPlanInput(s"Invalid state schema dataType $dataType for flatMapGroupsWithState")

def sqlCommandExpectsSqlOrWithRelations(other: proto.Relation.RelTypeCase): InvalidPlanInput =
InvalidPlanInput(s"SQL command expects either a SQL or a WithRelations, but got $other")

@@ -213,17 +217,6 @@
def invalidBucketCount(numBuckets: Int): InvalidCommandInput =
InvalidCommandInput("INVALID_BUCKET_COUNT", Map("numBuckets" -> numBuckets.toString))

-def invalidPythonUdtfReturnType(actualType: String): InvalidPlanInput =
-InvalidPlanInput(
-s"Invalid Python user-defined table function return type. " +
-s"Expect a struct type, but got $actualType.")
-
-def invalidUserDefinedOutputSchemaTypeForTransformWithState(
-actualType: String): InvalidPlanInput =
-InvalidPlanInput(
-s"Invalid user-defined output schema type for TransformWithStateInPandas. " +
-s"Expect a struct type, but got $actualType.")

def unsupportedUserDefinedFunctionImplementation(clazz: Class[_]): InvalidPlanInput =
InvalidPlanInput(s"Unsupported UserDefinedFunction implementation: ${clazz}")
}
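To make the new calling pattern concrete, here is a sketch of a hypothetical guard (requireStruct is not part of the PR; the types and the factory method are) mirroring the SparkConnectPlanner call sites below:

import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StructType}

// Reject any non-struct schema type with the NERF-backed error.
def requireStruct(dt: DataType): StructType = dt match {
  case s: StructType => s
  case other => throw InvalidInputErrors.invalidSchemaTypeNonStruct(other)
}

// requireStruct(ArrayType(IntegerType)) throws InvalidPlanInput with a message
// like: Invalid schema type. Expect a struct type, but got "ARRAY<INT>".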
@@ -773,7 +773,7 @@ class SparkConnectPlanner(
val stateSchema = DataTypeProtoConverter.toCatalystType(rel.getStateSchema) match {
case s: StructType => s
case other =>
-throw InvalidInputErrors.invalidStateSchemaDataType(other)
throw InvalidInputErrors.invalidSchemaTypeNonStruct(other)
}
val stateEncoder = TypedScalaUdf.encoderFor(
// the state agnostic encoder is the second element in the input encoders.
@@ -1105,8 +1105,7 @@
transformDataType(twsInfo.getOutputSchema) match {
case s: StructType => s
case dt =>
-throw InvalidInputErrors.invalidUserDefinedOutputSchemaTypeForTransformWithState(
-dt.typeName)
throw InvalidInputErrors.invalidSchemaTypeNonStruct(dt)
}
}

@@ -1502,7 +1501,7 @@
StructType.fromDDL,
fallbackParser = DataType.fromJson) match {
case s: StructType => s
-case other => throw InvalidInputErrors.invalidSchema(other)
case other => throw InvalidInputErrors.invalidSchemaStringNonStructType(schema, other)
}
}
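To illustrate when this branch fires, a sketch assuming only the two parsers named above: a schema string that StructType.fromDDL rejects can still parse through the DataType.fromJson fallback to a non-struct type, which is now reported together with the offending schema string.

import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType}

// StructType.fromDDL fails on this input, so the planner falls back to
// DataType.fromJson, which yields an ArrayType rather than a StructType;
// invalidSchemaStringNonStructType is then thrown.
val parsed: DataType =
  DataType.fromJson("""{"type":"array","elementType":"integer","containsNull":false}""")
// parsed == ArrayType(IntegerType, containsNull = false)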

@@ -1580,7 +1579,7 @@
if (rel.hasSchema) {
DataTypeProtoConverter.toCatalystType(rel.getSchema) match {
case s: StructType => reader.schema(s)
-case other => throw InvalidInputErrors.invalidSchemaDataType(other)
case other => throw InvalidInputErrors.invalidSchemaTypeNonStruct(other)
}
}
localMap.foreach { case (key, value) => reader.option(key, value) }
@@ -2967,8 +2966,7 @@
val returnType = if (udtf.hasReturnType) {
transformDataType(udtf.getReturnType) match {
case s: StructType => Some(s)
-case dt =>
-throw InvalidInputErrors.invalidPythonUdtfReturnType(dt.typeName)
case dt => throw InvalidInputErrors.invalidSchemaTypeNonStruct(dt)
}
} else {
None
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.planner

import org.apache.spark.SparkThrowableHelper
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput}
import org.apache.spark.sql.connect.planner.SparkConnectPlanTest
import org.apache.spark.sql.types._

class InvalidInputErrorsSuite extends PlanTest with SparkConnectPlanTest {

lazy val testLocalRelation =
createLocalRelationProto(
Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()),
Seq.empty)

val testCases = Seq(
TestCase(
name = "Invalid schema data type non struct for Parse",
expectedErrorCondition = "INVALID_SCHEMA_TYPE_NON_STRUCT",
expectedParameters = Map("dataType" -> "\"ARRAY<INT>\""),
invalidInput = {
val parse = proto.Parse
.newBuilder()
.setSchema(DataTypeProtoConverter.toConnectProtoType(ArrayType(IntegerType)))
.setFormat(proto.Parse.ParseFormat.PARSE_FORMAT_CSV)
.build()

proto.Relation.newBuilder().setParse(parse).build()
}),
TestCase(
name = "Invalid schema type non struct for TransformWithState",
expectedErrorCondition = "INVALID_SCHEMA_TYPE_NON_STRUCT",
expectedParameters = Map("dataType" -> "\"ARRAY<INT>\""),
invalidInput = {
val pythonUdf = proto.CommonInlineUserDefinedFunction
.newBuilder()
.setPythonUdf(
proto.PythonUDF
.newBuilder()
.setEvalType(211) // Python eval type used for transformWithStateInPandas
.setOutputType(DataTypeProtoConverter.toConnectProtoType(ArrayType(IntegerType)))
.build())
.build()

val groupMap = proto.GroupMap
.newBuilder()
.setInput(testLocalRelation)
.setFunc(pythonUdf)
.setTransformWithStateInfo(
proto.TransformWithStateInfo
.newBuilder()
.setOutputSchema(DataTypeProtoConverter.toConnectProtoType(ArrayType(IntegerType)))
.build())
.build()

proto.Relation.newBuilder().setGroupMap(groupMap).build()
}),
TestCase(
name = "Invalid schema string non struct type",
expectedErrorCondition = "INVALID_SCHEMA.NON_STRUCT_TYPE",
expectedParameters = Map(
"inputSchema" -> """"{"type":"array","elementType":"integer","containsNull":false}"""",
"dataType" -> "\"ARRAY<INT>\""),
invalidInput = {
val invalidSchema = """{"type":"array","elementType":"integer","containsNull":false}"""

val dataSource = proto.Read.DataSource
.newBuilder()
.setFormat("csv")
.setSchema(invalidSchema)
.build()

val read = proto.Read
.newBuilder()
.setDataSource(dataSource)
.build()

proto.Relation.newBuilder().setRead(read).build()
}))

// Run all test cases
testCases.foreach { testCase =>
test(s"${testCase.name}") {
val exception = intercept[InvalidPlanInput] {
transform(testCase.invalidInput)
}
val expectedMessage = SparkThrowableHelper.getMessage(
testCase.expectedErrorCondition,
testCase.expectedParameters)
assert(exception.getMessage == expectedMessage)
}
}

// Helper case class to define test cases
case class TestCase(
name: String,
expectedErrorCondition: String,
expectedParameters: Map[String, String],
invalidInput: proto.Relation)
}
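Extending coverage to another non-struct schema would just mean appending an entry to testCases. A sketch of a hypothetical map-schema case, reusing the suite's imports and converters and assuming toSQLType renders the type as "MAP<INT, INT>":

TestCase(
  name = "Invalid schema data type non struct for Parse (map schema)",
  expectedErrorCondition = "INVALID_SCHEMA_TYPE_NON_STRUCT",
  expectedParameters = Map("dataType" -> "\"MAP<INT, INT>\""),
  invalidInput = {
    val parse = proto.Parse
      .newBuilder()
      .setSchema(DataTypeProtoConverter.toConnectProtoType(MapType(IntegerType, IntegerType)))
      .setFormat(proto.Parse.ParseFormat.PARSE_FORMAT_CSV)
      .build()
    proto.Relation.newBuilder().setParse(parse).build()
  })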