diff --git a/pymc_extras/inference/laplace_approx/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py index 79a1dea5..b717e0ba 100644 --- a/pymc_extras/inference/laplace_approx/find_map.py +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -335,13 +335,20 @@ def find_MAP( var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) } - idata = map_results_to_inference_data(optimized_point, frozen_model, include_transformed) - idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv) + idata = map_results_to_inference_data( + map_point=optimized_point, model=frozen_model, include_transformed=include_transformed + ) + + idata = add_fit_to_inference_data( + idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model + ) + idata = add_optimizer_result_to_inference_data( - idata, optimizer_result, method, raveled_optimized, model + idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model ) + idata = add_data_to_inference_data( - idata, progressbar=False, model=model, compile_kwargs=compile_kwargs + idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs ) return idata diff --git a/tests/inference/laplace_approx/test_find_map.py b/tests/inference/laplace_approx/test_find_map.py index 309876ca..f1406ca6 100644 --- a/tests/inference/laplace_approx/test_find_map.py +++ b/tests/inference/laplace_approx/test_find_map.py @@ -185,6 +185,22 @@ def test_find_MAP( assert not hasattr(idata, "unconstrained_posterior") +def test_find_map_outside_model_context(): + """ + Test that find_MAP can be called outside of a model context. + """ + with pm.Model() as m: + mu = pm.Normal("mu", 0, 1) + sigma = pm.Exponential("sigma", 1) + y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10)) + + idata = find_MAP(model=m, method="L-BFGS-B", use_grad=True, progressbar=False) + + assert hasattr(idata, "posterior") + assert hasattr(idata, "fit") + assert hasattr(idata, "optimizer_result") + + @pytest.mark.parametrize( "backend, gradient_backend", [("jax", "jax")], diff --git a/tests/inference/laplace_approx/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py index ab0ed34b..6196673b 100644 --- a/tests/inference/laplace_approx/test_laplace.py +++ b/tests/inference/laplace_approx/test_laplace.py @@ -83,6 +83,28 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend): np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3) +def test_fit_laplace_outside_model_context(): + with pm.Model() as m: + mu = pm.Normal("mu", 0, 1) + sigma = pm.Exponential("sigma", 1) + y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10)) + + idata = fit_laplace( + model=m, + optimize_method="L-BFGS-B", + use_grad=True, + progressbar=False, + chains=1, + draws=100, + ) + + assert hasattr(idata, "posterior") + assert hasattr(idata, "fit") + assert hasattr(idata, "optimizer_result") + + assert idata.posterior["mu"].shape == (1, 100) + + @pytest.mark.parametrize( "include_transformed", [True, False], ids=["include_transformed", "no_transformed"] ) @@ -208,6 +230,50 @@ def test_model_with_nonstandard_dimensionality(rng): assert "class" in list(idata.unconstrained_posterior.sigma_log__.coords.keys()) +def test_laplace_nonstandard_dims_2d(): + true_P = np.array([[0.5, 0.3, 0.2], [0.1, 0.6, 0.3], [0.2, 0.4, 0.4]]) + y_obs = pm.draw( + pmx.DiscreteMarkovChain.dist( + P=true_P, + init_dist=pm.Categorical.dist( + logit_p=np.ones( + 3, + ) + ), + shape=(100, 5), + ) + ) + + with pm.Model( + coords={ + "time": range(y_obs.shape[0]), + "state": list("ABC"), + "next_state": list("ABC"), + "unit": [1, 2, 3, 4, 5], + } + ) as model: + y = pm.Data("y", y_obs, dims=["time", "unit"]) + init_dist = pm.Categorical.dist( + logit_p=np.ones( + 3, + ) + ) + P = pm.Dirichlet("P", a=np.eye(3) * 2 + 1, dims=["state", "next_state"]) + y_hat = pmx.DiscreteMarkovChain( + "y_hat", P=P, init_dist=init_dist, dims=["time", "unit"], observed=y_obs + ) + + idata = pmx.fit_laplace(progressbar=True) + + # The simplex transform should drop from the right-most dimension, so the left dimension should be unmodified + assert "state" in list(idata.unconstrained_posterior.P_simplex__.coords.keys()) + + # The mutated dimension should be unknown coords + assert "P_simplex___dim_1" in list(idata.unconstrained_posterior.P_simplex__.coords.keys()) + + assert idata.unconstrained_posterior.P_simplex__.shape[-2:] == (3, 2) + + def test_laplace_nonscalar_rv_without_dims(): with pm.Model(coords={"test": ["A", "B", "C"]}) as model: x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"])