Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-51568][SQL] Introduce isSupportedExtract to prevent unexpected behavior #50333

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSu
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
import org.apache.spark.sql.connector.expressions.{Expression, Extract, FieldReference, NamedReference}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, MetadataBuilder, ShortType, StringType, TimestampType}

Expand All @@ -57,6 +57,8 @@ private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError {
// Pushdown is allowed only for function names in this dialect's allow-list
// (`supportedFunctions` is declared elsewhere in H2Dialect — not visible here).
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)

// H2 accepts every EXTRACT field Spark pushes down, so no field-level filtering.
override def isSupportedExtract(extract: Extract): Boolean = true

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
sqlType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableChange}
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
import org.apache.spark.sql.connector.expressions.{Expression, Extract, Literal, NamedReference}
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
Expand Down Expand Up @@ -411,6 +411,14 @@ abstract class JdbcDialect extends Serializable with Logging {
s"CAST($expr AS $databaseTypeDefinition)"
}

/**
 * Compiles a V2 `Extract` to SQL only when the dialect has opted in via
 * `isSupportedExtract`; otherwise the expression is handed to
 * `visitUnexpectedExpr`, preventing SQL being generated for an EXTRACT
 * field the underlying database cannot evaluate.
 */
override def visitExtract(extract: Extract): String = {
if (isSupportedExtract(extract)) {
super.visitExtract(extract)
} else {
visitUnexpectedExpr(extract)
}
}

override def visitSQLFunction(funcName: String, inputs: Array[Expression]): String = {
if (isSupportedFunction(funcName)) {
super.visitSQLFunction(funcName, inputs)
Expand Down Expand Up @@ -499,6 +507,14 @@ abstract class JdbcDialect extends Serializable with Logging {
@Since("3.3.0")
def isSupportedFunction(funcName: String): Boolean = false

/**
 * Returns whether the database supports the given EXTRACT expression.
 * Dialects override this to opt in to EXTRACT pushdown; the default is
 * `false`, so pushdown is disabled unless a dialect explicitly enables it.
 * @param extract The V2 Extract to be converted.
 * @return True if the database supports this extract; false otherwise.
 */
@Since("4.1.0")
def isSupportedExtract(extract: Extract): Boolean = false

/**
* Converts V2 expression to String representing a SQL expression.
* @param expr The V2 expression to be converted.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,25 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No
// Pushdown is allowed only for function names in this dialect's allow-list
// (`supportedFunctions` is declared elsewhere in MySQLDialect — not visible here).
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)

// YEAR_OF_WEEK has no MySQL translation in this dialect (the previous
// visitExtract rejected it via visitUnexpectedExpr); every other EXTRACT
// field is pushed down.
override def isSupportedExtract(extract: Extract): Boolean = {
extract.field match {
case "YEAR_OF_WEEK" => false
case _ => true
}
}

class MySQLSQLBuilder extends JDBCSQLBuilder {

override def visitExtract(extract: Extract): String = {
val field = extract.field
override def visitExtract(field: String, source: String): String = {
field match {
case "DAY_OF_YEAR" => s"DAYOFYEAR(${build(extract.source())})"
case "WEEK" => s"WEEKOFYEAR(${build(extract.source())})"
case "YEAR_OF_WEEK" => visitUnexpectedExpr(extract)
case "DAY_OF_YEAR" => s"DAYOFYEAR($source)"
case "WEEK" => s"WEEKOFYEAR($source)"
// WEEKDAY uses Monday = 0, Tuesday = 1, ... and ISO standard is Monday = 1, ...,
// so we use the formula (WEEKDAY + 1) to follow the ISO standard.
case "DAY_OF_WEEK" => s"(WEEKDAY(${build(extract.source())}) + 1)"
case "DAY_OF_WEEK" => s"(WEEKDAY($source) + 1)"
// SECOND, MINUTE, HOUR, DAY, MONTH, QUARTER, YEAR are identical on MySQL and Spark for
// both datetime and interval types.
case _ => super.visitExtract(field, build(extract.source()))
case _ => super.visitExtract(field, source)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.internal.MDC
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.expressions.{Expression, NamedReference}
import org.apache.spark.sql.connector.expressions.{Expression, Extract, NamedReference}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
Expand All @@ -54,6 +54,8 @@ private case class PostgresDialect()
// Pushdown is allowed only for function names in this dialect's allow-list
// (`supportedFunctions` is declared elsewhere in PostgresDialect — not visible here).
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)

// PostgreSQL accepts every EXTRACT field Spark pushes down, so no field-level filtering.
override def isSupportedExtract(extract: Extract): Boolean = true

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
sqlType match {
Expand Down