Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regression test, etc. #39

Merged
merged 2 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/default_rules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"\n",
" @staticmethod\n",
" def default(\n",
" primitive: jax.core.Primitive,\n",
" primitive: jax.extend.core.Primitive,\n",
" values: Sequence[Union[ArrayLike, quax.Value]],\n",
" params: dict,\n",
" ):\n",
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ select = ["E", "F", "I001"]
ignore = ["E402", "E721", "E731", "E741", "F722"]
fixable = ["I001", "F401"]

[tool.ruff.lint.flake8-import-conventions.extend-aliases]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh now this is a cool rule that I didn't know existed!

"jax.extend" = "jex"

[tool.ruff.lint.isort]
combine-as-imports = true
lines-after-imports = 2
Expand Down
27 changes: 16 additions & 11 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax
import jax._src
import jax.core as core
import jax.extend as jex
import jax.extend.linear_util as lu
import jax.numpy as jnp
import jax.tree_util as jtu
Expand All @@ -24,10 +25,12 @@
#


_rules: dict[core.Primitive, plum.Function] = {}
_rules: dict[jex.core.Primitive, plum.Function] = {}


def register(primitive: core.Primitive, *, precedence: int = 0) -> Callable[[CT], CT]:
def register(
primitive: jex.core.Primitive, *, precedence: int = 0
) -> Callable[[CT], CT]:
"""Registers a multiple dispatch implementation for this JAX primitive.

!!! Example
Expand All @@ -47,8 +50,8 @@ def _(x: SomeValue, y: SomeValue):

**Arguments:**

- `primitive`: The `jax.core.Primitive` to provide a multiple dispatch
implementation for.
- `primitive`: The `jax.extend.core.Primitive` to provide a multiple
dispatch implementation for.

- `precedence`: The precedence of this rule.
See `plum.Dispatcher.dispatch` for details.
Expand Down Expand Up @@ -102,7 +105,7 @@ def full_lower(self):


def _default_process(
primitive: core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params
primitive: jex.core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params
):
defaults = set()
for x in values:
Expand Down Expand Up @@ -374,7 +377,9 @@ def aval(self) -> core.AbstractValue:

@staticmethod
def default(
primitive: core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params
primitive: jex.core.Primitive,
values: Sequence[Union[ArrayLike, "Value"]],
params,
) -> Union[ArrayLike, "Value", Sequence[Union[ArrayLike, "Value"]]]:
"""This is the default rule for when no rule has been [`quax.register`][]'d for
a primitive.
Expand All @@ -394,7 +399,7 @@ def default(

**Arguments:**

- `primitive`: the `jax.core.Primitive` being considered.
- `primitive`: the `jax.extend.core.Primitive` being considered.
- `values`: a sequence of what values this primitive is being called with. Each
value can either be [`quax.Value`][]s, or a normal JAX arraylike (i.e.
`bool`/`int`/`float`/`complex`/NumPy scalar/NumPy array/JAX array).
Expand Down Expand Up @@ -519,7 +524,7 @@ def aval(self) -> core.ShapedArray:
@register(jax._src.pjit.pjit_p) # pyright: ignore
def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs):
del kwargs
fun = quaxify(core.jaxpr_as_fun(jaxpr))
fun = quaxify(jex.core.jaxpr_as_fun(jaxpr))
if inline:
return fun(*args)
else:
Expand All @@ -541,9 +546,9 @@ def _(
init_vals = args[cond_nconsts + body_nconsts :]

# compute jaxpr of quaxified body and condition function
quax_cond_fn = quaxify(core.jaxpr_as_fun(cond_jaxpr))
quax_cond_fn = quaxify(jex.core.jaxpr_as_fun(cond_jaxpr))
quax_cond_jaxpr = jax.make_jaxpr(quax_cond_fn)(*cond_consts, *init_vals)
quax_body_fn = quaxify(core.jaxpr_as_fun(body_jaxpr))
quax_body_fn = quaxify(jex.core.jaxpr_as_fun(body_jaxpr))
quax_body_jaxpr = jax.make_jaxpr(quax_body_fn)(*body_consts, *init_vals)

cond_leaves, _ = jtu.tree_flatten(cond_consts)
Expand Down Expand Up @@ -581,7 +586,7 @@ def _(

def flat_quax_call(flat_args):
args = jtu.tree_unflatten(in_tree, flat_args)
out = quaxify(core.jaxpr_as_fun(jaxpr))(*args)
out = quaxify(jex.core.jaxpr_as_fun(jaxpr))(*args)
flat_out, out_tree = jtu.tree_flatten(out)
out_trees.append(out_tree)
return flat_out
Expand Down
3 changes: 2 additions & 1 deletion quax/examples/named/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import equinox as eqx
import jax.core
import jax.extend as jex
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import ArrayLike
Expand Down Expand Up @@ -98,7 +99,7 @@ def _broadcast_axes(axes1, axes2):


def _register_elementwise_binop(
op: Callable[[Any, Any], Any], prim: jax.core.Primitive
op: Callable[[Any, Any], Any], prim: jex.core.Primitive
):
quax_op = quax.quaxify(op)

Expand Down
17 changes: 17 additions & 0 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import jax.numpy as jnp
import jax.random as jr
import pytest
from jaxtyping import TypeCheckError
from plum import NotFoundLookupError

import quax
import quax.examples.lora as lora
Expand Down Expand Up @@ -110,3 +112,18 @@ def test_materialise():
_ = quax.quaxify(jax.nn.relu)(x_true)
with pytest.raises(RuntimeError, match="Refusing to materialise"):
_ = quax.quaxify(jax.nn.relu)(x_false)


def test_regression_38(getkey):
"""Regression test for PR 38 (stackless tracers)."""
x = jnp.arange(4.0).reshape(2, 2)
y = lora.LoraArray(x, rank=1, key=getkey())

def f(x):
return jax.lax.add_p.bind(x, y)

func = quax.quaxify(f)

# Error type depends on whether jaxtyping is on
with pytest.raises((TypeCheckError, NotFoundLookupError)):
_ = func(y)