JAX sparse is much slower than SciPy on CPU: is this expected? #30625
-
Just to add some context: JAX Sparse tends to be slower than SciPy on CPU because SciPy's sparse routines call dedicated compiled kernels, while JAX Sparse operations are lowered to generic XLA operations (such as gathers and scatter-adds) that are not specialized for sparse workloads on CPU. JAX Sparse is more focused on GPU/TPU acceleration via XLA and on enabling autodiff within JAX workflows. For CPU performance, SciPy is usually the better choice. JAX Sparse is more suitable when you are working with accelerators or need differentiability. This performance gap is a known and expected limitation at the current stage of development.
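To illustrate the differentiability point, here is a minimal sketch. It assumes a recent JAX with `jax.experimental.sparse` and SciPy are installed, and the matrix size and density are arbitrary. It converts a SciPy CSR matrix to a JAX `BCOO` array with `BCOO.from_scipy_sparse` and takes a jitted gradient through the sparse matrix-vector product, which `scipy.sparse` cannot do on its own:

```python
import numpy as np
import scipy.sparse as sp
import jax
import jax.numpy as jnp
from jax.experimental import sparse

rng = np.random.default_rng(0)

# A random 200x200 matrix with roughly 2% nonzeros, stored as SciPy CSR (float32).
dense = rng.standard_normal((200, 200)).astype(np.float32)
dense[rng.random((200, 200)) > 0.02] = 0.0
A_csr = sp.csr_matrix(dense)
x = rng.standard_normal(200).astype(np.float32)

# Convert the SciPy matrix into a JAX BCOO sparse array.
A_bcoo = sparse.BCOO.from_scipy_sparse(A_csr)


def loss(M, v):
    # Sparse @ dense matvec followed by a squared-norm reduction.
    return jnp.sum((M @ v) ** 2)


# Gradient with respect to the dense vector v (argnums=1), compiled with jit.
grad_v = jax.jit(jax.grad(loss, argnums=1))
g = np.asarray(grad_v(A_bcoo, jnp.asarray(x)))

# Analytic check: d/dv ||M v||^2 = 2 M^T M v.
g_ref = 2 * (A_csr.T @ (A_csr @ x))
print("max abs difference:", np.max(np.abs(g - g_ref)))
```

The reference `2 * A^T (A x)` is the analytic gradient of `sum((A x)**2)`, so the two should agree up to float32 rounding.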
-
Yes, it's expected that JAX sparse is slower than SciPy on CPU.
But the fact that you asked this question means that the docs didn't get that message across very effectively: if you have ideas for how to better communicate this, please let us know!
-
My test scripts:
My results on a Mac Studio (M2 Ultra):
JAX sparse is very efficient on GPU, though; my results on an RTX 5090 are below:
Is it expected that JAX sparse is more than 20 times slower than SciPy on CPU?
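The original test scripts and results are not reproduced here. The following is a hedged sketch of one way to run such a comparison fairly. It assumes a square random matrix, float32 on both sides, and a single matrix-vector product, and `n` and `density` are chosen arbitrarily. The JAX function is jitted and compiled once before timing, and `block_until_ready()` is used so that JAX's asynchronous dispatch does not make it look faster than it is:

```python
import time

import numpy as np
import scipy.sparse as sp
import jax
import jax.numpy as jnp
from jax.experimental import sparse


def best_time(fn, repeats=10):
    """Return the fastest wall-clock time (seconds) over `repeats` calls."""
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t0)
    return min(times)


def benchmark(n=10_000, density=1e-3, repeats=10, seed=0):
    rng = np.random.default_rng(seed)
    nnz = int(n * n * density)
    # Random COO entries; tocsr() sums any duplicate (row, col) pairs.
    rows = rng.integers(0, n, nnz)
    cols = rng.integers(0, n, nnz)
    vals = rng.standard_normal(nnz).astype(np.float32)
    A = sp.coo_matrix((vals, (rows, cols)), shape=(n, n)).tocsr()
    x = rng.standard_normal(n).astype(np.float32)

    # SciPy: CSR matrix-vector product.
    y_scipy = A @ x
    t_scipy = best_time(lambda: A @ x, repeats)

    # JAX: same matrix as BCOO, matvec compiled with jit.
    A_bcoo = sparse.BCOO.from_scipy_sparse(A)
    x_jax = jnp.asarray(x)
    matvec = jax.jit(lambda M, v: M @ v)
    y_jax = matvec(A_bcoo, x_jax).block_until_ready()  # warm-up call triggers compilation
    t_jax = best_time(lambda: matvec(A_bcoo, x_jax).block_until_ready(), repeats)

    return t_scipy, t_jax, y_scipy, np.asarray(y_jax)


t_scipy, t_jax, y_scipy, y_jax = benchmark()
print(f"JAX backend: {jax.default_backend()}")
print(f"scipy CSR : {t_scipy * 1e3:.3f} ms")
print(f"jax BCOO  : {t_jax * 1e3:.3f} ms ({t_jax / t_scipy:.1f}x scipy)")
```

To force JAX onto the CPU on a machine that also has a GPU, set the environment variable `JAX_PLATFORMS=cpu` before starting Python. The measured ratio will depend on hardware, matrix structure, and library versions.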