How to speed up model fitting of CustomDist? #6661

Hello, I'm trying to implement a skew Student-t distribution using:

```python
import numpy as np
import pymc as pm

# CustomDist calls logp(value, *dist_params): the value comes first,
# followed by the parameters in the order they are passed to CustomDist
def logp_skewt(value, nu, mu, sigma, alpha, *args, **kwargs):
    return (
        pm.math.log(2) +
        pm.logp(pm.StudentT.dist(nu, mu=mu, sigma=sigma), value) +
        pm.logcdf(pm.StudentT.dist(nu, mu=mu, sigma=sigma), alpha*value) -
        pm.math.log(sigma)
    )
```

I am able to sample from this distribution:

```python
with pm.Model():
    # parameters passed positionally: nu=1, mu=0, sigma=3, alpha=-10
    pm.CustomDist('target', 1, 0, 3, -10, logp=logp_skewt)
    model_trace = pm.sample(
        nuts_sampler="numpyro",
        draws=2_000,
        chains=1,
    )
samples = model_trace.posterior.target.to_numpy()

# drop the most extreme 1% of draws in each tail
eps = 0.01
min_val, max_val = np.quantile(samples, [eps, 1 - eps])
valid_samples = samples[(samples >= min_val) & (samples <= max_val)]
```

However, when I try to re-fit the model, it becomes very slow:

```python
with pm.Model() as fitted_model:
    nu = pm.HalfCauchy('nu', beta=1)
    mu = pm.Normal('mu', mu=0, sigma=1)
    sigma = pm.HalfCauchy('sigma', beta=1)
    alpha = pm.Normal('alpha', mu=0, sigma=1)

    # eps keeps nu and sigma bounded away from zero; fit on the first 1000 trimmed draws
    skewt = pm.CustomDist('likelihood', nu + eps, mu, sigma + eps, alpha, logp=logp_skewt, observed=valid_samples[:1000])

    model_trace = pm.sample(
        nuts_sampler="pymc",
        draws=100,
        tune=100,
        chains=1,
    )
```

There are warnings, which are: …

It took about 16 minutes to finish fitting on … So, is there a common way to speed up the computation?
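For reference, `logp_skewt` appears to target the skew-symmetric Student-t density f(x) = (2/sigma) t_nu(z) T_nu(alpha z) with z = (x - mu)/sigma, where t_nu and T_nu are the standard Student-t PDF and CDF. Two terms in the function differ from that form: `pm.logp` of a `StudentT` with `sigma=sigma` already contains the -log(sigma) term, so subtracting `pm.math.log(sigma)` again counts it twice, and the `logcdf` term evaluates T_nu((alpha x - mu)/sigma) rather than T_nu(alpha (x - mu)/sigma). A cheap way to catch this before a long fit is to check that the density integrates to 1. The sketch below does that with SciPy; `skewt_logpdf`, `question_logpdf` and `total_mass` are hypothetical helper names, and `question_logpdf` is only an assumed term-by-term mirror of `logp_skewt` using SciPy's `loc`/`scale` Student-t.

```python
import numpy as np
from scipy import integrate, stats


def skewt_logpdf(x, nu, mu, sigma, alpha):
    # Skew-symmetric Student-t: log 2 + log t_nu(z) - log sigma + log T_nu(alpha * z)
    z = (x - mu) / sigma
    return (np.log(2.0) + stats.t.logpdf(z, nu) - np.log(sigma)
            + stats.t.logcdf(alpha * z, nu))


def question_logpdf(x, nu, mu, sigma, alpha):
    # Hypothetical term-by-term mirror of logp_skewt from the question
    return (np.log(2.0)
            + stats.t.logpdf(x, nu, loc=mu, scale=sigma)
            + stats.t.logcdf(alpha * x, nu, loc=mu, scale=sigma)
            - np.log(sigma))


def total_mass(logpdf, *params):
    # Integrate exp(logpdf) over the real line; a proper density gives ~1.0
    value, _ = integrate.quad(lambda x: np.exp(logpdf(x, *params)), -np.inf, np.inf)
    return value


print(total_mass(skewt_logpdf, 5.0, 0.5, 2.0, 3.0))    # ≈ 1.0
print(total_mass(question_logpdf, 5.0, 0.0, 2.0, 3.0))  # ≈ 0.5, i.e. 1 / sigma
```

With mu = 0 the mirrored version integrates to 1/sigma instead of 1, which is the double-counted scale term showing up.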
  
Replies: 1 comment

We are working on speeding up these types of gradients in pymc-devs/pytensor#174. Right now they are implemented in NumPy and can't be compiled to JAX. I will try to push that PR over the finish line sometime in the next few weeks. For now, if you want to speed them up, you might need to re-implement the Ops manually in your target backend, which isn't trivial if you are not familiar with PyTensor and/or JAX.
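Assuming the gradients in question are those of the regularized incomplete beta function I_x(a, b) (an assumption; the thread does not name them), a small sketch helps show why they are the awkward part. The Student-t CDF can be written as T_nu(t) = I_x(nu/2, nu/2) with x = (t + sqrt(t^2 + nu)) / (2 sqrt(t^2 + nu)). The derivative with respect to x is just the Beta(a, b) density, but the derivative with respect to the shape parameters (here nu, through a = b = nu/2) has no elementary closed form and must be evaluated by a series or iterative routine, which is the kind of computation that needs a dedicated Op per backend. In this SciPy sketch, `student_t_cdf_via_betainc`, `dbetainc_dx` and `dbetainc_da_fd` are hypothetical helpers, and the central finite difference only stands in for such a routine; it is not how PyTensor computes the gradient.

```python
import numpy as np
from scipy import special, stats


def student_t_cdf_via_betainc(t, nu):
    # T_nu(t) = I_x(nu/2, nu/2) with x = (t + sqrt(t^2 + nu)) / (2 sqrt(t^2 + nu))
    s = np.sqrt(t**2 + nu)
    x = (t + s) / (2.0 * s)
    return special.betainc(nu / 2.0, nu / 2.0, x)


def dbetainc_dx(a, b, x):
    # Closed form: dI_x(a, b)/dx is the Beta(a, b) density evaluated at x
    return stats.beta.pdf(x, a, b)


def dbetainc_da_fd(a, b, x, h=1e-6):
    # No elementary closed form w.r.t. the shape parameter a;
    # a central finite difference stands in for a dedicated series routine
    return (special.betainc(a + h, b, x) - special.betainc(a - h, b, x)) / (2.0 * h)


t = np.linspace(-5.0, 5.0, 11)
# Maximum deviation from SciPy's own Student-t CDF (should be near machine precision)
print(np.max(np.abs(student_t_cdf_via_betainc(t, 3.0) - stats.t.cdf(t, 3.0))))
print(dbetainc_dx(1.5, 1.5, 0.3), dbetainc_da_fd(1.5, 1.5, 0.3))
```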
  