Skip to content

Commit 2ebff6f

Browse files
Forward model argument to add_fit_to_inference_data in find_MAP
1 parent 82621a2 commit 2ebff6f

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,20 @@ def find_MAP(
335335
var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
336336
}
337337

338-
idata = map_results_to_inference_data(optimized_point, frozen_model, include_transformed)
339-
idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv)
338+
idata = map_results_to_inference_data(
339+
map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
340+
)
341+
342+
idata = add_fit_to_inference_data(
343+
idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
344+
)
345+
340346
idata = add_optimizer_result_to_inference_data(
341-
idata, optimizer_result, method, raveled_optimized, model
347+
idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
342348
)
349+
343350
idata = add_data_to_inference_data(
344-
idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
351+
idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
345352
)
346353

347354
return idata

tests/inference/laplace_approx/test_find_map.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,22 @@ def test_find_MAP(
185185
assert not hasattr(idata, "unconstrained_posterior")
186186

187187

188+
def test_find_map_outside_model_context():
189+
"""
190+
Test that find_MAP can be called outside of a model context.
191+
"""
192+
with pm.Model() as m:
193+
mu = pm.Normal("mu", 0, 1)
194+
sigma = pm.Exponential("sigma", 1)
195+
y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))
196+
197+
idata = find_MAP(model=m, method="L-BFGS-B", use_grad=True, progressbar=False)
198+
199+
assert hasattr(idata, "posterior")
200+
assert hasattr(idata, "fit")
201+
assert hasattr(idata, "optimizer_result")
202+
203+
188204
@pytest.mark.parametrize(
189205
"backend, gradient_backend",
190206
[("jax", "jax")],

tests/inference/laplace_approx/test_laplace.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,28 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend):
8383
np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3)
8484

8585

86+
def test_fit_laplace_outside_model_context():
87+
with pm.Model() as m:
88+
mu = pm.Normal("mu", 0, 1)
89+
sigma = pm.Exponential("sigma", 1)
90+
y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))
91+
92+
idata = fit_laplace(
93+
model=m,
94+
optimize_method="L-BFGS-B",
95+
use_grad=True,
96+
progressbar=False,
97+
chains=1,
98+
draws=100,
99+
)
100+
101+
assert hasattr(idata, "posterior")
102+
assert hasattr(idata, "fit")
103+
assert hasattr(idata, "optimizer_result")
104+
105+
assert idata.posterior["mu"].shape == (1, 100)
106+
107+
86108
@pytest.mark.parametrize(
87109
"include_transformed", [True, False], ids=["include_transformed", "no_transformed"]
88110
)

0 commit comments

Comments
 (0)