@@ -2970,7 +2970,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
29702970 not local_errors .has_new_errors ()
29712971 and cont_type
29722972 and self .dangerous_comparison (
2973- left_type , cont_type , original_container = right_type
2973+ left_type , cont_type , original_container = right_type , prefer_literal = False
29742974 )
29752975 ):
29762976 self .msg .dangerous_comparison (left_type , cont_type , "container" , e )
@@ -2988,21 +2988,19 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
29882988 # testCustomEqCheckStrictEquality for an example.
29892989 if not w .has_new_errors () and operator in ("==" , "!=" ):
29902990 right_type = self .accept (right )
2991- # Also flag non-overlapping literals in situations like:
2992- # x: Literal['a', 'b']
2993- # if x == 'c':
2994- # ...
2995- left_type = try_getting_literal (left_type )
2996- right_type = try_getting_literal (right_type )
29972991 if self .dangerous_comparison (left_type , right_type ):
2992+ # Show the most specific literal types possible
2993+ left_type = try_getting_literal (left_type )
2994+ right_type = try_getting_literal (right_type )
29982995 self .msg .dangerous_comparison (left_type , right_type , "equality" , e )
29992996
30002997 elif operator == "is" or operator == "is not" :
30012998 right_type = self .accept (right ) # validate the right operand
30022999 sub_result = self .bool_type ()
3003- left_type = try_getting_literal (left_type )
3004- right_type = try_getting_literal (right_type )
30053000 if self .dangerous_comparison (left_type , right_type ):
3001+ # Show the most specific literal types possible
3002+ left_type = try_getting_literal (left_type )
3003+ right_type = try_getting_literal (right_type )
30063004 self .msg .dangerous_comparison (left_type , right_type , "identity" , e )
30073005 method_type = None
30083006 else :
@@ -3036,7 +3034,12 @@ def find_partial_type_ref_fast_path(self, expr: Expression) -> Type | None:
30363034 return None
30373035
30383036 def dangerous_comparison (
3039- self , left : Type , right : Type , original_container : Type | None = None
3037+ self ,
3038+ left : Type ,
3039+ right : Type ,
3040+ original_container : Type | None = None ,
3041+ * ,
3042+ prefer_literal : bool = True ,
30403043 ) -> bool :
30413044 """Check for dangerous non-overlapping comparisons like 42 == 'no'.
30423045
@@ -3064,6 +3067,14 @@ def dangerous_comparison(
30643067 if custom_special_method (left , "__eq__" ) or custom_special_method (right , "__eq__" ):
30653068 return False
30663069
3070+ if prefer_literal :
3071+ # Also flag non-overlapping literals in situations like:
3072+ # x: Literal['a', 'b']
3073+ # if x == 'c':
3074+ # ...
3075+ left = try_getting_literal (left )
3076+ right = try_getting_literal (right )
3077+
30673078 if self .chk .binder .is_unreachable_warning_suppressed ():
30683079 # We are inside a function that contains type variables with value restrictions in
30693080 # its signature. In this case we just suppress all strict-equality checks to avoid
0 commit comments