For example, consider the following:
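The original code block is missing here; below is a hypothetical stand-in for `f` using plain JAX. Like Catalyst's `accelerate`, `jax.pure_callback` stages the wrapped Python code as an opaque callback primitive in the jaxpr, which reproduces the failure mode described in this issue. (The name `f` and the use of `np.sin` are assumptions; `sin(0.43) ≈ 0.4168708` matches the output quoted below.)

```python
import jax
import numpy as np

# Hypothetical stand-in for the issue's `f`: the wrapped code runs as an
# opaque Python callback, analogous to Catalyst's `accelerate`.
def f(x):
    return jax.pure_callback(
        np.sin,                                  # runs outside the trace
        jax.ShapeDtypeStruct(x.shape, x.dtype),  # declared result shape/dtype
        x,
    )

x = jax.numpy.asarray(0.43)
print(f(x))       # forward evaluation works
# jax.grad(f)(x)  # fails: the callback primitive has no differentiation rule
```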
This works fine:
>>> f(0.43)
Array(0.4168708, dtype=float64)
However, differentiating it fails:
>>> qjit(grad(f))(0.43)
ValueError: Function vjp_wrapper must be jax.jit-able. But failed with error message Differentiation rule for 'python_callback' not implemented.
We should treat `accelerate` the same way that JAX treats `jit`; that is, as a nested `pjit` context in the jaxpr. This will allow for arbitrary nesting of arbitrary jax-jitable code, without breaking gradients.