-
Notifications
You must be signed in to change notification settings - Fork 1k
feat(optimizer)!: Annotate parameterized numeric types for Snowflake functions #6230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import typing as t | ||
|
|
||
| from decimal import Decimal | ||
| from sqlglot import exp | ||
| from sqlglot.typing import EXPRESSION_METADATA | ||
|
|
||
|
|
@@ -81,6 +82,230 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex | |
| return expression | ||
|
|
||
|
|
||
| def _extract_type_precision_scale( | ||
| data_type: exp.DataType, | ||
| ) -> t.Tuple[t.Optional[int], t.Optional[int]]: | ||
| expressions = data_type.expressions | ||
| if not expressions: | ||
| return None, None | ||
|
|
||
| precision = None | ||
| scale = None | ||
|
|
||
| if len(expressions) >= 1: | ||
| p_expr = expressions[0] | ||
| if isinstance(p_expr, exp.DataTypeParam) and isinstance(p_expr.this, exp.Literal): | ||
| precision = int(p_expr.this.this) if p_expr.this.is_int else None | ||
|
|
||
| if len(expressions) >= 2: | ||
| s_expr = expressions[1] | ||
| if isinstance(s_expr, exp.DataTypeParam) and isinstance(s_expr.this, exp.Literal): | ||
| scale = int(s_expr.this.this) if s_expr.this.is_int else None | ||
|
|
||
| return precision, scale | ||
|
|
||
|
|
||
| def _extract_literal_precision_scale(num_str: str) -> t.Tuple[int, int]: | ||
| d = Decimal(num_str).normalize() | ||
| s = format(d, "f").lstrip("-") | ||
|
|
||
| if "." in s: | ||
| int_part, frac_part = s.split(".", 1) | ||
| precision = len(int_part + frac_part) | ||
| scale = len(frac_part.rstrip("0")) | ||
| else: | ||
| precision = len(s) | ||
| scale = 0 | ||
| return precision, scale | ||
|
|
||
|
|
||
| def _is_float(type_: t.Optional[exp.DataType]) -> bool: | ||
| return type_ is not None and type_.is_type(exp.DataType.Type.FLOAT) | ||
|
|
||
|
|
||
| def _is_parameterized_numeric(type_: t.Optional[exp.DataType]) -> bool: | ||
| return ( | ||
| type_ is not None and type_.is_type(*exp.DataType.NUMERIC_TYPES) and bool(type_.expressions) | ||
| ) | ||
|
|
||
|
|
||
| def _get_normalized_type(expression: exp.Expression) -> t.Optional[exp.DataType]: | ||
| """ | ||
| Normalizes numeric expressions to their parameterized representation. | ||
| For literal numbers, return the parameterized representation. | ||
| For integer types, return NUMBER(38, 0). | ||
| """ | ||
| if expression.type is None: | ||
| return None | ||
|
|
||
| if expression.is_number: | ||
| precision, scale = _extract_literal_precision_scale(expression.this) | ||
| return exp.DataType( | ||
| this=exp.DataType.Type.DECIMAL, | ||
| expressions=[ | ||
| exp.DataTypeParam(this=exp.Literal.number(precision)), | ||
| exp.DataTypeParam(this=exp.Literal.number(scale)), | ||
| ], | ||
| ) | ||
|
Comment on lines
+141
to
+149
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should override the # In typing/snowflake.py
def _annotate_literal(self, expression):
expression = self._annotate_literal(expression)
if expression.is_type(exp.DataType.Type.DECIMAL):
precision, scale = _extract_literal_precision_scale(expression.this)
expression.type.set("expressions", [exp.DataTypeParam(...), exp.DataTypeParam(...)])
return expression |
||
|
|
||
| if expression.type.is_type(*exp.DataType.INTEGER_TYPES) and not expression.type.expressions: | ||
| return exp.DataType( | ||
| this=exp.DataType.Type.DECIMAL, | ||
| expressions=[ | ||
| exp.DataTypeParam(this=exp.Literal.number(38)), | ||
| exp.DataTypeParam(this=exp.Literal.number(0)), | ||
| ], | ||
| ) | ||
|
Comment on lines
+151
to
+158
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is an interesting change. Indeed, according to the docs, integers are really I wonder if we should hook into
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This also exposes an inconsistency where at parse time we change |
||
|
|
||
| return expression.type | ||
|
|
||
|
|
||
| def _coerce_two_parameterized_types( | ||
| type1: exp.DataType, p1: int, s1: int, type2: exp.DataType, p2: int, s2: int | ||
| ) -> t.Optional[exp.DataType]: | ||
| """ | ||
| Coerce two parameterized numeric types using Snowflake's type coercion rules. | ||
|
|
||
| Rules: | ||
| - If p1 >= p2 AND s1 >= s2: return type1 | ||
| - If p2 >= p1 AND s2 >= s1: return type2 | ||
| - Otherwise: return NUMBER(min(38, max(p1, p2) + |s2 - s1|), max(s1, s2)) | ||
| """ | ||
| if p1 >= p2 and s1 >= s2: | ||
| return type1.copy() | ||
|
|
||
| if p2 >= p1 and s2 >= s1: | ||
| return type2.copy() | ||
|
Comment on lines
+174
to
+178
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to |
||
|
|
||
| result_scale = max(s1, s2) | ||
| result_precision = min(38, max(p1, p2) + abs(s2 - s1)) | ||
|
|
||
| return exp.DataType( | ||
| this=type1.this, | ||
| expressions=[ | ||
| exp.DataTypeParam(this=exp.Literal.number(result_precision)), | ||
| exp.DataTypeParam(this=exp.Literal.number(result_scale)), | ||
| ], | ||
| ) | ||
|
|
||
|
|
||
| def _coerce_parameterized_numeric_types( | ||
| types: t.List[t.Optional[exp.DataType]], | ||
| ) -> t.Optional[exp.DataType]: | ||
| """ | ||
| Generalized function to coerce multiple parameterized numeric types. | ||
| Applies Snowflake's coercion logic pairwise across all types. | ||
| """ | ||
| if not types: | ||
| return None | ||
|
|
||
| result_type = None | ||
|
|
||
| for current_type in types: | ||
| if not current_type: | ||
| continue | ||
|
|
||
| if result_type is None: | ||
| result_type = current_type | ||
| continue | ||
|
|
||
| if not _is_parameterized_numeric(result_type) or not _is_parameterized_numeric( | ||
| current_type | ||
| ): | ||
| return None | ||
|
|
||
| p1, s1 = _extract_type_precision_scale(result_type) | ||
| p2, s2 = _extract_type_precision_scale(current_type) | ||
|
|
||
| if p1 is None or s1 is None or p2 is None or s2 is None: | ||
| return None | ||
|
|
||
| result_type = _coerce_two_parameterized_types(result_type, p1, s1, current_type, p2, s2) | ||
|
|
||
| return result_type | ||
|
|
||
|
|
||
| T = t.TypeVar("T", bound=exp.Expression) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have |
||
|
|
||
|
|
||
| def _apply_numeric_coercion( | ||
| self: TypeAnnotator, | ||
| expression: T, | ||
| expressions_to_coerce: t.List[exp.Expression], | ||
| ) -> t.Optional[T]: | ||
| if any(_is_float(e.type) for e in expressions_to_coerce): | ||
| self._set_type(expression, exp.DataType.Type.FLOAT) | ||
| return expression | ||
|
Comment on lines
+236
to
+238
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This feels like something that should be handled by a |
||
|
|
||
| if any(_is_parameterized_numeric(e.type) for e in expressions_to_coerce): | ||
| normalized_types = [_get_normalized_type(e) for e in expressions_to_coerce] | ||
| result_type = _coerce_parameterized_numeric_types(normalized_types) | ||
| if result_type: | ||
| self._set_type(expression, result_type) | ||
| return expression | ||
|
|
||
| return None | ||
|
|
||
|
|
||
| def _annotate_nullif(self: TypeAnnotator, expression: exp.Nullif) -> exp.Nullif: | ||
| self._annotate_args(expression) | ||
|
|
||
| expressions_to_coerce = [] | ||
| if expression.this: | ||
| expressions_to_coerce.append(expression.this) | ||
| if expression.expression: | ||
| expressions_to_coerce.append(expression.expression) | ||
|
|
||
| coerced_result = _apply_numeric_coercion(self, expression, expressions_to_coerce) | ||
| if coerced_result is None: | ||
| return self._annotate_by_args(expression, "this", "expression") | ||
|
|
||
| return coerced_result | ||
|
|
||
|
|
||
| def _annotate_iff(self: TypeAnnotator, expression: exp.If) -> exp.If: | ||
| self._annotate_args(expression) | ||
|
|
||
| expressions_to_coerce = [] | ||
| true_expr = expression.args.get("true") | ||
| false_expr = expression.args.get("false") | ||
|
|
||
| if true_expr: | ||
| expressions_to_coerce.append(true_expr) | ||
| if false_expr: | ||
| expressions_to_coerce.append(false_expr) | ||
|
|
||
| coerced_result = _apply_numeric_coercion(self, expression, expressions_to_coerce) | ||
| if coerced_result is None: | ||
| return self._annotate_by_args(expression, "true", "false") | ||
|
|
||
| return coerced_result | ||
|
|
||
|
|
||
| def _annotate_with_numeric_coercion( | ||
| self: TypeAnnotator, expression: exp.Expression | ||
| ) -> exp.Expression: | ||
| """ | ||
| Generic annotator for functions that return one of their numeric arguments. | ||
|
|
||
| These functions all have the same structure: 'this' + 'expressions' arguments, | ||
| and they all need to coerce all argument types to find a common result type. | ||
| """ | ||
| self._annotate_args(expression) | ||
|
|
||
| expressions_to_coerce = [] | ||
| if expression.this: | ||
| expressions_to_coerce.append(expression.this) | ||
| if expression.expressions: | ||
| expressions_to_coerce.extend(expression.expressions) | ||
|
|
||
| coerced_result = _apply_numeric_coercion(self, expression, expressions_to_coerce) | ||
| if coerced_result is None: | ||
| return self._annotate_by_args(expression, "this", "expressions") | ||
|
|
||
| return coerced_result | ||
|
|
||
|
|
||
| EXPRESSION_METADATA = { | ||
| **EXPRESSION_METADATA, | ||
| **{ | ||
|
|
@@ -248,6 +473,7 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex | |
| exp.Uuid, | ||
| } | ||
| }, | ||
| exp.Coalesce: {"annotator": _annotate_with_numeric_coercion}, | ||
| exp.ConcatWs: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions")}, | ||
| exp.ConvertTimezone: { | ||
| "annotator": lambda self, e: self._annotate_with_type( | ||
|
|
@@ -259,10 +485,14 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex | |
| }, | ||
| exp.DateAdd: {"annotator": _annotate_date_or_time_add}, | ||
| exp.DecodeCase: {"annotator": _annotate_decode_case}, | ||
| exp.Greatest: {"annotator": _annotate_with_numeric_coercion}, | ||
| exp.GreatestIgnoreNulls: { | ||
| "annotator": lambda self, e: self._annotate_by_args(e, "expressions") | ||
| }, | ||
| exp.If: {"annotator": _annotate_iff}, | ||
| exp.Least: {"annotator": _annotate_with_numeric_coercion}, | ||
| exp.LeastIgnoreNulls: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions")}, | ||
| exp.Nullif: {"annotator": _annotate_nullif}, | ||
| exp.Reverse: {"annotator": _annotate_reverse}, | ||
| exp.TimeAdd: {"annotator": _annotate_date_or_time_add}, | ||
| exp.TimestampFromParts: {"annotator": _annotate_timestamp_from_parts}, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd be surprised if non-integers were syntactically/semantically valid. I think it's safe to assume that only integers can be used for the precision and scale parameters.
Additionally, numerics always have default precision and scale if missing, e.g.,
NUMBERis a synonym ofNUMBER(38, 0)(and is parsed as such). So we could also do something like:And so it should be possible to simplify this whole logic quite a bit.