[Regression] Slower integration of differential equations since jaxlib > 0.4.32.dev20240807 #518
I'd definitely open this as an issue on the JAX GitHub! This is probably due to some change in the XLA compiler between those two builds, which unfortunately means there isn't much we can do about it from Diffrax. If you're interested in digging into it, you might be able to locate the responsible commit by bisecting through the XLA repo (https://github.com/openxla/xla/). I believe JAX itself hosts several benchmarks, so something could probably be added to those to prevent this regression from happening again.
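Bisecting, whether over XLA commits or over nightly builds, is a binary search for the first candidate that reproduces the slowdown. Here is a minimal sketch. It assumes the candidates are ordered oldest to newest, the first is known good, the last is known bad, and the regression persists once introduced. `is_slow` is a hypothetical predicate: in practice it would build or install the candidate, run the benchmark, and compare against a threshold.

```python
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def first_bad(candidates: Sequence[T], is_slow: Callable[[T], bool]) -> T:
    """Return the first candidate for which is_slow() is True.

    Assumes candidates[0] is good, candidates[-1] is bad, and that once the
    regression appears it stays (monotonic), so binary search applies.
    """
    lo, hi = 0, len(candidates) - 1  # invariant: lo is good, hi is bad
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_slow(candidates[mid]):
            hi = mid  # regression already present at mid
        else:
            lo = mid  # regression introduced after mid
    return candidates[hi]


if __name__ == "__main__":
    # Hypothetical commit labels; the predicate pretends c6 introduced the slowdown.
    commits = [f"c{i}" for i in range(10)]
    print(first_bad(commits, lambda c: int(c[1:]) >= 6))  # prints c6
```

Each step halves the range, so N candidates need only about log2(N) benchmark runs. This matters when each run means rebuilding jaxlib.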
I've just seen a suggested fix over here: jax-ml/jax#24501. Give it a try?
Yes, with the suggested fix I observe similar runtimes again. Somewhat unrelated: I can't test it with the newest nightly JAX version, as …
Great that the source of the slowdown has been fixed. Thanks for the heads-up on …
Between `jaxlib==0.4.32.dev20240807` and `jaxlib==0.4.32.dev20240812`, I observe a significant performance decrease when integrating differential equations with many solver steps (up to 8x slower). Minimal example:

Tested on Ubuntu 22.04 with the CPU backend. On my PC, the runtime is 1.8 ms with nightly jaxlib 20240807 and 14.8 ms with 20240812. The difference is largest when `t_steps` is large.

To test it quickly, I used the following one-liner:
uv venv --python 3.12 && uv pip install diffrax numpy --pre jax==0.4.32.dev20240807 jaxlib==0.4.32.dev20240807 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --reinstall --exclude-newer 2024-09-20 && uv run test_diffrax.py && uv pip install diffrax numpy --pre jax==0.4.32.dev20240807 jaxlib==0.4.32.dev20240812 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --reinstall --exclude-newer 2024-09-20 && uv run test_diffrax.py
where `test_diffrax.py` is the script above. I wasn't sure whether I should have opened this issue on the JAX GitHub instead; let me know if this isn't the right place.
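When comparing runtimes across builds like this, the first call of a jitted function includes tracing and compilation. JAX also dispatches work asynchronously, so a fair timing should exclude the first call and wait for results to be ready. The original script isn't reproduced here. What follows is only a minimal stdlib sketch of a hypothetical `median_runtime_ms` helper, not the author's measurement code. It warms up once, then reports the median over several runs. With JAX you would pass `sync=jax.block_until_ready`, which in recent JAX versions blocks on every array leaf of a pytree such as a Diffrax solution.

```python
import statistics
import time
from typing import Any, Callable


def median_runtime_ms(
    fn: Callable[..., Any],
    *args: Any,
    repeats: int = 20,
    sync: Callable[[Any], Any] = lambda out: out,
) -> float:
    """Median wall-clock time of fn(*args) in milliseconds, excluding the first call.

    The first call is an untimed warm-up (for a jitted function it triggers
    tracing and compilation). `sync` should block until the output is
    actually computed; the default does nothing, which suits plain Python.
    """
    sync(fn(*args))  # warm-up: not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        sync(fn(*args))
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


if __name__ == "__main__":
    # Stand-in workload; replace with the jitted solve from the benchmark script.
    ms = median_runtime_ms(sum, range(100_000))
    print(f"{ms:.3f} ms")
```

Comparing two builds then comes down to the ratio of their medians. For the numbers reported above, 14.8 / 1.8 ≈ 8.2, which is consistent with "up to 8x slower". The median is used instead of the mean so that occasional scheduler hiccups don't skew the comparison.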