Skip to content

Commit fd2933f

Browse files
Forward model argument to add_fit_to_inference_data in find_MAP (#543)
* Forward model argument to add_fit_to_inference_data in find_MAP * Add 2d test for nonstandard dims
1 parent 82621a2 commit fd2933f

File tree

3 files changed

+93
-4
lines changed

3 files changed

+93
-4
lines changed

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,20 @@ def find_MAP(
335335
var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
336336
}
337337

338-
idata = map_results_to_inference_data(optimized_point, frozen_model, include_transformed)
339-
idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv)
338+
idata = map_results_to_inference_data(
339+
map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
340+
)
341+
342+
idata = add_fit_to_inference_data(
343+
idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
344+
)
345+
340346
idata = add_optimizer_result_to_inference_data(
341-
idata, optimizer_result, method, raveled_optimized, model
347+
idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
342348
)
349+
343350
idata = add_data_to_inference_data(
344-
idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
351+
idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
345352
)
346353

347354
return idata

tests/inference/laplace_approx/test_find_map.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,22 @@ def test_find_MAP(
185185
assert not hasattr(idata, "unconstrained_posterior")
186186

187187

188+
def test_find_map_outside_model_context():
189+
"""
190+
Test that find_MAP can be called outside of a model context.
191+
"""
192+
with pm.Model() as m:
193+
mu = pm.Normal("mu", 0, 1)
194+
sigma = pm.Exponential("sigma", 1)
195+
y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))
196+
197+
idata = find_MAP(model=m, method="L-BFGS-B", use_grad=True, progressbar=False)
198+
199+
assert hasattr(idata, "posterior")
200+
assert hasattr(idata, "fit")
201+
assert hasattr(idata, "optimizer_result")
202+
203+
188204
@pytest.mark.parametrize(
189205
"backend, gradient_backend",
190206
[("jax", "jax")],

tests/inference/laplace_approx/test_laplace.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,28 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend):
8383
np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3)
8484

8585

86+
def test_fit_laplace_outside_model_context():
87+
with pm.Model() as m:
88+
mu = pm.Normal("mu", 0, 1)
89+
sigma = pm.Exponential("sigma", 1)
90+
y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))
91+
92+
idata = fit_laplace(
93+
model=m,
94+
optimize_method="L-BFGS-B",
95+
use_grad=True,
96+
progressbar=False,
97+
chains=1,
98+
draws=100,
99+
)
100+
101+
assert hasattr(idata, "posterior")
102+
assert hasattr(idata, "fit")
103+
assert hasattr(idata, "optimizer_result")
104+
105+
assert idata.posterior["mu"].shape == (1, 100)
106+
107+
86108
@pytest.mark.parametrize(
87109
"include_transformed", [True, False], ids=["include_transformed", "no_transformed"]
88110
)
@@ -208,6 +230,50 @@ def test_model_with_nonstandard_dimensionality(rng):
208230
assert "class" in list(idata.unconstrained_posterior.sigma_log__.coords.keys())
209231

210232

233+
def test_laplace_nonstandard_dims_2d():
234+
true_P = np.array([[0.5, 0.3, 0.2], [0.1, 0.6, 0.3], [0.2, 0.4, 0.4]])
235+
y_obs = pm.draw(
236+
pmx.DiscreteMarkovChain.dist(
237+
P=true_P,
238+
init_dist=pm.Categorical.dist(
239+
logit_p=np.ones(
240+
3,
241+
)
242+
),
243+
shape=(100, 5),
244+
)
245+
)
246+
247+
with pm.Model(
248+
coords={
249+
"time": range(y_obs.shape[0]),
250+
"state": list("ABC"),
251+
"next_state": list("ABC"),
252+
"unit": [1, 2, 3, 4, 5],
253+
}
254+
) as model:
255+
y = pm.Data("y", y_obs, dims=["time", "unit"])
256+
init_dist = pm.Categorical.dist(
257+
logit_p=np.ones(
258+
3,
259+
)
260+
)
261+
P = pm.Dirichlet("P", a=np.eye(3) * 2 + 1, dims=["state", "next_state"])
262+
y_hat = pmx.DiscreteMarkovChain(
263+
"y_hat", P=P, init_dist=init_dist, dims=["time", "unit"], observed=y_obs
264+
)
265+
266+
idata = pmx.fit_laplace(progressbar=True)
267+
268+
# The simplex transform should drop from the right-most dimension, so the left dimension should be unmodified
269+
assert "state" in list(idata.unconstrained_posterior.P_simplex__.coords.keys())
270+
271+
# The mutated dimension should be unknown coords
272+
assert "P_simplex___dim_1" in list(idata.unconstrained_posterior.P_simplex__.coords.keys())
273+
274+
assert idata.unconstrained_posterior.P_simplex__.shape[-2:] == (3, 2)
275+
276+
211277
def test_laplace_nonscalar_rv_without_dims():
212278
with pm.Model(coords={"test": ["A", "B", "C"]}) as model:
213279
x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"])

0 commit comments

Comments
 (0)