Fix breaking changes introduced in JAX 0.4.36. #907

Merged 2 commits on Dec 8, 2024

.github/workflows/run_tests.yml (3 changes: 2 additions & 1 deletion)
@@ -7,7 +7,8 @@ jobs:
run-test:
strategy:
matrix:
python-version: [ 3.9, 3.11 ]
# must match the `language_version` in `.pre-commit-config.yaml`
python-version: [ 3.11 ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
.pre-commit-config.yaml (2 changes: 2 additions & 0 deletions)
@@ -11,6 +11,8 @@ repos:
rev: v1.1.379
hooks:
- id: pyright
# must match the Python version used in CI
language_version: python3.11
additional_dependencies:
[
beartype,
equinox/_ad.py (4 changes: 3 additions & 1 deletion)
@@ -340,7 +340,9 @@ def filter_jvp(
flat_tangents = jtu.tree_leaves(tangents) # all non-None tangents are dynamic

def _fn(*_flat_dynamic):
_main = jax.core.find_top_trace(_flat_dynamic).main
_top_trace = jax.core.find_top_trace(_flat_dynamic)
assert _top_trace is not None
_main = _top_trace.main
_dynamic = jtu.tree_unflatten(treedef, _flat_dynamic)
_in = combine(_dynamic, static_primals)
_out = fn(*_in, **kwargs)
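
Reviewer note (not part of the diff): the new `assert` reflects that `jax.core.find_top_trace` appears to be typed as possibly returning `None` as of JAX 0.4.36, so the check narrows the type for pyright and fails loudly if the trace is missing. The user-facing behaviour of `eqx.filter_jvp` is unchanged; a minimal usage sketch, with an illustrative function of our own:

import equinox as eqx
import jax.numpy as jnp

def f(x):
    # Illustrative function only; any callable of JAX arrays works here.
    return jnp.sin(x) * 2.0

# Primals and tangents are passed as tuples; the result is (primal output, tangent output).
primal_out, tangent_out = eqx.filter_jvp(f, (jnp.array(1.0),), (jnp.array(1.0),))
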
equinox/_vmap_pmap.py (57 changes: 3 additions & 54 deletions)
@@ -8,8 +8,6 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.interpreters.batching as batching
import jax.interpreters.pxla as pxla
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
@@ -78,53 +76,6 @@ def __call__(self, x: Any) -> Optional[int]:
return self.axis if is_array(x) else None


@dataclasses.dataclass(frozen=True) # not a pytree
class if_mapped:
"""Used with the `out_axes` argument of [`equinox.filter_vmap`][], to only add an
output batch axis if necessary.
That is, `out_axes=if_mapped(i)` is equivalent to `out_axes=i` for any output that
is batched, and `out_axes=None` for any output that is not batched.
"""

axis: int

def __call__(self, x: Any):
raise RuntimeError(
"`eqx.internal.if_mapped` should not be called directly; it is only valid "
"when passed to `out_axes` of `eqx.filter_vmap`."
)


@dataclasses.dataclass(frozen=True) # not a pytree
class _if_mapped:
main: Any
axis: int

def __call__(self, x: Any) -> Optional[int]:
if isinstance(x, batching.BatchTracer) and x._trace.main is self.main:
if x.batch_dim is batching.not_mapped:
return None
else:
return self.axis
elif isinstance(x, pxla.MapTracer) and x._trace.main is self.main:
return self.axis
else:
return None


# The existence of this function is a complete hack: it couples together `filter_vmap`
# with `if_mapped`. I don't see an obvious way around it though.
def _bind_main(main, out_axes):
def _bind(axis):
if isinstance(axis, if_mapped):
return _if_mapped(main, axis.axis)
else:
return axis

return jtu.tree_map(_bind, out_axes)


def _moveaxis(array, axis):
return jnp.moveaxis(array, 0, axis)

@@ -199,11 +150,9 @@ def __call__(self, /, *args, **kwargs):
static_args, dynamic_args = partition(args, unmapped_axis)

def _fun_wrapper(_dynamic_args):
_main = jax.core.find_top_trace(jtu.tree_leaves(_dynamic_args)).main
_args = combine(_dynamic_args, static_args)
_out = self._fun(*_args)
_out_axes = _bind_main(_main, self._out_axes)
_out_axes = _resolve_axes(_out, _out_axes)
_out_axes = _resolve_axes(_out, self._out_axes)
_none_axes = jtu.tree_map(_is_none, _out_axes, is_leaf=_is_none)
_nonvmapd, _vmapd = partition(_out, _none_axes, is_leaf=_is_none)
_nonvmapd_arr, _nonvmapd_static = partition(_nonvmapd, is_array)
@@ -235,6 +184,7 @@ def _fun_wrapper(_dynamic_args):
return combine(vmapd, nonvmapd)

def __get__(self, instance, owner):
del owner
if instance is None:
return self
return Partial(self, instance)
@@ -439,10 +389,8 @@ def _check_map_out_axis(x: Optional[int]):
)

def fun_wrapped(_dynamic):
_main = jax.core.find_top_trace(jtu.tree_leaves(_dynamic))
_fun, _args, _, _out_axes = combine(_dynamic, static)
_out = _fun(*_args)
_out_axes = _bind_main(_main, _out_axes)
_out_axes = _resolve_axes(_out, _out_axes)
jtu.tree_map(_check_map_out_axis, _out_axes)
_pmapd = []
@@ -558,6 +506,7 @@ def lower(self, /, *args, **kwargs) -> Lowered:
return self._call(True, args, kwargs)

def __get__(self, instance, owner):
del owner
if instance is None:
return self
return Partial(self, instance)
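
Reviewer note (not part of the diff): `eqx.internal.if_mapped` worked by inspecting JAX's private `BatchTracer`/`MapTracer` types, which the JAX 0.4.36 tracing rework changed, hence its removal here. Plain integer or `None` `out_axes` remain supported; a minimal sketch with an illustrative function:

import equinox as eqx
import jax.numpy as jnp

def g(y):
    return y + 1

# Mapped over axis 0 of the input by default; the mapped axis is placed at
# position 1 of the output because of `out_axes=1`.
out = eqx.filter_vmap(g, out_axes=1)(jnp.arange(6).reshape(2, 3))
assert out.shape == (3, 2)
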
equinox/internal/__init__.py (1 change: 0 additions & 1 deletion)
@@ -30,7 +30,6 @@
unvmap_max as unvmap_max,
unvmap_max_p as unvmap_max_p,
)
from .._vmap_pmap import if_mapped as if_mapped

# Backward compatibility: expose via `equinox.internal`. Now available under
# `equinox.debug`.
equinox/internal/_loop/common.py (23 changes: 16 additions & 7 deletions)
@@ -83,17 +83,26 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
bp, bx, by = batch_axes
if bp is batching.not_mapped:
if bx is batching.not_mapped:
x = jnp.broadcast_to(x, (axis_size,) + x.shape)
else:
x = jnp.moveaxis(x, bx, 0)
if by is batching.not_mapped:
y = jnp.broadcast_to(y, (axis_size,) + y.shape)
if by is batching.not_mapped:
out_axis = None
else:
x = jnp.broadcast_to(x, (axis_size,) + x.shape)
y = jnp.moveaxis(y, by, 0)
out_axis = 0
else:
y = jnp.moveaxis(y, by, 0)
if by is batching.not_mapped:
x = jnp.moveaxis(x, bx, 0)
y = jnp.broadcast_to(y, (axis_size,) + y.shape)
out_axis = 0
else:
x = jnp.moveaxis(x, bx, 0)
y = jnp.moveaxis(y, by, 0)
out_axis = 0
out = _select_if_vmap(pred, x, y, makes_false_steps=False)
else:
out = jax.vmap(lax.select, in_axes=(bp, bx, by))(pred, x, y)
return out, 0
out_axis = 0
return out, out_axis


select_if_vmap_p = jax.core.Primitive("select_if_vmap")
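
Reviewer note (not part of the diff): the rewritten batch rule only broadcasts an operand when the other one actually carries a batch axis, and reports an output axis of `None` when neither does. A minimal sketch of the two array manipulations it relies on, with illustrative shapes:

import jax.numpy as jnp

axis_size = 4
x = jnp.ones((3, axis_size, 2))  # batched along axis 1
y = jnp.zeros((3, 2))            # unbatched

x = jnp.moveaxis(x, 1, 0)                        # batch axis to the front -> (4, 3, 2)
y = jnp.broadcast_to(y, (axis_size,) + y.shape)  # match the batch size    -> (4, 3, 2)
assert x.shape == y.shape == (axis_size, 3, 2)
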
equinox/internal/_primitive.py (24 changes: 14 additions & 10 deletions)
@@ -306,16 +306,20 @@ def create_vprim(name: str, impl, abstract_eval, jvp, transpose):

def batch_rule(axis_size, axis_name, trace_type, inputs, batch_axes, **params):
del trace_type
# delegates batching to `_vprim_p`
out = _vprim_p.bind(
*inputs,
prim=prim,
__axis_size=axis_size,
__axis_name=axis_name,
__batch_axes=batch_axes,
params=params,
)
batch_axes_out = jtu.tree_map(lambda _: 0, out)
if all(b is batching.not_mapped for b in jtu.tree_leaves(batch_axes)):
out = prim.bind(*inputs, **params)
batch_axes_out = jtu.tree_map(lambda _: batching.not_mapped, out)
else:
# delegates batching to `_vprim_p`
out = _vprim_p.bind(
*inputs,
prim=prim,
__axis_size=axis_size,
__axis_name=axis_name,
__batch_axes=batch_axes,
params=params,
)
batch_axes_out = jtu.tree_map(lambda _: 0, out)
return out, batch_axes_out

prim.def_impl(impl)
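
Reviewer note (not part of the diff): the batch rule now short-circuits when no input carries a batch axis, binding the wrapped primitive directly and reporting `batching.not_mapped` for every output instead of forcing a batched result. A minimal sketch of that check (the helper name is illustrative):

import jax.interpreters.batching as batching
import jax.tree_util as jtu

def all_unbatched(batch_axes) -> bool:
    # True when no leaf carries a batch axis, so the primitive can be bound as-is.
    return all(b is batching.not_mapped for b in jtu.tree_leaves(batch_axes))

assert all_unbatched((batching.not_mapped, batching.not_mapped))
assert not all_unbatched((0, batching.not_mapped))
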
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
[project]
name = "equinox"
version = "0.11.9"
version = "0.11.10"
description = "Elegant easy-to-use neural networks in JAX."
readme = "README.md"
requires-python =">=3.9"
tests/test_pmap.py (17 changes: 0 additions & 17 deletions)
@@ -272,23 +272,6 @@ def f(x, y):
compiled(x, y)


def test_double_if_mapped():
out_axes = eqx.internal.if_mapped(1)

def f(x):
assert x.shape == (3, 1)

def g(y):
assert y.shape == (1,)
return y + 1, x + 1

a, b = eqx.filter_vmap(g, out_axes=out_axes)(x)
assert a.shape == (1, 3)
assert b.shape == (3, 1)

filter_pmap(f)(jnp.arange(3).reshape(1, 3, 1))


# https://github.com/patrick-kidger/equinox/issues/900
# Unlike the vmap case we only test nonnegative integers, as pmap does not support
# negative indexing for `in_axes` or `out_axes`.
tests/test_primitive.py (4 changes: 2 additions & 2 deletions)
@@ -136,11 +136,11 @@ def fn(x):
def test_vprim():
def impl(x):
assert x.shape == (2,)
return 2 * x, jnp.concatenate([x, jnp.flip(x)])
return [2 * x, jnp.concatenate([x, jnp.flip(x)])]

def abstract(x):
assert type(x) is jax.core.ShapedArray
return x, jax.core.ShapedArray((4,), x.dtype)
return [x, jax.core.ShapedArray((4,), x.dtype)]

def jvp(primals, tangents):
(x,) = primals
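
Reviewer note (not part of the diff): the test now returns lists rather than tuples from `impl` and `abstract`, which appears to be the convention JAX 0.4.36 expects for primitives with multiple results. A minimal sketch of a two-output primitive following that convention (the primitive name is illustrative, and `jax.core.Primitive`/`jax.core.ShapedArray` are assumed to still be exposed in the JAX version in use):

import jax
import jax.numpy as jnp

two_outputs_p = jax.core.Primitive("two_outputs")
two_outputs_p.multiple_results = True

@two_outputs_p.def_impl
def _impl(x):
    # Both results are returned as a list, as in the updated test above.
    return [2 * x, jnp.concatenate([x, jnp.flip(x)])]

@two_outputs_p.def_abstract_eval
def _abstract(x):
    return [x, jax.core.ShapedArray((2 * x.shape[0],), x.dtype)]

doubled, mirrored = two_outputs_p.bind(jnp.arange(2.0))
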
tests/test_vmap.py (17 changes: 0 additions & 17 deletions)
@@ -160,23 -160,6 @@ def test_keyword_default(getkey):
eqx.filter_vmap(lambda x, y=1: x, in_axes=dict(y=0))(x)


def test_double_if_mapped():
out_axes = eqx.internal.if_mapped(1)

def f(x):
assert x.shape == (3, 1)

def g(y):
assert y.shape == (1,)
return y + 1, x + 1

a, b = eqx.filter_vmap(g, out_axes=out_axes)(x)
assert a.shape == (1, 3)
assert b.shape == (3, 1)

eqx.filter_vmap(f)(jnp.arange(6).reshape(2, 3, 1))


# https://github.com/patrick-kidger/equinox/issues/900
@pytest.mark.parametrize("out_axes", (0, 1, 2, -1, -2, -3))
def test_out_axes_with_at_least_three_dimensions(out_axes):