From 6ca8b23886d0e9648847451f00aa26835aec9857 Mon Sep 17 00:00:00 2001
From: beliefer
Date: Thu, 20 Mar 2025 14:40:01 +0800
Subject: [PATCH] [SPARK-51568][SQL] Introduce isSupportedExtract to prevent
 unexpected behavior

---
 .../org/apache/spark/sql/jdbc/H2Dialect.scala |  4 +++-
 .../apache/spark/sql/jdbc/JdbcDialects.scala  | 18 +++++++++++++++++-
 .../apache/spark/sql/jdbc/MySQLDialect.scala  | 19 ++++++++++++-------
 .../spark/sql/jdbc/PostgresDialect.scala      |  4 +++-
 4 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 82f6f5c6264c4..3825812f92214 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSu
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.connector.catalog.index.TableIndex
-import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
+import org.apache.spark.sql.connector.expressions.{Expression, Extract, FieldReference, NamedReference}
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, MetadataBuilder, ShortType, StringType, TimestampType}
@@ -57,6 +57,8 @@ private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError {
   override def isSupportedFunction(funcName: String): Boolean =
     supportedFunctions.contains(funcName)
 
+  override def isSupportedExtract(extract: Extract): Boolean = true
+
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
     sqlType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 0d10b8e04484e..e43506bd3ee3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.connector.catalog.index.TableIndex
-import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
+import org.apache.spark.sql.connector.expressions.{Expression, Extract, Literal, NamedReference}
 import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
@@ -411,6 +411,14 @@ abstract class JdbcDialect extends Serializable with Logging {
       s"CAST($expr AS $databaseTypeDefinition)"
     }
 
+    override def visitExtract(extract: Extract): String = {
+      if (isSupportedExtract(extract)) {
+        super.visitExtract(extract)
+      } else {
+        visitUnexpectedExpr(extract)
+      }
+    }
+
     override def visitSQLFunction(funcName: String, inputs: Array[Expression]): String = {
       if (isSupportedFunction(funcName)) {
         super.visitSQLFunction(funcName, inputs)
@@ -499,6 +507,14 @@ abstract class JdbcDialect extends Serializable with Logging {
@Since("3.3.0") def isSupportedFunction(funcName: String): Boolean = false + /** + * Returns whether the database supports extract. + * @param extract The V2 Extract to be converted. + * @return True if the database supports extract. + */ + @Since("4.1.0") + def isSupportedExtract(extract: Extract): Boolean = false + /** * Converts V2 expression to String representing a SQL expression. * @param expr The V2 expression to be converted. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index 5dec15b0fbcde..21e33335ecde3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -50,20 +50,25 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) + override def isSupportedExtract(extract: Extract): Boolean = { + extract.field match { + case "YEAR_OF_WEEK" => false + case _ => true + } + } + class MySQLSQLBuilder extends JDBCSQLBuilder { - override def visitExtract(extract: Extract): String = { - val field = extract.field + override def visitExtract(field: String, source: String): String = { field match { - case "DAY_OF_YEAR" => s"DAYOFYEAR(${build(extract.source())})" - case "WEEK" => s"WEEKOFYEAR(${build(extract.source())})" - case "YEAR_OF_WEEK" => visitUnexpectedExpr(extract) + case "DAY_OF_YEAR" => s"DAYOFYEAR($source)" + case "WEEK" => s"WEEKOFYEAR($source)" // WEEKDAY uses Monday = 0, Tuesday = 1, ... and ISO standard is Monday = 1, ..., // so we use the formula (WEEKDAY + 1) to follow the ISO standard. - case "DAY_OF_WEEK" => s"(WEEKDAY(${build(extract.source())}) + 1)" + case "DAY_OF_WEEK" => s"(WEEKDAY($source) + 1)" // SECOND, MINUTE, HOUR, DAY, MONTH, QUARTER, YEAR are identical on MySQL and Spark for // both datetime and interval types. - case _ => super.visitExtract(field, build(extract.source())) + case _ => super.visitExtract(field, source) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index b4cd5f578ccd1..32d061c4e8a0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -31,7 +31,7 @@ import org.apache.spark.internal.MDC import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException} import org.apache.spark.sql.connector.catalog.Identifier -import org.apache.spark.sql.connector.expressions.{Expression, NamedReference} +import org.apache.spark.sql.connector.expressions.{Expression, Extract, NamedReference} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -54,6 +54,8 @@ private case class PostgresDialect() override def isSupportedFunction(funcName: String): Boolean = supportedFunctions.contains(funcName) + override def isSupportedExtract(extract: Extract): Boolean = true + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { sqlType match {