From 9b27fc036b7b0727566b51e5cff9fc4859badece Mon Sep 17 00:00:00 2001 From: nstarman Date: Tue, 16 Jul 2024 22:18:49 -0400 Subject: [PATCH] feat: add select type annotations Signed-off-by: nstarman --- .pre-commit-config.yaml | 4 ++-- pyproject.toml | 1 - quax/_core.py | 30 +++++++++++++++++++----------- tests/test_core.py | 2 +- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9960a4..b73c180 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.7 + rev: v0.5.2 hooks: - id: ruff # linter types_or: [ python, pyi, jupyter ] @@ -8,7 +8,7 @@ repos: - id: ruff-format # formatter types_or: [ python, pyi, jupyter ] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.365 + rev: v1.1.372 hooks: - id: pyright additional_dependencies: ["equinox", "pytest", "jax", "jaxtyping", "plum-dispatch"] diff --git a/pyproject.toml b/pyproject.toml index 439e40f..1d352b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ addopts = "--jaxtyping-packages=quax,beartype.beartype(conf=beartype.BeartypeCon [tool.ruff.lint] select = ["E", "F", "I001"] ignore = ["E402", "E721", "E731", "E741", "F722"] -ignore-init-module-imports = true fixable = ["I001", "F401"] [tool.ruff.lint.isort] diff --git a/quax/_core.py b/quax/_core.py index 4103d45..35edd5d 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -2,7 +2,7 @@ import functools as ft import itertools as it from collections.abc import Callable, Sequence -from typing import Any, cast, Union +from typing import Any, cast, Generic, TypeVar, Union from typing_extensions import TypeGuard import equinox as eqx @@ -17,6 +17,8 @@ from jaxtyping import ArrayLike, PyTree +CT = TypeVar("CT", bound=Callable) + # # Rules # @@ -25,7 +27,7 @@ _rules: dict[core.Primitive, plum.Function] = {} -def register(primitive: core.Primitive): +def register(primitive: core.Primitive) -> Callable[[CT], CT]: """Registers a multiple dispatch implementation for this JAX primitive. !!! Example @@ -53,7 +55,7 @@ def _(x: SomeValue, y: SomeValue): A decorator for registering a multiple dispatch rule with the specified primitive. """ - def _register(rule: Callable): + def _register(rule: CT) -> CT: try: existing_rule = _rules[primitive] # pyright: ignore except KeyError: @@ -80,7 +82,7 @@ def existing_rule(): class _QuaxTracer(core.Tracer): __slots__ = ("value",) - def __init__(self, trace: "_QuaxTrace", value: "Value"): + def __init__(self, trace: "_QuaxTrace", value: "Value") -> None: assert _is_value(value) self._trace = trace self.value = value @@ -292,13 +294,13 @@ def _unwrap_tracer(trace, x): return x -class _Quaxify(eqx.Module): - fn: Callable +class _Quaxify(eqx.Module, Generic[CT]): + fn: CT filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] dynamic: bool = eqx.field(static=True) @property - def __wrapped__(self): + def __wrapped__(self) -> CT: return self.fn def __call__(self, *args, **kwargs): @@ -320,13 +322,16 @@ def __call__(self, *args, **kwargs): out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out) return out - def __get__(self, instance, owner): + def __get__(self, instance: Union[object, None], owner: Any): if instance is None: return self return eqx.Partial(self, instance) -def quaxify(fn, filter_spec=True): +def quaxify( + fn: CT, + filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = True, +) -> _Quaxify[CT]: """'Quaxifies' a function, so that it understands custom array-ish objects like [`quax.examples.lora.LoraArray`][]. When this function is called, multiple dispatch will be performed against the types it is called with. @@ -349,7 +354,10 @@ def quaxify(fn, filter_spec=True): nested `quax.quaxify`. See the [advanced tutorial](../examples/redispatch.ipynb). """ - return eqx.module_update_wrapper(_Quaxify(fn, filter_spec, dynamic=False)) + return cast( + _Quaxify[CT], + eqx.module_update_wrapper(_Quaxify(fn, filter_spec, dynamic=False)), + ) # @@ -381,7 +389,7 @@ def aval(self) -> core.AbstractValue: @staticmethod def default( - primitive, values: Sequence[Union[ArrayLike, "Value"]], params + primitive: core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params ) -> Union[ArrayLike, "Value", Sequence[Union[ArrayLike, "Value"]]]: """This is the default rule for when no rule has been [`quax.register`][]'d for a primitive. diff --git a/tests/test_core.py b/tests/test_core.py index 5395353..f031349 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -89,7 +89,7 @@ def default(primitive, values, params): if primitive.multiple_results: return [Foo(x) for x in out] else: - return Foo(out) + return Foo(cast(Array, out)) return Foo