Skip to content

Commit d197c47

Browse files
Refactor!!: make annotator format more extensible (#6222)
* Refactor!!: Refactor the Annotator format to make it more extensible
* Move `DATE_PARTS` outside of annotator helper

---------

Co-authored-by: George Sittas <[email protected]>
1 parent c8b0129 commit d197c47

File tree

18 files changed

+1058
-856
lines changed

18 files changed

+1058
-856
lines changed

sqlglot/dialects/bigquery.py

Lines changed: 3 additions & 285 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
from sqlglot.dialects.dialect import (
1313
Dialect,
1414
NormalizationStrategy,
15-
annotate_with_type_lambda,
1615
arg_max_or_min_no_count,
1716
binary_from_function,
1817
date_add_interval_sql,
@@ -34,9 +33,10 @@
3433
strposition_sql,
3534
groupconcat_sql,
3635
)
36+
from sqlglot.generator import unsupported_args
3737
from sqlglot.helper import seq_get, split_num_words
3838
from sqlglot.tokens import TokenType
39-
from sqlglot.generator import unsupported_args
39+
from sqlglot.typing.bigquery import EXPRESSION_SPEC
4040

4141
if t.TYPE_CHECKING:
4242
from sqlglot._typing import Lit
@@ -290,59 +290,6 @@ def _str_to_datetime_sql(
290290
return self.func(f"PARSE_{dtype}", fmt, this, expression.args.get("zone"))
291291

292292

293-
def _annotate_math_functions(self: TypeAnnotator, expression: E) -> E:
294-
"""
295-
Many BigQuery math functions such as CEIL, FLOOR etc follow this return type convention:
296-
+---------+---------+---------+------------+---------+
297-
| INPUT | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 |
298-
+---------+---------+---------+------------+---------+
299-
| OUTPUT | FLOAT64 | NUMERIC | BIGNUMERIC | FLOAT64 |
300-
+---------+---------+---------+------------+---------+
301-
"""
302-
self._annotate_args(expression)
303-
304-
this: exp.Expression = expression.this
305-
306-
self._set_type(
307-
expression,
308-
exp.DataType.Type.DOUBLE if this.is_type(*exp.DataType.INTEGER_TYPES) else this.type,
309-
)
310-
return expression
311-
312-
313-
def _annotate_by_args_with_coerce(self: TypeAnnotator, expression: E) -> E:
314-
"""
315-
+------------+------------+------------+-------------+---------+
316-
| INPUT | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 |
317-
+------------+------------+------------+-------------+---------+
318-
| INT64 | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 |
319-
| NUMERIC | NUMERIC | NUMERIC | BIGNUMERIC | FLOAT64 |
320-
| BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | FLOAT64 |
321-
| FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 |
322-
+------------+------------+------------+-------------+---------+
323-
"""
324-
self._annotate_args(expression)
325-
326-
self._set_type(expression, self._maybe_coerce(expression.this.type, expression.expression.type))
327-
return expression
328-
329-
330-
def _annotate_by_args_approx_top(self: TypeAnnotator, expression: exp.ApproxTopK) -> exp.ApproxTopK:
331-
self._annotate_args(expression)
332-
333-
struct_type = exp.DataType(
334-
this=exp.DataType.Type.STRUCT,
335-
expressions=[expression.this.type, exp.DataType(this=exp.DataType.Type.BIGINT)],
336-
nested=True,
337-
)
338-
self._set_type(
339-
expression,
340-
exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[struct_type], nested=True),
341-
)
342-
343-
return expression
344-
345-
346293
@unsupported_args("ins_cost", "del_cost", "sub_cost")
347294
def _levenshtein_sql(self: BigQuery.Generator, expression: exp.Levenshtein) -> str:
348295
max_dist = expression.args.get("max_dist")
@@ -398,44 +345,6 @@ def _json_extract_sql(self: BigQuery.Generator, expression: JSON_EXTRACT_TYPE) -
398345
return sql
399346

400347

401-
def _annotate_concat(self: TypeAnnotator, expression: exp.Concat) -> exp.Concat:
402-
annotated = self._annotate_by_args(expression, "expressions")
403-
404-
# Args must be BYTES or types that can be cast to STRING, return type is either BYTES or STRING
405-
# https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#concat
406-
if not annotated.is_type(exp.DataType.Type.BINARY, exp.DataType.Type.UNKNOWN):
407-
annotated.type = exp.DataType.Type.VARCHAR
408-
409-
return annotated
410-
411-
412-
def _annotate_array(self: TypeAnnotator, expression: exp.Array) -> exp.Array:
413-
array_args = expression.expressions
414-
415-
# BigQuery behaves as follows:
416-
#
417-
# SELECT t, TYPEOF(t) FROM (SELECT 'foo') AS t -- foo, STRUCT<STRING>
418-
# SELECT ARRAY(SELECT 'foo'), TYPEOF(ARRAY(SELECT 'foo')) -- foo, ARRAY<STRING>
419-
if (
420-
len(array_args) == 1
421-
and isinstance(select := array_args[0].unnest(), exp.Select)
422-
and (query_type := select.meta.get("query_type")) is not None
423-
and query_type.is_type(exp.DataType.Type.STRUCT)
424-
and len(query_type.expressions) == 1
425-
and isinstance(col_def := query_type.expressions[0], exp.ColumnDef)
426-
and (projection_type := col_def.kind) is not None
427-
and not projection_type.is_type(exp.DataType.Type.UNKNOWN)
428-
):
429-
array_type = exp.DataType(
430-
this=exp.DataType.Type.ARRAY,
431-
expressions=[projection_type.copy()],
432-
nested=True,
433-
)
434-
return self._annotate_with_type(expression, array_type)
435-
436-
return self._annotate_by_args(expression, "expressions", array=True)
437-
438-
439348
class BigQuery(Dialect):
440349
WEEK_OFFSET = -1
441350
UNNEST_COLUMN_ONLY = True
@@ -493,198 +402,7 @@ class BigQuery(Dialect):
493402
COERCES_TO[exp.DataType.Type.DECIMAL] |= {exp.DataType.Type.BIGDECIMAL}
494403
COERCES_TO[exp.DataType.Type.BIGINT] |= {exp.DataType.Type.BIGDECIMAL}
495404

496-
# BigQuery maps Type.TIMESTAMP to DATETIME, so we need to amend the inferred types
497-
TYPE_TO_EXPRESSIONS = {
498-
**Dialect.TYPE_TO_EXPRESSIONS,
499-
exp.DataType.Type.BIGINT: {
500-
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BIGINT],
501-
exp.Ascii,
502-
exp.BitwiseAndAgg,
503-
exp.BitwiseOrAgg,
504-
exp.BitwiseXorAgg,
505-
exp.BitwiseCount,
506-
exp.ByteLength,
507-
exp.DenseRank,
508-
exp.FarmFingerprint,
509-
exp.Grouping,
510-
exp.LaxInt64,
511-
exp.Length,
512-
exp.Ntile,
513-
exp.Rank,
514-
exp.RangeBucket,
515-
exp.RegexpInstr,
516-
exp.RowNumber,
517-
exp.Unicode,
518-
},
519-
exp.DataType.Type.BINARY: {
520-
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BINARY],
521-
exp.ByteString,
522-
exp.CodePointsToBytes,
523-
exp.MD5Digest,
524-
exp.SHA,
525-
exp.SHA2,
526-
exp.Unhex,
527-
},
528-
exp.DataType.Type.BOOLEAN: {
529-
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BOOLEAN],
530-
exp.IsInf,
531-
exp.IsNan,
532-
exp.JSONBool,
533-
exp.LaxBool,
534-
},
535-
exp.DataType.Type.DATE: {
536-
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.DATE],
537-
exp.DateFromUnixDate,
538-
},
539-
exp.DataType.Type.DATETIME: {
540-
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.DATETIME],
541-
exp.ParseDatetime,
542-
exp.TimestampFromParts,
543-
},
544-
exp.DataType.Type.DECIMAL: {
545-
exp.ParseNumeric,
546-
},
547-
exp.DataType.Type.BIGDECIMAL: {
548-
exp.ParseBignumeric,
549-
},
550-
exp.DataType.Type.DOUBLE: {
551-
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.DOUBLE],
552-
exp.Acos,
553-
exp.Acosh,
554-
exp.Asin,
555-
exp.Asinh,
556-
exp.Atan,
557-
exp.Atanh,
558-
exp.Atan2,
559-
exp.Cbrt,
560-
exp.Corr,
561-
exp.Cot,
562-
exp.CosineDistance,
563-
exp.Coth,
564-
exp.CovarPop,
565-
exp.CovarSamp,
566-
exp.Csc,
567-
exp.Csch,
568-
exp.CumeDist,
569-
exp.EuclideanDistance,
570-
exp.Float64,
571-
exp.LaxFloat64,
572-
exp.PercentRank,
573-
exp.Rand,
574-
exp.Sec,
575-
exp.Sech,
576-
exp.Sin,
577-
exp.Sinh,
578-
},
579-
exp.DataType.Type.JSON: {
580-
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.JSON],
581-
exp.JSONArray,
582-
exp.JSONArrayAppend,
583-
exp.JSONArrayInsert,
584-
exp.JSONObject,
585-
exp.JSONRemove,
586-
exp.JSONSet,
587-
exp.JSONStripNulls,
588-
},
589-
exp.DataType.Type.TIME: {
590-
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.TIME],
591-
exp.ParseTime,
592-
exp.TimeFromParts,
593-
exp.TimeTrunc,
594-
exp.TsOrDsToTime,
595-
},
596-
exp.DataType.Type.VARCHAR: {
597-
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.VARCHAR],
598-
exp.CodePointsToString,
599-
exp.Format,
600-
exp.JSONExtractScalar,
601-
exp.JSONType,
602-
exp.LaxString,
603-
exp.LowerHex,
604-
exp.Normalize,
605-
exp.SafeConvertBytesToString,
606-
exp.Soundex,
607-
exp.Uuid,
608-
},
609-
exp.DataType.Type.TIMESTAMPTZ: Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.TIMESTAMP],
610-
}
611-
TYPE_TO_EXPRESSIONS.pop(exp.DataType.Type.TIMESTAMP)
612-
613-
ANNOTATORS = {
614-
**Dialect.ANNOTATORS,
615-
**{
616-
expr_type: annotate_with_type_lambda(data_type)
617-
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
618-
for expr_type in expressions
619-
},
620-
**{
621-
expr_type: lambda self, e: _annotate_math_functions(self, e)
622-
for expr_type in (exp.Floor, exp.Ceil, exp.Log, exp.Ln, exp.Sqrt, exp.Exp, exp.Round)
623-
},
624-
**{
625-
expr_type: lambda self, e: self._annotate_by_args(e, "this")
626-
for expr_type in (
627-
exp.Abs,
628-
exp.ArgMax,
629-
exp.ArgMin,
630-
exp.DateTrunc,
631-
exp.DatetimeTrunc,
632-
exp.FirstValue,
633-
exp.GroupConcat,
634-
exp.IgnoreNulls,
635-
exp.JSONExtract,
636-
exp.Lead,
637-
exp.Left,
638-
exp.Lower,
639-
exp.NthValue,
640-
exp.Pad,
641-
exp.PercentileDisc,
642-
exp.RegexpExtract,
643-
exp.RegexpReplace,
644-
exp.Repeat,
645-
exp.Replace,
646-
exp.RespectNulls,
647-
exp.Reverse,
648-
exp.Right,
649-
exp.SafeNegate,
650-
exp.Sign,
651-
exp.Substring,
652-
exp.TimestampTrunc,
653-
exp.Translate,
654-
exp.Trim,
655-
exp.Upper,
656-
)
657-
},
658-
exp.ApproxTopSum: lambda self, e: _annotate_by_args_approx_top(self, e),
659-
exp.ApproxTopK: lambda self, e: _annotate_by_args_approx_top(self, e),
660-
exp.ApproxQuantiles: lambda self, e: self._annotate_by_args(e, "this", array=True),
661-
exp.Array: _annotate_array,
662-
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
663-
exp.Concat: _annotate_concat,
664-
exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
665-
e, exp.DataType.build("ARRAY<TIMESTAMP>", dialect="bigquery")
666-
),
667-
exp.JSONExtractArray: lambda self, e: self._annotate_by_args(e, "this", array=True),
668-
exp.JSONFormat: lambda self, e: self._annotate_with_type(
669-
e, exp.DataType.Type.JSON if e.args.get("to_json") else exp.DataType.Type.VARCHAR
670-
),
671-
exp.JSONKeysAtDepth: lambda self, e: self._annotate_with_type(
672-
e, exp.DataType.build("ARRAY<VARCHAR>", dialect="bigquery")
673-
),
674-
exp.JSONValueArray: lambda self, e: self._annotate_with_type(
675-
e, exp.DataType.build("ARRAY<VARCHAR>", dialect="bigquery")
676-
),
677-
exp.Lag: lambda self, e: self._annotate_by_args(e, "this", "default"),
678-
exp.PercentileCont: lambda self, e: _annotate_by_args_with_coerce(self, e),
679-
exp.RegexpExtractAll: lambda self, e: self._annotate_by_args(e, "this", array=True),
680-
exp.SafeAdd: lambda self, e: _annotate_by_args_with_coerce(self, e),
681-
exp.SafeMultiply: lambda self, e: _annotate_by_args_with_coerce(self, e),
682-
exp.SafeSubtract: lambda self, e: _annotate_by_args_with_coerce(self, e),
683-
exp.Split: lambda self, e: self._annotate_by_args(e, "this", array=True),
684-
exp.ToCodePoints: lambda self, e: self._annotate_with_type(
685-
e, exp.DataType.build("ARRAY<BIGINT>", dialect="bigquery")
686-
),
687-
}
405+
EXPRESSION_SPEC = EXPRESSION_SPEC.copy()
688406

689407
def normalize_identifier(self, expression: E) -> E:
690408
if (

0 commit comments

Comments
 (0)