Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions sqlglot/typing/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing as t

from decimal import Decimal
from sqlglot import exp
from sqlglot.typing import EXPRESSION_METADATA

Expand Down Expand Up @@ -81,6 +82,230 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex
return expression


def _extract_type_precision_scale(
data_type: exp.DataType,
) -> t.Tuple[t.Optional[int], t.Optional[int]]:
expressions = data_type.expressions
if not expressions:
return None, None

precision = None
scale = None

if len(expressions) >= 1:
p_expr = expressions[0]
if isinstance(p_expr, exp.DataTypeParam) and isinstance(p_expr.this, exp.Literal):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be surprised if non-integers were syntactically/semantically valid. I think it's safe to assume that only integers can be used for the precision and scale parameters.

Additionally, numerics always have default precision and scale if missing, e.g., NUMBER is a synonym of NUMBER(38, 0) (and is parsed as such). So we could also do something like:

precision_value = precision.this.to_py() if isinstance(precision, exp.DataTypeParam) or 38
scale_value = scale.this.to_py() if isinstance(scale, exp.DataTypeParam) or 0

And so it should be possible to simplify this whole logic quite a bit.

precision = int(p_expr.this.this) if p_expr.this.is_int else None

if len(expressions) >= 2:
s_expr = expressions[1]
if isinstance(s_expr, exp.DataTypeParam) and isinstance(s_expr.this, exp.Literal):
scale = int(s_expr.this.this) if s_expr.this.is_int else None

return precision, scale


def _extract_literal_precision_scale(num_str: str) -> t.Tuple[int, int]:
d = Decimal(num_str).normalize()
s = format(d, "f").lstrip("-")

if "." in s:
int_part, frac_part = s.split(".", 1)
precision = len(int_part + frac_part)
scale = len(frac_part.rstrip("0"))
else:
precision = len(s)
scale = 0
return precision, scale


def _is_float(type_: t.Optional[exp.DataType]) -> bool:
return type_ is not None and type_.is_type(exp.DataType.Type.FLOAT)


def _is_parameterized_numeric(type_: t.Optional[exp.DataType]) -> bool:
return (
type_ is not None and type_.is_type(*exp.DataType.NUMERIC_TYPES) and bool(type_.expressions)
)


def _get_normalized_type(expression: exp.Expression) -> t.Optional[exp.DataType]:
"""
Normalizes numeric expressions to their parameterized representation.
For literal numbers, return the parameterized representation.
For integer types, return NUMBER(38, 0).
"""
if expression.type is None:
return None

if expression.is_number:
precision, scale = _extract_literal_precision_scale(expression.this)
return exp.DataType(
this=exp.DataType.Type.DECIMAL,
expressions=[
exp.DataTypeParam(this=exp.Literal.number(precision)),
exp.DataTypeParam(this=exp.Literal.number(scale)),
],
)
Comment on lines +141 to +149
Copy link
Collaborator

@georgesittas georgesittas Nov 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should override the Literal annotator for this:

# In typing/snowflake.py
def _annotate_literal(self, expression):
    expression = self._annotate_literal(expression)
    if expression.is_type(exp.DataType.Type.DECIMAL):
        precision, scale = _extract_literal_precision_scale(expression.this)
        expression.type.set("expressions", [exp.DataTypeParam(...), exp.DataTypeParam(...)])

    return expression


if expression.type.is_type(*exp.DataType.INTEGER_TYPES) and not expression.type.expressions:
return exp.DataType(
this=exp.DataType.Type.DECIMAL,
expressions=[
exp.DataTypeParam(this=exp.Literal.number(38)),
exp.DataTypeParam(this=exp.Literal.number(0)),
],
)
Comment on lines +151 to +158
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting change. Indeed, according to the docs, integers are really NUMBER(38, 0).

I wonder if we should hook into _set_type for Snowflake and intercept integer type annotations in order to automatically change them to NUMBER(38, 0) and coercion works as expected... This feels like a much bigger change, though.

(sqlglot) ➜  sqlglot git:(main) snow sql -q "select system\$typeof(cast(1 as int) + cast(2 as number(5, 3)))"
select system$typeof(cast(1 as int) + cast(2 as number(5, 3)))
+---------------------------------------------------------+
| SYSTEM$TYPEOF(CAST(1 AS INT) + CAST(2 AS NUMBER(5, 3))) |
|---------------------------------------------------------|
| NUMBER(38,3)[SB2]                                       |
+---------------------------------------------------------+

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also exposes an inconsistency where at parse time we change NUMBER (or equivalent, e.g. DECIMAL) to NUMBER(38, 0), but we don't do it for these integer types...


return expression.type


def _coerce_two_parameterized_types(
type1: exp.DataType, p1: int, s1: int, type2: exp.DataType, p2: int, s2: int
) -> t.Optional[exp.DataType]:
"""
Coerce two parameterized numeric types using Snowflake's type coercion rules.

Rules:
- If p1 >= p2 AND s1 >= s2: return type1
- If p2 >= p1 AND s2 >= s1: return type2
- Otherwise: return NUMBER(min(38, max(p1, p2) + |s2 - s1|), max(s1, s2))
"""
if p1 >= p2 and s1 >= s2:
return type1.copy()

if p2 >= p1 and s2 >= s1:
return type2.copy()
Comment on lines +174 to +178
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to copy() here, this is taken care of in _set_type -> type.setter -> DataType.build (copies by default)


result_scale = max(s1, s2)
result_precision = min(38, max(p1, p2) + abs(s2 - s1))

return exp.DataType(
this=type1.this,
expressions=[
exp.DataTypeParam(this=exp.Literal.number(result_precision)),
exp.DataTypeParam(this=exp.Literal.number(result_scale)),
],
)


def _coerce_parameterized_numeric_types(
types: t.List[t.Optional[exp.DataType]],
) -> t.Optional[exp.DataType]:
"""
Generalized function to coerce multiple parameterized numeric types.
Applies Snowflake's coercion logic pairwise across all types.
"""
if not types:
return None

result_type = None

for current_type in types:
if not current_type:
continue

if result_type is None:
result_type = current_type
continue

if not _is_parameterized_numeric(result_type) or not _is_parameterized_numeric(
current_type
):
return None

p1, s1 = _extract_type_precision_scale(result_type)
p2, s2 = _extract_type_precision_scale(current_type)

if p1 is None or s1 is None or p2 is None or s2 is None:
return None

result_type = _coerce_two_parameterized_types(result_type, p1, s1, current_type, p2, s2)

return result_type


T = t.TypeVar("T", bound=exp.Expression)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have E in sqlglot._typing and can use that here instead.



def _apply_numeric_coercion(
self: TypeAnnotator,
expression: T,
expressions_to_coerce: t.List[exp.Expression],
) -> t.Optional[T]:
if any(_is_float(e.type) for e in expressions_to_coerce):
self._set_type(expression, exp.DataType.Type.FLOAT)
return expression
Comment on lines +236 to +238
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like something that should be handled by a COERCES_TO override for Snowflake, i.e. NUMBER always coerces to FLOAT:

(sqlglot) ➜  sqlglot git:(main) snow sql -q "select system\$typeof(coalesce(cast(1 as number(4, 2)), cast(1 as float)))"
select system$typeof(coalesce(cast(1 as number(4, 2)), cast(1 as float)))
+--------------------------------------------------------------------+
| SYSTEM$TYPEOF(COALESCE(CAST(1 AS NUMBER(4, 2)), CAST(1 AS FLOAT))) |
|--------------------------------------------------------------------|
| FLOAT[DOUBLE]                                                      |
+--------------------------------------------------------------------+


if any(_is_parameterized_numeric(e.type) for e in expressions_to_coerce):
normalized_types = [_get_normalized_type(e) for e in expressions_to_coerce]
result_type = _coerce_parameterized_numeric_types(normalized_types)
if result_type:
self._set_type(expression, result_type)
return expression

return None


def _annotate_nullif(self: TypeAnnotator, expression: exp.Nullif) -> exp.Nullif:
self._annotate_args(expression)

expressions_to_coerce = []
if expression.this:
expressions_to_coerce.append(expression.this)
if expression.expression:
expressions_to_coerce.append(expression.expression)

coerced_result = _apply_numeric_coercion(self, expression, expressions_to_coerce)
if coerced_result is None:
return self._annotate_by_args(expression, "this", "expression")

return coerced_result


def _annotate_iff(self: TypeAnnotator, expression: exp.If) -> exp.If:
self._annotate_args(expression)

expressions_to_coerce = []
true_expr = expression.args.get("true")
false_expr = expression.args.get("false")

if true_expr:
expressions_to_coerce.append(true_expr)
if false_expr:
expressions_to_coerce.append(false_expr)

coerced_result = _apply_numeric_coercion(self, expression, expressions_to_coerce)
if coerced_result is None:
return self._annotate_by_args(expression, "true", "false")

return coerced_result


def _annotate_with_numeric_coercion(
self: TypeAnnotator, expression: exp.Expression
) -> exp.Expression:
"""
Generic annotator for functions that return one of their numeric arguments.

These functions all have the same structure: 'this' + 'expressions' arguments,
and they all need to coerce all argument types to find a common result type.
"""
self._annotate_args(expression)

expressions_to_coerce = []
if expression.this:
expressions_to_coerce.append(expression.this)
if expression.expressions:
expressions_to_coerce.extend(expression.expressions)

coerced_result = _apply_numeric_coercion(self, expression, expressions_to_coerce)
if coerced_result is None:
return self._annotate_by_args(expression, "this", "expressions")

return coerced_result


EXPRESSION_METADATA = {
**EXPRESSION_METADATA,
**{
Expand Down Expand Up @@ -248,6 +473,7 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex
exp.Uuid,
}
},
exp.Coalesce: {"annotator": _annotate_with_numeric_coercion},
exp.ConcatWs: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions")},
exp.ConvertTimezone: {
"annotator": lambda self, e: self._annotate_with_type(
Expand All @@ -259,10 +485,14 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex
},
exp.DateAdd: {"annotator": _annotate_date_or_time_add},
exp.DecodeCase: {"annotator": _annotate_decode_case},
exp.Greatest: {"annotator": _annotate_with_numeric_coercion},
exp.GreatestIgnoreNulls: {
"annotator": lambda self, e: self._annotate_by_args(e, "expressions")
},
exp.If: {"annotator": _annotate_iff},
exp.Least: {"annotator": _annotate_with_numeric_coercion},
exp.LeastIgnoreNulls: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions")},
exp.Nullif: {"annotator": _annotate_nullif},
exp.Reverse: {"annotator": _annotate_reverse},
exp.TimeAdd: {"annotator": _annotate_date_or_time_add},
exp.TimestampFromParts: {"annotator": _annotate_timestamp_from_parts},
Expand Down
Loading