Skip to content

Commit

Permalink
feat: add select type annotations
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored and patrick-kidger committed Jul 21, 2024
1 parent 3433bc8 commit 9b27fc0
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
rev: v0.5.2
hooks:
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.365
rev: v1.1.372
hooks:
- id: pyright
additional_dependencies: ["equinox", "pytest", "jax", "jaxtyping", "plum-dispatch"]
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ addopts = "--jaxtyping-packages=quax,beartype.beartype(conf=beartype.BeartypeCon
[tool.ruff.lint]
select = ["E", "F", "I001"]
ignore = ["E402", "E721", "E731", "E741", "F722"]
ignore-init-module-imports = true
fixable = ["I001", "F401"]

[tool.ruff.lint.isort]
Expand Down
30 changes: 19 additions & 11 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import functools as ft
import itertools as it
from collections.abc import Callable, Sequence
from typing import Any, cast, Union
from typing import Any, cast, Generic, TypeVar, Union
from typing_extensions import TypeGuard

import equinox as eqx
Expand All @@ -17,6 +17,8 @@
from jaxtyping import ArrayLike, PyTree


CT = TypeVar("CT", bound=Callable)

#
# Rules
#
Expand All @@ -25,7 +27,7 @@
_rules: dict[core.Primitive, plum.Function] = {}


def register(primitive: core.Primitive):
def register(primitive: core.Primitive) -> Callable[[CT], CT]:
"""Registers a multiple dispatch implementation for this JAX primitive.
!!! Example
Expand Down Expand Up @@ -53,7 +55,7 @@ def _(x: SomeValue, y: SomeValue):
A decorator for registering a multiple dispatch rule with the specified primitive.
"""

def _register(rule: Callable):
def _register(rule: CT) -> CT:
try:
existing_rule = _rules[primitive] # pyright: ignore
except KeyError:
Expand All @@ -80,7 +82,7 @@ def existing_rule():
class _QuaxTracer(core.Tracer):
__slots__ = ("value",)

def __init__(self, trace: "_QuaxTrace", value: "Value"):
def __init__(self, trace: "_QuaxTrace", value: "Value") -> None:
assert _is_value(value)
self._trace = trace
self.value = value
Expand Down Expand Up @@ -292,13 +294,13 @@ def _unwrap_tracer(trace, x):
return x


class _Quaxify(eqx.Module):
fn: Callable
class _Quaxify(eqx.Module, Generic[CT]):
fn: CT
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]]
dynamic: bool = eqx.field(static=True)

@property
def __wrapped__(self):
def __wrapped__(self) -> CT:
return self.fn

def __call__(self, *args, **kwargs):
Expand All @@ -320,13 +322,16 @@ def __call__(self, *args, **kwargs):
out = jtu.tree_map(ft.partial(_unwrap_tracer, trace), out)
return out

def __get__(self, instance, owner):
def __get__(self, instance: Union[object, None], owner: Any):
if instance is None:
return self
return eqx.Partial(self, instance)


def quaxify(fn, filter_spec=True):
def quaxify(
fn: CT,
filter_spec: PyTree[Union[bool, Callable[[Any], bool]]] = True,
) -> _Quaxify[CT]:
"""'Quaxifies' a function, so that it understands custom array-ish objects like
[`quax.examples.lora.LoraArray`][]. When this function is called, multiple dispatch
will be performed against the types it is called with.
Expand All @@ -349,7 +354,10 @@ def quaxify(fn, filter_spec=True):
nested `quax.quaxify`. See the
[advanced tutorial](../examples/redispatch.ipynb).
"""
return eqx.module_update_wrapper(_Quaxify(fn, filter_spec, dynamic=False))
return cast(
_Quaxify[CT],
eqx.module_update_wrapper(_Quaxify(fn, filter_spec, dynamic=False)),
)


#
Expand Down Expand Up @@ -381,7 +389,7 @@ def aval(self) -> core.AbstractValue:

@staticmethod
def default(
primitive, values: Sequence[Union[ArrayLike, "Value"]], params
primitive: core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params
) -> Union[ArrayLike, "Value", Sequence[Union[ArrayLike, "Value"]]]:
"""This is the default rule for when no rule has been [`quax.register`][]'d for
a primitive.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def default(primitive, values, params):
if primitive.multiple_results:
return [Foo(x) for x in out]
else:
return Foo(out)
return Foo(cast(Array, out))

return Foo

Expand Down

0 comments on commit 9b27fc0

Please sign in to comment.