Skip to content

Commit

Permalink
Fix filter_vmap with out_axes!=0,1 producing the wrong axis order.
Browse files Browse the repository at this point in the history
Fixes #900
  • Loading branch information
patrick-kidger committed Nov 24, 2024
1 parent 15a800d commit 62dbaba
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
6 changes: 3 additions & 3 deletions equinox/_vmap_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def _bind(axis):
return jtu.tree_map(_bind, out_axes)


def _swapaxes(array, axis):
return jnp.swapaxes(array, 0, axis)
def _moveaxis(array, axis):
return jnp.moveaxis(array, 0, axis)


def _named_in_axes(fun, in_axes, args):
Expand Down Expand Up @@ -230,7 +230,7 @@ def _fun_wrapper(_dynamic_args):
nonvmapd = combine(nonvmapd_arr, nonvmapd_static)

assert jtu.tree_structure(vmapd) == jtu.tree_structure(out_axes)
vmapd = jtu.tree_map(_swapaxes, vmapd, out_axes)
vmapd = jtu.tree_map(_moveaxis, vmapd, out_axes)

return combine(vmapd, nonvmapd)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "equinox"
version = "0.11.8"
version = "0.11.9"
description = "Elegant easy-to-use neural networks in JAX."
readme = "README.md"
requires-python =">=3.9"
Expand Down
15 changes: 15 additions & 0 deletions tests/test_pmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,18 @@ def g(y):
assert b.shape == (3, 1)

filter_pmap(f)(jnp.arange(3).reshape(1, 3, 1))


# https://github.com/patrick-kidger/equinox/issues/900
# Unlike the vmap case we only test nonnegative integers, as pmap does not support
# negative indexing for `in_axes` or `out_axes`.
@pytest.mark.parametrize("out_axes", (0, 1, 2))
def test_out_axes_with_at_least_three_dimensions(out_axes):
def foo(x):
return x * 2

x = jnp.arange(24).reshape((1, 2, 3, 4))
y = jax.pmap(foo, out_axes=out_axes)(x)
z = filter_pmap(foo, out_axes=out_axes)(x)
assert y.shape == z.shape
assert (y == z).all()
13 changes: 13 additions & 0 deletions tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,16 @@ def g(y):
assert b.shape == (3, 1)

eqx.filter_vmap(f)(jnp.arange(6).reshape(2, 3, 1))


# https://github.com/patrick-kidger/equinox/issues/900
@pytest.mark.parametrize("out_axes", (0, 1, 2, -1, -2, -3))
def test_out_axes_with_at_least_three_dimensions(out_axes):
def foo(x):
return x * 2

x = jnp.arange(24).reshape((2, 3, 4))
y = jax.vmap(foo, out_axes=out_axes)(x)
z = eqx.filter_vmap(foo, out_axes=out_axes)(x)
assert y.shape == z.shape
assert (y == z).all()

0 comments on commit 62dbaba

Please sign in to comment.