
Has anyone successfully used HSSM/Numpyro with GPU on recent JAX versions (Linux/Windows)? #778

@JamesWeiChen


Hi all,

I’ve been trying to run HSSM models with GPU acceleration on my local machine (Linux, NVIDIA RTX 5080, CUDA 12.8), but I’ve run into compatibility issues between NumPyro and recent JAX releases:

  • JAX ≥ 0.4.31 removed pjit_p, but older NumPyro (0.15.3) still references it.
  • Updating to the newest JAX (0.7.x) breaks NumPyro; downgrading JAX to 0.4.x works but requires finding the exact matching jaxlib + CUDA/cuDNN combo.
  • I tried combinations like jax==0.4.28 with jaxlib==0.4.28+cuda12.cudnn91 and 0.4.29+cuda12.cudnn91, but GPU initialization still fails (falling back to CPU).
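For anyone comparing setups, a small diagnostic like the one below can show whether JAX actually picked up the GPU or silently fell back to CPU. This is a sketch, not part of HSSM or NumPyro: it only relies on `jax.default_backend()` and `jax.devices()`, and the function name `detect_jax_backend` is hypothetical.

```python
def detect_jax_backend():
    """Return JAX's default backend name (e.g. "cpu" or "gpu"), or None if JAX is missing."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print("JAX is not installed in this environment.")
        return None

    # Report the installed JAX version and every device JAX can see.
    print(f"jax {jax.__version__}, default backend: {jax.default_backend()}")
    for device in jax.devices():
        print(f"  device: {device} (platform={device.platform})")

    # A tiny computation confirms the selected backend actually executes work.
    x = jnp.arange(4.0)
    print("sanity check, sum of squares:", float(jnp.sum(x * x)))
    return jax.default_backend()


if __name__ == "__main__":
    backend = detect_jax_backend()
    if backend is not None and backend != "gpu":
        print("No GPU backend found: JAX fell back to", backend)
```

Running this in the same environment as HSSM makes it easy to tell a JAX/CUDA install problem apart from a NumPyro incompatibility.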

Questions:

  1. Has anyone recently managed to get NumPyro + HSSM running on GPU with current JAX versions (0.4.x or newer)?
  2. Which JAX/jaxlib/numpyro/PyMC version combinations worked for you?
  3. Has anyone successfully set this up on Windows (or is Linux the only reliable option right now)?
  4. Are there plans to update NumPyro for compatibility with JAX ≥ 0.5?
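To make answers to question 2 easier to compare, a sketch like the following collects the relevant package versions using only the standard library (`importlib.metadata`). The package list is an assumption: `jax-cuda12-plugin` and `jax-cuda12-pjrt` are the CUDA plugin distributions that newer JAX GPU installs use, and the helper name `collect_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names; adjust to match how packages were installed.
PACKAGES = [
    "jax",
    "jaxlib",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
    "numpyro",
    "pymc",
    "hssm",
]


def collect_versions(packages=PACKAGES):
    """Return {package: installed version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```

The output is in a `pip`-style `name==version` format, so it can be pasted straight into a reply or a requirements file.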

Any guidance or working environment specs would be greatly appreciated!
