diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index aa4ed692d5745..444a085c93280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -3695,17 +3695,60 @@ case class ConvertTimezone( TimestampNTZType) override def dataType: DataType = TimestampNTZType + // Resolve foldable timezone arguments once to avoid per-row ZoneId.of lookups, + // which involve zone-id normalization plus a ZoneRulesProvider map lookup and are + // not free even when the resulting ZoneId is cached. + @transient private lazy val sourceZoneId: Option[ZoneId] = foldableZoneId(sourceTz) + @transient private lazy val targetZoneId: Option[ZoneId] = foldableZoneId(targetTz) + + private def foldableZoneId(e: Expression): Option[ZoneId] = { + if (e.foldable) { + Option(e.eval()).map(v => DateTimeUtils.getZoneId(v.asInstanceOf[UTF8String].toString)) + } else { + None + } + } + override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = { - DateTimeUtils.convertTimestampNtzToAnotherTz( - srcTz.asInstanceOf[UTF8String].toString, - tgtTz.asInstanceOf[UTF8String].toString, - micros.asInstanceOf[Long]) + val srcZone = sourceZoneId.getOrElse( + DateTimeUtils.getZoneId(srcTz.asInstanceOf[UTF8String].toString)) + val tgtZone = targetZoneId.getOrElse( + DateTimeUtils.getZoneId(tgtTz.asInstanceOf[UTF8String].toString)) + DateTimeUtils.convertTimestampNtzToAnotherTz(srcZone, tgtZone, micros.asInstanceOf[Long]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (srcTz, tgtTz, micros) => - s"""$dtu.convertTimestampNtzToAnotherTz($srcTz.toString(), $tgtTz.toString(), $micros)""") + val tzClass = classOf[ZoneId].getName + + // If a foldable timezone literal is null, the expression always returns null. + // (Constant folding usually catches this, but be defensive.) + def isFoldableNull(e: Expression): Boolean = e.foldable && e.eval() == null + if (isFoldableNull(sourceTz) || isFoldableNull(targetTz)) { + ev.copy(code = code""" + |boolean ${ev.isNull} = true; + |long ${ev.value} = 0L; + """.stripMargin) + } else { + // Cache foldable ZoneIds in mutable state; non-foldable ones are resolved per row. + def cachedZoneTerm(e: Expression): Option[String] = { + if (!e.foldable) None + else { + val tz = e.eval().asInstanceOf[UTF8String].toString + val escapedTz = StringEscapeUtils.escapeJava(tz) + Some(ctx.addMutableState(tzClass, "tz", + v => s"""$v = $dtu.getZoneId("$escapedTz");""")) + } + } + val srcZoneTermOpt = cachedZoneTerm(sourceTz) + val tgtZoneTermOpt = cachedZoneTerm(targetTz) + + nullSafeCodeGen(ctx, ev, (srcTzCode, tgtTzCode, micros) => { + val srcZoneExpr = srcZoneTermOpt.getOrElse(s"$dtu.getZoneId($srcTzCode.toString())") + val tgtZoneExpr = tgtZoneTermOpt.getOrElse(s"$dtu.getZoneId($tgtTzCode.toString())") + s"${ev.value} = $dtu.convertTimestampNtzToAnotherTz($srcZoneExpr, $tgtZoneExpr, $micros);" + }) + } } override def prettyName: String = "convert_timezone" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 4c4b7160de509..76279f016fe33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -566,9 +566,22 @@ object DateTimeUtils extends SparkDateTimeUtils { * the input timestamp in the input time zone, but in the destination time zone. */ def convertTimestampNtzToAnotherTz(sourceTz: String, targetTz: String, micros: Long): Long = { + convertTimestampNtzToAnotherTz(getZoneId(sourceTz), getZoneId(targetTz), micros) + } + + /** + * Same as [[convertTimestampNtzToAnotherTz(String, String, Long)]] but accepts pre-resolved + * `ZoneId` objects. Useful when the timezone arguments are foldable and have been resolved + * once at expression construction time, avoiding per-row `ZoneId.of` lookups (zone-id + * normalization plus a `ZoneRulesProvider` map lookup) on every input row. + */ + def convertTimestampNtzToAnotherTz( + sourceZoneId: ZoneId, + targetZoneId: ZoneId, + micros: Long): Long = { val ldt = microsToLocalDateTime(micros) - .atZone(getZoneId(sourceTz)) - .withZoneSameInstant(getZoneId(targetTz)) + .atZone(sourceZoneId) + .withZoneSameInstant(targetZoneId) .toLocalDateTime localDateTimeToMicros(ldt) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 824d08d67c508..9bd84f41530af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -1892,6 +1892,34 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-56745: convert_timezone with non-foldable timezone arguments") { + // Exercises the fallback path in ConvertTimezone where source/target timezones + // are not foldable and must be resolved per row, complementing the foldable-cache + // path covered by SPARK-37552. + checkEvaluation( + ConvertTimezone( + NonFoldableLiteral.create("Europe/Brussels", StringType), + NonFoldableLiteral.create("Europe/Moscow", StringType), + Literal(LocalDateTime.of(2022, 3, 27, 3, 0, 0))), + LocalDateTime.of(2022, 3, 27, 4, 0, 0)) + + // Mixed: foldable source, non-foldable target. + checkEvaluation( + ConvertTimezone( + Literal("America/Los_Angeles"), + NonFoldableLiteral.create("UTC", StringType), + Literal(LocalDateTime.of(2022, 1, 1, 0, 0, 0))), + LocalDateTime.of(2022, 1, 1, 8, 0, 0)) + + // Non-foldable null timezone -- nullIntolerant must propagate null. + checkEvaluation( + ConvertTimezone( + NonFoldableLiteral.create(null, StringType), + Literal("UTC"), + Literal(LocalDateTime.of(2022, 1, 1, 0, 0, 0))), + null) + } + test("SPARK-38195: add a quantity of interval units to a timestamp") { // Check case-insensitivity checkEvaluation(