
Commit 37f4588

More options for DADVI (#579)
* Add more options to DADVI minimization
* Rename `include_transformed` for consistency, and return `unconstrained_posterior` in a separate group
* Re-run the DADVI notebook
* Allow basinhopping when fitting DADVI
* Respond to feedback
1 parent 07c6ab4 commit 37f4588

File tree

8 files changed (+2024, −539 lines)

notebooks/deterministic_advi_example.ipynb

Lines changed: 1609 additions & 425 deletions
Large diffs are not rendered by default.

pymc_extras/inference/dadvi/dadvi.py

Lines changed: 162 additions & 72 deletions
@@ -5,88 +5,97 @@
 import pytensor.tensor as pt
 import xarray
 
-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
     PointFunc,
     apply_function_over_dataset,
     coords_and_dims_for_inferencedata,
 )
+from pymc.blocking import RaveledVars
 from pymc.util import RandomSeed, get_default_varnames
 from pytensor.tensor.variable import TensorVariable
 
+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
 from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
 from pymc_extras.inference.laplace_approx.scipy_interface import (
-    _compile_functions_for_scipy_optimize,
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 
 def fit_dadvi(
     model: Model | None = None,
     n_fixed_draws: int = 30,
-    random_seed: RandomSeed = None,
     n_draws: int = 1000,
-    keep_untransformed: bool = False,
+    include_transformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
-    use_grad: bool = True,
-    use_hessp: bool = True,
-    use_hess: bool = False,
-    **minimize_kwargs,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
 ) -> az.InferenceData:
     """
-    Does inference using deterministic ADVI (automatic differentiation
-    variational inference), DADVI for short.
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
 
-    For full details see the paper cited in the references:
-    https://www.jmlr.org/papers/v25/23-1015.html
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html
 
     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
 
     n_fixed_draws : int
-        The number of fixed draws to use for the optimisation. More
-        draws will result in more accurate estimates, but also
-        increase inference time. Usually, the default of 30 is a good
-        tradeoff.between speed and accuracy.
+        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.
 
     random_seed: int
-        The random seed to use for the fixed draws. Running the optimisation
-        twice with the same seed should arrive at the same result.
+        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+        the same result.
 
     n_draws: int
         The number of draws to return from the variational approximation.
 
-    keep_untransformed: bool
-        Whether or not to keep the unconstrained variables (such as
-        logs of positive-constrained parameters) in the output.
+    include_transformed: bool
+        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+        output.
 
     optimizer_method: str
-        Which optimization method to use. The function calls
-        ``scipy.optimize.minimize``, so any of the methods there can
-        be used. The default is trust-ncg, which uses second-order
-        information and is generally very reliable. Other methods such
-        as L-BFGS-B might be faster but potentially more brittle and
-        may not converge exactly to the optimum.
-
-    minimize_kwargs:
-        Additional keyword arguments to pass to the
-        ``scipy.optimize.minimize`` function. See the documentation of
-        that function for details.
+        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+        the optimum.
 
-    use_grad:
-        If True, pass the gradient function to
-        `scipy.optimize.minimize` (where it is referred to as `jac`).
+    gradient_backend: str
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
 
-    use_hessp:
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to `pytensor.function`
+
+    use_grad: bool, optional
+        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
+
+    use_hessp: bool, optional
         If True, pass the hessian vector product to `scipy.optimize.minimize`.
 
-    use_hess:
-        If True, pass the hessian to `scipy.optimize.minimize`. Note that
-        this is generally not recommended since its computation can be slow
-        and memory-intensive if there are many parameters.
+    use_hess: bool, optional
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+        computation can be slow and memory-intensive if there are many parameters.
+
+    progressbar: bool
+        Whether or not to show a progress bar during optimization. Default is True.
+
+    optimizer_kwargs:
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
+        that function for details.
 
     Returns
     -------
@@ -95,16 +104,25 @@ def fit_dadvi(
 
     References
     ----------
-    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
-    Variational Inference with a Deterministic Objective: Faster, More
-    Accurate, and Even More Black Box. Journal of Machine Learning
-    Research, 25(18), 1–39.
+    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
     """
 
     model = pymc.modelcontext(model) if model is None else model
+    do_basinhopping = optimizer_method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = optimizer_method
 
     initial_point_dict = model.initial_point()
-    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
+    initial_point = DictToArrayBijection.map(initial_point_dict)
+    n_params = initial_point.data.shape[0]
 
     var_params, objective = create_dadvi_graph(
         model,
@@ -113,31 +131,65 @@ def fit_dadvi(
         n_params=n_params,
     )
 
-    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
-        objective,
-        [var_params],
-        compute_grad=use_grad,
-        compute_hessp=use_hessp,
-        compute_hess=use_hess,
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        optimizer_method, use_grad, use_hess, use_hessp
     )
 
-    derivative_kwargs = {}
-
-    if use_grad:
-        derivative_kwargs["jac"] = True
-    if use_hessp:
-        derivative_kwargs["hessp"] = f_hessp
-    if use_hess:
-        derivative_kwargs["hess"] = True
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=objective,
+        inputs=[var_params],
+        initial_point_dict=None,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        inputs_are_flat=True,
+    )
 
-    result = minimize(
-        f_fused,
-        np.zeros(2 * n_params),
-        method=optimizer_method,
-        **derivative_kwargs,
-        **minimize_kwargs,
+    dadvi_initial_point = {
+        f"{var_name}_mu": np.zeros_like(value).ravel()
+        for var_name, value in initial_point_dict.items()
+    }
+    dadvi_initial_point.update(
+        {
+            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+            for var_name, value in initial_point_dict.items()
+        }
     )
 
+    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
+    args = optimizer_kwargs.pop("args", ())
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = optimizer_method
+
+        result = basinhopping(
+            func=f_fused,
+            x0=dadvi_initial_point.data,
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        result = minimize(
+            f=f_fused,
+            x0=dadvi_initial_point.data,
+            args=args,
+            method=optimizer_method,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            **optimizer_kwargs,
+        )
+
+    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)
 
@@ -148,9 +200,29 @@ def fit_dadvi(
     draws = opt_means + draws_raw * np.exp(opt_log_sds)
     draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
 
-    transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    idata = dadvi_result_to_idata(
+        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    )
 
-    return transformed_draws
+    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+    var_name_to_model_var.update(
+        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+    )
+
+    idata = add_optimizer_result_to_inference_data(
+        idata=idata,
+        result=result,
+        method=optimizer_method,
+        mu=raveled_optimized,
+        model=model,
+        var_name_to_model_var=var_name_to_model_var,
+    )
+
+    idata = add_data_to_inference_data(
+        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )
+
+    return idata
 
 
 def create_dadvi_graph(
@@ -213,10 +285,11 @@ def create_dadvi_graph(
     return var_params, objective
 
 
-def transform_draws(
+def dadvi_result_to_idata(
     unstacked_draws: xarray.Dataset,
     model: Model,
-    keep_untransformed: bool = False,
+    include_transformed: bool = False,
+    progressbar: bool = True,
 ):
     """
     Transforms the unconstrained draws back into the constrained space.
@@ -232,9 +305,12 @@
     n_draws: int
         The number of draws to return from the variational approximation.
 
-    keep_untransformed: bool
+    include_transformed: bool
         Whether or not to keep the unconstrained variables in the output.
 
+    progressbar: bool
+        Whether or not to show a progress bar during the transformation. Default is True.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -243,7 +319,7 @@
 
     filtered_var_names = model.unobserved_value_vars
     vars_to_sample = list(
-        get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
+        get_default_varnames(filtered_var_names, include_transformed=include_transformed)
    )
     fn = pytensor.function(model.value_vars, vars_to_sample)
     point_func = PointFunc(fn)
@@ -256,6 +332,20 @@
         output_var_names=[x.name for x in vars_to_sample],
         coords=coords,
         dims=dims,
+        progressbar=progressbar,
     )
 
-    return transformed_result
+    constrained_names = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+    ]
+    all_varnames = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+    ]
+    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
+
+    idata = az.InferenceData(posterior=transformed_result[constrained_names])
+
+    if unconstrained_names and include_transformed:
+        idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
+
+    return idata
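
A minimal usage sketch of the new `fit_dadvi` options introduced in this diff. The toy model is illustrative only, and the import path is assumed to match the module file shown above (`pymc_extras/inference/dadvi/dadvi.py`); only the keyword arguments come from the changed signature itself.

```python
import numpy as np
import pymc as pm

from pymc_extras.inference.dadvi.dadvi import fit_dadvi

# Toy model, for illustration only.
rng = np.random.default_rng(0)
y = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

    # Default path: a single scipy.optimize.minimize run with trust-ncg, with
    # use_grad / use_hess / use_hessp resolved by set_optimizer_function_defaults.
    idata = fit_dadvi(n_draws=1000, random_seed=1, include_transformed=True)

    # New basinhopping path: the inner optimizer is configured through
    # minimizer_kwargs and falls back to L-BFGS-B when no method is given.
    idata_bh = fit_dadvi(
        optimizer_method="basinhopping",
        minimizer_kwargs={"method": "L-BFGS-B"},
        n_draws=1000,
        random_seed=1,
    )
```

With `include_transformed=True`, the constrained draws land in `idata.posterior` while the unconstrained ones go to the separate `unconstrained_posterior` group built by `dadvi_result_to_idata` above.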

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 2 additions & 30 deletions
@@ -7,7 +7,7 @@
 import pymc as pm
 
 from better_optimize import basinhopping, minimize
-from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
+from better_optimize.constants import minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
 from pymc.model.transform.optimization import freeze_dims_and_data
@@ -24,40 +24,12 @@
 from pymc_extras.inference.laplace_approx.scipy_interface import (
     GradientBackend,
     scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 _log = logging.getLogger(__name__)
 
 
-def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
-    method_info = MINIMIZE_MODE_KWARGS[method].copy()
-
-    if use_hess and use_hessp:
-        _log.warning(
-            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
-            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
-            'Setting "use_hess" to False.'
-        )
-        use_hess = False
-
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-
-    if use_hessp is not None and use_hess is None:
-        use_hess = not use_hessp
-
-    elif use_hess is not None and use_hessp is None:
-        use_hessp = not use_hess
-
-    elif use_hessp is None and use_hess is None:
-        use_hessp = method_info["uses_hessp"]
-        use_hess = method_info["uses_hess"]
-        if use_hessp and use_hess:
-            # If a method could use either hess or hessp, we default to using hessp
-            use_hess = False
-
-    return use_grad, use_hess, use_hessp
-
-
 def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     """
     Compute the nearest positive semi-definite matrix to a given matrix.
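
The helper removed here now lives in `pymc_extras.inference.laplace_approx.scipy_interface`, as the updated import above reflects. A rough sketch of how a caller can resolve derivative defaults with it, assuming the moved function keeps the behaviour of the deleted code (the actual per-method defaults come from `MINIMIZE_MODE_KWARGS` in `better_optimize`):

```python
from pymc_extras.inference.laplace_approx.scipy_interface import (
    set_optimizer_function_defaults,
)

# Passing None lets the helper pick method-appropriate defaults; for a
# second-order method such as trust-ncg this typically enables the gradient
# and the Hessian-vector product while leaving the full Hessian off.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=None, use_hess=None, use_hessp=None
)
```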
