filter_vmap
import equinox as eqx
import jax
import jax.numpy as jnp

def foo(x):
    return x * 2.0

x = jnp.arange(24).reshape((2, 3, 4))
y = jax.vmap(foo, out_axes=-1)(x)
z = eqx.filter_vmap(foo, out_axes=-1)(x)
print("jax:", y.shape, "eqx:", z.shape)
# jax: (3, 4, 2) eqx: (4, 3, 2)
equinox version: 0.11.8
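For reference, here is a minimal sketch of what `out_axes=-1` is expected to produce, using only `jax.vmap` and a hand-written loop. The names `stacked`, `reference`, and `swapped` are illustrative only. The last line shows that swapping axes 0 and -1 (rather than moving axis 0 to the end) yields the same shape as the `filter_vmap` result reported above; that is an observation about shapes, not a claim about the actual cause inside Equinox.

```python
import jax
import jax.numpy as jnp


def foo(x):
    return x * 2.0


x = jnp.arange(24).reshape((2, 3, 4))

# Hand-written reference: apply foo to each example along axis 0 and stack the
# results, which puts the batch axis first.
stacked = jnp.stack([foo(x[i]) for i in range(x.shape[0])])  # shape (2, 3, 4)

# out_axes=-1 asks for the batch axis to be moved to the last position,
# leaving the remaining axes in their original order.
reference = jnp.moveaxis(stacked, 0, -1)  # shape (3, 4, 2)

vmapped = jax.vmap(foo, out_axes=-1)(x)
print(vmapped.shape, bool(jnp.array_equal(vmapped, reference)))

# Swapping (instead of moving) the batch axis reverses the other axes here,
# which matches the shape reported for filter_vmap in this issue.
swapped = jnp.swapaxes(stacked, 0, -1)
print(swapped.shape)
```

With a three-dimensional input the two operations differ, which is why the mismatch only shows up for `out_axes` values other than 0 and 1.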
Fix `filter_vmap` producing the wrong axis order when `out_axes` is not 0 or 1.
596cb16
Fixes #900
Thank you for the report! I've just fixed this in #901.
This one is important enough that I'm also going to do a new Equinox release with the fix.
62dbaba