Hi all,
I’ve been trying to run HSSM models with GPU acceleration on my local machine (Linux, NVIDIA RTX 5080, CUDA 12.8). However, I ran into compatibility issues between NumPyro and the latest JAX releases:
- JAX ≥ 0.4.31 removed pjit_p, but older NumPyro (0.15.3) still references it.
- Updating to the newest JAX (0.7.x) breaks NumPyro. Downgrading JAX to 0.4.x restores NumPyro compatibility, but it requires finding the exact matching jaxlib + CUDA/cuDNN combination.
- I tried combinations like jax==0.4.28 with jaxlib==0.4.28+cuda12.cudnn91 and 0.4.29+cuda12.cudnn91, but GPU initialization still fails (falling back to CPU).
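To make it easier to compare environments, here is a minimal sketch that prints the installed versions and which backend JAX actually selected. It assumes the PyPI distribution names jax, jaxlib, numpyro, pymc and hssm (the PACKAGES list is illustrative, not something HSSM defines). If the backend prints as "cpu" on a GPU machine, JAX has fallen back to CPU:

```python
import importlib
from importlib import metadata

# Assumed PyPI distribution names; adjust if your install uses different ones.
PACKAGES = ["jax", "jaxlib", "numpyro", "pymc", "hssm"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string or None}, without importing the packages."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def jax_backend():
    """Return JAX's default backend name (e.g. 'gpu' or 'cpu'), or None if JAX is missing."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    backend = jax_backend()
    print(f"JAX default backend: {backend}")
    if backend is not None:
        import jax
        # Lists the devices JAX can see; a working CUDA setup should show GPU devices here.
        print("Devices:", jax.devices())
```

Posting this output along with your working (or failing) setup would make version combinations easier to compare.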
Questions:
- Has anyone recently managed to get NumPyro + HSSM running on GPU with current JAX versions (0.4.x or newer)?
- Which JAX/jaxlib/numpyro/PyMC version combinations worked for you?
- Has anyone successfully set this up on Windows (or is Linux the only reliable option right now)?
- Are there plans to update NumPyro for compatibility with JAX ≥ 0.5?
Any guidance or working environment specs would be greatly appreciated!