Skip to content

Commit 3a0186e

Browse files
authored
Add support for sample_stats group in pymc.testing.mock_sample (#7887)
1 parent 1950e4c commit 3a0186e

File tree

2 files changed

+79
-10
lines changed

2 files changed

+79
-10
lines changed

pymc/testing.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
import numpy as np
2121
import pytensor
2222
import pytensor.tensor as pt
23+
import xarray as xr
2324

2425
from arviz import InferenceData
2526
from numpy import random as nr
2627
from numpy import testing as npt
28+
from numpy.typing import NDArray
2729
from pytensor.compile import SharedVariable
2830
from pytensor.compile.mode import Mode
2931
from pytensor.graph.basic import Constant, Variable, equal_computations
@@ -977,7 +979,14 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
977979
raise AssertionError(f"RV found in graph: {rvs}")
978980

979981

980-
def mock_sample(draws: int = 10, **kwargs):
982+
SampleStatsCreator = Callable[[tuple[int, int]], NDArray]
983+
984+
985+
def mock_sample(
986+
draws: int = 10,
987+
sample_stats: dict[str, SampleStatsCreator] | None = None,
988+
**kwargs,
989+
) -> InferenceData:
981990
"""Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`.
982991
983992
Useful for testing models that use pm.sample without running MCMC sampling.
@@ -1007,6 +1016,36 @@ def mock_pymc_sample():
10071016
10081017
pm.sample = original_sample
10091018
1019+
By default, the sample_stats group is not created. Pass a dictionary of functions
1020+
that create sample statistics, where the keys are the names of the statistics
1021+
and the values are functions that take a size tuple and return an array of that size.
1022+
1023+
.. code-block:: python
1024+
1025+
from functools import partial
1026+
1027+
import numpy as np
1028+
from numpy.typing import NDArray
1029+
1030+
from pymc.testing import mock_sample
1031+
1032+
1033+
def mock_diverging(size: tuple[int, int]) -> NDArray:
1034+
return np.zeros(size)
1035+
1036+
1037+
def mock_tree_depth(size: tuple[int, int]) -> NDArray:
1038+
return np.random.choice(range(2, 10), size=size)
1039+
1040+
1041+
mock_sample_with_stats = partial(
1042+
mock_sample,
1043+
sample_stats={
1044+
"diverging": mock_diverging,
1045+
"tree_depth": mock_tree_depth,
1046+
},
1047+
)
1048+
10101049
"""
10111050
random_seed = kwargs.get("random_seed", None)
10121051
model = kwargs.get("model", None)
@@ -1031,6 +1070,16 @@ def mock_pymc_sample():
10311070
del idata["prior"]
10321071
if "prior_predictive" in idata:
10331072
del idata["prior_predictive"]
1073+
1074+
if sample_stats is not None:
1075+
sizes = idata["posterior"].sizes
1076+
size = (sizes["chain"], sizes["draw"])
1077+
sample_stats_ds = xr.Dataset(
1078+
{name: (("chain", "draw"), creator(size)) for name, creator in sample_stats.items()},
1079+
coords=idata["posterior"].coords,
1080+
)
1081+
idata.add_groups(sample_stats=sample_stats_ds)
1082+
10341083
return idata
10351084

10361085

tests/test_testing.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from contextlib import ExitStack as does_not_raise
1515

16+
import numpy as np
1617
import pytest
1718

1819
import pymc as pm
@@ -38,28 +39,47 @@ def test_domain(values, edges, expectation):
3839

3940

4041
@pytest.mark.parametrize(
41-
"args, kwargs, expected_size",
42+
"args, kwargs, expected_size, sample_stats",
4243
[
43-
pytest.param((), {}, (1, 10), id="default"),
44-
pytest.param((100,), {}, (1, 100), id="positional-draws"),
45-
pytest.param((), {"draws": 100}, (1, 100), id="keyword-draws"),
46-
pytest.param((100,), {"chains": 6}, (6, 100), id="chains"),
44+
pytest.param((), {}, (1, 10), None, id="default"),
45+
pytest.param((100,), {}, (1, 100), None, id="positional-draws"),
46+
pytest.param((), {"draws": 100}, (1, 100), None, id="keyword-draws"),
47+
pytest.param((100,), {"chains": 6}, (6, 100), None, id="chains"),
48+
pytest.param(
49+
(100,),
50+
{"chains": 6},
51+
(6, 100),
52+
{
53+
"diverging": np.zeros,
54+
"tree_depth": lambda size: np.random.choice(range(2, 10), size=size),
55+
},
56+
id="with_sample_stats",
57+
),
4758
],
4859
)
49-
def test_mock_sample(args, kwargs, expected_size) -> None:
60+
def test_mock_sample(args, kwargs, expected_size, sample_stats) -> None:
5061
expected_chains, expected_draws = expected_size
5162
_, model, _ = simple_normal(bounded_prior=True)
5263

5364
with model:
54-
idata = mock_sample(*args, **kwargs)
65+
idata = mock_sample(*args, **kwargs, sample_stats=sample_stats)
5566

5667
assert "posterior" in idata
5768
assert "observed_data" in idata
5869
assert "prior" not in idata
5970
assert "posterior_predictive" not in idata
60-
assert "sample_stats" not in idata
6171

62-
assert idata.posterior.sizes == {"chain": expected_chains, "draw": expected_draws}
72+
expected_sizes = {"chain": expected_chains, "draw": expected_draws}
73+
74+
if sample_stats:
75+
sample_stats_ds = idata["sample_stats"]
76+
for name in sample_stats.keys():
77+
assert sample_stats_ds[name].sizes == expected_sizes
78+
79+
else:
80+
assert "sample_stats" not in idata
81+
82+
assert idata.posterior.sizes == expected_sizes
6383

6484

6585
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)

0 commit comments

Comments
 (0)