From d21552ac4c504d7b139ad8e4f15d5f102b54d705 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 14 Dec 2024 00:35:55 +0100 Subject: [PATCH] Fix #118 --- tests/test_vmap_vmap.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_vmap_vmap.py b/tests/test_vmap_vmap.py index d0bc01a..8e2b98d 100644 --- a/tests/test_vmap_vmap.py +++ b/tests/test_vmap_vmap.py @@ -15,7 +15,6 @@ import functools as ft import equinox as eqx -import equinox.internal as eqxi import jax.numpy as jnp import jax.random as jr import lineax as lx @@ -123,10 +122,10 @@ def linear_solve(operator, vector): eqx.filter_vmap( lambda x: x.as_matrix(), in_axes=vmap1_op, - out_axes=eqxi.if_mapped(0), + out_axes=None if vmap1_op is None else 0, ), in_axes=vmap2_op, - out_axes=eqxi.if_mapped(0), + out_axes=None if vmap2_op is None else 0, )(operator) vmap1_axes = (vmap1_op, vmap1_vec)