Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ methods in the current release of PyMC experimental.
marginalize
recover_marginals
model_builder.ModelBuilder
opt_sample

Inference
=========
Expand Down
1 change: 1 addition & 0 deletions pymc_extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
recover_marginals,
)
from pymc_extras.model.model_api import as_model
from pymc_extras.sampling.mcmc import opt_sample

_log = logging.getLogger("pmx")

Expand Down
16 changes: 2 additions & 14 deletions pymc_extras/model/marginal/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,19 @@
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.pytensorf import constant_fold
from pytensor import Variable
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import Mode
from pytensor.graph import FunctionGraph, Op, vectorize_graph
from pytensor.graph.basic import equal_computations
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.type import Variable
from pytensor.scan import map as scan_map
from pytensor.scan import scan
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.type import RandomType

from pymc_extras.distributions import DiscreteMarkovChain
from pymc_extras.utils.ofg import inline_ofg_outputs


class MarginalRV(OpFromGraph, MeasurableOp):
Expand Down Expand Up @@ -206,19 +207,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens
return logp.transpose(*dims_alignment)


def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
"""Inline the inner graph (outputs) of an OpFromGraph Op.

Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
the inner graph.
"""
return graph_replace(
op.inner_outputs,
replace=tuple(zip(op.inner_inputs, inputs)),
strict=False,
)


class NonSeparableLogpWarning(UserWarning):
pass

Expand Down
2 changes: 1 addition & 1 deletion pymc_extras/model/marginal/graph_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pymc.model.fgraph import ModelVar
from pymc.variational.minibatch_rv import MinibatchRandomVariable
from pytensor.graph import Variable, ancestors
from pytensor.graph.basic import io_toposort
from pytensor.graph.traversal import io_toposort
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
Expand Down
Empty file.
114 changes: 114 additions & 0 deletions pymc_extras/sampling/mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import sys

from pymc.model.core import Model
from pymc.sampling.mcmc import sample
from pytensor.graph.rewriting.basic import GraphRewriter

from pymc_extras.sampling.optimizations.optimize import (
TAGS_TYPE,
optimize_model_for_mcmc_sampling,
)


def opt_sample(
*args,
model: Model | None = None,
include: TAGS_TYPE = ("default",),
exclude: TAGS_TYPE = None,
rewriter: GraphRewriter | None = None,
verbose: bool = False,
**kwargs,
):
"""Sample from a model after applying optimizations.

Parameters
----------
model : Model, optinoal
The model to sample from. If None, use the model associated with the context.
include : TAGS_TYPE
The tags to include in the optimizations. Ignored if `rewriter` is not None.
exclude : TAGS_TYPE
The tags to exclude from the optimizations. Ignored if `rewriter` is not None.
rewriter : RewriteDatabaseQuery (optional)
The rewriter to use. If None, use the default rewriter with the given `include` and `exclude` tags.
verbose : bool, default=False
Print information about the optimizations applied.
*args, **kwargs:
Passed to `pm.sample`

Returns
-------
sample_output:
The output of `pm.sample`

Examples
--------
.. code:: python
import pymc as pm
import pymc_experimental as pmx

with pm.Model() as m:
p = pm.Beta("p", 1, 1, shape=(1000,))
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)

idata = pmx.opt_sample(verbose=True)

# Applied optimization: beta_binomial_conjugacy 1x
# ConjugateRVSampler: [p]


You can control which optimizations are applied using the `include` and `exclude` arguments:

.. code:: python
import pymc as pm
import pymc_experimental as pmx

with pm.Model() as m:
p = pm.Beta("p", 1, 1, shape=(1000,))
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)

idata = pmx.opt_sample(exclude="conjugacy", verbose=True)

# No optimizations applied
# NUTS: [p]

.. code:: python
import pymc as pm
import pymc_experimental as pmx

with pm.Model() as m:
a = pm.InverseGamma("a", 1, 1)
b = pm.InverseGamma("b", 1, 1)
p = pm.Beta("p", a, b, shape=(1000,))
y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250)

# By default, the conjugacy of p will not be applied because it depends on other free variables
idata = pmx.opt_sample(include="conjugacy-eager", verbose=True)

# Applied optimization: beta_binomial_conjugacy_eager 1x
# CompoundStep
# >NUTS: [a, b]
# >ConjugateRVSampler: [p]

"""
if kwargs.get("step", None) is not None:
raise ValueError(
"The `step` argument is not supported in `opt_sample`, as custom steps would refer to the original model.\n"
"You can manually transform the model with `pymc_experimental.sampling.optimizations.optimize_model_for_mcmc_sampling` "
"and then define the custom steps and forward them to `pymc.sample`."
)

opt_model, rewrite_counters = optimize_model_for_mcmc_sampling(
model, include=include, exclude=exclude, rewriter=rewriter
)

if verbose:
applied_opt = False
for rewrite_counter in rewrite_counters:
for rewrite, counts in rewrite_counter.items():
applied_opt = True
print(f"Applied optimization: {rewrite} {counts}x", file=sys.stdout)
if not applied_opt:
print("No optimizations applied", file=sys.stdout)

return sample(*args, model=opt_model, **kwargs)
13 changes: 13 additions & 0 deletions pymc_extras/sampling/optimizations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# ruff: noqa: F401
# Add rewrites to the optimization DBs

from pymc_extras.sampling.optimizations import conjugacy, summary_stats
from pymc_extras.sampling.optimizations.optimize import (
optimize_model_for_mcmc_sampling,
posterior_optimization_db,
)

__all__ = [
"posterior_optimization_db",
"optimize_model_for_mcmc_sampling",
]
Loading
Loading