Cache initialization fails when a JAX Array is created before enabling local cache #25768
Comments
I think this is working as expected: the compilation cache state must be enabled before backend initialization. Many of the JAX configuration options have similar requirements. I'm assigning @skye, who knows more about this code path and may be able to confirm.
Thanks for reporting this issue. We have an idea of how to address it and will work on a fix.
This issue has been resolved with PR #25889. Closing now.
Description
The persistent compilation cache in JAX fails to initialize if a JAX array is created prior to activating the local cache using jax.config.update. Removing the array creation line allows the cache to initialize correctly.
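The sketch below shows the ordering described in the maintainer comment: the cache is configured before any JAX array exists, because creating an array initializes the backend. It is not the MRE from this report. The config keys jax_compilation_cache_dir and jax_persistent_cache_min_compile_time_secs are existing JAX options. The temporary cache directory and the double_sum function are hypothetical, chosen only for illustration.

```python
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical cache location for this sketch; any writable directory works.
CACHE_DIR = tempfile.mkdtemp(prefix="jax_cache_")

# 1. Configure the persistent compilation cache FIRST, before any JAX array
#    is created (creating an array initializes the backend).
jax.config.update("jax_compilation_cache_dir", CACHE_DIR)
# Cache every compiled program regardless of compile time (sketch-only setting).
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


# 2. Only now create arrays and run jitted code.
@jax.jit
def double_sum(x):
    # Hypothetical example computation: sum of 2 * x.
    return (2 * x).sum()


result = double_sum(jnp.arange(4.0))
print(result)
```

Swapping steps 1 and 2 (for example, by calling jnp.arange before jax.config.update) reproduces the ordering that this issue reports as breaking cache initialization.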
MRE with array allocation:
This issue is also present with ClassVar default values that are JAX NumPy arrays, and with default arguments of functions (see also ami-iit/jaxsim#322 and ami-iit/jaxsim#329).
MRE with ClassVar:
MRE with default arguments:
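Both cases fail for the same reason: a ClassVar default or a function default argument is evaluated when the class or function is defined, so a jnp array there initializes the backend at import time. The following sketch shows one possible workaround. It is not taken from the issue or from the jaxsim PRs. The names Model, GRAVITY, gravity, scale and factor are hypothetical. The idea is to keep definition-time constants as NumPy arrays, or to default to None, and to build the JAX array only when the code actually runs.

```python
import tempfile
from typing import ClassVar, Optional

import numpy as np
import jax
import jax.numpy as jnp

# Problematic pattern (as described in the issue): these defaults are evaluated
# at definition time and create JAX arrays before jax.config.update(...) runs:
#
#     def scale(x, factor=jnp.array(2.0)): ...
#     class Model:
#         GRAVITY: ClassVar[jax.Array] = jnp.array([0.0, 0.0, -9.81])


class Model:
    # Workaround 1 (hypothetical): store the constant as a NumPy array,
    # which does not touch the JAX backend at class-definition time.
    GRAVITY: ClassVar[np.ndarray] = np.array([0.0, 0.0, -9.81])

    @classmethod
    def gravity(cls) -> jax.Array:
        # The conversion to a JAX array happens only when this is called.
        return jnp.asarray(cls.GRAVITY)


def scale(x: jax.Array, factor: Optional[jax.Array] = None) -> jax.Array:
    # Workaround 2 (hypothetical): use None as the default and create the
    # JAX array inside the function body, at call time.
    if factor is None:
        factor = jnp.array(2.0)
    return x * factor


# No JAX array exists yet, so configuring the cache here still happens
# before backend initialization. The directory is a hypothetical location.
jax.config.update("jax_compilation_cache_dir", tempfile.mkdtemp(prefix="jax_cache_"))

print(scale(jnp.arange(3.0)))
print(Model.gravity())
```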
This was quite hard for me to spot, so I would expect a clearer error message if, for some reason, the cache cannot be initialized.
Thank you for your help!
FYI @traversaro @xela-95 @CarlottaSartore @paLeziart
System info (python version, jaxlib version, accelerator, etc.)
pip list