Commit 4a18179

[SPARK-54683][SQL] Unify geo and time types blocking
### What changes were proposed in this pull request?

This PR aims to refactor the code that blocks time and geo types.

### Why are the changes needed?

Code unification.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53438 from cloud-fan/block.

Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 3481649 commit 4a18179
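
In short, the scattered per-call-site checks collapse into a single helper, TypeUtils.failUnsupportedDataType(dataType, conf). A standalone sketch of the pattern follows (the DataType tree, Conf, and error messages here are simplified stand-ins, not Spark's actual classes):

    // Minimal stand-ins for the real catalyst types, for illustration only.
    sealed trait DataType {
      // True if this type, or any type nested inside it, satisfies f.
      def existsRecursively(f: DataType => Boolean): Boolean = f(this)
    }
    case object IntegerType extends DataType
    case object TimeType extends DataType
    case object GeometryType extends DataType
    case class ArrayType(elementType: DataType) extends DataType {
      override def existsRecursively(f: DataType => Boolean): Boolean =
        f(this) || elementType.existsRecursively(f)
    }

    case class Conf(isTimeTypeEnabled: Boolean, geospatialEnabled: Boolean)

    object TypeUtils {
      // The unified guard: every call site delegates here instead of
      // re-implementing the time and geo checks inline.
      def failUnsupportedDataType(dataType: DataType, conf: Conf): Unit = {
        if (!conf.isTimeTypeEnabled && dataType.existsRecursively(_ == TimeType)) {
          throw new UnsupportedOperationException("TIME type is not enabled")
        }
        if (!conf.geospatialEnabled && dataType.existsRecursively(_ == GeometryType)) {
          throw new UnsupportedOperationException("geospatial types are not enabled")
        }
      }
    }

Because the helper walks the type recursively, nested occurrences such as ARRAY<TIME> are rejected too, e.g. TypeUtils.failUnsupportedDataType(ArrayType(TimeType), Conf(isTimeTypeEnabled = false, geospatialEnabled = true)).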

File tree

9 files changed: +30 −43 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 3 additions & 7 deletions
@@ -25,11 +25,12 @@ import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
 
+import scala.language.existentials
+
 import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.DayTimeIntervalType._
@@ -60,6 +61,7 @@ object CatalystTypeConverters {
   }
 
   private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
+    TypeUtils.failUnsupportedDataType(dataType, SQLConf.get)
     val converter = dataType match {
       case udt: UserDefinedType[_] => UDTConverter(udt)
       case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
@@ -68,18 +70,12 @@ object CatalystTypeConverters {
       case CharType(length) => new CharConverter(length)
       case VarcharType(length) => new VarcharConverter(length)
       case _: StringType => StringConverter
-      case _ @ (_: GeographyType | _: GeometryType) if !SQLConf.get.geospatialEnabled =>
-        throw new org.apache.spark.sql.AnalysisException(
-          errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
-          messageParameters = scala.collection.immutable.Map.empty)
       case g: GeographyType =>
         new GeographyConverter(g)
       case g: GeometryType =>
         new GeometryConverter(g)
       case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
       case DateType => DateConverter
-      case _: TimeType if !SQLConf.get.isTimeTypeEnabled =>
-        QueryCompilationErrors.unsupportedTimeTypeError()
       case _: TimeType => TimeConverter
       case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter
       case TimestampType => TimestampConverter
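
The shape of the change above: the feature-flag cases no longer interleave with the converter cases; a single guard runs before the match. A sketch reusing the stand-in types from the earlier example (converter names here are illustrative strings):

    // Guard once at the top of the dispatcher; the match then only
    // contains real converter cases.
    def getConverterFor(dt: DataType, conf: Conf): String = {
      TypeUtils.failUnsupportedDataType(dt, conf) // single early check
      dt match {
        case TimeType     => "TimeConverter"     // only reached when TIME is enabled
        case GeometryType => "GeometryConverter" // only reached when geo is enabled
        case _            => "IdentityConverter"
      }
    }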

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala

Lines changed: 2 additions & 0 deletions
@@ -345,6 +345,8 @@ object DeserializerBuildHelper {
       createDeserializerForInstant(path)
     case LocalDateTimeEncoder =>
       createDeserializerForLocalDateTime(path)
+    case LocalTimeEncoder if !SQLConf.get.isTimeTypeEnabled =>
+      throw org.apache.spark.sql.errors.QueryCompilationErrors.unsupportedTimeTypeError()
     case LocalTimeEncoder =>
       createDeserializerForLocalTime(path)
     case UDTEncoder(udt, udtClass) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala

Lines changed: 2 additions & 0 deletions
@@ -367,6 +367,8 @@ object SerializerBuildHelper {
     case TimestampEncoder(false) => createSerializerForSqlTimestamp(input)
     case InstantEncoder(false) => createSerializerForJavaInstant(input)
     case LocalDateTimeEncoder => createSerializerForLocalDateTime(input)
+    case LocalTimeEncoder if !SQLConf.get.isTimeTypeEnabled =>
+      throw org.apache.spark.sql.errors.QueryCompilationErrors.unsupportedTimeTypeError()
     case LocalTimeEncoder => createSerializerForLocalTime(input)
     case UDTEncoder(udt, udtClass) => createSerializerForUserDefinedType(input, udt, udtClass)
     case OptionEncoder(valueEnc) =>
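
Both encoder helpers gate LocalTimeEncoder with a pattern guard rather than a separate if block. Note the ordering requirement: the guarded case must appear before the unguarded one, or it would never be reached. A sketch of that idiom (Encoder, the case objects, and the error are stand-ins):

    sealed trait Encoder
    case object LocalTimeEncoder extends Encoder
    case object InstantEncoder extends Encoder

    def buildSerializer(enc: Encoder, timeTypeEnabled: Boolean): String = enc match {
      // Guarded case first: rejects the encoder while the feature is off.
      case LocalTimeEncoder if !timeTypeEnabled =>
        throw new UnsupportedOperationException("TIME type is not enabled")
      case LocalTimeEncoder => "serializerForLocalTime"
      case InstantEncoder   => "serializerForJavaInstant"
    }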

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 0 deletions
@@ -712,6 +712,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
         }
 
         create.tableSchema.foreach(f => TypeUtils.failWithIntervalType(f.dataType))
+        TypeUtils.failUnsupportedDataType(create.tableSchema, SQLConf.get)
         SchemaUtils.checkIndeterminateCollationInSchema(create.tableSchema)
 
       case write: V2WriteCommand if write.resolved =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 2 additions & 19 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToDecimal, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
+import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{GeographyVal, UTF8String, VariantVal}
@@ -90,12 +90,6 @@ object Cast extends QueryErrorsBase {
    * - String <=> Binary
    */
  def canAnsiCast(from: DataType, to: DataType): Boolean = (from, to) match {
-    case (fromType, toType) if !SQLConf.get.geospatialEnabled &&
-        (isGeoSpatialType(fromType) || isGeoSpatialType(toType)) =>
-      throw new org.apache.spark.sql.AnalysisException(
-        errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
-        messageParameters = scala.collection.immutable.Map.empty)
-
    case (fromType, toType) if fromType == toType => true
 
    case (NullType, _) => true
@@ -224,12 +218,6 @@ object Cast extends QueryErrorsBase {
    * Returns true iff we can cast `from` type to `to` type.
    */
  def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
-    case (fromType, toType) if !SQLConf.get.geospatialEnabled &&
-        (isGeoSpatialType(fromType) || isGeoSpatialType(toType)) =>
-      throw new org.apache.spark.sql.AnalysisException(
-        errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
-        messageParameters = scala.collection.immutable.Map.empty)
-
    case (fromType, toType) if fromType == toType => true
 
    case (NullType, _) => true
@@ -617,12 +605,7 @@ case class Cast(
  }
 
  override def checkInputDataTypes(): TypeCheckResult = {
-    dataType match {
-      // If the cast is to a TIME type, first check if TIME type is enabled.
-      case _: TimeType if !SQLConf.get.isTimeTypeEnabled =>
-        throw QueryCompilationErrors.unsupportedTimeTypeError()
-      case _ =>
-    }
+    TypeUtils.failUnsupportedDataType(dataType, SQLConf.get)
    val canCast = evalMode match {
      case EvalMode.LEGACY => Cast.canCast(child.dataType, dataType)
      case EvalMode.ANSI => Cast.canAnsiCast(child.dataType, dataType)
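
Two things happen in this file: the geo guards disappear from canCast/canAnsiCast entirely, leaving them as pure capability predicates, and checkInputDataTypes runs the shared guard before any cast rule is consulted. A simplified sketch of the new control flow, again with the stand-in types from the first example (the cast rule itself is a placeholder):

    def checkInputDataTypes(from: DataType, to: DataType, conf: Conf): Boolean = {
      // May throw before any cast logic runs, exactly once per check.
      TypeUtils.failUnsupportedDataType(to, conf)
      from == to || to == IntegerType // placeholder for the real cast rules
    }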

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala

Lines changed: 2 additions & 0 deletions
@@ -112,6 +112,7 @@ case class AddColumns(
     columnsToAdd: Seq[QualifiedColType]) extends AlterTableCommand {
   columnsToAdd.foreach { c =>
     TypeUtils.failWithIntervalType(c.dataType)
+    TypeUtils.failUnsupportedDataType(c.dataType, conf)
   }
 
   override lazy val resolved: Boolean = table.resolved && columnsToAdd.forall(_.resolved)
@@ -144,6 +145,7 @@ case class ReplaceColumns(
     columnsToAdd: Seq[QualifiedColType]) extends AlterTableCommand {
   columnsToAdd.foreach { c =>
     TypeUtils.failWithIntervalType(c.dataType)
+    TypeUtils.failUnsupportedDataType(c.dataType, conf)
   }
 
   override lazy val resolved: Boolean = table.resolved && columnsToAdd.forall(_.resolved)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 13 additions & 0 deletions
@@ -20,8 +20,10 @@ package org.apache.spark.sql.catalyst.util
 import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.st.STExpressionUtils.isGeoSpatialType
 import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalNumericType}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 /**
@@ -136,4 +138,15 @@ object TypeUtils extends QueryErrorsBase {
     }
     if (dataType.existsRecursively(isInterval)) f
   }
+
+  def failUnsupportedDataType(dataType: DataType, conf: SQLConf): Unit = {
+    if (!conf.isTimeTypeEnabled && dataType.existsRecursively(_.isInstanceOf[TimeType])) {
+      throw QueryCompilationErrors.unsupportedTimeTypeError()
+    }
+    if (!conf.geospatialEnabled && dataType.existsRecursively(isGeoSpatialType)) {
+      throw new org.apache.spark.sql.AnalysisException(
+        errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
+        messageParameters = scala.collection.immutable.Map.empty)
+    }
+  }
 }
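
For reference, a call site only needs a DataType (a StructType works too, since schemas are data types) and a SQLConf. A minimal sketch of the call shape against the real API added above; the validate wrapper itself is hypothetical:

    import org.apache.spark.sql.catalyst.util.TypeUtils
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.types.StructType

    // Hypothetical call site: reject a schema up front if it contains a
    // TIME or geospatial type while the corresponding feature flag is off.
    def validate(schema: StructType): Unit = {
      TypeUtils.failUnsupportedDataType(schema, SQLConf.get)
    }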

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 3 additions & 13 deletions
@@ -29,19 +29,18 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.st.STExpressionUtils
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
 import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_ARROW_MAX_BATCH_SIZE, CONNECT_SESSION_RESULT_CHUNKING_MAX_CHUNK_SIZE}
 import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
 import org.apache.spark.sql.connect.service.ExecuteHolder
 import org.apache.spark.sql.connect.utils.{MetricGenerator, PipelineAnalysisContextUtils}
-import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StructType, TimeType}
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -138,16 +137,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
     val sessionId = executePlan.sessionHolder.sessionId
     val spark = dataframe.sparkSession
     val schema = dataframe.schema
-    val geospatialEnabled = spark.sessionState.conf.geospatialEnabled
-    if (!geospatialEnabled && schema.existsRecursively(STExpressionUtils.isGeoSpatialType)) {
-      throw new org.apache.spark.sql.AnalysisException(
-        errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
-        messageParameters = scala.collection.immutable.Map.empty)
-    }
-    val timeTypeEnabled = spark.sessionState.conf.isTimeTypeEnabled
-    if (!timeTypeEnabled && schema.existsRecursively(_.isInstanceOf[TimeType])) {
-      throw QueryCompilationErrors.unsupportedTimeTypeError()
-    }
+    TypeUtils.failUnsupportedDataType(schema, spark.sessionState.conf)
     val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
     val largeVarTypes = spark.sessionState.conf.arrowUseLargeVarTypes

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala

Lines changed: 2 additions & 4 deletions
@@ -30,7 +30,7 @@ import org.apache.spark.{SparkException, SparkUpgradeException}
 import org.apache.spark.sql.{sources, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY}
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
-import org.apache.spark.sql.catalyst.util.RebaseDateTime
+import org.apache.spark.sql.catalyst.util.{RebaseDateTime, TypeUtils}
 import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
@@ -93,9 +93,7 @@ object DataSourceUtils extends PredicateHelper {
    * in a driver side.
    */
  def verifySchema(format: FileFormat, schema: StructType, readOnly: Boolean = false): Unit = {
-    if (!SQLConf.get.isTimeTypeEnabled && schema.existsRecursively(_.isInstanceOf[TimeType])) {
-      throw QueryCompilationErrors.unsupportedTimeTypeError()
-    }
+    TypeUtils.failUnsupportedDataType(schema, SQLConf.get)
    schema.foreach { field =>
      val supported = if (readOnly) {
        format.supportReadDataType(field.dataType)
