Skip to content

Commit 2fbbf6a

Browse files
authored
Fix(optimizer)!: infer timestamp function types as TIMESTAMPTZ for bigquery (#4914)
1 parent f17004e commit 2fbbf6a

File tree

4 files changed

+43
-5
lines changed

sqlglot/dialects/bigquery.py

+13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlglot.dialects.dialect import (
99
Dialect,
1010
NormalizationStrategy,
11+
annotate_with_type_lambda,
1112
arg_max_or_min_no_count,
1213
binary_from_function,
1314
date_add_interval_sql,
@@ -398,8 +399,20 @@ class BigQuery(Dialect):
398399
# All set operations require either a DISTINCT or ALL specifier
399400
SET_OP_DISTINCT_BY_DEFAULT = dict.fromkeys((exp.Except, exp.Intersect, exp.Union), None)
400401

402+
# BigQuery maps Type.TIMESTAMP to DATETIME (its TIMESTAMP type is Type.TIMESTAMPTZ), so the functions inferred as TIMESTAMP by default must be re-keyed to TIMESTAMPTZ here
403+
TYPE_TO_EXPRESSIONS = {
404+
**Dialect.TYPE_TO_EXPRESSIONS,
405+
exp.DataType.Type.TIMESTAMPTZ: Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.TIMESTAMP],
406+
}
407+
TYPE_TO_EXPRESSIONS.pop(exp.DataType.Type.TIMESTAMP)
408+
401409
ANNOTATORS = {
402410
**Dialect.ANNOTATORS,
411+
**{
412+
expr_type: annotate_with_type_lambda(data_type)
413+
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
414+
for expr_type in expressions
415+
},
403416
**{
404417
expr_type: lambda self, e: _annotate_math_functions(self, e)
405418
for expr_type in (exp.Floor, exp.Ceil, exp.Log, exp.Ln, exp.Sqrt, exp.Exp, exp.Round)

sqlglot/dialects/dialect.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
}
5252

5353

54-
def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
54+
def annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
5555
return lambda self, e: self._annotate_with_type(e, data_type)
5656

5757

@@ -683,15 +683,15 @@ class Dialect(metaclass=_Dialect):
683683
exp.ParseJSON,
684684
},
685685
exp.DataType.Type.TIME: {
686+
exp.CurrentTime,
686687
exp.Time,
688+
exp.TimeAdd,
689+
exp.TimeSub,
687690
},
688691
exp.DataType.Type.TIMESTAMP: {
689-
exp.CurrentTime,
690692
exp.CurrentTimestamp,
691693
exp.StrToTime,
692-
exp.TimeAdd,
693694
exp.TimeStrToTime,
694-
exp.TimeSub,
695695
exp.TimestampAdd,
696696
exp.TimestampSub,
697697
exp.UnixToTime,
@@ -733,7 +733,7 @@ class Dialect(metaclass=_Dialect):
733733
for expr_type in subclasses(exp.__name__, exp.Binary)
734734
},
735735
**{
736-
expr_type: _annotate_with_type_lambda(data_type)
736+
expr_type: annotate_with_type_lambda(data_type)
737737
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
738738
for expr_type in expressions
739739
},

tests/dialects/test_bigquery.py

+16
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sqlglot.helper import logger as helper_logger
1616
from sqlglot.parser import logger as parser_logger
1717
from tests.dialects.test_dialect import Validator
18+
from sqlglot.optimizer.annotate_types import annotate_types
1819

1920

2021
class TestBigQuery(Validator):
@@ -2366,3 +2367,18 @@ def test_string_agg(self):
23662367
"STRING_AGG(DISTINCT a ORDER BY b DESC, c DESC LIMIT 10)",
23672368
"STRING_AGG(DISTINCT a, ',' ORDER BY b DESC, c DESC LIMIT 10)",
23682369
)
2370+
2371+
def test_annotate_timestamps(self):
2372+
sql = """
2373+
SELECT
2374+
CURRENT_TIMESTAMP() AS curr_ts,
2375+
TIMESTAMP_SECONDS(2) AS ts_seconds,
2376+
PARSE_TIMESTAMP('%c', 'Thu Dec 25 07:30:00 2008', 'UTC') AS parsed_ts,
2377+
TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE) AS ts_add,
2378+
TIMESTAMP_SUB(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE) AS ts_sub,
2379+
"""
2380+
2381+
annotated = annotate_types(self.parse_one(sql), dialect="bigquery")
2382+
2383+
for select in annotated.selects:
2384+
self.assertEqual(select.type.sql("bigquery"), "TIMESTAMP")

tests/fixtures/optimizer/annotate_functions.sql

+9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ INT;
1919
LEAST(1, 2.5, 3);
2020
DOUBLE;
2121

22+
CURRENT_TIME();
23+
TIME;
24+
25+
TIME_ADD(CAST('09:05:03' AS TIME), INTERVAL 2 HOUR);
26+
TIME;
27+
28+
TIME_SUB(CAST('09:05:03' AS TIME), INTERVAL 2 HOUR);
29+
TIME;
30+
2231
--------------------------------------
2332
-- Spark2 / Spark3 / Databricks
2433
--------------------------------------

0 commit comments

Comments
 (0)