import pytensor.tensor as pt
import xarray

-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
from better_optimize.constants import minimize_method
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
from pymc.backends.arviz import (
    PointFunc,
    apply_function_over_dataset,
    coords_and_dims_for_inferencedata,
)
+from pymc.blocking import RaveledVars
from pymc.util import RandomSeed, get_default_varnames
from pytensor.tensor.variable import TensorVariable

+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
from pymc_extras.inference.laplace_approx.scipy_interface import (
-    _compile_functions_for_scipy_optimize,
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
)


def fit_dadvi(
    model: Model | None = None,
    n_fixed_draws: int = 30,
-    random_seed: RandomSeed = None,
    n_draws: int = 1000,
-    keep_untransformed: bool = False,
+    include_transformed: bool = False,
    optimizer_method: minimize_method = "trust-ncg",
-    use_grad: bool = True,
-    use_hessp: bool = True,
-    use_hess: bool = False,
-    **minimize_kwargs,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
) -> az.InferenceData:
    """
-    Does inference using deterministic ADVI (automatic differentiation
-    variational inference), DADVI for short.
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.

-    For full details see the paper cited in the references:
-    https://www.jmlr.org/papers/v25/23-1015.html
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.

    n_fixed_draws : int
-        The number of fixed draws to use for the optimisation. More
-        draws will result in more accurate estimates, but also
-        increase inference time. Usually, the default of 30 is a good
-        tradeoff.between speed and accuracy.
+        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.

    random_seed: int
-        The random seed to use for the fixed draws. Running the optimisation
-        twice with the same seed should arrive at the same result.
+        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+        the same result.

    n_draws: int
        The number of draws to return from the variational approximation.

-    keep_untransformed : bool
-        Whether or not to keep the unconstrained variables (such as
-        logs of positive-constrained parameters) in the output.
+    include_transformed : bool
+        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+        output.

    optimizer_method: str
-        Which optimization method to use. The function calls
-        ``scipy.optimize.minimize``, so any of the methods there can
-        be used. The default is trust-ncg, which uses second-order
-        information and is generally very reliable. Other methods such
-        as L-BFGS-B might be faster but potentially more brittle and
-        may not converge exactly to the optimum.
-
-    minimize_kwargs:
-        Additional keyword arguments to pass to the
-        ``scipy.optimize.minimize`` function. See the documentation of
-        that function for details.
+        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+        the optimum.

-    use_grad:
-        If True, pass the gradient function to
-        `scipy.optimize.minimize` (where it is referred to as `jac`).
+    gradient_backend: str
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".

-    use_hessp:
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to `pytensor.function`
+
+    use_grad: bool, optional
+        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
+
+    use_hessp: bool, optional
        If True, pass the hessian vector product to `scipy.optimize.minimize`.

-    use_hess:
-        If True, pass the hessian to `scipy.optimize.minimize`. Note that
-        this is generally not recommended since its computation can be slow
-        and memory-intensive if there are many parameters.
+    use_hess: bool, optional
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+        computation can be slow and memory-intensive if there are many parameters.
+
+    progressbar: bool
+        Whether or not to show a progress bar during optimization. Default is True.
+
+    optimizer_kwargs:
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
+        that function for details.

    Returns
    -------
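
The new signature is easiest to read alongside a usage sketch. The toy model and the import path below are assumptions for illustration, not taken from this diff:

import numpy as np
import pymc as pm

# Assumed import path; adjust to wherever fit_dadvi is exposed in your installation.
from pymc_extras.inference.dadvi import fit_dadvi

rng = np.random.default_rng(0)
observed = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

    # model=None falls back to the model context, as documented above.
    idata = fit_dadvi(
        n_fixed_draws=30,
        n_draws=1000,
        optimizer_method="trust-ncg",
        include_transformed=True,
        gradient_backend="pytensor",
        random_seed=42,
        progressbar=False,
    )

    # The diff below also allows a global basinhopping wrapper around an inner scipy minimizer:
    # idata = fit_dadvi(optimizer_method="basinhopping", minimizer_kwargs={"method": "L-BFGS-B"})
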
@@ -95,16 +104,25 @@ def fit_dadvi(

    References
    ----------
-    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
-    Variational Inference with a Deterministic Objective: Faster, More
-    Accurate, and Even More Black Box. Journal of Machine Learning
-    Research, 25(18), 1–39.
+    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
    """

    model = pymc.modelcontext(model) if model is None else model
+    do_basinhopping = optimizer_method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = optimizer_method

    initial_point_dict = model.initial_point()
-    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
+    initial_point = DictToArrayBijection.map(initial_point_dict)
+    n_params = initial_point.data.shape[0]

    var_params, objective = create_dadvi_graph(
        model,
@@ -113,31 +131,65 @@ def fit_dadvi(
        n_params=n_params,
    )

-    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
-        objective,
-        [var_params],
-        compute_grad=use_grad,
-        compute_hessp=use_hessp,
-        compute_hess=use_hess,
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        optimizer_method, use_grad, use_hess, use_hessp
    )

-    derivative_kwargs = {}
-
-    if use_grad:
-        derivative_kwargs["jac"] = True
-    if use_hessp:
-        derivative_kwargs["hessp"] = f_hessp
-    if use_hess:
-        derivative_kwargs["hess"] = True
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=objective,
+        inputs=[var_params],
+        initial_point_dict=None,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        inputs_are_flat=True,
+    )

-    result = minimize(
-        f_fused,
-        np.zeros(2 * n_params),
-        method=optimizer_method,
-        **derivative_kwargs,
-        **minimize_kwargs,
+    dadvi_initial_point = {
+        f"{var_name}_mu": np.zeros_like(value).ravel()
+        for var_name, value in initial_point_dict.items()
+    }
+    dadvi_initial_point.update(
+        {
+            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+            for var_name, value in initial_point_dict.items()
+        }
    )

+    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
+    args = optimizer_kwargs.pop("args", ())
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = optimizer_method
+
+        result = basinhopping(
+            func=f_fused,
+            x0=dadvi_initial_point.data,
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        result = minimize(
+            f=f_fused,
+            x0=dadvi_initial_point.data,
+            args=args,
+            method=optimizer_method,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            **optimizer_kwargs,
+        )
+
+    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
    opt_var_params = result.x
    opt_means, opt_log_sds = np.split(opt_var_params, 2)

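
The flat parameter vector seen by the optimizer stacks all variational means first and all log standard deviations second, which is why a single np.split recovers both halves. A small self-contained NumPy sketch of that layout (shapes and values are illustrative only):

import numpy as np

n_params, n_draws = 3, 5
rng = np.random.default_rng(0)

# Variational parameters live in one flat vector: [mu_1, ..., mu_n, log_sd_1, ..., log_sd_n].
opt_var_params = rng.normal(size=2 * n_params)
opt_means, opt_log_sds = np.split(opt_var_params, 2)

# Draws from the mean-field Gaussian: shift and scale standard-normal noise,
# mirroring draws = opt_means + draws_raw * np.exp(opt_log_sds) in the hunk below.
draws_raw = rng.normal(size=(n_draws, n_params))
draws = opt_means + draws_raw * np.exp(opt_log_sds)
print(draws.shape)  # (5, 3)
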
@@ -148,9 +200,29 @@ def fit_dadvi(
    draws = opt_means + draws_raw * np.exp(opt_log_sds)
    draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

-    transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    idata = dadvi_result_to_idata(
+        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    )

-    return transformed_draws
+    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+    var_name_to_model_var.update(
+        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+    )
+
+    idata = add_optimizer_result_to_inference_data(
+        idata=idata,
+        result=result,
+        method=optimizer_method,
+        mu=raveled_optimized,
+        model=model,
+        var_name_to_model_var=var_name_to_model_var,
+    )
+
+    idata = add_data_to_inference_data(
+        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )
+
+    return idata


def create_dadvi_graph(
@@ -213,10 +285,11 @@ def create_dadvi_graph(
    return var_params, objective


-def transform_draws(
+def dadvi_result_to_idata(
    unstacked_draws: xarray.Dataset,
    model: Model,
-    keep_untransformed: bool = False,
+    include_transformed: bool = False,
+    progressbar: bool = True,
):
    """
    Transforms the unconstrained draws back into the constrained space.
@@ -232,9 +305,12 @@ def transform_draws(
    n_draws: int
        The number of draws to return from the variational approximation.

-    keep_untransformed : bool
+    include_transformed : bool
        Whether or not to keep the unconstrained variables in the output.

+    progressbar: bool
+        Whether or not to show a progress bar during the transformation. Default is True.
+
    Returns
    -------
    :class:`~arviz.InferenceData`
@@ -243,7 +319,7 @@ def transform_draws(

    filtered_var_names = model.unobserved_value_vars
    vars_to_sample = list(
-        get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
+        get_default_varnames(filtered_var_names, include_transformed=include_transformed)
    )
    fn = pytensor.function(model.value_vars, vars_to_sample)
    point_func = PointFunc(fn)
@@ -256,6 +332,20 @@ def transform_draws(
        output_var_names=[x.name for x in vars_to_sample],
        coords=coords,
        dims=dims,
+        progressbar=progressbar,
    )

-    return transformed_result
+    constrained_names = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+    ]
+    all_varnames = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+    ]
+    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
+
+    idata = az.InferenceData(posterior=transformed_result[constrained_names])
+
+    if unconstrained_names and include_transformed:
+        idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
+
+    return idata
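
With the change from transform_draws to dadvi_result_to_idata, constrained draws land in the posterior group and, when include_transformed=True, unconstrained draws in a separate unconstrained_posterior group. A short sketch of inspecting the result, continuing the toy model from the first example (the groups added by add_optimizer_result_to_inference_data and add_data_to_inference_data are not listed here, since their names are not visible in this diff):

import arviz as az

# idata is the InferenceData returned by fit_dadvi(..., include_transformed=True) above.
print(idata.groups())  # should include "posterior" and "unconstrained_posterior"
print(az.summary(idata, var_names=["mu", "sigma"]))

# Unconstrained (e.g. log-transformed) variables live in their own group.
print(list(idata.unconstrained_posterior.data_vars))
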