Skip to content

Commit bac9e77

Browse files
authored
[dataclass_transform] minimal implementation of dataclass_transform (#14523)
This is a very simple first step to implementing [PEP 0681](https://peps.python.org/pep-0681/#decorator-function-example), which will allow MyPy to recognize user-defined types that behave similarly to dataclasses. This initial implementation is very limited: we only support decorator-style use of `typing.dataclass_transform` and do not support passing additional options to the transform (such as `freeze` or `init`). Within MyPy, we add a new `is_dataclass_transform` field to `FuncBase` which is populated during semantic analysis. When we check for plugin hooks later, we add new special cases to use the existing dataclasses plugin if a class decorator is marked with `is_dataclass_transform`. Ideally we would use a proper plugin API; the hacky special case here can be replaced in subsequent iterations. Co-authored-by: Wesley Wright <[email protected]>
1 parent 6442b02 commit bac9e77

10 files changed

+106
-9
lines changed

mypy/nodes.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,13 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
480480
return visitor.visit_import_all(self)
481481

482482

483-
FUNCBASE_FLAGS: Final = ["is_property", "is_class", "is_static", "is_final"]
483+
FUNCBASE_FLAGS: Final = [
484+
"is_property",
485+
"is_class",
486+
"is_static",
487+
"is_final",
488+
"is_dataclass_transform",
489+
]
484490

485491

486492
class FuncBase(Node):
@@ -506,6 +512,7 @@ class FuncBase(Node):
506512
"is_static", # Uses "@staticmethod"
507513
"is_final", # Uses "@final"
508514
"_fullname",
515+
"is_dataclass_transform", # Is decorated with "@typing.dataclass_transform" or similar
509516
)
510517

511518
def __init__(self) -> None:
@@ -524,6 +531,7 @@ def __init__(self) -> None:
524531
self.is_final = False
525532
# Name with module prefix
526533
self._fullname = ""
534+
self.is_dataclass_transform = False
527535

528536
@property
529537
@abstractmethod

mypy/plugins/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
Var,
2020
)
2121
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
22-
from mypy.semanal import ALLOW_INCOMPATIBLE_OVERRIDE, set_callable_name
22+
from mypy.semanal_shared import ALLOW_INCOMPATIBLE_OVERRIDE, set_callable_name
2323
from mypy.typeops import ( # noqa: F401 # Part of public API
2424
try_getting_str_literals as try_getting_str_literals,
2525
)

mypy/semanal.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@
194194
Plugin,
195195
SemanticAnalyzerPluginInterface,
196196
)
197+
from mypy.plugins import dataclasses as dataclasses_plugin
197198
from mypy.reachability import (
198199
ALWAYS_FALSE,
199200
ALWAYS_TRUE,
@@ -208,6 +209,7 @@
208209
from mypy.semanal_namedtuple import NamedTupleAnalyzer
209210
from mypy.semanal_newtype import NewTypeAnalyzer
210211
from mypy.semanal_shared import (
212+
ALLOW_INCOMPATIBLE_OVERRIDE,
211213
PRIORITY_FALLBACKS,
212214
SemanticAnalyzerInterface,
213215
calculate_tuple_fallback,
@@ -234,6 +236,7 @@
234236
from mypy.typeops import function_type, get_type_vars
235237
from mypy.types import (
236238
ASSERT_TYPE_NAMES,
239+
DATACLASS_TRANSFORM_NAMES,
237240
FINAL_DECORATOR_NAMES,
238241
FINAL_TYPE_NAMES,
239242
NEVER_NAMES,
@@ -304,10 +307,6 @@
304307
# available very early on.
305308
CORE_BUILTIN_CLASSES: Final = ["object", "bool", "function"]
306309

307-
# Subclasses can override these Var attributes with incompatible types. This can also be
308-
# set for individual attributes using 'allow_incompatible_override' of Var.
309-
ALLOW_INCOMPATIBLE_OVERRIDE: Final = ("__slots__", "__deletable__", "__match_args__")
310-
311310

312311
# Used for tracking incomplete references
313312
Tag: _TypeAlias = int
@@ -1508,6 +1507,10 @@ def visit_decorator(self, dec: Decorator) -> None:
15081507
removed.append(i)
15091508
else:
15101509
self.fail("@final cannot be used with non-method functions", d)
1510+
elif isinstance(d, CallExpr) and refers_to_fullname(
1511+
d.callee, DATACLASS_TRANSFORM_NAMES
1512+
):
1513+
dec.func.is_dataclass_transform = True
15111514
elif not dec.var.is_property:
15121515
# We have seen a "non-trivial" decorator before seeing @property, if
15131516
# we will see a @property later, give an error, as we don't support this.
@@ -1709,6 +1712,11 @@ def apply_class_plugin_hooks(self, defn: ClassDef) -> None:
17091712
decorator_name = self.get_fullname_for_hook(decorator)
17101713
if decorator_name:
17111714
hook = self.plugin.get_class_decorator_hook(decorator_name)
1715+
# Special case: if the decorator is itself decorated with
1716+
# typing.dataclass_transform, apply the hook for the dataclasses plugin
1717+
# TODO: remove special casing here
1718+
if hook is None and is_dataclass_transform_decorator(decorator):
1719+
hook = dataclasses_plugin.dataclass_tag_callback
17121720
if hook:
17131721
hook(ClassDefContext(defn, decorator, self))
17141722

@@ -6599,3 +6607,10 @@ def halt(self, reason: str = ...) -> NoReturn:
65996607
return isinstance(stmt, PassStmt) or (
66006608
isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr)
66016609
)
6610+
6611+
6612+
def is_dataclass_transform_decorator(node: Node | None) -> bool:
6613+
if isinstance(node, RefExpr):
6614+
return is_dataclass_transform_decorator(node.node)
6615+
6616+
return isinstance(node, Decorator) and node.func.is_dataclass_transform

mypy/semanal_main.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@
3737
from mypy.nodes import Decorator, FuncDef, MypyFile, OverloadedFuncDef, TypeInfo, Var
3838
from mypy.options import Options
3939
from mypy.plugin import ClassDefContext
40+
from mypy.plugins import dataclasses as dataclasses_plugin
4041
from mypy.semanal import (
4142
SemanticAnalyzer,
4243
apply_semantic_analyzer_patches,
44+
is_dataclass_transform_decorator,
4345
remove_imported_names_from_symtable,
4446
)
4547
from mypy.semanal_classprop import (
@@ -457,11 +459,19 @@ def apply_hooks_to_class(
457459
ok = True
458460
for decorator in defn.decorators:
459461
with self.file_context(file_node, options, info):
462+
hook = None
463+
460464
decorator_name = self.get_fullname_for_hook(decorator)
461465
if decorator_name:
462466
hook = self.plugin.get_class_decorator_hook_2(decorator_name)
463-
if hook:
464-
ok = ok and hook(ClassDefContext(defn, decorator, self))
467+
# Special case: if the decorator is itself decorated with
468+
# typing.dataclass_transform, apply the hook for the dataclasses plugin
469+
# TODO: remove special casing here
470+
if hook is None and is_dataclass_transform_decorator(decorator):
471+
hook = dataclasses_plugin.dataclass_class_maker_callback
472+
473+
if hook:
474+
ok = ok and hook(ClassDefContext(defn, decorator, self))
465475
return ok
466476

467477

mypy/semanal_shared.py

+5
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@
3838
get_proper_type,
3939
)
4040

41+
# Subclasses can override these Var attributes with incompatible types. This can also be
42+
# set for individual attributes using 'allow_incompatible_override' of Var.
43+
ALLOW_INCOMPATIBLE_OVERRIDE: Final = ("__slots__", "__deletable__", "__match_args__")
44+
45+
4146
# Priorities for ordering of patches within the "patch" phase of semantic analysis
4247
# (after the main pass):
4348

mypy/types.py

+5
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@
150150
"typing_extensions.Never",
151151
)
152152

153+
DATACLASS_TRANSFORM_NAMES: Final = (
154+
"typing.dataclass_transform",
155+
"typing_extensions.dataclass_transform",
156+
)
157+
153158
# A placeholder used for Bogus[...] parameters
154159
_dummy: Final[Any] = object()
155160

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
[case testDataclassTransformReusesDataclassLogic]
2+
# flags: --python-version 3.7
3+
from typing import dataclass_transform, Type
4+
5+
@dataclass_transform()
6+
def my_dataclass(cls: Type) -> Type:
7+
return cls
8+
9+
@my_dataclass
10+
class Person:
11+
name: str
12+
age: int
13+
14+
def summary(self):
15+
return "%s is %d years old." % (self.name, self.age)
16+
17+
reveal_type(Person) # N: Revealed type is "def (name: builtins.str, age: builtins.int) -> __main__.Person"
18+
Person('John', 32)
19+
Person('Jonh', 21, None) # E: Too many arguments for "Person"
20+
21+
[typing fixtures/typing-medium.pyi]
22+
[builtins fixtures/dataclasses.pyi]
23+
24+
[case testDataclassTransformIsFoundInTypingExtensions]
25+
# flags: --python-version 3.7
26+
from typing import Type
27+
from typing_extensions import dataclass_transform
28+
29+
@dataclass_transform()
30+
def my_dataclass(cls: Type) -> Type:
31+
return cls
32+
33+
@my_dataclass
34+
class Person:
35+
name: str
36+
age: int
37+
38+
def summary(self):
39+
return "%s is %d years old." % (self.name, self.age)
40+
41+
reveal_type(Person) # N: Revealed type is "def (name: builtins.str, age: builtins.int) -> __main__.Person"
42+
Person('John', 32)
43+
Person('Jonh', 21, None) # E: Too many arguments for "Person"
44+
45+
[typing fixtures/typing-full.pyi]
46+
[builtins fixtures/dataclasses.pyi]

test-data/unit/fixtures/dataclasses.pyi

+5-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ class dict(Mapping[KT, VT]):
3737
def get(self, k: KT, default: Union[KT, _T]) -> Union[VT, _T]: pass
3838
def __len__(self) -> int: ...
3939

40-
class list(Generic[_T], Sequence[_T]): pass
40+
class list(Generic[_T], Sequence[_T]):
41+
def __contains__(self, item: object) -> int: pass
42+
def __getitem__(self, key: int) -> _T: pass
43+
def __iter__(self) -> Iterator[_T]: pass
44+
4145
class function: pass
4246
class classmethod: pass
4347
property = object()

test-data/unit/fixtures/typing-medium.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,5 @@ class ContextManager(Generic[T]):
7171
class _SpecialForm: pass
7272

7373
TYPE_CHECKING = 1
74+
75+
def dataclass_transform() -> Callable[[T], T]: ...

test-data/unit/lib-stub/typing_extensions.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,5 @@ class _TypedDict(Mapping[str, object]):
5757
def TypedDict(typename: str, fields: Dict[str, Type[_T]], *, total: Any = ...) -> Type[dict]: ...
5858

5959
def reveal_type(__obj: T) -> T: pass
60+
61+
def dataclass_transform() -> Callable[[T], T]: ...

0 commit comments

Comments
 (0)