Skip to content

Commit

Permalink
refactor(jax.extend): update from deprecated imports
Browse files Browse the repository at this point in the history
style(jax.extend): alias to jex

Signed-off-by: Nathaniel Starkman <[email protected]>
  • Loading branch information
nstarman authored and patrick-kidger committed Feb 4, 2025
1 parent 37dc139 commit 3b50d2d
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docs/examples/default_rules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"\n",
" @staticmethod\n",
" def default(\n",
" primitive: jax.core.Primitive,\n",
" primitive: jax.extend.core.Primitive,\n",
" values: Sequence[Union[ArrayLike, quax.Value]],\n",
" params: dict,\n",
" ):\n",
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ select = ["E", "F", "I001"]
ignore = ["E402", "E721", "E731", "E741", "F722"]
fixable = ["I001", "F401"]

[tool.ruff.lint.flake8-import-conventions.extend-aliases]
"jax.extend" = "jex"

[tool.ruff.lint.isort]
combine-as-imports = true
lines-after-imports = 2
Expand Down
27 changes: 16 additions & 11 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax
import jax._src
import jax.core as core
import jax.extend as jex
import jax.extend.linear_util as lu
import jax.numpy as jnp
import jax.tree_util as jtu
Expand All @@ -24,10 +25,12 @@
#


_rules: dict[core.Primitive, plum.Function] = {}
_rules: dict[jex.core.Primitive, plum.Function] = {}


def register(primitive: core.Primitive, *, precedence: int = 0) -> Callable[[CT], CT]:
def register(
primitive: jex.core.Primitive, *, precedence: int = 0
) -> Callable[[CT], CT]:
"""Registers a multiple dispatch implementation for this JAX primitive.
!!! Example
Expand All @@ -47,8 +50,8 @@ def _(x: SomeValue, y: SomeValue):
**Arguments:**
- `primitive`: The `jax.core.Primitive` to provide a multiple dispatch
implementation for.
- `primitive`: The `jax.extend.core.Primitive` to provide a multiple
dispatch implementation for.
- `precedence`: The precedence of this rule.
See `plum.Dispatcher.dispatch` for details.
Expand Down Expand Up @@ -102,7 +105,7 @@ def full_lower(self):


def _default_process(
primitive: core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params
primitive: jex.core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params
):
defaults = set()
for x in values:
Expand Down Expand Up @@ -374,7 +377,9 @@ def aval(self) -> core.AbstractValue:

@staticmethod
def default(
primitive: core.Primitive, values: Sequence[Union[ArrayLike, "Value"]], params
primitive: jex.core.Primitive,
values: Sequence[Union[ArrayLike, "Value"]],
params,
) -> Union[ArrayLike, "Value", Sequence[Union[ArrayLike, "Value"]]]:
"""This is the default rule for when no rule has been [`quax.register`][]'d for
a primitive.
Expand All @@ -394,7 +399,7 @@ def default(
**Arguments:**
- `primitive`: the `jax.core.Primitive` being considered.
- `primitive`: the `jax.extend.core.Primitive` being considered.
- `values`: a sequence of what values this primitive is being called with. Each
value can either be [`quax.Value`][]s, or a normal JAX arraylike (i.e.
`bool`/`int`/`float`/`complex`/NumPy scalar/NumPy array/JAX array).
Expand Down Expand Up @@ -519,7 +524,7 @@ def aval(self) -> core.ShapedArray:
@register(jax._src.pjit.pjit_p) # pyright: ignore
def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs):
del kwargs
fun = quaxify(core.jaxpr_as_fun(jaxpr))
fun = quaxify(jex.core.jaxpr_as_fun(jaxpr))
if inline:
return fun(*args)
else:
Expand All @@ -541,9 +546,9 @@ def _(
init_vals = args[cond_nconsts + body_nconsts :]

# compute jaxpr of quaxified body and condition function
quax_cond_fn = quaxify(core.jaxpr_as_fun(cond_jaxpr))
quax_cond_fn = quaxify(jex.core.jaxpr_as_fun(cond_jaxpr))
quax_cond_jaxpr = jax.make_jaxpr(quax_cond_fn)(*cond_consts, *init_vals)
quax_body_fn = quaxify(core.jaxpr_as_fun(body_jaxpr))
quax_body_fn = quaxify(jex.core.jaxpr_as_fun(body_jaxpr))
quax_body_jaxpr = jax.make_jaxpr(quax_body_fn)(*body_consts, *init_vals)

cond_leaves, _ = jtu.tree_flatten(cond_consts)
Expand Down Expand Up @@ -581,7 +586,7 @@ def _(

def flat_quax_call(flat_args):
args = jtu.tree_unflatten(in_tree, flat_args)
out = quaxify(core.jaxpr_as_fun(jaxpr))(*args)
out = quaxify(jex.core.jaxpr_as_fun(jaxpr))(*args)
flat_out, out_tree = jtu.tree_flatten(out)
out_trees.append(out_tree)
return flat_out
Expand Down
3 changes: 2 additions & 1 deletion quax/examples/named/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import equinox as eqx
import jax.core
import jax.extend as jex
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import ArrayLike
Expand Down Expand Up @@ -98,7 +99,7 @@ def _broadcast_axes(axes1, axes2):


def _register_elementwise_binop(
op: Callable[[Any, Any], Any], prim: jax.core.Primitive
op: Callable[[Any, Any], Any], prim: jex.core.Primitive
):
quax_op = quax.quaxify(op)

Expand Down

0 comments on commit 3b50d2d

Please sign in to comment.