Skip to content

Commit d141a03

Browse files
authored
Add direct support for dataclasses (#26)
* WIP support dataclasses * Ignore stubs/ directory * Also test KW_ONLY and InitVar in dataclasses
1 parent 6cad9c5 commit d141a03

File tree

4 files changed

+74
-15
lines changed

4 files changed

+74
-15
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Project specific folders
2+
stubs/
3+
14
# Byte-compiled / optimized / DLL files
25
__pycache__/
36
*.py[cod]

src/docstub/_stubs.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,30 +119,42 @@ class ScopeType(enum.StrEnum):
119119
# docstub: on
120120

121121

122+
# TODO use `libcst.metadata.ScopeProvider` instead
122123
@dataclass(slots=True, frozen=True)
123124
class _Scope:
124125
""""""
125126

126127
type: ScopeType
127-
node: cst.CSTNode = None
128+
node: cst.CSTNode | None = None
128129

129130
@property
130-
def has_self_or_cls(self):
131+
def has_self_or_cls(self) -> bool:
131132
return self.type in {ScopeType.METHOD, ScopeType.CLASSMETHOD}
132133

133134
@property
134-
def is_method(self):
135+
def is_method(self) -> bool:
135136
return self.type in {
136137
ScopeType.METHOD,
137138
ScopeType.CLASSMETHOD,
138139
ScopeType.STATICMETHOD,
139140
}
140141

141142
@property
142-
def is_class_init(self):
143+
def is_class_init(self) -> bool:
143144
out = self.is_method and self.node.name.value == "__init__"
144145
return out
145146

147+
@property
148+
def is_dataclass(self) -> bool:
149+
if cstm.matches(self.node, cstm.ClassDef()):
150+
# Determine if dataclass
151+
decorators = cstm.findall(self.node, cstm.Decorator())
152+
is_dataclass = any(
153+
cstm.findall(d, cstm.Name("dataclass")) for d in decorators
154+
)
155+
return is_dataclass
156+
return False
157+
146158

147159
def _get_docstring_node(node):
148160
"""Extract the node with the docstring from a definition.
@@ -672,16 +684,27 @@ def leave_AnnAssign(self, original_node, updated_node):
672684
updated_node : cst.AnnAssign
673685
"""
674686
name = updated_node.target.value
675-
is_type_alias = cstm.matches(
676-
updated_node.annotation, cstm.Annotation(cstm.Name("TypeAlias"))
677-
)
678-
is__all__ = cstm.matches(updated_node.target, cstm.Name("__all__"))
679687

680-
# Remove value if not type alias or __all__
681-
if updated_node.value is not None and not is_type_alias and not is__all__:
682-
updated_node = updated_node.with_changes(
683-
value=None, equal=cst.MaybeSentinel.DEFAULT
688+
if updated_node.value is not None:
689+
is_type_alias = cstm.matches(
690+
updated_node.annotation, cstm.Annotation(cstm.Name("TypeAlias"))
684691
)
692+
is__all__ = cstm.matches(updated_node.target, cstm.Name("__all__"))
693+
is_dataclass = self._scope_stack[-1].is_dataclass
694+
is_classvar = any(
695+
cstm.findall(updated_node.annotation, cstm.Name("ClassVar"))
696+
)
697+
698+
# Replace with ellipses if dataclass
699+
if is_dataclass and not is_classvar:
700+
updated_node = updated_node.with_changes(
701+
value=cst.Ellipsis(), equal=cst.MaybeSentinel.DEFAULT
702+
)
703+
# Remove value if not type alias or __all__
704+
elif not is_type_alias and not is__all__:
705+
updated_node = updated_node.with_changes(
706+
value=None, equal=cst.MaybeSentinel.DEFAULT
707+
)
685708

686709
# Replace with type annotation from docstring, if available
687710
pytypes = self._pytypes_stack[-1]

stubtest_allow.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,3 @@ docstub\._version\..*
22
docstub\..*\.__match_args__$
33
docstub._cache.FuncSerializer.__type_params__
44
docstub._cli.main
5-
docstub._config.Config.__init__
6-
docstub._docstrings.Annotation.__init__
7-
docstub._stubs._Scope.__init__

tests/test_stubs.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,39 @@ class Foo:
394394
# remove these empty lines from the result too
395395
result = dedent(result)
396396
assert expected == result
397+
398+
@pytest.mark.parametrize("decorator", ["dataclass", "dataclasses.dataclass"])
399+
def test_dataclass(self, decorator):
400+
source = dedent(
401+
f"""
402+
@{decorator}
403+
class Foo:
404+
a: float
405+
b: int = 3
406+
c: str = None
407+
_: KW_ONLY
408+
d: dict[str, Any] = field(default_factory=dict)
409+
e: InitVar[tuple] = tuple()
410+
f: ClassVar
411+
g: ClassVar[float]
412+
h: Final[ClassVar[int]] = 1
413+
"""
414+
)
415+
expected = dedent(
416+
f"""
417+
@{decorator}
418+
class Foo:
419+
a: float
420+
b: int = ...
421+
c: str = ...
422+
_: KW_ONLY
423+
d: dict[str, Any] = ...
424+
e: InitVar[tuple] = ...
425+
f: ClassVar
426+
g: ClassVar[float]
427+
h: Final[ClassVar[int]]
428+
"""
429+
)
430+
transformer = Py2StubTransformer()
431+
result = transformer.python_to_stub(source)
432+
assert expected == result

0 commit comments

Comments
 (0)