diff --git a/mypy/checker.py b/mypy/checker.py index 7d0b41c516e1..ccbed78d49ff 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -67,6 +67,7 @@ CallExpr, ClassDef, ComparisonExpr, + ComplexExpr, Context, ContinueStmt, Decorator, @@ -350,6 +351,9 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi): # functions such as open(), etc. plugin: Plugin + # A helper state to produce unique temporary names on demand. + _unique_id: int + def __init__( self, errors: Errors, @@ -414,6 +418,7 @@ def __init__( self, self.msg, self.plugin, per_line_checking_time_ns ) self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options) + self._unique_id = 0 @property def expr_checker(self) -> mypy.checkexpr.ExpressionChecker: @@ -5413,21 +5418,7 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: return def visit_match_stmt(self, s: MatchStmt) -> None: - named_subject: Expression - if isinstance(s.subject, CallExpr): - # Create a dummy subject expression to handle cases where a match statement's subject - # is not a literal value. This lets us correctly narrow types and check exhaustivity - # This is hack! - if s.subject_dummy is None: - id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else "" - name = "dummy-match-" + id - v = Var(name) - s.subject_dummy = NameExpr(name) - s.subject_dummy.node = v - named_subject = s.subject_dummy - else: - named_subject = s.subject - + named_subject = self._make_named_statement_for_match(s) with self.binder.frame_context(can_skip=False, fall_through=0): subject_type = get_proper_type(self.expr_checker.accept(s.subject)) @@ -5459,6 +5450,12 @@ def visit_match_stmt(self, s: MatchStmt) -> None: pattern_map, else_map = conditional_types_to_typemaps( named_subject, pattern_type.type, pattern_type.rest_type ) + # Maybe the subject type can be inferred from constraints on + # its attribute/item? + if pattern_map and named_subject in pattern_map: + pattern_map[s.subject] = pattern_map[named_subject] + if else_map and named_subject in else_map: + else_map[s.subject] = else_map[named_subject] pattern_map = self.propagate_up_typemap_info(pattern_map) else_map = self.propagate_up_typemap_info(else_map) self.remove_capture_conflicts(pattern_type.captures, inferred_types) @@ -5506,6 +5503,36 @@ def visit_match_stmt(self, s: MatchStmt) -> None: with self.binder.frame_context(can_skip=False, fall_through=2): pass + def _make_named_statement_for_match(self, s: MatchStmt) -> Expression: + """Construct a fake NameExpr for inference if a match clause is complex.""" + subject = s.subject + expressions_to_preserve = ( + # Already named - we should infer type of it as given + NameExpr, + AssignmentExpr, + # Primitive literals - their type is known, no need to name them + IntExpr, + StrExpr, + BytesExpr, + FloatExpr, + ComplexExpr, + EllipsisExpr, + ) + if isinstance(subject, expressions_to_preserve): + return subject + elif s.subject_dummy is not None: + return s.subject_dummy + else: + # Create a dummy subject expression to handle cases where a match statement's subject + # is not a literal value. This lets us correctly narrow types and check exhaustivity + # This is hack! + name = self.new_unique_dummy_name("match") + v = Var(name) + named_subject = NameExpr(name) + named_subject.node = v + s.subject_dummy = named_subject + return named_subject + def _get_recursive_sub_patterns_map( self, expr: Expression, typ: Type ) -> dict[Expression, Type]: @@ -7885,6 +7912,12 @@ def warn_deprecated_overload_item( if candidate == target: self.warn_deprecated(item.func, context) + def new_unique_dummy_name(self, namespace: str) -> str: + """Generate a name that is guaranteed to be unique for this TypeChecker instance.""" + name = f"dummy-{namespace}-{self._unique_id}" + self._unique_id += 1 + return name + # leafs def visit_pass_stmt(self, o: PassStmt, /) -> None: diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 3774abfc548b..40af68a037d0 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1302,7 +1302,7 @@ def main() -> None: case a: reveal_type(a) # N: Revealed type is "builtins.int" -[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail] +[case testMatchCapturePatternFromAsyncFunctionReturningUnion] async def func1(arg: bool) -> str | int: ... async def func2(arg: bool) -> bytes | int: ... @@ -2586,6 +2586,89 @@ def fn2(x: Some | int | str) -> None: pass [builtins fixtures/dict.pyi] +[case testMatchFunctionCall] +# flags: --warn-unreachable + +def fn() -> int | str: ... + +match fn(): + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchAttribute] +# flags: --warn-unreachable + +class A: + foo: int | str + +match A().foo: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[case testMatchOperations] +# flags: --warn-unreachable + +x: int +match -x: + case -1 as s: + reveal_type(s) # N: Revealed type is "Literal[-1]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 + 2: + case 3 as s: + reveal_type(s) # N: Revealed type is "Literal[3]" + case int(s): + reveal_type(s) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +match 1 > 2: + case True as s: + reveal_type(s) # N: Revealed type is "Literal[True]" + case False as s: + reveal_type(s) # N: Revealed type is "Literal[False]" + case other: + other # E: Statement is unreachable +[builtins fixtures/ops.pyi] + +[case testMatchDictItem] +# flags: --warn-unreachable + +m: dict[str, int | str] +k: str + +match m[k]: + case str(s): + reveal_type(s) # N: Revealed type is "builtins.str" + case int(i): + reveal_type(i) # N: Revealed type is "builtins.int" + case other: + other # E: Statement is unreachable + +[builtins fixtures/dict.pyi] + +[case testMatchLiteralValuePathological] +# flags: --warn-unreachable + +match 0: + case 0 as i: + reveal_type(i) # N: Revealed type is "Literal[0]?" + case int(i): + i # E: Statement is unreachable + case other: + other # E: Statement is unreachable + [case testMatchNamedTupleSequence] from typing import Any, NamedTuple