diff --git a/tests/test_vmap_vmap.py b/tests/test_vmap_vmap.py index d0bc01a..8e2b98d 100644 --- a/tests/test_vmap_vmap.py +++ b/tests/test_vmap_vmap.py @@ -15,7 +15,6 @@ import functools as ft import equinox as eqx -import equinox.internal as eqxi import jax.numpy as jnp import jax.random as jr import lineax as lx @@ -123,10 +122,10 @@ def linear_solve(operator, vector): eqx.filter_vmap( lambda x: x.as_matrix(), in_axes=vmap1_op, - out_axes=eqxi.if_mapped(0), + out_axes=None if vmap1_op is None else 0, ), in_axes=vmap2_op, - out_axes=eqxi.if_mapped(0), + out_axes=None if vmap2_op is None else 0, )(operator) vmap1_axes = (vmap1_op, vmap1_vec)