Skip to content

[SPARK-51554][SQL] Add the time_trunc() function #51547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -3430,6 +3430,11 @@
"expects a string literal, but got <invalidValue>."
]
},
"TIMETRUNC_UNIT" : {
"message" : [
"expects one of the units 'HOUR', 'MINUTE', 'SECOND', 'MILLISECOND', 'MICROSECOND', but got the string literal <invalidValue>."
]
},
"ZERO_INDEX" : {
"message" : [
"expects %1$, %2$ and so on, but got %0$."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ object FunctionRegistry {
expression[WindowTime]("window_time"),
expression[MakeDate]("make_date"),
expression[MakeTime]("make_time"),
expression[TimeTrunc]("time_trunc"),
expression[MakeTimestamp]("make_timestamp"),
expression[TryMakeTimestamp]("try_make_timestamp"),
expression[MonthName]("monthname"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,56 @@ case class SubtractTimes(left: Expression, right: Expression)
newLeft: Expression, newRight: Expression): SubtractTimes =
copy(left = newLeft, right = newRight)
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(unit, expr) - Returns time `expr` truncated to the unit specified by the unit `unit`.
""",
arguments = """
Arguments:
* unit - the unit representing the unit to be truncated to
- "HOUR" - zero out the minutes and seconds with fraction part
- "MINUTE" - zero out the seconds with fraction part
- "SECOND" - zero out the seconds with fraction part
- "MILLISECOND" - zero out the microseconds
- "MICROSECOND" - zero out the nanoseconds
* expr - a TIME with a valid time format
""",
examples = """
Examples:
> SELECT _FUNC_('HOUR', TIME'09:32:05.359');
09:00:00
> SELECT _FUNC_('MILLISECOND', TIME'09:32:05.123456');
09:32:05.123
""",
group = "datetime_funcs",
since = "4.1.0")
// scalastyle:on line.size.limit
case class TimeTrunc(unit: Expression, time: Expression)
extends BinaryExpression with RuntimeReplaceable with ImplicitCastInputTypes {

override def left: Expression = unit
override def right: Expression = time

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true), AnyTimeType)

override def dataType: DataType = time.dataType

override def prettyName: String = "time_trunc"

override protected def withNewChildrenInternal(
newUnit: Expression, newTime: Expression): TimeTrunc =
copy(unit = newUnit, time = newTime)

override def replacement: Expression = {
StaticInvoke(
classOf[DateTimeUtils.type],
dataType,
"timeTrunc",
Seq(unit, time),
Seq(unit.dataType, time.dataType)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,53 @@ object DateTimeUtils extends SparkDateTimeUtils {
}
}

/**
* Returns time truncated to the unit specified by the level.
*/
private def parseTimeTruncLevel(level: Int): ChronoUnit = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you simplify the implementation and parse input string units to ChronoUnits directly.

level match {
case TRUNC_TO_HOUR => ChronoUnit.HOURS
case TRUNC_TO_MINUTE => ChronoUnit.MINUTES
case TRUNC_TO_SECOND => ChronoUnit.SECONDS
case TRUNC_TO_MILLISECOND => ChronoUnit.MILLIS
case TRUNC_TO_MICROSECOND => ChronoUnit.MICROS
case _ =>
throw new IllegalArgumentException(s"Unsupported time truncation level: $level")
}
}

/**
* Returns time truncated to the unit specified by the level.
*/
private def timeTrunc(level: Int, nanos: Long): Long = {
localTimeToNanos(nanosToLocalTime(nanos).truncatedTo(parseTimeTruncLevel(level)))
}

/**
* Set of supported time truncation levels for TIME values.
*/
private val supportedTimeTruncLevels = Set(
TRUNC_TO_HOUR,
TRUNC_TO_MINUTE,
TRUNC_TO_SECOND,
TRUNC_TO_MILLISECOND,
TRUNC_TO_MICROSECOND
)

/**
* Returns time truncated to the unit specified by the level. Trunc level should be generated
* using `parseTruncLevel()`, and should be between TRUNC_TO_HOUR and TRUNC_TO_MICROSECOND.
*/
def timeTrunc(level: UTF8String, nanos: Long): Long = {
require(level != null, "Truncation level must not be null.")
require(nanos >= 0, "Nanoseconds must be non-negative.")
val truncLevel = parseTruncLevel(level)
if (!supportedTimeTruncLevels.contains(truncLevel)) {
throw QueryExecutionErrors.invalidTimeTruncUnitError("time_trunc", level.toString)
}
timeTrunc(truncLevel, nanos)
}

/**
* Returns the truncate level, could be from TRUNC_TO_MICROSECOND to TRUNC_TO_YEAR,
* or TRUNC_INVALID, TRUNC_INVALID means unsupported truncate level.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3067,6 +3067,21 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
)
}

// Throws a SparkIllegalArgumentException when an invalid time truncation unit is specified.
// Note that the supported units are: HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND.
def invalidTimeTruncUnitError(
functionName: String,
invalidValue: String): Throwable = {
new SparkIllegalArgumentException(
errorClass = "INVALID_PARAMETER_VALUE.TIMETRUNC_UNIT",
messageParameters = Map(
"functionName" -> toSQLId(functionName),
"parameter" -> toSQLId("unit"),
"invalidValue" -> toSQLValue(invalidValue)
)
)
}

// Throws a SparkRuntimeException when a CHECK constraint is violated, including details of the
// violation. This is a Java-friendly version of the above method.
def checkViolationJava(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions

import java.time.{Duration, LocalTime}

import org.apache.spark.{SPARK_DOC_ROOT, SparkDateTimeException, SparkFunSuite}
import org.apache.spark.{SPARK_DOC_ROOT, SparkDateTimeException, SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLValue}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils.localTimeToNanos
import org.apache.spark.sql.types.{DayTimeIntervalType, Decimal, DecimalType, IntegerType, StringType, TimeType}
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, SECOND}

Expand Down Expand Up @@ -418,4 +419,108 @@ class TimeExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}

test("SPARK-51554: TimeTrunc") {
// Test cases for different truncation units - 09:32:05.359123.
val testTime = localTime(9, 32, 5, 359123)

// Test HOUR truncation.
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)
// Test MINUTE truncation.
checkEvaluation(
TimeTrunc(Literal("MINUTE"), Literal(testTime, TimeType())),
localTime(9, 32, 0, 0)
)
// Test SECOND truncation.
checkEvaluation(
TimeTrunc(Literal("SECOND"), Literal(testTime, TimeType())),
localTime(9, 32, 5, 0)
)
// Test MILLISECOND truncation.
checkEvaluation(
TimeTrunc(Literal("MILLISECOND"), Literal(testTime, TimeType())),
localTime(9, 32, 5, 359000)
)
// Test MICROSECOND truncation.
checkEvaluation(
TimeTrunc(Literal("MICROSECOND"), Literal(testTime, TimeType())),
testTime
)

// Test case-insensitive units.
checkEvaluation(
TimeTrunc(Literal("hour"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)
checkEvaluation(
TimeTrunc(Literal("Hour"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)
checkEvaluation(
TimeTrunc(Literal("hoUR"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)

// Test invalid units.
val invalidUnits: Seq[String] = Seq("MS", "INVALID", "ABC", "XYZ", " ", "")
invalidUnits.foreach { unit =>
checkError(
exception = intercept[SparkIllegalArgumentException] {
TimeTrunc(Literal(unit), Literal(testTime, TimeType())).eval()
},
condition = "INVALID_PARAMETER_VALUE.TIMETRUNC_UNIT",
parameters = Map(
"functionName" -> "`time_trunc`",
"parameter" -> "`unit`",
"invalidValue" -> s"'$unit'"
)
)
}

// Test null inputs.
checkEvaluation(
TimeTrunc(Literal.create(null, StringType), Literal(testTime, TimeType())),
null
)
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal.create(null, TimeType())),
null
)
checkEvaluation(
TimeTrunc(Literal.create(null, StringType), Literal.create(null, TimeType())),
null
)

// Test edge cases.
val midnightTime = localTime(0, 0, 0, 0)
val supportedUnits: Seq[String] = Seq("HOUR", "MINUTE", "SECOND", "MILLISECOND", "MICROSECOND")
supportedUnits.foreach { unit =>
checkEvaluation(
TimeTrunc(Literal(unit), Literal(midnightTime, TimeType())),
midnightTime
)
}

val maxTime = localTimeToNanos(LocalTime.of(23, 59, 59, 999999999))
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal(maxTime, TimeType())),
localTime(23, 0, 0, 0)
)
checkEvaluation(
TimeTrunc(Literal("MICROSECOND"), Literal(maxTime, TimeType())),
localTimeToNanos(LocalTime.of(23, 59, 59, 999999000))
)

// Test precision loss.
val timeWithMicroPrecision = localTime(15, 30, 45, 123456)
val timeTruncMin = TimeTrunc(Literal("MINUTE"), Literal(timeWithMicroPrecision, TimeType(3)))
assert(timeTruncMin.dataType == TimeType(3))
checkEvaluation(timeTruncMin, localTime(15, 30, 0, 0))
val timeTruncSec = TimeTrunc(Literal("SECOND"), Literal(timeWithMicroPrecision, TimeType(3)))
assert(timeTruncSec.dataType == TimeType(3))
checkEvaluation(timeTruncSec, localTime(15, 30, 45, 0))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,43 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
}
}

test("SPARK-51554: time truncation using timeTrunc") {
// 01:02:03.400500600
val input = localTimeToNanos(LocalTime.of(1, 2, 3, 400500600))
// Truncate the minutes, seconds, and fractions of seconds. Result is: 01:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("HOUR"), input) === 3600000000000L)
// Truncate the seconds and fractions of seconds. Result is: 01:02:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MINUTE"), input) === 3720000000000L)
// Truncate the fractions of seconds. Result is: 01:02:03.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("SECOND"), input) === 3723000000000L)
// Truncate the milliseconds. Result is: 01:02:03.400.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MILLISECOND"), input) === 3723400000000L)
// Truncate the microseconds. Result is: 01:02:03.400500.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MICROSECOND"), input) === 3723400500000L)

// 00:00:00
val midnight = localTimeToNanos(LocalTime.MIDNIGHT)
// Truncate the minutes, seconds, and fractions of seconds. Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("HOUR"), midnight) === 0)
// Truncate the seconds and fractions of seconds. Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MINUTE"), midnight) === 0)
// Truncate the fractions of seconds. Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("SECOND"), midnight) === 0)
// Truncate the milliseconds. Result is: Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MILLISECOND"), midnight) === 0)
// Truncate the microseconds. Result is: Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MICROSECOND"), midnight) === 0)

// Unsupported truncation levels.
Seq("DAY", "WEEK", "MONTH", "QUARTER", "YEAR", "INVALID", "ABC", "XYZ", "MS", " ", "", null).
map(UTF8String.fromString).foreach { level =>
intercept[IllegalArgumentException] {
DateTimeUtils.timeTrunc(level, input)
DateTimeUtils.timeTrunc(level, midnight)
}
}
}

test("SPARK-35664: microseconds to LocalDateTime") {
assert(microsToLocalDateTime(0) == LocalDateTime.parse("1970-01-01T00:00:00"))
assert(microsToLocalDateTime(100) == LocalDateTime.parse("1970-01-01T00:00:00.0001"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@
| org.apache.spark.sql.catalyst.expressions.Subtract | - | SELECT 2 - 1 | struct<(2 - 1):int> |
| org.apache.spark.sql.catalyst.expressions.Tan | tan | SELECT tan(0) | struct<TAN(0):double> |
| org.apache.spark.sql.catalyst.expressions.Tanh | tanh | SELECT tanh(0) | struct<TANH(0):double> |
| org.apache.spark.sql.catalyst.expressions.TimeTrunc | time_trunc | SELECT time_trunc('HOUR', TIME'09:32:05.359') | struct<time_trunc(HOUR, TIME '09:32:05.359'):time(6)> |
| org.apache.spark.sql.catalyst.expressions.TimeWindow | window | SELECT a, window.start, window.end, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, start | struct<a:string,start:timestamp,end:timestamp,cnt:bigint> |
| org.apache.spark.sql.catalyst.expressions.ToBinary | to_binary | SELECT to_binary('abc', 'utf-8') | struct<to_binary(abc, utf-8):binary> |
| org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_char | SELECT to_char(454, '999') | struct<to_char(454, 999):string> |
Expand Down
Loading