From 3b50d2d7aef3ed22928d430828dc2d4569f0f4c0 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Mon, 27 Jan 2025 23:15:50 -0500 Subject: [PATCH] refactor(jax.extend): update from deprecated imports style(jax.extend): alias to jex Signed-off-by: Nathaniel Starkman --- docs/examples/default_rules.ipynb | 2 +- pyproject.toml | 3 +++ quax/_core.py | 27 ++++++++++++++++----------- quax/examples/named/_core.py | 3 ++- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/docs/examples/default_rules.ipynb b/docs/examples/default_rules.ipynb index 7fae846..35806d5 100644 --- a/docs/examples/default_rules.ipynb +++ b/docs/examples/default_rules.ipynb @@ -63,7 +63,7 @@ "\n", " @staticmethod\n", " def default(\n", - " primitive: jax.core.Primitive,\n", + " primitive: jax.extend.core.Primitive,\n", " values: Sequence[Union[ArrayLike, quax.Value]],\n", " params: dict,\n", " ):\n", diff --git a/pyproject.toml b/pyproject.toml index 7c5fbca..0aedeef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,9 @@ select = ["E", "F", "I001"] ignore = ["E402", "E721", "E731", "E741", "F722"] fixable = ["I001", "F401"] +[tool.ruff.lint.flake8-import-conventions.extend-aliases] +"jax.extend" = "jex" + [tool.ruff.lint.isort] combine-as-imports = true lines-after-imports = 2 diff --git a/quax/_core.py b/quax/_core.py index 5408d7b..0416b71 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -9,6 +9,7 @@ import jax import jax._src import jax.core as core +import jax.extend as jex import jax.extend.linear_util as lu import jax.numpy as jnp import jax.tree_util as jtu @@ -24,10 +25,12 @@ # -_rules: dict[core.Primitive, plum.Function] = {} +_rules: dict[jex.core.Primitive, plum.Function] = {} -def register(primitive: core.Primitive, *, precedence: int = 0) -> Callable[[CT], CT]: +def register( + primitive: jex.core.Primitive, *, precedence: int = 0 +) -> Callable[[CT], CT]: """Registers a multiple dispatch implementation for this JAX primitive. !!! Example @@ -47,8 +50,8 @@ def _(x: SomeValue, y: SomeValue): **Arguments:** - - `primitive`: The `jax.core.Primitive` to provide a multiple dispatch - implementation for. + - `primitive`: The `jax.extend.core.Primitive` to provide a multiple + dispatch implementation for. - `precedence`: The precedence of this rule. See `plum.Dispatcher.dispatch` for details. @@ -102,7 +105,7 @@ def full_lower(self): def _default_process( - primitive: core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params + primitive: jex.core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params ): defaults = set() for x in values: @@ -374,7 +377,9 @@ def aval(self) -> core.AbstractValue: @staticmethod def default( - primitive: core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params + primitive: jex.core.Primitive, + values: Sequence[Union[ArrayLike, "Value"]], + params, ) -> Union[ArrayLike, "Value", Sequence[Union[ArrayLike, "Value"]]]: """This is the default rule for when no rule has been [`quax.register`][]'d for a primitive. @@ -394,7 +399,7 @@ def default( **Arguments:** - - `primitive`: the `jax.core.Primitive` being considered. + - `primitive`: the `jax.extend.core.Primitive` being considered. - `values`: a sequence of what values this primitive is being called with. Each value can either be [`quax.Value`][]s, or a normal JAX arraylike (i.e. `bool`/`int`/`float`/`complex`/NumPy scalar/NumPy array/JAX array). @@ -519,7 +524,7 @@ def aval(self) -> core.ShapedArray: @register(jax._src.pjit.pjit_p) # pyright: ignore def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs): del kwargs - fun = quaxify(core.jaxpr_as_fun(jaxpr)) + fun = quaxify(jex.core.jaxpr_as_fun(jaxpr)) if inline: return fun(*args) else: @@ -541,9 +546,9 @@ def _( init_vals = args[cond_nconsts + body_nconsts :] # compute jaxpr of quaxified body and condition function - quax_cond_fn = quaxify(core.jaxpr_as_fun(cond_jaxpr)) + quax_cond_fn = quaxify(jex.core.jaxpr_as_fun(cond_jaxpr)) quax_cond_jaxpr = jax.make_jaxpr(quax_cond_fn)(*cond_consts, *init_vals) - quax_body_fn = quaxify(core.jaxpr_as_fun(body_jaxpr)) + quax_body_fn = quaxify(jex.core.jaxpr_as_fun(body_jaxpr)) quax_body_jaxpr = jax.make_jaxpr(quax_body_fn)(*body_consts, *init_vals) cond_leaves, _ = jtu.tree_flatten(cond_consts) @@ -581,7 +586,7 @@ def _( def flat_quax_call(flat_args): args = jtu.tree_unflatten(in_tree, flat_args) - out = quaxify(core.jaxpr_as_fun(jaxpr))(*args) + out = quaxify(jex.core.jaxpr_as_fun(jaxpr))(*args) flat_out, out_tree = jtu.tree_flatten(out) out_trees.append(out_tree) return flat_out diff --git a/quax/examples/named/_core.py b/quax/examples/named/_core.py index aa68b88..691ddd2 100644 --- a/quax/examples/named/_core.py +++ b/quax/examples/named/_core.py @@ -4,6 +4,7 @@ import equinox as eqx import jax.core +import jax.extend as jex import jax.lax as lax import jax.numpy as jnp from jaxtyping import ArrayLike @@ -98,7 +99,7 @@ def _broadcast_axes(axes1, axes2): def _register_elementwise_binop( - op: Callable[[Any, Any], Any], prim: jax.core.Primitive + op: Callable[[Any, Any], Any], prim: jex.core.Primitive ): quax_op = quax.quaxify(op)