Skip to content

Commit b98a216

Browse files
committed
Fixed review comments
1 parent 5a71082 commit b98a216

File tree

6 files changed

+157
-1099
lines changed

6 files changed

+157
-1099
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/timeExpressions.scala

Lines changed: 59 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -25,92 +25,18 @@ import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult}
2626
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
2727
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType, toSQLValue}
28-
import org.apache.spark.sql.catalyst.expressions.codegen._
2928
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
3029
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
3130
import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
32-
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
3331
import org.apache.spark.sql.catalyst.util.DateTimeUtils
3432
import org.apache.spark.sql.catalyst.util.TimeFormatter
3533
import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber
3634
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
37-
import org.apache.spark.sql.internal.SQLConf
3835
import org.apache.spark.sql.internal.types.StringTypeWithCollation
3936
import org.apache.spark.sql.types.{AbstractDataType, AnyTimeType, ByteType, DataType, DayTimeIntervalType, DecimalType, IntegerType, IntegralType, LongType, NumericType, ObjectType, TimeType}
4037
import org.apache.spark.sql.types.DayTimeIntervalType.{HOUR, SECOND}
4138
import org.apache.spark.unsafe.types.UTF8String
4239

43-
/**
44-
* Helper trait for TIME conversion expressions with consistent error handling.
45-
*/
46-
trait TimeConversionErrorHandling {
47-
def failOnError: Boolean
48-
49-
/** Wraps evaluation with error handling (throws in ANSI mode, null otherwise). */
50-
protected def evalWithErrorHandling[T](f: => T): Any = {
51-
try {
52-
f
53-
} catch {
54-
case e: DateTimeException if failOnError =>
55-
throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e)
56-
case e: ArithmeticException if failOnError =>
57-
throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(
58-
new DateTimeException(s"Overflow in TIME conversion: ${e.getMessage}"))
59-
case _: DateTimeException | _: ArithmeticException => null
60-
}
61-
}
62-
63-
/** Generates error handling code (DateTimeException + ArithmeticException). */
64-
protected def doGenErrorHandling(
65-
ctx: CodegenContext,
66-
ev: ExprCode,
67-
utilCall: String): String = {
68-
val dateTimeErrorBranch = if (failOnError) {
69-
"throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e);"
70-
} else {
71-
s"${ev.isNull} = true;"
72-
}
73-
74-
val arithmeticErrorBranch = if (failOnError) {
75-
s"""
76-
|throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(
77-
| new java.time.DateTimeException("Overflow in TIME conversion: " + e.getMessage()));
78-
|""".stripMargin
79-
} else {
80-
s"${ev.isNull} = true;"
81-
}
82-
83-
s"""
84-
|try {
85-
| ${ev.value} = $utilCall;
86-
|} catch (java.time.DateTimeException e) {
87-
| $dateTimeErrorBranch
88-
|} catch (java.lang.ArithmeticException e) {
89-
| $arithmeticErrorBranch
90-
|}
91-
|""".stripMargin
92-
}
93-
94-
/** Generates error handling code (DateTimeException only). */
95-
protected def doGenDateTimeError(
96-
ctx: CodegenContext,
97-
ev: ExprCode,
98-
utilCall: String): String = {
99-
val errorBranch = if (failOnError) {
100-
"throw QueryExecutionErrors.ansiDateTimeArgumentOutOfRange(e);"
101-
} else {
102-
s"${ev.isNull} = true;"
103-
}
104-
s"""
105-
|try {
106-
| ${ev.value} = $utilCall;
107-
|} catch (java.time.DateTimeException e) {
108-
| $errorBranch
109-
|}
110-
|""".stripMargin
111-
}
112-
}
113-
11440
/**
11541
* Parses a column to a time based on the given format.
11642
*/
@@ -824,36 +750,6 @@ case class TimeTrunc(unit: Expression, time: Expression)
824750
}
825751
}
826752

827-
abstract class IntegralToTimeBase
828-
extends UnaryExpression with ExpectsInputTypes with TimeConversionErrorHandling {
829-
protected def upScaleFactor: Long
830-
def failOnError: Boolean = SQLConf.get.ansiEnabled
831-
832-
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
833-
override def dataType: DataType = TimeType(TimeType.MICROS_PRECISION)
834-
override def nullable: Boolean = true
835-
override def nullIntolerant: Boolean = true
836-
837-
override protected def nullSafeEval(input: Any): Any = {
838-
evalWithErrorHandling {
839-
DateTimeUtils.timeFromIntegral(input.asInstanceOf[Number].longValue(), upScaleFactor)
840-
}
841-
}
842-
843-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
844-
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
845-
nullSafeCodeGen(ctx, ev, c =>
846-
doGenErrorHandling(ctx, ev, s"$dtu.timeFromIntegral($c, ${upScaleFactor}L)")
847-
)
848-
}
849-
}
850-
851-
abstract class TimeToLongBase extends UnaryExpression with ExpectsInputTypes {
852-
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimeType)
853-
override def dataType: DataType = LongType
854-
override def nullIntolerant: Boolean = true
855-
}
856-
857753
// scalastyle:off line.size.limit
858754
@ExpressionDescription(
859755
usage = "_FUNC_(seconds) - Creates a TIME value from seconds since midnight.",
@@ -877,28 +773,18 @@ abstract class TimeToLongBase extends UnaryExpression with ExpectsInputTypes {
877773
group = "datetime_funcs")
878774
// scalastyle:on line.size.limit
879775
case class TimeFromSeconds(child: Expression)
880-
extends UnaryExpression with ExpectsInputTypes with TimeConversionErrorHandling {
881-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
882-
override def dataType: DataType = TimeType(TimeType.MICROS_PRECISION)
883-
override def nullable: Boolean = true
884-
override def nullIntolerant: Boolean = true
885-
886-
def failOnError: Boolean = SQLConf.get.ansiEnabled
776+
extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes {
887777

888-
override def nullSafeEval(input: Any): Any = {
889-
evalWithErrorHandling {
890-
DateTimeUtils.timeFromSeconds(input, child.dataType)
891-
}
892-
}
893-
894-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
895-
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
896-
val dt = ctx.addReferenceObj("childDataType", child.dataType)
897-
nullSafeCodeGen(ctx, ev, c =>
898-
doGenErrorHandling(ctx, ev, s"$dtu.timeFromSeconds($c, $dt)")
899-
)
900-
}
778+
override def replacement: Expression = StaticInvoke(
779+
classOf[DateTimeUtils.type],
780+
TimeType(TimeType.MICROS_PRECISION),
781+
"timeFromSeconds",
782+
Seq(child),
783+
Seq(child.dataType)
784+
)
901785

786+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
787+
override def dataType: DataType = TimeType(TimeType.MICROS_PRECISION)
902788
override def prettyName: String = "time_from_seconds"
903789

904790
override protected def withNewChildInternal(newChild: Expression): TimeFromSeconds =
@@ -927,10 +813,18 @@ case class TimeFromSeconds(child: Expression)
927813
group = "datetime_funcs")
928814
// scalastyle:on line.size.limit
929815
case class TimeFromMillis(child: Expression)
930-
extends IntegralToTimeBase {
816+
extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes {
931817

932-
override def upScaleFactor: Long = NANOS_PER_MILLIS
818+
override def replacement: Expression = StaticInvoke(
819+
classOf[DateTimeUtils.type],
820+
TimeType(TimeType.MICROS_PRECISION),
821+
"timeFromMillis",
822+
Seq(child),
823+
Seq(child.dataType)
824+
)
933825

826+
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
827+
override def dataType: DataType = TimeType(TimeType.MICROS_PRECISION)
934828
override def prettyName: String = "time_from_millis"
935829

936830
override protected def withNewChildInternal(newChild: Expression): TimeFromMillis =
@@ -959,10 +853,18 @@ case class TimeFromMillis(child: Expression)
959853
group = "datetime_funcs")
960854
// scalastyle:on line.size.limit
961855
case class TimeFromMicros(child: Expression)
962-
extends IntegralToTimeBase {
856+
extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes {
963857

964-
override def upScaleFactor: Long = NANOS_PER_MICROS
858+
override def replacement: Expression = StaticInvoke(
859+
classOf[DateTimeUtils.type],
860+
TimeType(TimeType.MICROS_PRECISION),
861+
"timeFromMicros",
862+
Seq(child),
863+
Seq(child.dataType)
864+
)
965865

866+
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
867+
override def dataType: DataType = TimeType(TimeType.MICROS_PRECISION)
966868
override def prettyName: String = "time_from_micros"
967869

968870
override protected def withNewChildInternal(newChild: Expression): TimeFromMicros =
@@ -992,28 +894,18 @@ case class TimeFromMicros(child: Expression)
992894
group = "datetime_funcs")
993895
// scalastyle:on line.size.limit
994896
case class TimeToSeconds(child: Expression)
995-
extends UnaryExpression with ImplicitCastInputTypes with TimeConversionErrorHandling {
897+
extends UnaryExpression with RuntimeReplaceable with ImplicitCastInputTypes {
898+
899+
override def replacement: Expression = StaticInvoke(
900+
classOf[DateTimeUtils.type],
901+
DecimalType(14, 6),
902+
"timeToSeconds",
903+
Seq(child),
904+
Seq(child.dataType)
905+
)
996906

997907
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimeType)
998908
override def dataType: DataType = DecimalType(14, 6)
999-
override def nullable: Boolean = true
1000-
override def nullIntolerant: Boolean = true
1001-
1002-
def failOnError: Boolean = SQLConf.get.ansiEnabled
1003-
1004-
protected override def nullSafeEval(input: Any): Any = {
1005-
evalWithErrorHandling {
1006-
DateTimeUtils.timeToSeconds(input.asInstanceOf[Long])
1007-
}
1008-
}
1009-
1010-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1011-
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
1012-
nullSafeCodeGen(ctx, ev, nanos =>
1013-
doGenDateTimeError(ctx, ev, s"$dtu.timeToSeconds($nanos)")
1014-
)
1015-
}
1016-
1017909
override def prettyName: String = "time_to_seconds"
1018910

1019911
override protected def withNewChildInternal(newChild: Expression): TimeToSeconds =
@@ -1043,17 +935,18 @@ case class TimeToSeconds(child: Expression)
1043935
group = "datetime_funcs")
1044936
// scalastyle:on line.size.limit
1045937
case class TimeToMillis(child: Expression)
1046-
extends TimeToLongBase {
1047-
1048-
override def nullSafeEval(input: Any): Any = {
1049-
DateTimeUtils.timeToMillis(input.asInstanceOf[Number].longValue())
1050-
}
938+
extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes {
1051939

1052-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1053-
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
1054-
defineCodeGen(ctx, ev, c => s"$dtu.timeToMillis($c)")
1055-
}
940+
override def replacement: Expression = StaticInvoke(
941+
classOf[DateTimeUtils.type],
942+
LongType,
943+
"timeToMillis",
944+
Seq(child),
945+
Seq(child.dataType)
946+
)
1056947

948+
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimeType)
949+
override def dataType: DataType = LongType
1057950
override def prettyName: String = "time_to_millis"
1058951

1059952
override protected def withNewChildInternal(newChild: Expression): TimeToMillis =
@@ -1083,17 +976,18 @@ case class TimeToMillis(child: Expression)
1083976
group = "datetime_funcs")
1084977
// scalastyle:on line.size.limit
1085978
case class TimeToMicros(child: Expression)
1086-
extends TimeToLongBase {
1087-
1088-
override def nullSafeEval(input: Any): Any = {
1089-
DateTimeUtils.timeToMicros(input.asInstanceOf[Number].longValue())
1090-
}
979+
extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes {
1091980

1092-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1093-
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
1094-
defineCodeGen(ctx, ev, c => s"$dtu.timeToMicros($c)")
1095-
}
981+
override def replacement: Expression = StaticInvoke(
982+
classOf[DateTimeUtils.type],
983+
LongType,
984+
"timeToMicros",
985+
Seq(child),
986+
Seq(child.dataType)
987+
)
1096988

989+
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimeType)
990+
override def dataType: DataType = LongType
1097991
override def prettyName: String = "time_to_micros"
1098992

1099993
override protected def withNewChildInternal(newChild: Expression): TimeToMicros =

0 commit comments

Comments
 (0)