Skip to content

Commit 08684f9

Browse files
committed
Typing: Add overload signatures for open
Added for the `FolderData` and `NodeRepository` classes. The signature of the `SinglefileData` was actually incorrect as it defined: t.Iterator[t.BinaryIO | t.TextIO] as the return type, but which should really be: t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO] The former will cause `mypy` to raise an error.
1 parent a2a1796 commit 08684f9

File tree

6 files changed

+43
-13
lines changed

6 files changed

+43
-13
lines changed

aiida/orm/nodes/data/folder.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,18 @@ def list_object_names(self, path: str | None = None) -> list[str]:
7171
"""
7272
return self.base.repository.list_object_names(path)
7373

74+
@t.overload
75+
@contextlib.contextmanager
76+
def open(self, path: FilePath, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
77+
...
78+
79+
@t.overload
80+
@contextlib.contextmanager
81+
def open(self, path: FilePath, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
82+
...
83+
7484
@contextlib.contextmanager
75-
def open(self, path: str, mode='r') -> t.Iterator[t.BinaryIO | t.TextIO]:
85+
def open(self, path: FilePath, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]:
7686
"""Open a file handle to an object stored under the given key.
7787
7888
.. note:: this should only be used to open a handle to read an existing file. To write a new file use the method

aiida/orm/nodes/data/singlefile.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
__all__ = ('SinglefileData',)
2424

25+
FilePath = t.Union[str, pathlib.PurePosixPath]
26+
2527

2628
class SinglefileData(Data):
2729
"""Data class that can be used to store a single file in its repository."""
@@ -37,7 +39,9 @@ def from_string(cls, content: str, filename: str | pathlib.Path | None = None, *
3739
"""
3840
return cls(io.StringIO(content), filename, **kwargs)
3941

40-
def __init__(self, file: str | t.IO, filename: str | pathlib.Path | None = None, **kwargs: t.Any) -> None:
42+
def __init__(
43+
self, file: str | pathlib.Path | t.IO, filename: str | pathlib.Path | None = None, **kwargs: t.Any
44+
) -> None:
4145
"""Construct a new instance and set the contents to that of the file.
4246
4347
:param file: an absolute filepath or filelike object whose contents to copy.
@@ -60,26 +64,30 @@ def filename(self) -> str:
6064

6165
@t.overload
6266
@contextlib.contextmanager
63-
def open(self, path: str, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
67+
def open(self, path: FilePath, mode: t.Literal['r'] = ...) -> t.Iterator[t.TextIO]:
6468
...
6569

6670
@t.overload
6771
@contextlib.contextmanager
68-
def open(self, path: None, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
72+
def open(self, path: FilePath, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
6973
...
7074

7175
@t.overload
7276
@contextlib.contextmanager
73-
def open(self, path: str, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
77+
def open( # type: ignore[overload-overlap]
78+
self, path: None = None, mode: t.Literal['r'] = ...
79+
) -> t.Iterator[t.TextIO]:
7480
...
7581

7682
@t.overload
7783
@contextlib.contextmanager
78-
def open(self, path: None, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
84+
def open(self, path: None = None, mode: t.Literal['rb'] = ...) -> t.Iterator[t.BinaryIO]:
7985
...
8086

8187
@contextlib.contextmanager
82-
def open(self, path: str | None = None, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO | t.TextIO]:
88+
def open(self,
89+
path: FilePath | None = None,
90+
mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]:
8391
"""Return an open file handle to the content of this data node.
8492
8593
:param path: the relative path of the object within the repository.
@@ -113,7 +121,7 @@ def get_content(self, mode: str = 'r') -> str | bytes:
113121
with self.open(mode=mode) as handle: # type: ignore[call-overload]
114122
return handle.read()
115123

116-
def set_file(self, file: str | t.IO, filename: str | pathlib.Path | None = None) -> None:
124+
def set_file(self, file: str | pathlib.Path | t.IO, filename: str | pathlib.Path | None = None) -> None:
117125
"""Store the content of the file in the node's repository, deleting any other existing objects.
118126
119127
:param file: an absolute filepath or filelike object whose contents to copy

aiida/orm/nodes/repository.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,18 @@ def list_object_names(self, path: str | None = None) -> list[str]:
164164
"""
165165
return self._repository.list_object_names(path)
166166

167+
@t.overload
168+
@contextlib.contextmanager
169+
def open(self, path: FilePath, mode: t.Literal['r']) -> t.Iterator[t.TextIO]:
170+
...
171+
172+
@t.overload
173+
@contextlib.contextmanager
174+
def open(self, path: FilePath, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
175+
...
176+
167177
@contextlib.contextmanager
168-
def open(self, path: FilePath, mode='r') -> t.Iterator[t.BinaryIO | t.TextIO]:
178+
def open(self, path: FilePath, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO] | t.Iterator[t.TextIO]:
169179
"""Open a file handle to an object stored under the given key.
170180
171181
.. note:: this should only be used to open a handle to read an existing file. To write a new file use the method
@@ -210,7 +220,7 @@ def as_path(self, path: FilePath | None = None) -> t.Iterator[pathlib.Path]:
210220
assert path is not None
211221
with self.open(path, mode='rb') as source:
212222
with filepath.open('wb') as target:
213-
shutil.copyfileobj(source, target) # type: ignore[misc]
223+
shutil.copyfileobj(source, target)
214224
yield filepath
215225

216226
def get_object(self, path: FilePath | None = None) -> File:

aiida/parsers/plugins/arithmetic/add.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
###########################################################################
1010
# Warning: this implementation is used directly in the documentation as a literal-include, which means that if any part
1111
# of this code is changed, the snippets in the file `docs/source/howto/codes.rst` have to be checked for consistency.
12-
# mypy: disable_error_code=arg-type
12+
# mypy: disable_error_code=call-overload
1313
"""Parser for an `ArithmeticAddCalculation` job."""
1414
from aiida.parsers.parser import Parser
1515

aiida/parsers/plugins/diff_tutorial/parsers.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
55
Register parsers via the "aiida.parsers" entry point in the pyproject.toml file.
66
"""
7+
# mypy: disable_error_code=call-overload
78
# START PARSER HEAD
89
from aiida.engine import ExitCode
910
from aiida.orm import SinglefileData
@@ -38,7 +39,7 @@ def parse(self, **kwargs):
3839

3940
# add output file
4041
self.logger.info(f"Parsing '{output_filename}'")
41-
with self.retrieved.open(output_filename, 'rb') as handle: # type: ignore[arg-type]
42+
with self.retrieved.open(output_filename, 'rb') as handle:
4243
output_node = SinglefileData(file=handle)
4344
self.out('diff', output_node)
4445

@@ -59,7 +60,7 @@ def parse(self, **kwargs):
5960

6061
# add output file
6162
self.logger.info(f"Parsing '{output_filename}'")
62-
with self.retrieved.open(output_filename, 'rb') as handle: # type: ignore[arg-type]
63+
with self.retrieved.open(output_filename, 'rb') as handle:
6364
output_node = SinglefileData(file=handle)
6465
self.out('diff', output_node)
6566

docs/source/nitpick-exceptions

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ py:class BinaryIO
1414
py:class EntryPoint
1515
py:class EntryPoints
1616
py:class IO
17+
py:class FilePath
1718
py:class Path
1819
py:class str | list[str]
1920
py:class str | Path

0 commit comments

Comments
 (0)