Skip to content

Commit 31d3717

Browse files
feat(tidy3d): FXC-3573-static-type-checking-in-python-client
1 parent cb3e2d0 commit 31d3717

File tree

112 files changed

+699
-502
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+699
-502
lines changed

.pre-commit-config.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,11 @@ repos:
2424
hooks:
2525
- id: zizmor
2626
stages: [pre-commit]
27-
27+
- repo: https://github.com/pre-commit/mirrors-mypy
28+
rev: v1.13.0
29+
hooks:
30+
- id: mypy
31+
name: mypy (type signatures)
32+
files: ^tidy3d/web
33+
args:
34+
- --config-file=pyproject.toml

poetry.lock

Lines changed: 92 additions & 30 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ pytest-env = "^1.1.5"
6868
tox = { version = "*", optional = true }
6969
diff-cover = { version = "*", optional = true }
7070
zizmor = { version = "*", optional = true }
71+
mypy = { version = "1.13.0", optional = true }
7172

7273
# gdstk
7374
gdstk = { version = ">=0.9.49", optional = true }
@@ -136,6 +137,7 @@ dev = [
136137
'jupyter',
137138
'myst-parser',
138139
'memory_profiler',
140+
'mypy',
139141
'psutil',
140142
'nbconvert',
141143
'nbdime',
@@ -312,3 +314,81 @@ norecursedirs = [
312314
filterwarnings = "ignore::DeprecationWarning"
313315
testpaths = ["tidy3d", "tests", "docs"]
314316
python_files = "*.py"
317+
318+
[tool.mypy]
319+
python_version = "3.10"
320+
ignore_missing_imports = true
321+
follow_imports = "skip"
322+
disallow_untyped_defs = true
323+
disable_error_code = [
324+
"abstract",
325+
"annotation-unchecked",
326+
"arg-type",
327+
"assert-type",
328+
"assignment",
329+
"attr-defined",
330+
"await-not-async",
331+
"call-arg",
332+
"call-overload",
333+
"comparison-overlap",
334+
"dict-item",
335+
"empty-body",
336+
"exit-return",
337+
"explicit-override",
338+
"func-returns-value",
339+
"has-type",
340+
"ignore-without-code",
341+
"import",
342+
"import-not-found",
343+
"import-untyped",
344+
"index",
345+
"list-item",
346+
"literal-required",
347+
"method-assign",
348+
"misc",
349+
"mutable-override",
350+
"name-defined",
351+
"name-match",
352+
"narrowed-type-not-subtype",
353+
"no-any-return",
354+
"no-any-unimported",
355+
"no-overload-impl",
356+
"no-redef",
357+
"no-untyped-call",
358+
"operator",
359+
"overload-cannot-match",
360+
"overload-overlap",
361+
"override",
362+
"possibly-undefined",
363+
"prop-decorator",
364+
"redundant-cast",
365+
"redundant-expr",
366+
"redundant-self",
367+
"return",
368+
"return-value",
369+
"safe-super",
370+
"str-bytes-safe",
371+
"str-format",
372+
"syntax",
373+
"top-level-await",
374+
"truthy-bool",
375+
"truthy-function",
376+
"truthy-iterable",
377+
"type-abstract",
378+
"type-arg",
379+
"type-var",
380+
"typeddict-item",
381+
"typeddict-readonly-mutated",
382+
"typeddict-unknown-key",
383+
"unimported-reveal",
384+
"union-attr",
385+
"unreachable",
386+
"unused-awaitable",
387+
"unused-coroutine",
388+
"unused-ignore",
389+
"used-before-def",
390+
"valid-newtype",
391+
"valid-type",
392+
"var-annotated",
393+
]
394+
enable_error_code = ["no-untyped-def"]

tidy3d/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tidy3d.web import Job
1010

1111

12-
def main(args):
12+
def main(args) -> None:
1313
"""Parse args and run the corresponding tidy3d simulaton."""
1414

1515
parser = argparse.ArgumentParser(description="Tidy3D")

tidy3d/components/autograd/derivative_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass, field, replace
6-
from typing import Callable, Optional, Union
6+
from typing import Any, Callable, Optional, Union
77

88
import numpy as np
99
import xarray as xr
@@ -25,12 +25,12 @@
2525
class LazyInterpolator:
2626
"""Lazy wrapper for interpolators that creates them on first access."""
2727

28-
def __init__(self, creator_func: Callable):
28+
def __init__(self, creator_func: Callable) -> None:
2929
"""Initialize with a function that creates the interpolator when called."""
3030
self.creator_func = creator_func
3131
self._interpolator = None
3232

33-
def __call__(self, *args, **kwargs):
33+
def __call__(self, *args: Any, **kwargs: Any):
3434
"""Create interpolator on first call and delegate to it."""
3535
if self._interpolator is None:
3636
self._interpolator = self.creator_func()
@@ -172,7 +172,7 @@ class DerivativeInfo:
172172
# private cache for interpolators
173173
_interpolators_cache: dict = field(default_factory=dict, init=False, repr=False)
174174

175-
def updated_copy(self, **kwargs):
175+
def updated_copy(self, **kwargs: Any):
176176
"""Create a copy with updated fields."""
177177
kwargs.pop("deep", None)
178178
kwargs.pop("validate", None)
@@ -251,7 +251,7 @@ def create_interpolators(self, dtype: Optional[np.dtype] = None) -> dict:
251251
interpolators = {}
252252
coord_cache = {}
253253

254-
def _make_lazy_interpolator_group(field_data_dict, group_key, is_field_group=True):
254+
def _make_lazy_interpolator_group(field_data_dict, group_key, is_field_group=True) -> None:
255255
"""Helper to create a group of lazy interpolators."""
256256
if is_field_group:
257257
interpolators[group_key] = {}

tidy3d/components/autograd/functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import itertools
4+
from typing import Any
45

56
import autograd.numpy as anp
67
import numpy as np
@@ -98,7 +99,7 @@ def interpn(
9899
xi: tuple[NDArray[np.float64], ...],
99100
*,
100101
method: InterpolationType = "linear",
101-
**kwargs,
102+
**kwargs: Any,
102103
) -> NDArray[np.float64]:
103104
"""Interpolate over a rectilinear grid in arbitrary dimensions.
104105

tidy3d/components/base.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def skip_if_fields_missing(fields: list[str], root=False):
137137

138138
def actual_decorator(validator):
139139
@wraps(validator)
140-
def _validator(cls, *args, **kwargs):
140+
def _validator(cls, *args: Any, **kwargs: Any):
141141
"""New validator function."""
142142
values = kwargs.get("values")
143143
if values is None:
@@ -180,7 +180,7 @@ def _hash_self(self) -> str:
180180
self.to_hdf5(bf)
181181
return hashlib.md5(bf.getvalue()).hexdigest()
182182

183-
def __init__(self, **kwargs):
183+
def __init__(self, **kwargs: Any) -> None:
184184
"""Init method, includes post-init validators."""
185185
log.begin_capture()
186186
super().__init__(**kwargs)
@@ -274,7 +274,7 @@ def _default(o):
274274

275275
return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
276276

277-
def copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Self:
277+
def copy(self, deep: bool = True, validate: bool = True, **kwargs: Any) -> Self:
278278
"""Copy a Tidy3dBaseModel. With ``deep=True`` and ``validate=True`` as default."""
279279
kwargs.update(deep=deep)
280280
new_copy = pydantic.BaseModel.copy(self, **kwargs)
@@ -286,7 +286,7 @@ def copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Self:
286286
return new_copy
287287

288288
def updated_copy(
289-
self, path: Optional[str] = None, deep: bool = True, validate: bool = True, **kwargs
289+
self, path: Optional[str] = None, deep: bool = True, validate: bool = True, **kwargs: Any
290290
) -> Self:
291291
"""Make copy of a component instance with ``**kwargs`` indicating updated field values.
292292
@@ -345,7 +345,7 @@ def updated_copy(
345345

346346
return self._updated_copy(deep=deep, validate=validate, **{field_name: new_component})
347347

348-
def _updated_copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Self:
348+
def _updated_copy(self, deep: bool = True, validate: bool = True, **kwargs: Any) -> Self:
349349
"""Make copy of a component instance with ``**kwargs`` indicating updated field values."""
350350
return self.copy(update=kwargs, deep=deep, validate=validate)
351351

@@ -370,7 +370,7 @@ def from_file(
370370
group_path: Optional[str] = None,
371371
lazy: bool = False,
372372
on_load: Optional[Callable] = None,
373-
**parse_obj_kwargs,
373+
**parse_obj_kwargs: Any,
374374
) -> Self:
375375
"""Loads a :class:`Tidy3dBaseModel` from .yaml, .json, .hdf5, or .hdf5.gz file.
376376
@@ -471,7 +471,7 @@ def to_file(self, fname: PathLike) -> None:
471471
return converter(fname=fname)
472472

473473
@classmethod
474-
def from_json(cls, fname: PathLike, **parse_obj_kwargs) -> Self:
474+
def from_json(cls, fname: PathLike, **parse_obj_kwargs: Any) -> Self:
475475
"""Load a :class:`Tidy3dBaseModel` from .json file.
476476
477477
Parameters
@@ -536,7 +536,7 @@ def to_json(self, fname: PathLike) -> None:
536536
file_handle.write(json_string)
537537

538538
@classmethod
539-
def from_yaml(cls, fname: PathLike, **parse_obj_kwargs) -> Self:
539+
def from_yaml(cls, fname: PathLike, **parse_obj_kwargs: Any) -> Self:
540540
"""Loads :class:`Tidy3dBaseModel` from .yaml file.
541541
542542
Parameters
@@ -759,7 +759,7 @@ def from_hdf5(
759759
fname: PathLike,
760760
group_path: str = "",
761761
custom_decoders: Optional[list[Callable]] = None,
762-
**parse_obj_kwargs,
762+
**parse_obj_kwargs: Any,
763763
) -> Self:
764764
"""Loads :class:`Tidy3dBaseModel` instance to .hdf5 file.
765765
@@ -903,7 +903,7 @@ def from_hdf5_gz(
903903
fname: PathLike,
904904
group_path: str = "",
905905
custom_decoders: Optional[list[Callable]] = None,
906-
**parse_obj_kwargs,
906+
**parse_obj_kwargs: Any,
907907
) -> Self:
908908
"""Loads :class:`Tidy3dBaseModel` instance to .hdf5.gz file.
909909
@@ -1036,7 +1036,7 @@ def _json_string(self) -> str:
10361036
"""
10371037
return self._json()
10381038

1039-
def _json(self, indent=INDENT, exclude_unset=False, **kwargs) -> str:
1039+
def _json(self, indent=INDENT, exclude_unset=False, **kwargs: Any) -> str:
10401040
"""Overwrites the model ``json`` representation with some extra customized handling.
10411041
10421042
Parameters
@@ -1114,7 +1114,7 @@ def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Self:
11141114

11151115
self_dict = self.dict()
11161116

1117-
def insert_value(x, path: tuple[str, ...], sub_dict: dict):
1117+
def insert_value(x, path: tuple[str, ...], sub_dict: dict) -> None:
11181118
"""Insert a value into the path into a dictionary."""
11191119
current_dict = sub_dict
11201120
for key in path[:-1]:
@@ -1349,13 +1349,13 @@ def __init__(
13491349
self,
13501350
fname: PathLike,
13511351
group_path: Optional[str],
1352-
parse_obj_kwargs: Optional[dict[str, Any]],
1352+
parse_obj_kwargs: Any,
13531353
):
13541354
object.__setattr__(self, "_lazy_fname", Path(fname))
13551355
object.__setattr__(self, "_lazy_group_path", group_path)
13561356
object.__setattr__(self, "_lazy_parse_obj_kwargs", dict(parse_obj_kwargs or {}))
13571357

1358-
def copy(self, **kwargs):
1358+
def copy(self, **kwargs: Any):
13591359
"""Return another lazy proxy instead of materializing."""
13601360
return _LazyProxy(
13611361
self._lazy_fname,

tidy3d/components/base_sim/simulation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6-
from typing import Optional
6+
from typing import Any, Optional
77

88
import autograd.numpy as anp
99
import pydantic.v1 as pd
@@ -251,7 +251,7 @@ def plot(
251251
hlim: Optional[tuple[float, float]] = None,
252252
vlim: Optional[tuple[float, float]] = None,
253253
fill_structures: bool = True,
254-
**patch_kwargs,
254+
**patch_kwargs: Any,
255255
) -> Ax:
256256
"""Plot each of simulation's components on a plane defined by one nonzero x,y,z coordinate.
257257
@@ -482,7 +482,7 @@ def plot_boundaries(
482482
y: Optional[float] = None,
483483
z: Optional[float] = None,
484484
ax: Ax = None,
485-
**kwargs,
485+
**kwargs: Any,
486486
) -> Ax:
487487
"""Plot the simulation boundary conditions as lines on a plane
488488
defined by one nonzero x,y,z coordinate.
@@ -685,7 +685,7 @@ def plot_structures_heat_conductivity(
685685
)
686686

687687
@classmethod
688-
def from_scene(cls, scene: Scene, **kwargs) -> AbstractSimulation:
688+
def from_scene(cls, scene: Scene, **kwargs: Any) -> AbstractSimulation:
689689
"""Create a simulation from a :class:`.Scene` instance. Must provide additional parameters
690690
to define a valid simulation (for example, ``size``, ``run_time``, ``grid_spec``, etc).
691691

tidy3d/components/boundary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from abc import ABC
6-
from typing import Optional, Union
6+
from typing import Any, Optional, Union
77

88
import numpy as np
99
import pydantic.v1 as pd
@@ -416,7 +416,7 @@ def plot(
416416
y: Optional[float] = None,
417417
z: Optional[float] = None,
418418
ax: Ax = None,
419-
**patch_kwargs,
419+
**patch_kwargs: Any,
420420
) -> Ax:
421421
"""Plot this absorber."""
422422

tidy3d/components/data/data_array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class DataArray(xr.DataArray):
7777
# stores a dictionary of attributes corresponding to the data values
7878
_data_attrs: dict[str, str] = {}
7979

80-
def __init__(self, data, *args, **kwargs):
80+
def __init__(self, data, *args: Any, **kwargs: Any) -> None:
8181
# if data is a vanilla autograd box, convert to our box
8282
if isbox(data) and not is_tidy_box(data):
8383
data = TidyArrayBox.from_arraybox(data)
@@ -155,7 +155,7 @@ def assign_coord_attrs(cls, val):
155155
return val
156156

157157
@classmethod
158-
def __modify_schema__(cls, field_schema):
158+
def __modify_schema__(cls, field_schema) -> None:
159159
"""Sets the schema of DataArray object."""
160160

161161
schema = {
@@ -435,7 +435,7 @@ def _ag_interp(
435435
return self._from_temp_dataset(ds)
436436

437437
@staticmethod
438-
def _ag_interp_func(var, indexes_coords, method, **kwargs):
438+
def _ag_interp_func(var, indexes_coords, method, **kwargs: Any):
439439
"""
440440
Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`.
441441

0 commit comments

Comments
 (0)