Skip to content

Commit 96f7a2e

Browse files
zaxtaxricardoV94
authored andcommitted
Wrap function arguments with pm.Data if they support it.
1 parent 9c7a6fb commit 96f7a2e

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

pymc_extras/model/model_api.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from functools import wraps
2+
from inspect import signature
23

3-
from pymc import Model
4+
import pytensor.tensor as pt
5+
6+
from pymc import Data, Model
47

58

69
def as_model(*model_args, **model_kwargs):
@@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
912
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
1013
Additionally, a coords argument is added to the function so coords can be changed during function invocation
1114
15+
All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
16+
1217
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
1318
1419
Examples
@@ -47,8 +52,19 @@ def decorator(f):
4752
@wraps(f)
4853
def make_model(*args, **kwargs):
4954
coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
55+
sig = signature(f)
56+
ba = sig.bind(*args, **kwargs)
57+
ba.apply_defaults()
58+
5059
with Model(*model_args, coords=coords, **model_kwargs) as m:
51-
f(*args, **kwargs)
60+
for name, v in ba.arguments.items():
61+
# Only wrap pm.Data around values pytensor can process
62+
try:
63+
_ = pt.as_tensor_variable(v)
64+
ba.arguments[name] = Data(name, v)
65+
except (NotImplementedError, TypeError, ValueError):
66+
pass
67+
f(*ba.args, **ba.kwargs)
5268
return m
5369

5470
return make_model

tests/model/test_model_api.py

+9
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,14 @@ def model_wrapped2():
2525

2626
mw2 = model_wrapped2(coords=coords)
2727

28+
@pmx.as_model()
29+
def model_wrapped3(mu):
30+
pm.Normal("x", mu, 1.0, dims="obs")
31+
32+
mw3 = model_wrapped3(0.0, coords=coords)
33+
mw4 = model_wrapped3(np.array([np.nan]), coords=coords)
34+
2835
np.testing.assert_equal(model.point_logps(), mw.point_logps())
2936
np.testing.assert_equal(mw.point_logps(), mw2.point_logps())
37+
assert mw3["mu"] in mw3.data_vars
38+
assert "mu" not in mw4

0 commit comments

Comments
 (0)