@@ -573,6 +573,102 @@ def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> ex
573573 return expression
574574
575575
576+ def _extract_precision_scale (data_type : exp .DataType ) -> t .Tuple [t .Optional [int ], t .Optional [int ]]:
577+ """Extract precision and scale from a parameterized numeric type."""
578+ expressions = data_type .expressions
579+ if not expressions :
580+ return None , None
581+
582+ precision = None
583+ scale = None
584+
585+ if len (expressions ) >= 1 :
586+ p_expr = expressions [0 ]
587+ if isinstance (p_expr , exp .DataTypeParam ) and isinstance (p_expr .this , exp .Literal ):
588+ precision = int (p_expr .this .this ) if p_expr .this .is_int else None
589+
590+ if len (expressions ) >= 2 :
591+ s_expr = expressions [1 ]
592+ if isinstance (s_expr , exp .DataTypeParam ) and isinstance (s_expr .this , exp .Literal ):
593+ scale = int (s_expr .this .this ) if s_expr .this .is_int else None
594+
595+ return precision , scale
596+
597+
598+ def _compute_nullif_result_type (
599+ base_type : exp .DataType , p1 : int , s1 : int , p2 : int , s2 : int
600+ ) -> t .Optional [exp .DataType ]:
601+ """
602+ Compute result type for NULLIF with two parameterized numeric types.
603+
604+ Rules:
605+ - If p1 >= p2 AND s1 >= s2: return type1
606+ - If p2 >= p1 AND s2 >= s1: return type2
607+ - Otherwise: return DECIMAL(max(p1, p2) + |s2 - s1|, max(s1, s2))
608+ """
609+
610+ if p1 >= p2 and s1 >= s2 :
611+ return base_type .copy ()
612+
613+ if p2 >= p1 and s2 >= s1 :
614+ return exp .DataType (
615+ this = base_type .this ,
616+ expressions = [
617+ exp .DataTypeParam (this = exp .Literal .number (p2 )),
618+ exp .DataTypeParam (this = exp .Literal .number (s2 )),
619+ ],
620+ )
621+
622+ result_scale = max (s1 , s2 )
623+ result_precision = max (p1 , p2 ) + abs (s2 - s1 )
624+
625+ return exp .DataType (
626+ this = base_type .this ,
627+ expressions = [
628+ exp .DataTypeParam (this = exp .Literal .number (result_precision )),
629+ exp .DataTypeParam (this = exp .Literal .number (result_scale )),
630+ ],
631+ )
632+
633+
634+ def _annotate_nullif (self : TypeAnnotator , expression : exp .Nullif ) -> exp .Nullif :
635+ """
636+ Annotate NULLIF with Snowflake-specific type coercion rules for parameterized numeric types.
637+
638+ When both arguments are parameterized numeric types (e.g., DECIMAL(p, s)):
639+ - If one type dominates (p1 >= p2 AND s1 >= s2), use that type
640+ - Otherwise, compute new type with:
641+ - scale = max(s1, s2)
642+ - precision = max(p1, p2) + |s2 - s1|
643+ """
644+
645+ self ._annotate_args (expression )
646+
647+ this_type = expression .this .type
648+ expr_type = expression .expression .type
649+
650+ if not this_type or not expr_type :
651+ return self ._annotate_by_args (expression , "this" , "expression" )
652+
653+ # Snowflake specific type coercion for NULLIF with parameterized numeric types
654+ if (
655+ this_type .is_type (* exp .DataType .NUMERIC_TYPES )
656+ and expr_type .is_type (* exp .DataType .NUMERIC_TYPES )
657+ and this_type .expressions
658+ and expr_type .expressions
659+ ):
660+ p1 , s1 = _extract_precision_scale (this_type )
661+ p2 , s2 = _extract_precision_scale (expr_type )
662+
663+ if p1 is not None and s1 is not None and p2 is not None and s2 is not None :
664+ result_type = _compute_nullif_result_type (this_type , p1 , s1 , p2 , s2 )
665+ if result_type :
666+ self ._set_type (expression , result_type )
667+ return expression
668+
669+ return self ._annotate_by_args (expression , "this" , "expression" )
670+
671+
576672def _annotate_timestamp_from_parts (
577673 self : TypeAnnotator , expression : exp .TimestampFromParts
578674) -> exp .TimestampFromParts :
@@ -795,6 +891,7 @@ class Snowflake(Dialect):
795891 exp .GreatestIgnoreNulls : lambda self , e : self ._annotate_by_args (e , "expressions" ),
796892 exp .LeastIgnoreNulls : lambda self , e : self ._annotate_by_args (e , "expressions" ),
797893 exp .DecodeCase : _annotate_decode_case ,
894+ exp .Nullif : _annotate_nullif ,
798895 exp .Reverse : _annotate_reverse ,
799896 exp .TimestampFromParts : _annotate_timestamp_from_parts ,
800897 }
0 commit comments