|
12 | 12 | from sqlglot.dialects.dialect import ( |
13 | 13 | Dialect, |
14 | 14 | NormalizationStrategy, |
15 | | - annotate_with_type_lambda, |
16 | 15 | arg_max_or_min_no_count, |
17 | 16 | binary_from_function, |
18 | 17 | date_add_interval_sql, |
|
34 | 33 | strposition_sql, |
35 | 34 | groupconcat_sql, |
36 | 35 | ) |
| 36 | +from sqlglot.generator import unsupported_args |
37 | 37 | from sqlglot.helper import seq_get, split_num_words |
38 | 38 | from sqlglot.tokens import TokenType |
39 | | -from sqlglot.generator import unsupported_args |
| 39 | +from sqlglot.typing.bigquery import EXPRESSION_SPEC |
40 | 40 |
|
41 | 41 | if t.TYPE_CHECKING: |
42 | 42 | from sqlglot._typing import Lit |
@@ -290,59 +290,6 @@ def _str_to_datetime_sql( |
290 | 290 | return self.func(f"PARSE_{dtype}", fmt, this, expression.args.get("zone")) |
291 | 291 |
|
292 | 292 |
|
293 | | -def _annotate_math_functions(self: TypeAnnotator, expression: E) -> E: |
294 | | - """ |
295 | | - Many BigQuery math functions such as CEIL, FLOOR etc follow this return type convention: |
296 | | - +---------+---------+---------+------------+---------+ |
297 | | - | INPUT | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 | |
298 | | - +---------+---------+---------+------------+---------+ |
299 | | - | OUTPUT | FLOAT64 | NUMERIC | BIGNUMERIC | FLOAT64 | |
300 | | - +---------+---------+---------+------------+---------+ |
301 | | - """ |
302 | | - self._annotate_args(expression) |
303 | | - |
304 | | - this: exp.Expression = expression.this |
305 | | - |
306 | | - self._set_type( |
307 | | - expression, |
308 | | - exp.DataType.Type.DOUBLE if this.is_type(*exp.DataType.INTEGER_TYPES) else this.type, |
309 | | - ) |
310 | | - return expression |
311 | | - |
312 | | - |
313 | | -def _annotate_by_args_with_coerce(self: TypeAnnotator, expression: E) -> E: |
314 | | - """ |
315 | | - +------------+------------+------------+-------------+---------+ |
316 | | - | INPUT | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 | |
317 | | - +------------+------------+------------+-------------+---------+ |
318 | | - | INT64 | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 | |
319 | | - | NUMERIC | NUMERIC | NUMERIC | BIGNUMERIC | FLOAT64 | |
320 | | - | BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | BIGNUMERIC | FLOAT64 | |
321 | | - | FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 | FLOAT64 | |
322 | | - +------------+------------+------------+-------------+---------+ |
323 | | - """ |
324 | | - self._annotate_args(expression) |
325 | | - |
326 | | - self._set_type(expression, self._maybe_coerce(expression.this.type, expression.expression.type)) |
327 | | - return expression |
328 | | - |
329 | | - |
330 | | -def _annotate_by_args_approx_top(self: TypeAnnotator, expression: exp.ApproxTopK) -> exp.ApproxTopK: |
331 | | - self._annotate_args(expression) |
332 | | - |
333 | | - struct_type = exp.DataType( |
334 | | - this=exp.DataType.Type.STRUCT, |
335 | | - expressions=[expression.this.type, exp.DataType(this=exp.DataType.Type.BIGINT)], |
336 | | - nested=True, |
337 | | - ) |
338 | | - self._set_type( |
339 | | - expression, |
340 | | - exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[struct_type], nested=True), |
341 | | - ) |
342 | | - |
343 | | - return expression |
344 | | - |
345 | | - |
346 | 293 | @unsupported_args("ins_cost", "del_cost", "sub_cost") |
347 | 294 | def _levenshtein_sql(self: BigQuery.Generator, expression: exp.Levenshtein) -> str: |
348 | 295 | max_dist = expression.args.get("max_dist") |
@@ -398,44 +345,6 @@ def _json_extract_sql(self: BigQuery.Generator, expression: JSON_EXTRACT_TYPE) - |
398 | 345 | return sql |
399 | 346 |
|
400 | 347 |
|
401 | | -def _annotate_concat(self: TypeAnnotator, expression: exp.Concat) -> exp.Concat: |
402 | | - annotated = self._annotate_by_args(expression, "expressions") |
403 | | - |
404 | | - # Args must be BYTES or types that can be cast to STRING, return type is either BYTES or STRING |
405 | | - # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#concat |
406 | | - if not annotated.is_type(exp.DataType.Type.BINARY, exp.DataType.Type.UNKNOWN): |
407 | | - annotated.type = exp.DataType.Type.VARCHAR |
408 | | - |
409 | | - return annotated |
410 | | - |
411 | | - |
412 | | -def _annotate_array(self: TypeAnnotator, expression: exp.Array) -> exp.Array: |
413 | | - array_args = expression.expressions |
414 | | - |
415 | | - # BigQuery behaves as follows: |
416 | | - # |
417 | | - # SELECT t, TYPEOF(t) FROM (SELECT 'foo') AS t -- foo, STRUCT<STRING> |
418 | | - # SELECT ARRAY(SELECT 'foo'), TYPEOF(ARRAY(SELECT 'foo')) -- foo, ARRAY<STRING> |
419 | | - if ( |
420 | | - len(array_args) == 1 |
421 | | - and isinstance(select := array_args[0].unnest(), exp.Select) |
422 | | - and (query_type := select.meta.get("query_type")) is not None |
423 | | - and query_type.is_type(exp.DataType.Type.STRUCT) |
424 | | - and len(query_type.expressions) == 1 |
425 | | - and isinstance(col_def := query_type.expressions[0], exp.ColumnDef) |
426 | | - and (projection_type := col_def.kind) is not None |
427 | | - and not projection_type.is_type(exp.DataType.Type.UNKNOWN) |
428 | | - ): |
429 | | - array_type = exp.DataType( |
430 | | - this=exp.DataType.Type.ARRAY, |
431 | | - expressions=[projection_type.copy()], |
432 | | - nested=True, |
433 | | - ) |
434 | | - return self._annotate_with_type(expression, array_type) |
435 | | - |
436 | | - return self._annotate_by_args(expression, "expressions", array=True) |
437 | | - |
438 | | - |
439 | 348 | class BigQuery(Dialect): |
440 | 349 | WEEK_OFFSET = -1 |
441 | 350 | UNNEST_COLUMN_ONLY = True |
@@ -493,198 +402,7 @@ class BigQuery(Dialect): |
493 | 402 | COERCES_TO[exp.DataType.Type.DECIMAL] |= {exp.DataType.Type.BIGDECIMAL} |
494 | 403 | COERCES_TO[exp.DataType.Type.BIGINT] |= {exp.DataType.Type.BIGDECIMAL} |
495 | 404 |
|
496 | | - # BigQuery maps Type.TIMESTAMP to DATETIME, so we need to amend the inferred types |
497 | | - TYPE_TO_EXPRESSIONS = { |
498 | | - **Dialect.TYPE_TO_EXPRESSIONS, |
499 | | - exp.DataType.Type.BIGINT: { |
500 | | - *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BIGINT], |
501 | | - exp.Ascii, |
502 | | - exp.BitwiseAndAgg, |
503 | | - exp.BitwiseOrAgg, |
504 | | - exp.BitwiseXorAgg, |
505 | | - exp.BitwiseCount, |
506 | | - exp.ByteLength, |
507 | | - exp.DenseRank, |
508 | | - exp.FarmFingerprint, |
509 | | - exp.Grouping, |
510 | | - exp.LaxInt64, |
511 | | - exp.Length, |
512 | | - exp.Ntile, |
513 | | - exp.Rank, |
514 | | - exp.RangeBucket, |
515 | | - exp.RegexpInstr, |
516 | | - exp.RowNumber, |
517 | | - exp.Unicode, |
518 | | - }, |
519 | | - exp.DataType.Type.BINARY: { |
520 | | - *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BINARY], |
521 | | - exp.ByteString, |
522 | | - exp.CodePointsToBytes, |
523 | | - exp.MD5Digest, |
524 | | - exp.SHA, |
525 | | - exp.SHA2, |
526 | | - exp.Unhex, |
527 | | - }, |
528 | | - exp.DataType.Type.BOOLEAN: { |
529 | | - *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BOOLEAN], |
530 | | - exp.IsInf, |
531 | | - exp.IsNan, |
532 | | - exp.JSONBool, |
533 | | - exp.LaxBool, |
534 | | - }, |
535 | | - exp.DataType.Type.DATE: { |
536 | | - *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.DATE], |
537 | | - exp.DateFromUnixDate, |
538 | | - }, |
539 | | - exp.DataType.Type.DATETIME: { |
540 | | - *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.DATETIME], |
541 | | - exp.ParseDatetime, |
542 | | - exp.TimestampFromParts, |
543 | | - }, |
544 | | - exp.DataType.Type.DECIMAL: { |
545 | | - exp.ParseNumeric, |
546 | | - }, |
547 | | - exp.DataType.Type.BIGDECIMAL: { |
548 | | - exp.ParseBignumeric, |
549 | | - }, |
550 | | - exp.DataType.Type.DOUBLE: { |
551 | | - *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.DOUBLE], |
552 | | - exp.Acos, |
553 | | - exp.Acosh, |
554 | | - exp.Asin, |
555 | | - exp.Asinh, |
556 | | - exp.Atan, |
557 | | - exp.Atanh, |
558 | | - exp.Atan2, |
559 | | - exp.Cbrt, |
560 | | - exp.Corr, |
561 | | - exp.Cot, |
562 | | - exp.CosineDistance, |
563 | | - exp.Coth, |
564 | | - exp.CovarPop, |
565 | | - exp.CovarSamp, |
566 | | - exp.Csc, |
567 | | - exp.Csch, |
568 | | - exp.CumeDist, |
569 | | - exp.EuclideanDistance, |
570 | | - exp.Float64, |
571 | | - exp.LaxFloat64, |
572 | | - exp.PercentRank, |
573 | | - exp.Rand, |
574 | | - exp.Sec, |
575 | | - exp.Sech, |
576 | | - exp.Sin, |
577 | | - exp.Sinh, |
578 | | - }, |
579 | | - exp.DataType.Type.JSON: { |
580 | | - *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.JSON], |
581 | | - exp.JSONArray, |
582 | | - exp.JSONArrayAppend, |
583 | | - exp.JSONArrayInsert, |
584 | | - exp.JSONObject, |
585 | | - exp.JSONRemove, |
586 | | - exp.JSONSet, |
587 | | - exp.JSONStripNulls, |
588 | | - }, |
589 | | - exp.DataType.Type.TIME: { |
590 | | - *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.TIME], |
591 | | - exp.ParseTime, |
592 | | - exp.TimeFromParts, |
593 | | - exp.TimeTrunc, |
594 | | - exp.TsOrDsToTime, |
595 | | - }, |
596 | | - exp.DataType.Type.VARCHAR: { |
597 | | - *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.VARCHAR], |
598 | | - exp.CodePointsToString, |
599 | | - exp.Format, |
600 | | - exp.JSONExtractScalar, |
601 | | - exp.JSONType, |
602 | | - exp.LaxString, |
603 | | - exp.LowerHex, |
604 | | - exp.Normalize, |
605 | | - exp.SafeConvertBytesToString, |
606 | | - exp.Soundex, |
607 | | - exp.Uuid, |
608 | | - }, |
609 | | - exp.DataType.Type.TIMESTAMPTZ: Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.TIMESTAMP], |
610 | | - } |
611 | | - TYPE_TO_EXPRESSIONS.pop(exp.DataType.Type.TIMESTAMP) |
612 | | - |
613 | | - ANNOTATORS = { |
614 | | - **Dialect.ANNOTATORS, |
615 | | - **{ |
616 | | - expr_type: annotate_with_type_lambda(data_type) |
617 | | - for data_type, expressions in TYPE_TO_EXPRESSIONS.items() |
618 | | - for expr_type in expressions |
619 | | - }, |
620 | | - **{ |
621 | | - expr_type: lambda self, e: _annotate_math_functions(self, e) |
622 | | - for expr_type in (exp.Floor, exp.Ceil, exp.Log, exp.Ln, exp.Sqrt, exp.Exp, exp.Round) |
623 | | - }, |
624 | | - **{ |
625 | | - expr_type: lambda self, e: self._annotate_by_args(e, "this") |
626 | | - for expr_type in ( |
627 | | - exp.Abs, |
628 | | - exp.ArgMax, |
629 | | - exp.ArgMin, |
630 | | - exp.DateTrunc, |
631 | | - exp.DatetimeTrunc, |
632 | | - exp.FirstValue, |
633 | | - exp.GroupConcat, |
634 | | - exp.IgnoreNulls, |
635 | | - exp.JSONExtract, |
636 | | - exp.Lead, |
637 | | - exp.Left, |
638 | | - exp.Lower, |
639 | | - exp.NthValue, |
640 | | - exp.Pad, |
641 | | - exp.PercentileDisc, |
642 | | - exp.RegexpExtract, |
643 | | - exp.RegexpReplace, |
644 | | - exp.Repeat, |
645 | | - exp.Replace, |
646 | | - exp.RespectNulls, |
647 | | - exp.Reverse, |
648 | | - exp.Right, |
649 | | - exp.SafeNegate, |
650 | | - exp.Sign, |
651 | | - exp.Substring, |
652 | | - exp.TimestampTrunc, |
653 | | - exp.Translate, |
654 | | - exp.Trim, |
655 | | - exp.Upper, |
656 | | - ) |
657 | | - }, |
658 | | - exp.ApproxTopSum: lambda self, e: _annotate_by_args_approx_top(self, e), |
659 | | - exp.ApproxTopK: lambda self, e: _annotate_by_args_approx_top(self, e), |
660 | | - exp.ApproxQuantiles: lambda self, e: self._annotate_by_args(e, "this", array=True), |
661 | | - exp.Array: _annotate_array, |
662 | | - exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), |
663 | | - exp.Concat: _annotate_concat, |
664 | | - exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( |
665 | | - e, exp.DataType.build("ARRAY<TIMESTAMP>", dialect="bigquery") |
666 | | - ), |
667 | | - exp.JSONExtractArray: lambda self, e: self._annotate_by_args(e, "this", array=True), |
668 | | - exp.JSONFormat: lambda self, e: self._annotate_with_type( |
669 | | - e, exp.DataType.Type.JSON if e.args.get("to_json") else exp.DataType.Type.VARCHAR |
670 | | - ), |
671 | | - exp.JSONKeysAtDepth: lambda self, e: self._annotate_with_type( |
672 | | - e, exp.DataType.build("ARRAY<VARCHAR>", dialect="bigquery") |
673 | | - ), |
674 | | - exp.JSONValueArray: lambda self, e: self._annotate_with_type( |
675 | | - e, exp.DataType.build("ARRAY<VARCHAR>", dialect="bigquery") |
676 | | - ), |
677 | | - exp.Lag: lambda self, e: self._annotate_by_args(e, "this", "default"), |
678 | | - exp.PercentileCont: lambda self, e: _annotate_by_args_with_coerce(self, e), |
679 | | - exp.RegexpExtractAll: lambda self, e: self._annotate_by_args(e, "this", array=True), |
680 | | - exp.SafeAdd: lambda self, e: _annotate_by_args_with_coerce(self, e), |
681 | | - exp.SafeMultiply: lambda self, e: _annotate_by_args_with_coerce(self, e), |
682 | | - exp.SafeSubtract: lambda self, e: _annotate_by_args_with_coerce(self, e), |
683 | | - exp.Split: lambda self, e: self._annotate_by_args(e, "this", array=True), |
684 | | - exp.ToCodePoints: lambda self, e: self._annotate_with_type( |
685 | | - e, exp.DataType.build("ARRAY<BIGINT>", dialect="bigquery") |
686 | | - ), |
687 | | - } |
| 405 | + EXPRESSION_SPEC = EXPRESSION_SPEC.copy() |
688 | 406 |
|
689 | 407 | def normalize_identifier(self, expression: E) -> E: |
690 | 408 | if ( |
|
0 commit comments