
Possibly improve NUTS tuning #689

@amal-ghamdi

Description


General description of task

The demo https://github.com/CUQI-DTU/CUQIpy/blob/main/demos/demo37_experimental_mcmc_module_Gibbs.ipynb uses NUTS within Gibbs. For the NUTS example, before the fix in #621, the inferred signal was as follows (execution time ~20 min):

[Image: inferred signal before the fix in #621]

After this fix, the inferred signal is as follows (execution time ~20 min):

[Image: inferred signal after the fix in #621]

Note that this fix does not necessarily improve the inference (the two results look fairly similar). However, for a small number of steps (10), the old implementation seems to perform better:

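from cuqi.experimental.mcmc import HybridGibbs  # CUQIpy experimental MCMC module used in the demo

# Number of inner sampling steps per Gibbs sweep for each variable
num_sampling_steps = {
    "x": 1,
    "s": 1
}

# target (joint distribution) and sampling_strategy (one sampler per variable) are defined earlier in the demo
sampler = HybridGibbs(target, sampling_strategy, num_sampling_steps)

sampler.warmup(10)  # 10 warm-up steps (tuning enabled)
sampler.sample(10)  # 10 sampling steps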

Before the fix:
[Image: inferred signal before the fix, 10 warm-up and 10 sampling steps]

After the fix:
[Image: inferred signal after the fix, 10 warm-up and 10 sampling steps]

This is likely because the NUTS tuning suggests a worse step size than the method _find_good_epsilon, which is called at the beginning of sampling. Before NUTS was made stateful, this method seems to have been called every time the sampler was re-initialized (i.e., at every Gibbs step).
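To make this mechanism concrete, here is a minimal, self-contained sketch. It assumes that _find_good_epsilon follows the step-size heuristic from the original NUTS paper (Hoffman & Gelman, 2014, Algorithm 4): starting from eps = 1, repeatedly double or halve the step size until the acceptance ratio of a single leapfrog step crosses 0.5. This is not CUQIpy's code; find_reasonable_epsilon and ToyStepSizeState are hypothetical names, and the standard Gaussian target is chosen only for illustration. The second half contrasts re-creating the sampler state at every Gibbs step (the heuristic runs every time) with keeping one stateful object (the heuristic runs once and the stored step size is reused).

```python
import numpy as np

def leapfrog(x, p, eps, grad_logpdf):
    """One leapfrog step of Hamiltonian dynamics (identity mass matrix)."""
    p = p + 0.5 * eps * grad_logpdf(x)
    x = x + eps * p
    p = p + 0.5 * eps * grad_logpdf(x)
    return x, p

def find_reasonable_epsilon(x, logpdf, grad_logpdf, rng, eps=1.0):
    """Hypothetical sketch of a Hoffman & Gelman style heuristic: double or
    halve eps until the acceptance ratio of one leapfrog step crosses 0.5."""
    p = rng.standard_normal(x.shape)

    def log_joint(x_, p_):
        # Log of the unnormalised joint density of position and momentum
        return logpdf(x_) - 0.5 * p_ @ p_

    x_new, p_new = leapfrog(x, p, eps, grad_logpdf)
    log_ratio = log_joint(x_new, p_new) - log_joint(x, p)
    # Acceptance ratio above 0.5: grow eps (a = +1); below: shrink eps (a = -1)
    a = 1.0 if log_ratio > np.log(0.5) else -1.0
    while a * log_ratio > -a * np.log(2.0):
        eps *= 2.0 ** a
        x_new, p_new = leapfrog(x, p, eps, grad_logpdf)
        log_ratio = log_joint(x_new, p_new) - log_joint(x, p)
    return eps

class ToyStepSizeState:
    """Hypothetical stand-in for the step-size state a NUTS sampler keeps."""

    def __init__(self):
        self.eps = None
        self.n_heuristic_calls = 0

    def get_step_size(self, x, logpdf, grad_logpdf, rng):
        # The heuristic only runs when no step size has been stored yet
        if self.eps is None:
            self.eps = find_reasonable_epsilon(x, logpdf, grad_logpdf, rng)
            self.n_heuristic_calls += 1
        return self.eps

rng = np.random.default_rng(0)
logpdf = lambda x: -0.5 * x @ x  # standard Gaussian target (illustrative only)
grad_logpdf = lambda x: -x
x0 = np.ones(3)
n_gibbs_steps = 10

# Stateless: a fresh state per Gibbs step, so the heuristic re-runs each time
calls_stateless = 0
for _ in range(n_gibbs_steps):
    fresh = ToyStepSizeState()
    fresh.get_step_size(x0, logpdf, grad_logpdf, rng)
    calls_stateless += fresh.n_heuristic_calls

# Stateful: one state reused across Gibbs steps, so the heuristic runs once
stateful = ToyStepSizeState()
for _ in range(n_gibbs_steps):
    stateful.get_step_size(x0, logpdf, grad_logpdf, rng)

print(calls_stateless, stateful.n_heuristic_calls)  # 10 1
```

The toy omits the warm-up adaptation itself. The point is only that, in the stateful case, whatever step size is stored (for a real NUTS sampler, the value left by tuning) is reused for all later Gibbs steps instead of being replaced by a fresh heuristic estimate. This is consistent with, though it does not prove, the explanation above.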

Definition of Done (Feature/change)

  • description1
  • description2

Definition of Done (Mandatory)

  • Documentation added (docstrings on all public methods/classes)
  • Unit tests added/updated (and passing!)
  • Code reviewed and approved by 2 CUQI-maintainers
  • Online documentation is rendered correctly (esp. math)

Definition of Done (Optional)

  • HowTo/tutorial added

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests