Is GPU support available for running the HSSM model on Windows? #803
Replies: 3 comments 1 reply
-
Hi, are you using JAX-based samplers (e.g. Thanks!
-
Hello, I have successfully set up a Conda environment in WSL2 to run HSSM with GPU acceleration (JAX + NumPyro). This works perfectly for models like model="angle" (LAN backend), giving me massive speedups.

However, when I try to run model="full_ddm" (to include trial-wise variability sv, st, sz), it fails to use the GPU/NUTS sampler and throws this error:

ValueError: Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability.

My understanding (research and AI): I suspect this is because full_ddm in HSSM relies on the Blackbox backend (since analytical gradients/LANs might not be available for the full variable DDM?). Because Blackbox likelihoods are non-differentiable (no gradients), the NUTS sampler, which drives the JAX/GPU acceleration, cannot run.

My question: Is there any way to run full_ddm (with sv, st, sz) on the GPU? Or effectively, does the requirement for a non-gradient sampler (like Slice) force us back to the CPU for this specific model type?

Any advice on getting "Full DDM" performance closer to "Angle/LAN" performance would be appreciated!
-
It is possible to use a LAN as a differentiable likelihood for the full DDM
(as in the 2021 Fengler paper), which would allow NUTS sampling on GPU,
but adding a robust LAN for this model to HSSM is still on the to-do
list.
On Thu, Jan 8, 2026 at 9:53 AM Paul Xu wrote:
Hello @Pessegir <https://github.com/Pessegir>,
Unfortunately, the full_ddm likelihood function is implemented in Cython
and does not support sampling on GPU or any gradient-based sampling
methods. We only included this black-box likelihood function for
compatibility with the legacy HDDM package. If st and sz are not strictly
necessary, the closest alternative we have is ddm_sdv, which includes only sv.
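
For readers who want to try the sv-only route, here is a minimal sketch of what the swap might look like. The lookup table and the helper names (`LIKELIHOOD_KIND`, `supports_nuts`, `fit_sv_only`) are hypothetical and only encode what is said in this thread. The default sampler string `"nuts_numpyro"` is an assumption that may differ between HSSM versions, so check the `sample` documentation for your install.

```python
# Likelihood kinds as described in this thread. This table is an
# illustrative assumption, not something read from HSSM at runtime.
LIKELIHOOD_KIND = {
    "angle": "lan",          # LAN backend, reported to run with NUTS on GPU
    "ddm_sdv": "gradient",   # sv-only model suggested as the alternative
    "full_ddm": "blackbox",  # Cython likelihood, no gradients available
}


def supports_nuts(model_name: str) -> bool:
    """Hypothetical helper: True if the likelihood is not a black box."""
    return LIKELIHOOD_KIND[model_name] != "blackbox"


def fit_sv_only(data, sampler="nuts_numpyro", **sample_kwargs):
    """Fit ddm_sdv (sv only, no st/sz) with a gradient-based sampler.

    `data` is expected to be a DataFrame with `rt` and `response` columns.
    The sampler name is an assumption; adjust it to your HSSM version.
    """
    import hssm  # imported lazily so the helpers above work without HSSM

    model = hssm.HSSM(data=data, model="ddm_sdv")
    return model.sample(sampler=sampler, **sample_kwargs)
```

Calling `fit_sv_only(my_data, draws=500, tune=500)` would then sample the sv-only model. Whether dropping st and sz is acceptable depends on the research question.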
-
I installed HSSM in a conda virtual environment. However, fitting models on Windows is very slow. How can I improve the speed? My system is equipped with an NVIDIA RTX 5070 and CUDA 13.0.
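
A useful first check is whether JAX can see the GPU at all, since the JAX-based samplers only run on GPU when it does. The sketch below uses a hypothetical helper name, `jax_gpu_available`, and returns False if JAX is not installed. Elsewhere in this thread, GPU acceleration is reported to work from a Conda environment inside WSL2; whether a native Windows install can use the GPU is not confirmed here.

```python
def jax_gpu_available() -> bool:
    """Return True if JAX is importable and its default backend is a GPU."""
    try:
        import jax
    except ImportError:
        # JAX is not installed in this environment, so no GPU sampling.
        return False
    # default_backend() returns a string such as "cpu", "gpu", or "tpu".
    return jax.default_backend() == "gpu"


if __name__ == "__main__":
    print("JAX sees a GPU:", jax_gpu_available())
```

If this prints False on a machine with an NVIDIA card, the likely culprit is the JAX install (CPU-only build) rather than HSSM.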