Skip to content

Commit a08388c

Browse files
authored
Fix crash in astdiff and clean it up (#14497)
Ref #14329 This fixes one of the crashes reported in the issue. In fact, using recursive type caught this crash statically, plus another subtle crash in `snapshot_optional_type()`, _without a single false positive_ (I was able to cleanly type also symbol table snapshots, but decided it is not worth the churn since we only ever compare them with `==`, supported by ~every Python object). I feel triumphant :-)
1 parent e8c844b commit a08388c

File tree

3 files changed

+53
-15
lines changed

3 files changed

+53
-15
lines changed

mypy/server/astdiff.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'
5252

5353
from __future__ import annotations
5454

55-
from typing import Sequence, Tuple, cast
55+
from typing import Sequence, Tuple, Union, cast
5656
from typing_extensions import TypeAlias as _TypeAlias
5757

5858
from mypy.expandtype import expand_type
@@ -109,11 +109,17 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'
109109
# snapshots are immutable).
110110
#
111111
# For example, the snapshot of the 'int' type is ('Instance', 'builtins.int', ()).
112-
SnapshotItem: _TypeAlias = Tuple[object, ...]
112+
113+
# Type snapshots are strict, they must be hashable and ordered (e.g. for Unions).
114+
Primitive: _TypeAlias = Union[str, float, int, bool] # float is for Literal[3.14] support.
115+
SnapshotItem: _TypeAlias = Tuple[Union[Primitive, "SnapshotItem"], ...]
116+
117+
# Symbol snapshots can be more lenient.
118+
SymbolSnapshot: _TypeAlias = Tuple[object, ...]
113119

114120

115121
def compare_symbol_table_snapshots(
116-
name_prefix: str, snapshot1: dict[str, SnapshotItem], snapshot2: dict[str, SnapshotItem]
122+
name_prefix: str, snapshot1: dict[str, SymbolSnapshot], snapshot2: dict[str, SymbolSnapshot]
117123
) -> set[str]:
118124
"""Return names that are different in two snapshots of a symbol table.
119125
@@ -155,7 +161,7 @@ def compare_symbol_table_snapshots(
155161
return triggers
156162

157163

158-
def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, SnapshotItem]:
164+
def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, SymbolSnapshot]:
159165
"""Create a snapshot description that represents the state of a symbol table.
160166
161167
The snapshot has a representation based on nested tuples and dicts
@@ -165,7 +171,7 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sna
165171
things defined in other modules are represented just by the names of
166172
the targets.
167173
"""
168-
result: dict[str, SnapshotItem] = {}
174+
result: dict[str, SymbolSnapshot] = {}
169175
for name, symbol in table.items():
170176
node = symbol.node
171177
# TODO: cross_ref?
@@ -206,7 +212,7 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> dict[str, Sna
206212
return result
207213

208214

209-
def snapshot_definition(node: SymbolNode | None, common: tuple[object, ...]) -> tuple[object, ...]:
215+
def snapshot_definition(node: SymbolNode | None, common: SymbolSnapshot) -> SymbolSnapshot:
210216
"""Create a snapshot description of a symbol table node.
211217
212218
The representation is nested tuples and dicts. Only externally
@@ -290,11 +296,11 @@ def snapshot_type(typ: Type) -> SnapshotItem:
290296
return typ.accept(SnapshotTypeVisitor())
291297

292298

293-
def snapshot_optional_type(typ: Type | None) -> SnapshotItem | None:
299+
def snapshot_optional_type(typ: Type | None) -> SnapshotItem:
294300
if typ:
295301
return snapshot_type(typ)
296302
else:
297-
return None
303+
return ("<not set>",)
298304

299305

300306
def snapshot_types(types: Sequence[Type]) -> SnapshotItem:
@@ -396,7 +402,7 @@ def visit_parameters(self, typ: Parameters) -> SnapshotItem:
396402
"Parameters",
397403
snapshot_types(typ.arg_types),
398404
tuple(encode_optional_str(name) for name in typ.arg_names),
399-
tuple(typ.arg_kinds),
405+
tuple(k.value for k in typ.arg_kinds),
400406
)
401407

402408
def visit_callable_type(self, typ: CallableType) -> SnapshotItem:
@@ -407,7 +413,7 @@ def visit_callable_type(self, typ: CallableType) -> SnapshotItem:
407413
snapshot_types(typ.arg_types),
408414
snapshot_type(typ.ret_type),
409415
tuple(encode_optional_str(name) for name in typ.arg_names),
410-
tuple(typ.arg_kinds),
416+
tuple(k.value for k in typ.arg_kinds),
411417
typ.is_type_obj(),
412418
typ.is_ellipsis_args,
413419
snapshot_types(typ.variables),
@@ -464,7 +470,7 @@ def visit_type_alias_type(self, typ: TypeAliasType) -> SnapshotItem:
464470
return ("TypeAliasType", typ.alias.fullname, snapshot_types(typ.args))
465471

466472

467-
def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> tuple[object, ...]:
473+
def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> SymbolSnapshot:
468474
"""Create a snapshot of the signature of a function that has no explicit signature.
469475
470476
If the arguments to a function without signature change, it must be
@@ -476,7 +482,7 @@ def snapshot_untyped_signature(func: OverloadedFuncDef | FuncItem) -> tuple[obje
476482
if isinstance(func, FuncItem):
477483
return (tuple(func.arg_names), tuple(func.arg_kinds))
478484
else:
479-
result = []
485+
result: list[SymbolSnapshot] = []
480486
for item in func.items:
481487
if isinstance(item, Decorator):
482488
if item.var.type:

mypy/server/update.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,11 @@
151151
semantic_analysis_for_scc,
152152
semantic_analysis_for_targets,
153153
)
154-
from mypy.server.astdiff import SnapshotItem, compare_symbol_table_snapshots, snapshot_symbol_table
154+
from mypy.server.astdiff import (
155+
SymbolSnapshot,
156+
compare_symbol_table_snapshots,
157+
snapshot_symbol_table,
158+
)
155159
from mypy.server.astmerge import merge_asts
156160
from mypy.server.aststrip import SavedAttributes, strip_target
157161
from mypy.server.deps import get_dependencies_of_target, merge_dependencies
@@ -417,7 +421,7 @@ def update_module(
417421

418422
t0 = time.time()
419423
# Record symbol table snapshot of old version the changed module.
420-
old_snapshots: dict[str, dict[str, SnapshotItem]] = {}
424+
old_snapshots: dict[str, dict[str, SymbolSnapshot]] = {}
421425
if module in manager.modules:
422426
snapshot = snapshot_symbol_table(module, manager.modules[module].names)
423427
old_snapshots[module] = snapshot
@@ -751,7 +755,7 @@ def get_sources(
751755

752756
def calculate_active_triggers(
753757
manager: BuildManager,
754-
old_snapshots: dict[str, dict[str, SnapshotItem]],
758+
old_snapshots: dict[str, dict[str, SymbolSnapshot]],
755759
new_modules: dict[str, MypyFile | None],
756760
) -> set[str]:
757761
"""Determine activated triggers by comparing old and new symbol tables.

test-data/unit/fine-grained.test

+28
Original file line numberDiff line numberDiff line change
@@ -10313,3 +10313,31 @@ a.py:3: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#varia
1031310313
a.py:4: note: Revealed type is "A?"
1031410314
==
1031510315
a.py:4: note: Revealed type is "Union[builtins.str, builtins.int]"
10316+
10317+
[case testUnionOfSimilarCallablesCrash]
10318+
import b
10319+
10320+
[file b.py]
10321+
from a import x
10322+
10323+
[file m.py]
10324+
from typing import Union, TypeVar
10325+
10326+
T = TypeVar("T")
10327+
S = TypeVar("S")
10328+
def foo(x: T, y: S) -> Union[T, S]: ...
10329+
def f(x: int) -> int: ...
10330+
def g(*x: int) -> int: ...
10331+
10332+
[file a.py]
10333+
from m import f, g, foo
10334+
x = foo(f, g)
10335+
10336+
[file a.py.2]
10337+
from m import f, g, foo
10338+
x = foo(f, g)
10339+
reveal_type(x)
10340+
[builtins fixtures/tuple.pyi]
10341+
[out]
10342+
==
10343+
a.py:3: note: Revealed type is "Union[def (x: builtins.int) -> builtins.int, def (*x: builtins.int) -> builtins.int]"

0 commit comments

Comments
 (0)