Skip to content

Commit 7946b90

Browse files
Michael LeeMichael Lee
authored andcommitted
feat(optimizer)!: Annotate numeric parameterized types for NULLIF
1 parent 7f93e85 commit 7946b90

File tree

2 files changed

+117
-0
lines changed

2 files changed

+117
-0
lines changed

sqlglot/dialects/snowflake.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,102 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex
573573
return expression
574574

575575

576+
def _extract_precision_scale(data_type: exp.DataType) -> t.Tuple[t.Optional[int], t.Optional[int]]:
577+
"""Extract precision and scale from a parameterized numeric type."""
578+
expressions = data_type.expressions
579+
if not expressions:
580+
return None, None
581+
582+
precision = None
583+
scale = None
584+
585+
if len(expressions) >= 1:
586+
p_expr = expressions[0]
587+
if isinstance(p_expr, exp.DataTypeParam) and isinstance(p_expr.this, exp.Literal):
588+
precision = int(p_expr.this.this) if p_expr.this.is_int else None
589+
590+
if len(expressions) >= 2:
591+
s_expr = expressions[1]
592+
if isinstance(s_expr, exp.DataTypeParam) and isinstance(s_expr.this, exp.Literal):
593+
scale = int(s_expr.this.this) if s_expr.this.is_int else None
594+
595+
return precision, scale
596+
597+
598+
def _compute_nullif_result_type(
599+
base_type: exp.DataType, p1: int, s1: int, p2: int, s2: int
600+
) -> t.Optional[exp.DataType]:
601+
"""
602+
Compute result type for NULLIF with two parameterized numeric types.
603+
604+
Rules:
605+
- If p1 >= p2 AND s1 >= s2: return type1
606+
- If p2 >= p1 AND s2 >= s1: return type2
607+
- Otherwise: return DECIMAL(max(p1, p2) + |s2 - s1|, max(s1, s2))
608+
"""
609+
610+
if p1 >= p2 and s1 >= s2:
611+
return base_type.copy()
612+
613+
if p2 >= p1 and s2 >= s1:
614+
return exp.DataType(
615+
this=base_type.this,
616+
expressions=[
617+
exp.DataTypeParam(this=exp.Literal.number(p2)),
618+
exp.DataTypeParam(this=exp.Literal.number(s2)),
619+
],
620+
)
621+
622+
result_scale = max(s1, s2)
623+
result_precision = max(p1, p2) + abs(s2 - s1)
624+
625+
return exp.DataType(
626+
this=base_type.this,
627+
expressions=[
628+
exp.DataTypeParam(this=exp.Literal.number(result_precision)),
629+
exp.DataTypeParam(this=exp.Literal.number(result_scale)),
630+
],
631+
)
632+
633+
634+
def _annotate_nullif(self: TypeAnnotator, expression: exp.Nullif) -> exp.Nullif:
635+
"""
636+
Annotate NULLIF with Snowflake-specific type coercion rules for parameterized numeric types.
637+
638+
When both arguments are parameterized numeric types (e.g., DECIMAL(p, s)):
639+
- If one type dominates (p1 >= p2 AND s1 >= s2), use that type
640+
- Otherwise, compute new type with:
641+
- scale = max(s1, s2)
642+
- precision = max(p1, p2) + |s2 - s1|
643+
"""
644+
645+
self._annotate_args(expression)
646+
647+
this_type = expression.this.type
648+
expr_type = expression.expression.type
649+
650+
if not this_type or not expr_type:
651+
return self._annotate_by_args(expression, "this", "expression")
652+
653+
# Snowflake specific type coercion for NULLIF with parameterized numeric types
654+
if (
655+
this_type.is_type(*exp.DataType.NUMERIC_TYPES)
656+
and expr_type.is_type(*exp.DataType.NUMERIC_TYPES)
657+
and this_type.expressions
658+
and expr_type.expressions
659+
):
660+
p1, s1 = _extract_precision_scale(this_type)
661+
p2, s2 = _extract_precision_scale(expr_type)
662+
663+
if p1 is not None and s1 is not None and p2 is not None and s2 is not None:
664+
result_type = _compute_nullif_result_type(this_type, p1, s1, p2, s2)
665+
if result_type:
666+
self._set_type(expression, result_type)
667+
return expression
668+
669+
return self._annotate_by_args(expression, "this", "expression")
670+
671+
576672
def _annotate_timestamp_from_parts(
577673
self: TypeAnnotator, expression: exp.TimestampFromParts
578674
) -> exp.TimestampFromParts:
@@ -795,6 +891,7 @@ class Snowflake(Dialect):
795891
exp.GreatestIgnoreNulls: lambda self, e: self._annotate_by_args(e, "expressions"),
796892
exp.LeastIgnoreNulls: lambda self, e: self._annotate_by_args(e, "expressions"),
797893
exp.DecodeCase: _annotate_decode_case,
894+
exp.Nullif: _annotate_nullif,
798895
exp.Reverse: _annotate_reverse,
799896
exp.TimestampFromParts: _annotate_timestamp_from_parts,
800897
}

tests/fixtures/optimizer/annotate_functions.sql

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2543,6 +2543,26 @@ BIGINT;
25432543
NULLIF(1::INT, 2.5::DOUBLE);
25442544
DOUBLE;
25452545

2546+
# dialect: snowflake
2547+
NULLIF(1::DECIMAL(10, 2), 2::DECIMAL(10, 2));
2548+
DECIMAL(10, 2);
2549+
2550+
# dialect: snowflake
2551+
NULLIF(1::DECIMAL(12, 3), 2::DECIMAL(10, 2));
2552+
DECIMAL(12, 3);
2553+
2554+
# dialect: snowflake
2555+
NULLIF(1::DECIMAL(10, 2), 2::DECIMAL(12, 3));
2556+
DECIMAL(12, 3);
2557+
2558+
# dialect: snowflake
2559+
NULLIF(1::DECIMAL(12, 2), 2::DECIMAL(10, 3));
2560+
DECIMAL(13, 3);
2561+
2562+
# dialect: snowflake
2563+
NULLIF(1::DECIMAL(10, 3), 2::DECIMAL(12, 2));
2564+
DECIMAL(13, 3);
2565+
25462566
# dialect: snowflake
25472567
NULLIFZERO(5);
25482568
INT;

0 commit comments

Comments
 (0)