Skip to content

Introduce temporary named expressions for match subjects #18446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
63 changes: 48 additions & 15 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
CallExpr,
ClassDef,
ComparisonExpr,
ComplexExpr,
Context,
ContinueStmt,
Decorator,
Expand Down Expand Up @@ -350,6 +351,9 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi):
# functions such as open(), etc.
plugin: Plugin

# A helper state to produce unique temporary names on demand.
_unique_id: int

def __init__(
self,
errors: Errors,
Expand Down Expand Up @@ -414,6 +418,7 @@ def __init__(
self, self.msg, self.plugin, per_line_checking_time_ns
)
self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options)
self._unique_id = 0

@property
def expr_checker(self) -> mypy.checkexpr.ExpressionChecker:
Expand Down Expand Up @@ -5413,21 +5418,7 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
return

def visit_match_stmt(self, s: MatchStmt) -> None:
named_subject: Expression
if isinstance(s.subject, CallExpr):
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
if s.subject_dummy is None:
id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else ""
name = "dummy-match-" + id
v = Var(name)
s.subject_dummy = NameExpr(name)
s.subject_dummy.node = v
named_subject = s.subject_dummy
else:
named_subject = s.subject

named_subject = self._make_named_statement_for_match(s)
with self.binder.frame_context(can_skip=False, fall_through=0):
subject_type = get_proper_type(self.expr_checker.accept(s.subject))

Expand Down Expand Up @@ -5459,6 +5450,12 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
pattern_map, else_map = conditional_types_to_typemaps(
named_subject, pattern_type.type, pattern_type.rest_type
)
# Maybe the subject type can be inferred from constraints on
# its attribute/item?
if pattern_map and named_subject in pattern_map:
pattern_map[s.subject] = pattern_map[named_subject]
if else_map and named_subject in else_map:
else_map[s.subject] = else_map[named_subject]
pattern_map = self.propagate_up_typemap_info(pattern_map)
else_map = self.propagate_up_typemap_info(else_map)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
Expand Down Expand Up @@ -5506,6 +5503,36 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
with self.binder.frame_context(can_skip=False, fall_through=2):
pass

def _make_named_statement_for_match(self, s: MatchStmt) -> Expression:
"""Construct a fake NameExpr for inference if a match clause is complex."""
subject = s.subject
expressions_to_preserve = (
# Already named - we should infer type of it as given
NameExpr,
AssignmentExpr,
# Primitive literals - their type is known, no need to name them
IntExpr,
StrExpr,
BytesExpr,
FloatExpr,
ComplexExpr,
EllipsisExpr,
)
if isinstance(subject, expressions_to_preserve):
return subject
elif s.subject_dummy is not None:
return s.subject_dummy
else:
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
name = self.new_unique_dummy_name("match")
v = Var(name)
named_subject = NameExpr(name)
named_subject.node = v
s.subject_dummy = named_subject
return named_subject

def _get_recursive_sub_patterns_map(
self, expr: Expression, typ: Type
) -> dict[Expression, Type]:
Expand Down Expand Up @@ -7885,6 +7912,12 @@ def warn_deprecated_overload_item(
if candidate == target:
self.warn_deprecated(item.func, context)

def new_unique_dummy_name(self, namespace: str) -> str:
"""Generate a name that is guaranteed to be unique for this TypeChecker instance."""
name = f"dummy-{namespace}-{self._unique_id}"
self._unique_id += 1
return name

# leafs

def visit_pass_stmt(self, o: PassStmt, /) -> None:
Expand Down
85 changes: 84 additions & 1 deletion test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,7 @@ def main() -> None:
case a:
reveal_type(a) # N: Revealed type is "builtins.int"

[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail]
[case testMatchCapturePatternFromAsyncFunctionReturningUnion]
async def func1(arg: bool) -> str | int: ...
async def func2(arg: bool) -> bytes | int: ...

Expand Down Expand Up @@ -2586,6 +2586,89 @@ def fn2(x: Some | int | str) -> None:
pass
[builtins fixtures/dict.pyi]

[case testMatchFunctionCall]
# flags: --warn-unreachable

def fn() -> int | str: ...

match fn():
case str(s):
reveal_type(s) # N: Revealed type is "builtins.str"
case int(i):
reveal_type(i) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

[case testMatchAttribute]
# flags: --warn-unreachable

class A:
foo: int | str

match A().foo:
case str(s):
reveal_type(s) # N: Revealed type is "builtins.str"
case int(i):
reveal_type(i) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

[case testMatchOperations]
# flags: --warn-unreachable

x: int
match -x:
case -1 as s:
reveal_type(s) # N: Revealed type is "Literal[-1]"
case int(s):
reveal_type(s) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

match 1 + 2:
case 3 as s:
reveal_type(s) # N: Revealed type is "Literal[3]"
case int(s):
reveal_type(s) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

match 1 > 2:
case True as s:
reveal_type(s) # N: Revealed type is "Literal[True]"
case False as s:
reveal_type(s) # N: Revealed type is "Literal[False]"
case other:
other # E: Statement is unreachable
[builtins fixtures/ops.pyi]

[case testMatchDictItem]
# flags: --warn-unreachable

m: dict[str, int | str]
k: str

match m[k]:
case str(s):
reveal_type(s) # N: Revealed type is "builtins.str"
case int(i):
reveal_type(i) # N: Revealed type is "builtins.int"
case other:
other # E: Statement is unreachable

[builtins fixtures/dict.pyi]

[case testMatchLiteralValuePathological]
# flags: --warn-unreachable

match 0:
case 0 as i:
reveal_type(i) # N: Revealed type is "Literal[0]?"
case int(i):
i # E: Statement is unreachable
case other:
other # E: Statement is unreachable

[case testMatchNamedTupleSequence]
from typing import Any, NamedTuple

Expand Down