From 9f3a119684731ad241e60a3fbcac45bc9a36d0e7 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 16 Jul 2025 11:56:26 +0200 Subject: [PATCH 1/7] Fix missing import --- pymc/logprob/abstract.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 3e0ad532f..4b8808a3b 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -35,6 +35,7 @@ # SOFTWARE. import abc +import warnings from collections.abc import Sequence from functools import singledispatch From 7164bc62ab48fc6d29f46df4727e6a85751fbeaf Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 5 May 2025 10:50:47 +0200 Subject: [PATCH 2/7] Fix progressbar with nested compound step samplers --- pymc/step_methods/compound.py | 23 ++++++++------------ pymc/step_methods/hmc/nuts.py | 16 ++++---------- pymc/step_methods/metropolis.py | 20 +++++++---------- pymc/step_methods/slicer.py | 15 ++++--------- pymc/util.py | 38 ++++++++++++++++++++++++--------- tests/test_util.py | 32 +++++++++++++++++++++++++++ 6 files changed, 85 insertions(+), 59 deletions(-) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d07b070f0..50b23b1d4 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -189,11 +189,11 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - return stats + def _make_update_stats_functions(): + def update_stats(step_stats): + return step_stats - return update_stats + return (update_stats,) # Hack for creating the class correctly when unpickling. 
def __getnewargs_ex__(self): @@ -332,16 +332,11 @@ def _progressbar_config(self, n_chains=1): return columns, stats - def _make_update_stats_function(self): - update_fns = [method._make_update_stats_function() for method in self.methods] - - def update_stats(stats, step_stats, chain_idx): - for step_stat, update_fn in zip(step_stats, update_fns): - stats = update_fn(stats, step_stat, chain_idx) - - return stats - - return update_stats + def _make_update_stats_functions(self): + update_functions = [] + for method in self.methods: + update_functions.extend(method._make_update_stats_functions()) + return update_functions def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]: diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 18707c359..334a4eac3 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -248,19 +248,11 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_update_stats_functions(): + def update_stats(stats): + return {key: stats[key] for key in ("diverging", "step_size", "tree_size")} - if not step_stats["tune"]: - stats["divergences"][chain_idx] += step_stats["diverging"] - - stats["step_size"][chain_idx] = step_stats["step_size"] - stats["tree_size"][chain_idx] = step_stats["tree_size"] - return stats - - return update_stats + return (update_stats,) # A proposal for the next position diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653..4d798e947 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -346,18 +346,14 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = 
step_stats[0] - - stats["tune"][chain_idx] = step_stats["tune"] - stats["accept_rate"][chain_idx] = step_stats["accept"] - stats["scaling"][chain_idx] = step_stats["scaling"] - - return stats - - return update_stats + def _make_update_stats_functions(): + def update_stats(step_stats): + return { + "accept_rate" if key == "accept" else key: step_stats[key] + for key in ("tune", "accept", "scaling") + } + + return (update_stats,) def tune(scale, acc_rate): diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 9c10acfdf..ef5bbebc4 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -212,15 +212,8 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_update_stats_functions(): + def update_stats(step_stats): + return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}} - stats["tune"][chain_idx] = step_stats["tune"] - stats["nstep_out"][chain_idx] = step_stats["nstep_out"] - stats["nstep_in"][chain_idx] = step_stats["nstep_in"] - - return stats - - return update_stats + return (update_stats,) diff --git a/pymc/util.py b/pymc/util.py index ad9256add..08e22e106 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -763,9 +763,8 @@ def __init__( progressbar=progressbar, progressbar_theme=progressbar_theme, ) - self.progress_stats = progress_stats - self.update_stats = step_method._make_update_stats_function() + self.update_stats_functions = step_method._make_update_stats_functions() self._show_progress = show_progress self.divergences = 0 @@ -829,12 +828,31 @@ def update(self, chain_idx, is_last, draw, tuning, stats): if not tuning and stats and stats[0].get("diverging"): self.divergences += 1 - self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx) - more_updates = ( - {stat: value[chain_idx] for stat, value 
in self.progress_stats.items()} - if self.full_stats - else {} - ) + if self.full_stats: + # TODO: Index by chain already? + chain_progress_stats = [ + update_states_fn(step_stats) + for update_states_fn, step_stats in zip( + self.update_stats_functions, stats, strict=True + ) + ] + all_step_stats = {} + for step_stats in chain_progress_stats: + for key, val in step_stats.items(): + if key in all_step_stats: + # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now + continue + else: + all_step_stats[key] = val + + else: + all_step_stats = {} + + # more_updates = ( + # {stat: value[chain_idx] for stat, value in progress_stats.items()} + # if self.full_stats + # else {} + # ) self._progress.update( self.tasks[chain_idx], @@ -842,14 +860,14 @@ def update(self, chain_idx, is_last, draw, tuning, stats): draws=draw, sampling_speed=speed, speed_unit=unit, - **more_updates, + **all_step_stats, ) if is_last: self._progress.update( self.tasks[chain_idx], draws=draw + 1 if not self.combined_progress else draw, - **more_updates, + **all_step_stats, refresh=True, ) diff --git a/tests/test_util.py b/tests/test_util.py index 98cc168f0..cc07d75c2 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -250,3 +250,35 @@ def test_get_value_vars_from_user_vars(): get_value_vars_from_user_vars([x2], model1) with pytest.raises(ValueError, match=rf"{prefix} \['det2'\]"): get_value_vars_from_user_vars([det2], model2) + + +def test_progressbar_nested_compound(): + # Regression test for https://github.com/pymc-devs/pymc/issues/7721 + + with pm.Model(): + a = pm.Poisson("a", mu=10) + b = pm.Binomial("b", n=a, p=0.8) + c = pm.Poisson("c", mu=11) + d = pm.Dirichlet("d", a=[c, b]) + + step = pm.CompoundStep( + [ + pm.CompoundStep([pm.Metropolis(a), pm.Metropolis(b), pm.Metropolis(c)]), + pm.NUTS([d]), + ] + ) + + kwargs = { + "draws": 10, + "tune": 10, + "chains": 2, + "compute_convergence_checks": False, + "step": step, + } + + # We don't parametrize to 
avoid recompiling the model functions + for cores in (1, 2): + pm.sample(**kwargs, cores=cores, progressbar=True) # default is split+stats + pm.sample(**kwargs, cores=cores, progressbar="combined") + pm.sample(**kwargs, cores=cores, progressbar="split") + pm.sample(**kwargs, cores=cores, progressbar=False) From fd230d39a24825a624cf6182e9d50d4a32484b54 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 15 Jul 2025 12:09:19 +0200 Subject: [PATCH 3/7] Move progressbar code to its own module --- .github/workflows/tests.yml | 1 + pymc/backends/arviz.py | 3 +- pymc/progress_bar.py | 429 ++++++++++++++++++++++++++++++++++ pymc/sampling/forward.py | 3 +- pymc/sampling/mcmc.py | 4 +- pymc/sampling/parallel.py | 3 +- pymc/sampling/population.py | 2 +- pymc/smc/sampling.py | 3 +- pymc/tuning/starting.py | 3 +- pymc/util.py | 420 +-------------------------------- pymc/variational/inference.py | 2 +- tests/test_progress_bar.py | 46 ++++ tests/test_util.py | 32 --- 13 files changed, 488 insertions(+), 463 deletions(-) create mode 100644 pymc/progress_bar.py create mode 100644 tests/test_progress_bar.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d9b4b000c..efbb457f7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -69,6 +69,7 @@ jobs: tests/distributions/test_shape_utils.py tests/distributions/test_mixture.py tests/test_testing.py + tests/test_progress_bar.py - | tests/distributions/test_continuous.py diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index f0f0eec96..71f08da82 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -39,8 +39,9 @@ import pymc from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.pytensorf import PointFunc, extract_obs_data -from pymc.util import CustomProgress, default_progress_theme, get_default_varnames +from pymc.util import get_default_varnames if TYPE_CHECKING: from pymc.backends.base 
import MultiTrace diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py new file mode 100644 index 000000000..655560b08 --- /dev/null +++ b/pymc/progress_bar.py @@ -0,0 +1,429 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Iterable +from typing import TYPE_CHECKING, Literal + +from rich.box import SIMPLE_HEAD +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + Task, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from rich.style import Style +from rich.table import Column, Table +from rich.theme import Theme + +if TYPE_CHECKING: + from pymc.step_methods.compound import BlockedStep, CompoundStep + +ProgressBarType = Literal[ + "combined", + "split", + "combined+stats", + "stats+combined", + "split+stats", + "stats+split", +] +default_progress_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + "progress.remaining": "none", + "progress.elapsed": "none", + } +) + + +class CustomProgress(Progress): + """A child of Progress that allows to disable progress bars and its container. + + The implementation simply checks an `is_enabled` flag and generates the progress bar only if + it's `True`. 
+ """ + + def __init__(self, *args, disable=False, include_headers=False, **kwargs): + self.is_enabled = not disable + self.include_headers = include_headers + + if self.is_enabled: + super().__init__(*args, **kwargs) + + def __enter__(self): + """Enter the context manager.""" + if self.is_enabled: + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the context manager.""" + if self.is_enabled: + super().__exit__(exc_type, exc_val, exc_tb) + + def add_task(self, *args, **kwargs): + if self.is_enabled: + return super().add_task(*args, **kwargs) + return None + + def advance(self, task_id, advance=1) -> None: + if self.is_enabled: + super().advance(task_id, advance) + return None + + def update( + self, + task_id, + *, + total=None, + completed=None, + advance=None, + description=None, + visible=None, + refresh=False, + **fields, + ): + if self.is_enabled: + super().update( + task_id, + total=total, + completed=completed, + advance=advance, + description=description, + visible=visible, + refresh=refresh, + **fields, + ) + return None + + def make_tasks_table(self, tasks: Iterable[Task]) -> Table: + """Get a table to render the Progress display. + + Unlike the parent method, this one returns a full table (not a grid), allowing for column headings. + + Parameters + ---------- + tasks: Iterable[Task] + An iterable of Task instances, one per row of the table. + + Returns + ------- + table: Table + A table instance. 
+ """ + + def call_column(column, task): + # Subclass rich.BarColumn and add a callback method to dynamically update the display + if hasattr(column, "callbacks"): + column.callbacks(task) + + return column(task) + + table_columns = ( + ( + Column(no_wrap=True) + if isinstance(_column, str) + else _column.get_table_column().copy() + ) + for _column in self.columns + ) + if self.include_headers: + table = Table( + *table_columns, + padding=(0, 1), + expand=self.expand, + show_header=True, + show_edge=True, + box=SIMPLE_HEAD, + ) + else: + table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand) + + for task in tasks: + if task.visible: + table.add_row( + *( + ( + column.format(task=task) + if isinstance(column, str) + else call_column(column, task) + ) + for column in self.columns + ) + ) + + return table + + +class DivergenceBarColumn(BarColumn): + """Rich colorbar that changes color when a chain has detected a divergence.""" + + def __init__(self, *args, diverging_color="red", **kwargs): + from matplotlib.colors import to_rgb + + self.diverging_color = diverging_color + self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)] + + super().__init__(*args, **kwargs) + + self.non_diverging_style = self.complete_style + self.non_diverging_finished_style = self.finished_style + + def callbacks(self, task: "Task"): + divergences = task.fields.get("divergences", 0) + if isinstance(divergences, float | int) and divergences > 0: + self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) + self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) + else: + self.complete_style = self.non_diverging_style + self.finished_style = self.non_diverging_finished_style + + +class ProgressBarManager: + """Manage progress bars displayed during sampling.""" + + def __init__( + self, + step_method: "BlockedStep | CompoundStep", + chains: int, + draws: int, + tune: int, + progressbar: bool | ProgressBarType = True, + 
progressbar_theme: Theme | None = None, + ): + """ + Manage progress bars displayed during sampling. + + When sampling, Step classes are responsible for computing and exposing statistics that can be reported on + progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config` + and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which + columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics + that will be displayed on the progress bar. + + Parameters + ---------- + step_method: BlockedStep or CompoundStep + The step method being used to sample + chains: int + Number of chains being sampled + draws: int + Number of draws per chain + tune: int + Number of tuning steps per chain + progressbar: bool or ProgressType, optional + How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask + for one of the following: + - "combined": A single progress bar that displays the total progress across all chains. Only timing + information is shown. + - "split": A separate progress bar for each chain. Only timing information is shown. + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + chains. Aggregate sample statistics are also displayed. + - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain + are also displayed. + + If True, the default is "split+stats" is used. + + progressbar_theme: Theme, optional + The theme to use for the progress bar. Defaults to the default theme. 
+ """ + if progressbar_theme is None: + progressbar_theme = default_progress_theme + + match progressbar: + case True: + self.combined_progress = False + self.full_stats = True + show_progress = True + case False: + self.combined_progress = False + self.full_stats = True + show_progress = False + case "combined": + self.combined_progress = True + self.full_stats = False + show_progress = True + case "split": + self.combined_progress = False + self.full_stats = False + show_progress = True + case "combined+stats" | "stats+combined": + self.combined_progress = True + self.full_stats = True + show_progress = True + case "split+stats" | "stats+split": + self.combined_progress = False + self.full_stats = True + show_progress = True + case _: + raise ValueError( + "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " + "one of 'combined', 'split', 'split+stats', or 'combined+stats." + ) + + progress_columns, progress_stats = step_method._progressbar_config(chains) + + self._progress = self.create_progress_bar( + progress_columns, + progressbar=progressbar, + progressbar_theme=progressbar_theme, + ) + self.progress_stats = progress_stats + self.update_stats_functions = step_method._make_update_stats_functions() + + self._show_progress = show_progress + self.divergences = 0 + self.completed_draws = 0 + self.total_draws = draws + tune + self.desc = "Sampling chain" + self.chains = chains + + self._tasks: list[Task] | None = None # type: ignore[annotation-unchecked] + + def __enter__(self): + self._initialize_tasks() + + return self._progress.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._progress.__exit__(exc_type, exc_val, exc_tb) + + def _initialize_tasks(self): + if self.combined_progress: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws * self.chains - 1, + chain_idx=0, + sampling_speed=0, + speed_unit="draws/s", + **{stat: 
value[0] for stat, value in self.progress_stats.items()}, + ) + ] + + else: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws - 1, + chain_idx=chain_idx, + sampling_speed=0, + speed_unit="draws/s", + **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, + ) + for chain_idx in range(self.chains) + ] + + def update(self, chain_idx, is_last, draw, tuning, stats): + if not self._show_progress: + return + + self.completed_draws += 1 + if self.combined_progress: + draw = self.completed_draws + chain_idx = 0 + + elapsed = self._progress.tasks[chain_idx].elapsed + speed, unit = compute_draw_speed(elapsed, draw) + + if not tuning and stats and stats[0].get("diverging"): + self.divergences += 1 + + if self.full_stats: + # TODO: Index by chain already? + chain_progress_stats = [ + update_states_fn(step_stats) + for update_states_fn, step_stats in zip( + self.update_stats_functions, stats, strict=True + ) + ] + all_step_stats = {} + for step_stats in chain_progress_stats: + for key, val in step_stats.items(): + if key in all_step_stats: + # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now + continue + else: + all_step_stats[key] = val + + else: + all_step_stats = {} + + # more_updates = ( + # {stat: value[chain_idx] for stat, value in progress_stats.items()} + # if self.full_stats + # else {} + # ) + + self._progress.update( + self.tasks[chain_idx], + completed=draw, + draws=draw, + sampling_speed=speed, + speed_unit=unit, + **all_step_stats, + ) + + if is_last: + self._progress.update( + self.tasks[chain_idx], + draws=draw + 1 if not self.combined_progress else draw, + **all_step_stats, + refresh=True, + ) + + def create_progress_bar(self, step_columns, progressbar, progressbar_theme): + columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] + + if self.full_stats: + columns += step_columns + + columns += [ + 
TextColumn( + "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", + table_column=Column("Sampling Speed", ratio=1), + ), + TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), + TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), + ] + + return CustomProgress( + DivergenceBarColumn( + table_column=Column("Progress", ratio=2), + diverging_color="tab:red", + complete_style=Style.parse("rgb(31,119,180)"), # tab:blue + finished_style=Style.parse("rgb(31,119,180)"), # tab:blue + ), + *columns, + console=Console(theme=progressbar_theme), + disable=not progressbar, + include_headers=True, + ) + + +def compute_draw_speed(elapsed, draws): + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 1be14f77f..d65c6c011 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -57,12 +57,11 @@ from pymc.distributions.shape_utils import change_dist_size from pymc.logprob.utils import rvs_in_graph from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.pytensorf import compile from pymc.util import ( - CustomProgress, RandomState, _get_seeds_per_chain, - default_progress_theme, get_default_varnames, point_wrapper, ) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index d3a02b91b..542797caa 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -54,6 +54,7 @@ from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain from pymc.model import Model, modelcontext +from pymc.progress_bar import ProgressBarManager, ProgressBarType, default_progress_theme from pymc.sampling.parallel import Draw, _cpu_count from pymc.sampling.population import _sample_population from pymc.stats.convergence import ( @@ -65,12 +66,9 @@ from 
pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential from pymc.util import ( - ProgressBarManager, - ProgressBarType, RandomSeed, RandomState, _get_seeds_per_chain, - default_progress_theme, drop_warning_stat, get_random_generator, get_untransformed_name, diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index af2106ce6..6e229b960 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -33,10 +33,9 @@ from pymc.backends.zarr import ZarrChain from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError +from pymc.progress_bar import ProgressBarManager, default_progress_theme from pymc.util import ( - ProgressBarManager, RandomGeneratorState, - default_progress_theme, get_state_from_generator, random_generator_from_state, ) diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 92de63d0c..5bd177170 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -30,6 +30,7 @@ from pymc.backends.zarr import ZarrChain from pymc.initial_point import PointType from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress from pymc.stats.convergence import log_warning_stats from pymc.step_methods import CompoundStep from pymc.step_methods.arraystep import ( @@ -39,7 +40,6 @@ ) from pymc.step_methods.compound import StepMethodState from pymc.step_methods.metropolis import DEMetropolis -from pymc.util import CustomProgress __all__ = () diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index f3176f464..5afd39828 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -39,10 +39,11 @@ from pymc.distributions.distribution import _support_point from pymc.logprob.abstract import _icdf, _logcdf, _logprob from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH from 
pymc.stats.convergence import log_warnings, run_convergence_checks -from pymc.util import CustomProgress, RandomState, _get_seeds_per_chain +from pymc.util import RandomState, _get_seeds_per_chain def sample_smc( diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 2fbbba633..1385f3348 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -36,9 +36,8 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.util import ( - CustomProgress, - default_progress_theme, get_default_varnames, get_value_vars_from_user_vars, ) diff --git a/pymc/util.py b/pymc/util.py index 08e22e106..3f108b8b0 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -16,9 +16,9 @@ import re from collections import namedtuple -from collections.abc import Iterable, Sequence +from collections.abc import Sequence from copy import deepcopy -from typing import TYPE_CHECKING, Literal, NewType, cast +from typing import NewType, cast import arviz import cloudpickle @@ -28,47 +28,11 @@ from cachetools import LRUCache, cachedmethod from pytensor import Variable from pytensor.compile import SharedVariable -from rich.box import SIMPLE_HEAD -from rich.console import Console -from rich.progress import ( - BarColumn, - Progress, - Task, - TextColumn, - TimeElapsedColumn, - TimeRemainingColumn, -) -from rich.style import Style -from rich.table import Column, Table -from rich.theme import Theme from pymc.exceptions import BlockModelAccessError -if TYPE_CHECKING: - from pymc.step_methods.compound import BlockedStep, CompoundStep - - -ProgressBarType = Literal[ - "combined", - "split", - "combined+stats", - "stats+combined", - "split+stats", - "stats+split", -] - - VarName = NewType("VarName", str) -default_progress_theme = Theme( - { - "bar.complete": "#1764f4", - "bar.finished": "green", - "progress.remaining": 
"none", - "progress.elapsed": "none", - } -) - class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs.""" @@ -532,386 +496,6 @@ def makeiter(a): return [a] -class CustomProgress(Progress): - """A child of Progress that allows to disable progress bars and its container. - - The implementation simply checks an `is_enabled` flag and generates the progress bar only if - it's `True`. - """ - - def __init__(self, *args, disable=False, include_headers=False, **kwargs): - self.is_enabled = not disable - self.include_headers = include_headers - - if self.is_enabled: - super().__init__(*args, **kwargs) - - def __enter__(self): - """Enter the context manager.""" - if self.is_enabled: - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Exit the context manager.""" - if self.is_enabled: - super().__exit__(exc_type, exc_val, exc_tb) - - def add_task(self, *args, **kwargs): - if self.is_enabled: - return super().add_task(*args, **kwargs) - return None - - def advance(self, task_id, advance=1) -> None: - if self.is_enabled: - super().advance(task_id, advance) - return None - - def update( - self, - task_id, - *, - total=None, - completed=None, - advance=None, - description=None, - visible=None, - refresh=False, - **fields, - ): - if self.is_enabled: - super().update( - task_id, - total=total, - completed=completed, - advance=advance, - description=description, - visible=visible, - refresh=refresh, - **fields, - ) - return None - - def make_tasks_table(self, tasks: Iterable[Task]) -> Table: - """Get a table to render the Progress display. - - Unlike the parent method, this one returns a full table (not a grid), allowing for column headings. - - Parameters - ---------- - tasks: Iterable[Task] - An iterable of Task instances, one per row of the table. - - Returns - ------- - table: Table - A table instance. 
- """ - - def call_column(column, task): - # Subclass rich.BarColumn and add a callback method to dynamically update the display - if hasattr(column, "callbacks"): - column.callbacks(task) - - return column(task) - - table_columns = ( - ( - Column(no_wrap=True) - if isinstance(_column, str) - else _column.get_table_column().copy() - ) - for _column in self.columns - ) - if self.include_headers: - table = Table( - *table_columns, - padding=(0, 1), - expand=self.expand, - show_header=True, - show_edge=True, - box=SIMPLE_HEAD, - ) - else: - table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand) - - for task in tasks: - if task.visible: - table.add_row( - *( - ( - column.format(task=task) - if isinstance(column, str) - else call_column(column, task) - ) - for column in self.columns - ) - ) - - return table - - -class DivergenceBarColumn(BarColumn): - """Rich colorbar that changes color when a chain has detected a divergence.""" - - def __init__(self, *args, diverging_color="red", **kwargs): - from matplotlib.colors import to_rgb - - self.diverging_color = diverging_color - self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)] - - super().__init__(*args, **kwargs) - - self.non_diverging_style = self.complete_style - self.non_diverging_finished_style = self.finished_style - - def callbacks(self, task: "Task"): - divergences = task.fields.get("divergences", 0) - if isinstance(divergences, float | int) and divergences > 0: - self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) - self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) - else: - self.complete_style = self.non_diverging_style - self.finished_style = self.non_diverging_finished_style - - -class ProgressBarManager: - """Manage progress bars displayed during sampling.""" - - def __init__( - self, - step_method: "BlockedStep | CompoundStep", - chains: int, - draws: int, - tune: int, - progressbar: bool | ProgressBarType = True, - 
progressbar_theme: Theme | None = None, - ): - """ - Manage progress bars displayed during sampling. - - When sampling, Step classes are responsible for computing and exposing statistics that can be reported on - progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config` - and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which - columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics - that will be displayed on the progress bar. - - Parameters - ---------- - step_method: BlockedStep or CompoundStep - The step method being used to sample - chains: int - Number of chains being sampled - draws: int - Number of draws per chain - tune: int - Number of tuning steps per chain - progressbar: bool or ProgressType, optional - How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask - for one of the following: - - "combined": A single progress bar that displays the total progress across all chains. Only timing - information is shown. - - "split": A separate progress bar for each chain. Only timing information is shown. - - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all - chains. Aggregate sample statistics are also displayed. - - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain - are also displayed. - - If True, the default is "split+stats" is used. - - progressbar_theme: Theme, optional - The theme to use for the progress bar. Defaults to the default theme. 
- """ - if progressbar_theme is None: - progressbar_theme = default_progress_theme - - match progressbar: - case True: - self.combined_progress = False - self.full_stats = True - show_progress = True - case False: - self.combined_progress = False - self.full_stats = True - show_progress = False - case "combined": - self.combined_progress = True - self.full_stats = False - show_progress = True - case "split": - self.combined_progress = False - self.full_stats = False - show_progress = True - case "combined+stats" | "stats+combined": - self.combined_progress = True - self.full_stats = True - show_progress = True - case "split+stats" | "stats+split": - self.combined_progress = False - self.full_stats = True - show_progress = True - case _: - raise ValueError( - "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " - "one of 'combined', 'split', 'split+stats', or 'combined+stats." - ) - - progress_columns, progress_stats = step_method._progressbar_config(chains) - - self._progress = self.create_progress_bar( - progress_columns, - progressbar=progressbar, - progressbar_theme=progressbar_theme, - ) - self.progress_stats = progress_stats - self.update_stats_functions = step_method._make_update_stats_functions() - - self._show_progress = show_progress - self.divergences = 0 - self.completed_draws = 0 - self.total_draws = draws + tune - self.desc = "Sampling chain" - self.chains = chains - - self._tasks: list[Task] | None = None # type: ignore[annotation-unchecked] - - def __enter__(self): - self._initialize_tasks() - - return self._progress.__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - return self._progress.__exit__(exc_type, exc_val, exc_tb) - - def _initialize_tasks(self): - if self.combined_progress: - self.tasks = [ - self._progress.add_task( - self.desc.format(self), - completed=0, - draws=0, - total=self.total_draws * self.chains - 1, - chain_idx=0, - sampling_speed=0, - speed_unit="draws/s", - **{stat: 
value[0] for stat, value in self.progress_stats.items()}, - ) - ] - - else: - self.tasks = [ - self._progress.add_task( - self.desc.format(self), - completed=0, - draws=0, - total=self.total_draws - 1, - chain_idx=chain_idx, - sampling_speed=0, - speed_unit="draws/s", - **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, - ) - for chain_idx in range(self.chains) - ] - - def update(self, chain_idx, is_last, draw, tuning, stats): - if not self._show_progress: - return - - self.completed_draws += 1 - if self.combined_progress: - draw = self.completed_draws - chain_idx = 0 - - elapsed = self._progress.tasks[chain_idx].elapsed - speed, unit = compute_draw_speed(elapsed, draw) - - if not tuning and stats and stats[0].get("diverging"): - self.divergences += 1 - - if self.full_stats: - # TODO: Index by chain already? - chain_progress_stats = [ - update_states_fn(step_stats) - for update_states_fn, step_stats in zip( - self.update_stats_functions, stats, strict=True - ) - ] - all_step_stats = {} - for step_stats in chain_progress_stats: - for key, val in step_stats.items(): - if key in all_step_stats: - # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now - continue - else: - all_step_stats[key] = val - - else: - all_step_stats = {} - - # more_updates = ( - # {stat: value[chain_idx] for stat, value in progress_stats.items()} - # if self.full_stats - # else {} - # ) - - self._progress.update( - self.tasks[chain_idx], - completed=draw, - draws=draw, - sampling_speed=speed, - speed_unit=unit, - **all_step_stats, - ) - - if is_last: - self._progress.update( - self.tasks[chain_idx], - draws=draw + 1 if not self.combined_progress else draw, - **all_step_stats, - refresh=True, - ) - - def create_progress_bar(self, step_columns, progressbar, progressbar_theme): - columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] - - if self.full_stats: - columns += step_columns - - columns += [ - 
TextColumn( - "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", - table_column=Column("Sampling Speed", ratio=1), - ), - TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), - TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), - ] - - return CustomProgress( - DivergenceBarColumn( - table_column=Column("Progress", ratio=2), - diverging_color="tab:red", - complete_style=Style.parse("rgb(31,119,180)"), # tab:blue - finished_style=Style.parse("rgb(31,119,180)"), # tab:blue - ), - *columns, - console=Console(theme=progressbar_theme), - disable=not progressbar, - include_headers=True, - ) - - -def compute_draw_speed(elapsed, draws): - speed = draws / max(elapsed, 1e-6) - - if speed > 1 or speed == 0: - unit = "draws/s" - else: - unit = "s/draws" - speed = 1 / speed - - return speed, unit - - RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"]) diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index d9da7fb78..b83c1db4a 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -23,7 +23,7 @@ import pymc as pm -from pymc.util import CustomProgress, default_progress_theme +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.variational import test_functions from pymc.variational.approximations import Empirical, FullRank, MeanField from pymc.variational.operators import KL, KSD diff --git a/tests/test_progress_bar.py b/tests/test_progress_bar.py new file mode 100644 index 000000000..6687db1ae --- /dev/null +++ b/tests/test_progress_bar.py @@ -0,0 +1,46 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pymc as pm + + +def test_progressbar_nested_compound(): + # Regression test for https://github.com/pymc-devs/pymc/issues/7721 + + with pm.Model(): + a = pm.Poisson("a", mu=10) + b = pm.Binomial("b", n=a, p=0.8) + c = pm.Poisson("c", mu=11) + d = pm.Dirichlet("d", a=[c, b]) + + step = pm.CompoundStep( + [ + pm.CompoundStep([pm.Metropolis(a), pm.Metropolis(b), pm.Metropolis(c)]), + pm.NUTS([d]), + ] + ) + + kwargs = { + "draws": 10, + "tune": 10, + "chains": 2, + "compute_convergence_checks": False, + "step": step, + } + + # We don't parametrize to avoid recompiling the model functions + for cores in (1, 2): + pm.sample(**kwargs, cores=cores, progressbar=True) # default is split+stats + pm.sample(**kwargs, cores=cores, progressbar="combined") + pm.sample(**kwargs, cores=cores, progressbar="split") + pm.sample(**kwargs, cores=cores, progressbar=False) diff --git a/tests/test_util.py b/tests/test_util.py index cc07d75c2..98cc168f0 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -250,35 +250,3 @@ def test_get_value_vars_from_user_vars(): get_value_vars_from_user_vars([x2], model1) with pytest.raises(ValueError, match=rf"{prefix} \['det2'\]"): get_value_vars_from_user_vars([det2], model2) - - -def test_progressbar_nested_compound(): - # Regression test for https://github.com/pymc-devs/pymc/issues/7721 - - with pm.Model(): - a = pm.Poisson("a", mu=10) - b = pm.Binomial("b", n=a, p=0.8) - c = pm.Poisson("c", mu=11) - d = pm.Dirichlet("d", a=[c, b]) - - step = pm.CompoundStep( - [ - pm.CompoundStep([pm.Metropolis(a), pm.Metropolis(b), 
pm.Metropolis(c)]), - pm.NUTS([d]), - ] - ) - - kwargs = { - "draws": 10, - "tune": 10, - "chains": 2, - "compute_convergence_checks": False, - "step": step, - } - - # We don't parametrize to avoid recompiling the model functions - for cores in (1, 2): - pm.sample(**kwargs, cores=cores, progressbar=True) # default is split+stats - pm.sample(**kwargs, cores=cores, progressbar="combined") - pm.sample(**kwargs, cores=cores, progressbar="split") - pm.sample(**kwargs, cores=cores, progressbar=False) From 870488682df12158cabafccda9eb6d7a05f3cfdc Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 15 Jul 2025 12:14:26 +0200 Subject: [PATCH 4/7] Rename `make_update_stats_functions` to `_make_progressbar_update_functions` --- pymc/progress_bar.py | 6 +++--- pymc/step_methods/compound.py | 6 +++--- pymc/step_methods/hmc/nuts.py | 2 +- pymc/step_methods/metropolis.py | 2 +- pymc/step_methods/slicer.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py index 655560b08..8d004fbba 100644 --- a/pymc/progress_bar.py +++ b/pymc/progress_bar.py @@ -209,8 +209,8 @@ def __init__( When sampling, Step classes are responsible for computing and exposing statistics that can be reported on progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config` - and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which - columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics + and :meth:`pymc.step_methods.BlockedStep._make_progressbar_update_functions`. `_progressbar_config` reports which + columns should be displayed on the progress bar, and `_make_progressbar_update_functions` computes the statistics that will be displayed on the progress bar. 
Parameters @@ -281,7 +281,7 @@ def __init__( progressbar_theme=progressbar_theme, ) self.progress_stats = progress_stats - self.update_stats_functions = step_method._make_update_stats_functions() + self.update_stats_functions = step_method._make_progressbar_update_functions() self._show_progress = show_progress self.divergences = 0 diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 50b23b1d4..a9cae903f 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -189,7 +189,7 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_functions(): + def _make_progressbar_update_functions(): def update_stats(step_stats): return step_stats @@ -332,10 +332,10 @@ def _progressbar_config(self, n_chains=1): return columns, stats - def _make_update_stats_functions(self): + def _make_progressbar_update_functions(self): update_functions = [] for method in self.methods: - update_functions.extend(method._make_update_stats_functions()) + update_functions.extend(method._make_progressbar_update_functions()) return update_functions diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 334a4eac3..5ecaa8ae5 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -248,7 +248,7 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_functions(): + def _make_progressbar_update_functions(): def update_stats(stats): return {key: stats[key] for key in ("diverging", "step_size", "tree_size")} diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 4d798e947..4b2a8f18e 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -346,7 +346,7 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_functions(): + def _make_progressbar_update_functions(): def update_stats(step_stats): return { "accept_rate" 
if key == "accept" else key: step_stats[key] diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index ef5bbebc4..180ac1c88 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -212,7 +212,7 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_functions(): + def _make_progressbar_update_functions(): def update_stats(step_stats): return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}} From 84f92f65eeaa627b47f81e8f7e64316df00ea63f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 15 Jul 2025 12:18:32 +0200 Subject: [PATCH 5/7] Cleanup progress_bar --- pymc/progress_bar.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py index 8d004fbba..82afc9b42 100644 --- a/pymc/progress_bar.py +++ b/pymc/progress_bar.py @@ -330,6 +330,18 @@ def _initialize_tasks(self): for chain_idx in range(self.chains) ] + @staticmethod + def compute_draw_speed(elapsed, draws): + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit + def update(self, chain_idx, is_last, draw, tuning, stats): if not self._show_progress: return @@ -340,7 +352,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats): chain_idx = 0 elapsed = self._progress.tasks[chain_idx].elapsed - speed, unit = compute_draw_speed(elapsed, draw) + speed, unit = self.compute_draw_speed(elapsed, draw) if not tuning and stats and stats[0].get("diverging"): self.divergences += 1 @@ -365,12 +377,6 @@ def update(self, chain_idx, is_last, draw, tuning, stats): else: all_step_stats = {} - # more_updates = ( - # {stat: value[chain_idx] for stat, value in progress_stats.items()} - # if self.full_stats - # else {} - # ) - self._progress.update( self.tasks[chain_idx], completed=draw, @@ -415,15 +421,3 @@ def 
create_progress_bar(self, step_columns, progressbar, progressbar_theme): disable=not progressbar, include_headers=True, ) - - -def compute_draw_speed(elapsed, draws): - speed = draws / max(elapsed, 1e-6) - - if speed > 1 or speed == 0: - unit = "draws/s" - else: - unit = "s/draws" - speed = 1 / speed - - return speed, unit From 5482fafa4f3be9818656c1211b1f2001fde0082c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 9 Jul 2025 12:52:46 +0200 Subject: [PATCH 6/7] Abstract special behavior of NUTS divergences in ProgressBar Every step sampler can now decide whether sampling is failing or not by setting "failing" in the returned update dict --- pymc/progress_bar.py | 76 +++++++++++++++-------------- pymc/step_methods/hmc/base_hmc.py | 9 +++- pymc/step_methods/hmc/hmc.py | 33 +++++++++++++ pymc/step_methods/hmc/nuts.py | 5 +- tests/step_methods/hmc/test_nuts.py | 1 + 5 files changed, 85 insertions(+), 39 deletions(-) diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py index 82afc9b42..729958430 100644 --- a/pymc/progress_bar.py +++ b/pymc/progress_bar.py @@ -168,28 +168,28 @@ def call_column(column, task): return table -class DivergenceBarColumn(BarColumn): - """Rich colorbar that changes color when a chain has detected a divergence.""" +class RecolorOnFailureBarColumn(BarColumn): + """Rich colorbar that changes color when a chain has detected a failure.""" - def __init__(self, *args, diverging_color="red", **kwargs): + def __init__(self, *args, failing_color="red", **kwargs): from matplotlib.colors import to_rgb - self.diverging_color = diverging_color - self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)] + self.failing_color = failing_color + self.failing_rgb = [int(x * 255) for x in to_rgb(self.failing_color)] super().__init__(*args, **kwargs) - self.non_diverging_style = self.complete_style - self.non_diverging_finished_style = self.finished_style + self.default_complete_style = self.complete_style + self.default_finished_style = 
self.finished_style def callbacks(self, task: "Task"): - divergences = task.fields.get("divergences", 0) - if isinstance(divergences, float | int) and divergences > 0: - self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) - self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) + if task.fields["failing"]: + self.complete_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb)) + self.finished_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb)) else: - self.complete_style = self.non_diverging_style - self.finished_style = self.non_diverging_finished_style + # Recovered from failing yay + self.complete_style = self.default_complete_style + self.finished_style = self.default_finished_style class ProgressBarManager: @@ -284,7 +284,6 @@ def __init__( self.update_stats_functions = step_method._make_progressbar_update_functions() self._show_progress = show_progress - self.divergences = 0 self.completed_draws = 0 self.total_draws = draws + tune self.desc = "Sampling chain" @@ -311,6 +310,7 @@ def _initialize_tasks(self): chain_idx=0, sampling_speed=0, speed_unit="draws/s", + failing=False, **{stat: value[0] for stat, value in self.progress_stats.items()}, ) ] @@ -325,6 +325,7 @@ def _initialize_tasks(self): chain_idx=chain_idx, sampling_speed=0, speed_unit="draws/s", + failing=False, **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, ) for chain_idx in range(self.chains) @@ -354,28 +355,27 @@ def update(self, chain_idx, is_last, draw, tuning, stats): elapsed = self._progress.tasks[chain_idx].elapsed speed, unit = self.compute_draw_speed(elapsed, draw) - if not tuning and stats and stats[0].get("diverging"): - self.divergences += 1 + failing = False + all_step_stats = {} - if self.full_stats: - # TODO: Index by chain already? 
- chain_progress_stats = [ - update_states_fn(step_stats) - for update_states_fn, step_stats in zip( - self.update_stats_functions, stats, strict=True - ) - ] - all_step_stats = {} - for step_stats in chain_progress_stats: - for key, val in step_stats.items(): - if key in all_step_stats: - # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now - continue - else: - all_step_stats[key] = val - - else: - all_step_stats = {} + chain_progress_stats = [ + update_stats_fn(step_stats) + for update_stats_fn, step_stats in zip(self.update_stats_functions, stats, strict=True) + ] + for step_stats in chain_progress_stats: + for key, val in step_stats.items(): + if key == "failing": + failing |= val + continue + if not self.full_stats: + # Only care about the "failing" flag + continue + + if key in all_step_stats: + # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now + continue + else: + all_step_stats[key] = val self._progress.update( self.tasks[chain_idx], @@ -383,6 +383,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats): draws=draw, sampling_speed=speed, speed_unit=unit, + failing=failing, **all_step_stats, ) @@ -390,6 +391,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats): self._progress.update( self.tasks[chain_idx], draws=draw + 1 if not self.combined_progress else draw, + failing=failing, **all_step_stats, refresh=True, ) @@ -410,9 +412,9 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme): ] return CustomProgress( - DivergenceBarColumn( + RecolorOnFailureBarColumn( table_column=Column("Progress", ratio=2), - diverging_color="tab:red", + failing_color="tab:red", complete_style=Style.parse("rgb(31,119,180)"), # tab:blue finished_style=Style.parse("rgb(31,119,180)"), # tab:blue ), diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index e8c96e8c4..297b095e2 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ 
b/pymc/step_methods/hmc/base_hmc.py @@ -184,6 +184,7 @@ def __init__( self._step_rand = step_rand self._num_divs_sample = 0 + self.divergences = 0 @abstractmethod def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData: @@ -266,11 +267,15 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: divergence_info=info_store, ) + diverging = bool(hmc_step.divergence_info) + if not self.tune: + self.divergences += diverging self.iter_count += 1 stats: dict[str, Any] = { "tune": self.tune, - "diverging": bool(hmc_step.divergence_info), + "diverging": diverging, + "divergences": self.divergences, "perf_counter_diff": perf_end - perf_start, "process_time_diff": process_end - process_start, "perf_counter_start": perf_start, @@ -288,6 +293,8 @@ def reset_tuning(self, start=None): self.reset(start=None) def reset(self, start=None): + self.iter_count = 0 + self.divergences = 0 self.tune = True self.potential.reset() diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 565c1fd78..1697341bc 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -19,6 +19,9 @@ import numpy as np +from rich.progress import TextColumn +from rich.table import Column + from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData @@ -55,6 +58,7 @@ class HamiltonianMC(BaseHMC): "accept": (np.float64, []), "diverging": (bool, []), "energy_error": (np.float64, []), + "divergences": (np.int64, []), "energy": (np.float64, []), "path_length": (np.float64, []), "accepted": (bool, []), @@ -202,3 +206,32 @@ def competence(var, has_grad): if var.dtype in discrete_types or not has_grad: return Competence.INCOMPATIBLE return Competence.COMPATIBLE + + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)), + 
TextColumn("{task.fields[n_steps]}", table_column=Column("Grad evals", ratio=1)), + ] + + stats = { + "divergences": [0] * n_chains, + "n_steps": [0] * n_chains, + } + + return columns, stats + + @staticmethod + def _make_progressbar_update_functions(): + def update_stats(stats): + return { + key: stats[key] + for key in ( + "divergences", + "n_steps", + ) + } | { + "failing": stats["divergences"] > 0, + } + + return (update_stats,) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 5ecaa8ae5..0f19d3c08 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -115,6 +115,7 @@ class NUTS(BaseHMC): "step_size_bar": (np.float64, []), "tree_size": (np.float64, []), "diverging": (bool, []), + "divergences": (int, []), "energy_error": (np.float64, []), "energy": (np.float64, []), "max_energy_error": (np.float64, []), @@ -250,7 +251,9 @@ def _progressbar_config(n_chains=1): @staticmethod def _make_progressbar_update_functions(): def update_stats(stats): - return {key: stats[key] for key in ("diverging", "step_size", "tree_size")} + return {key: stats[key] for key in ("divergences", "step_size", "tree_size")} | { + "failing": stats["divergences"] > 0, + } return (update_stats,) diff --git a/tests/step_methods/hmc/test_nuts.py b/tests/step_methods/hmc/test_nuts.py index 432418a33..8d497f301 100644 --- a/tests/step_methods/hmc/test_nuts.py +++ b/tests/step_methods/hmc/test_nuts.py @@ -148,6 +148,7 @@ def test_sampler_stats(self): expected_stat_names = { "depth", "diverging", + "divergences", "energy", "energy_error", "model_logp", From 2b50d67fa09a4f3275721e5697f1d87dc24d83c1 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Wed, 9 Jul 2025 21:47:41 +0800 Subject: [PATCH 7/7] Correct `stats_dtypes_shapes` in `CategoricalGibbsMetropolis` --- pymc/step_methods/metropolis.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index 4b2a8f18e..2cd2e1369 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -680,7 +680,6 @@ def competence(var): class CategoricalGibbsMetropolisState(StepMethodState): shuffle_dims: bool dimcats: list[tuple] - tune: bool class CategoricalGibbsMetropolis(ArrayStep): @@ -763,10 +762,6 @@ def __init__( else: raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'") - # Doesn't actually tune, but it's required to emit a sampler stat - # that indicates whether a draw was done in a tuning phase. - self.tune = True - if compile_kwargs is None: compile_kwargs = {} super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng) @@ -796,10 +791,8 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType if accepted: logp_curr = logp_prop - stats = { - "tune": self.tune, - } - return q, [stats] + # This step doesn't have any tunable parameters + return q, [{"tune": False}] def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp = args[0] @@ -816,7 +809,8 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType for dim, k in dimcats: logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k) - return q, [] + # This step doesn't have any tunable parameters + return q, [{"tune": False}] def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: raise NotImplementedError()