Skip to content

Forward model argument to add_fit_to_inference_data in find_MAP #543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions pymc_extras/inference/laplace_approx/find_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,20 @@ def find_MAP(
var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
}

idata = map_results_to_inference_data(optimized_point, frozen_model, include_transformed)
idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv)
idata = map_results_to_inference_data(
map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
)

idata = add_fit_to_inference_data(
idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
)

idata = add_optimizer_result_to_inference_data(
idata, optimizer_result, method, raveled_optimized, model
idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
)

idata = add_data_to_inference_data(
idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
)

return idata
16 changes: 16 additions & 0 deletions tests/inference/laplace_approx/test_find_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,22 @@ def test_find_MAP(
assert not hasattr(idata, "unconstrained_posterior")


def test_find_map_outside_model_context():
    """
    Test that find_MAP can be called outside of a model context.

    The model is passed explicitly via ``model=m`` rather than being picked
    up from the context stack, and the resulting InferenceData must contain
    every group that find_MAP is responsible for attaching.
    """
    with pm.Model() as m:
        mu = pm.Normal("mu", 0, 1)
        sigma = pm.Exponential("sigma", 1)
        y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))

    # Deliberately called OUTSIDE the `with m:` block.
    idata = find_MAP(model=m, method="L-BFGS-B", use_grad=True, progressbar=False)

    assert hasattr(idata, "posterior")
    assert hasattr(idata, "fit")
    assert hasattr(idata, "optimizer_result")
    # find_MAP also calls add_data_to_inference_data, so the observed data
    # group must be present as well (flagged in review but originally missed).
    assert hasattr(idata, "observed_data")
Comment on lines +199 to +201

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it relevant / worth checking the presence of any data group? I see above a call like the following within find_MAP

    idata = add_data_to_inference_data(
        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
    )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

crap yes it is, but I just clicked merge

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha, well another thing I just saw:

Is it necessary to compute deterministics within add_data_to_inference_data? See

if model.deterministics:
expand_dims = {}
if "chain" not in idata.posterior.coords:
expand_dims["chain"] = [0]
if "draw" not in idata.posterior.coords:
expand_dims["draw"] = [0]
idata.posterior = pm.compute_deterministics(
idata.posterior.expand_dims(expand_dims),
model=model,
merge_dataset=True,
progressbar=progressbar,
compile_kwargs=compile_kwargs,
)

I guess that function is not exposed to the user, but I just wanted to raise that potentially silent side effect.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not necessary, but that's what pm.sample does right?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh I see now why it's included



@pytest.mark.parametrize(
"backend, gradient_backend",
[("jax", "jax")],
Expand Down
66 changes: 66 additions & 0 deletions tests/inference/laplace_approx/test_laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,28 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend):
np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3)


def test_fit_laplace_outside_model_context():
    """
    Test that fit_laplace can be called outside of a model context.

    Mirrors test_find_map_outside_model_context: the model is forwarded
    explicitly via ``model=m`` and the returned InferenceData must carry all
    expected groups, including the observed data attached by
    add_data_to_inference_data.
    """
    with pm.Model() as m:
        mu = pm.Normal("mu", 0, 1)
        sigma = pm.Exponential("sigma", 1)
        y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=np.random.normal(size=10))

    # Deliberately called OUTSIDE the `with m:` block.
    idata = fit_laplace(
        model=m,
        optimize_method="L-BFGS-B",
        use_grad=True,
        progressbar=False,
        chains=1,
        draws=100,
    )

    assert hasattr(idata, "posterior")
    assert hasattr(idata, "fit")
    assert hasattr(idata, "optimizer_result")
    # The data group should be attached too, not just the sampling results.
    assert hasattr(idata, "observed_data")

    # Posterior shape follows the requested chains/draws.
    assert idata.posterior["mu"].shape == (1, 100)


@pytest.mark.parametrize(
"include_transformed", [True, False], ids=["include_transformed", "no_transformed"]
)
Expand Down Expand Up @@ -208,6 +230,50 @@ def test_model_with_nonstandard_dimensionality(rng):
assert "class" in list(idata.unconstrained_posterior.sigma_log__.coords.keys())


def test_laplace_nonstandard_dims_2d():
    """
    Test fit_laplace on a model whose free RV carries a 2d dim spec
    ("state", "next_state") that is mutated by the simplex transform.

    The simplex transform drops one element from the right-most dimension, so
    the unconstrained posterior should keep the "state" coord on the left and
    gain an auto-generated coord for the shrunken right dimension.
    """
    true_P = np.array([[0.5, 0.3, 0.2], [0.1, 0.6, 0.3], [0.2, 0.4, 0.4]])
    y_obs = pm.draw(
        pmx.DiscreteMarkovChain.dist(
            P=true_P,
            init_dist=pm.Categorical.dist(logit_p=np.ones(3)),
            shape=(100, 5),
        )
    )

    with pm.Model(
        coords={
            "time": range(y_obs.shape[0]),
            "state": list("ABC"),
            "next_state": list("ABC"),
            "unit": [1, 2, 3, 4, 5],
        }
    ) as model:
        y = pm.Data("y", y_obs, dims=["time", "unit"])
        init_dist = pm.Categorical.dist(logit_p=np.ones(3))
        P = pm.Dirichlet("P", a=np.eye(3) * 2 + 1, dims=["state", "next_state"])
        # Observe the registered Data variable (not the raw array) so the
        # pm.Data container is actually used by the model.
        y_hat = pmx.DiscreteMarkovChain(
            "y_hat", P=P, init_dist=init_dist, dims=["time", "unit"], observed=y
        )

        # progressbar=False for consistency with the other tests in this file.
        idata = pmx.fit_laplace(progressbar=False)

    # The simplex transform should drop from the right-most dimension, so the left dimension should be unmodified
    assert "state" in list(idata.unconstrained_posterior.P_simplex__.coords.keys())

    # The mutated dimension should be unknown coords
    assert "P_simplex___dim_1" in list(idata.unconstrained_posterior.P_simplex__.coords.keys())

    assert idata.unconstrained_posterior.P_simplex__.shape[-2:] == (3, 2)


def test_laplace_nonscalar_rv_without_dims():
with pm.Model(coords={"test": ["A", "B", "C"]}) as model:
x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"])
Expand Down