vmap(custom_jvp) does not strip zeros from nondifferentiable return values, leading to AD crashes #25724
Comments
Thanks for tracking this down and for the clear reproduction, @patrick-kidger! I haven't had a chance to dig in too deeply yet, but I think the key place where this issue hits is here: jax/jax/_src/interpreters/batching.py, lines 863 to 865 (at commit 51b9fe3), which has an existing TODO for @mattjj and @froystig. To get consistent behavior under vmap, I think we would need to remove the … I'll take a look soon!
Testing the provided repro on Colab with JAX 0.5.0 does not result in a crash; instead, it produces the following output:

[1 1]
(None, None)

Please find the Colab gist for reference. Thank you.
Ah, interesting. Looks like things might have gotten fixed. I'll mention this in our upstream issue and see if things have been fixed there too.
Whoops, sorry, I totally dropped the ball on this one. I'm glad to see that it seems to be fixed (no thanks to me!). Let me know if we should close this issue or if there are further outstanding problems.
Closing as per patrick-kidger/optimistix#104 (comment): it looks like this is fixed! :D
Description
This:
prints:
In contrast, if the `jax.vmap` is removed, then there is no `JVPTrace` at all. It seems that `custom_jvp` attempts to strip symbolic zeros from nondifferentiable return values, but this is defeated by having a vmap wrapper. So (a) there is a discrepancy there, and (b) this is now a problem for downstream nondifferentiable primitives! They see an AD tracer, they don't have an AD rule, and they explode. Moreover, this spurious tangent isn't removable with `lax.stop_gradient`, because that skips over nondifferentiable types: jax/jax/_src/lax/lax.py, lines 2047 to 2054 (at commit 54fd738).
I imagine the fix is either to allow `lax.stop_gradient` to operate on nondifferentiable types, or to adjust `custom_jvp` to have consistent behavior regardless of whether it is vmapped. (Or maybe both?)

System info (python version, jaxlib version, accelerator, etc.)
JAX 0.4.38