
[Feature] Nested calls to accelerate should behave like nested jax.jit functions. #1086

Open
josh146 opened this issue Aug 30, 2024 · 0 comments


For example, consider the following:

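from catalyst import accelerate, grad, qjit
import jax.numpy as jnp

@qjit
def f(x):
    # Nest one accelerate call inside another.
    return accelerate(accelerate(jnp.sin))(x)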

This works fine:

>>> f(0.43)
Array(0.4168708, dtype=float64)

However, differentiating it fails:

>>> qjit(grad(f))(0.43)
ValueError: Function vjp_wrapper must be jax.jit-able.But failed with error message Differentiation rule for 'python_callback' not implemented.
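The error points at a host callback primitive that has no differentiation rule. As a rough plain-JAX analogy (an assumption about the mechanism, not Catalyst's actual code path), `jax.pure_callback` produces the same class of failure: the forward call evaluates fine, but `jax.grad` cannot see inside the opaque callback. The helper `host_sin` is hypothetical and only stands in for an accelerate-style call.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    """Hypothetical stand-in for an accelerate-style call: evaluates sin on the
    host through jax.pure_callback, which has no differentiation rule."""
    x = jnp.asarray(x)
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(lambda v: np.sin(v).astype(v.dtype), result_shape, x)


# The forward pass works, like f(0.43) above.
forward = host_sin(0.43)

# Differentiating through the opaque callback fails. The exact exception type
# and message depend on the JAX version, so catch broadly here.
try:
    jax.grad(host_sin)(0.43)
    grad_failed = False
except Exception as err:
    grad_failed = True
    print(type(err).__name__, err)
```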

We should treat `accelerate` the same way JAX treats `jit`: each nested call should appear as a nested `pjit` equation in the jaxpr rather than as another opaque callback. This would allow arbitrary JAX-jittable code to be nested to any depth without breaking gradients.
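For comparison, a minimal plain-JAX sketch of the behaviour being requested: nested `jax.jit` calls are staged out as nested call equations in the jaxpr, and `jax.grad` differentiates straight through them. The names `inner`, `outer` and the `nesting_depth` helper are illustrative only; this shows what JAX does for `jit`, not how Catalyst would implement it for `accelerate`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def inner(x):
    return jnp.sin(x)


@jax.jit
def outer(x):
    # Calling a jitted function while tracing another one stages it out as a
    # nested call equation, not as an opaque callback.
    return inner(x)


def nesting_depth(jaxpr):
    """Count how deeply equations carrying a 'jaxpr' parameter are nested."""
    depth = 0
    for eqn in jaxpr.eqns:
        sub = eqn.params.get("jaxpr")
        if sub is not None:
            # A ClosedJaxpr wraps a Jaxpr; a plain Jaxpr is used as-is.
            sub_jaxpr = getattr(sub, "jaxpr", sub)
            depth = max(depth, 1 + nesting_depth(sub_jaxpr))
    return depth


closed = jax.make_jaxpr(outer)(0.43)
print(closed)  # the printed name of the call primitive varies by JAX version
depth = nesting_depth(closed.jaxpr)

# Autodiff sees through the nested jaxprs: d/dx sin(x) = cos(x).
g = jax.grad(outer)(0.43)
```

Because each nested call is ordinary jaxpr, every differentiation rule inside it stays visible to `jax.grad`, which is exactly what the nested `accelerate` case currently loses.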
