20
20
from typing import Dict , List , Tuple , Union , Optional
21
21
from functools import partial
22
22
23
+ from copy import deepcopy
23
24
import lmfit
24
25
import numpy as np
25
26
import pandas as pd
31
32
)
32
33
from qiskit_experiments .framework .containers import FigureType , ArtifactData
33
34
from qiskit_experiments .data_processing .exceptions import DataProcessorError
35
+ from qiskit_experiments .visualization import PlotStyle
34
36
35
37
from .base_curve_analysis import BaseCurveAnalysis , DATA_ENTRY_PREFIX , PARAMS_ENTRY_PREFIX
36
38
from .curve_data import FitOptions , CurveFitResult
@@ -123,6 +125,7 @@ def __init__(
123
125
124
126
self ._models = models or []
125
127
self ._name = name or self .__class__ .__name__
128
+ self ._plot_config_cache = {}
126
129
127
130
@property
128
131
def name (self ) -> str :
@@ -148,6 +151,118 @@ def model_names(self) -> List[str]:
148
151
"""Return model names."""
149
152
return [getattr (m , "_name" , f"model-{ i } " ) for i , m in enumerate (self ._models )]
150
153
154
+ def set_options (self , ** fields ):
155
+ """Set the analysis options for :meth:`run` method.
156
+
157
+ Args:
158
+ fields: The fields to update the options
159
+
160
+ Raises:
161
+ KeyError: When removed option ``curve_fitter`` is set.
162
+ """
163
+ if fields .get ("plot_residuals" ) and not self .options .get ("plot_residuals" ):
164
+ # checking there are no subplots for the figure to prevent collision in subplot indices.
165
+ if self .plotter .options .get ("subplots" ) != (1 , 1 ):
166
+ warnings .warn (
167
+ "Residuals plotting is currently supported for analysis with 1 subplot." ,
168
+ UserWarning ,
169
+ stacklevel = 2 ,
170
+ )
171
+ fields ["plot_residuals" ] = False
172
+ else :
173
+ self ._add_residuals_plot_config ()
174
+ if not fields .get ("plot_residuals" , True ) and self .options .get ("plot_residuals" ):
175
+ self ._remove_residuals_plot_config ()
176
+
177
+ super ().set_options (** fields )
178
+
179
+ def _add_residuals_plot_config (self ):
180
+ """Configure plotter options for residuals plot."""
181
+ # check we have model to fit into
182
+ residual_plot_y_axis_size = 3
183
+ if self .models :
184
+ # Cache figure options.
185
+ self ._plot_config_cache ["figure_options" ] = {}
186
+ self ._plot_config_cache ["figure_options" ]["ylabel" ] = self .plotter .figure_options .get (
187
+ "ylabel"
188
+ )
189
+ self ._plot_config_cache ["figure_options" ]["series_params" ] = deepcopy (
190
+ self .plotter .figure_options .get ("series_params" )
191
+ )
192
+ self ._plot_config_cache ["figure_options" ]["sharey" ] = self .plotter .figure_options .get (
193
+ "sharey"
194
+ )
195
+
196
+ self .plotter .set_figure_options (
197
+ ylabel = [
198
+ self .plotter .figure_options .get ("ylabel" , "" ),
199
+ "Residuals" ,
200
+ ],
201
+ )
202
+
203
+ model_names = self .model_names ()
204
+ series_params = self .plotter .figure_options ["series_params" ]
205
+ for model_name in model_names :
206
+ if series_params .get (model_name ):
207
+ series_params [model_name ]["canvas" ] = 0
208
+ else :
209
+ series_params [model_name ] = {"canvas" : 0 }
210
+ series_params [model_name + "_residuals" ] = series_params [model_name ].copy ()
211
+ series_params [model_name + "_residuals" ]["canvas" ] = 1
212
+ self .plotter .set_figure_options (sharey = False , series_params = series_params )
213
+
214
+ # Cache plotter options.
215
+ self ._plot_config_cache ["plotter" ] = {}
216
+ self ._plot_config_cache ["plotter" ]["subplots" ] = self .plotter .options .get ("subplots" )
217
+ self ._plot_config_cache ["plotter" ]["style" ] = deepcopy (
218
+ self .plotter .options .get ("style" , PlotStyle ({}))
219
+ )
220
+
221
+ # removing the name from the plotter style, so it will not clash with the new name
222
+ previous_plotter_style = self ._plot_config_cache ["plotter" ]["style" ].copy ()
223
+ previous_plotter_style .pop ("style_name" , "" )
224
+
225
+ # creating new fig size based on previous size
226
+ new_figsize = self .plotter .drawer .options .get ("figsize" , (8 , 5 ))
227
+ new_figsize = (new_figsize [0 ], new_figsize [1 ] + residual_plot_y_axis_size )
228
+
229
+ # Here add the configuration for the residuals plot:
230
+ self .plotter .set_options (
231
+ subplots = (2 , 1 ),
232
+ style = PlotStyle .merge (
233
+ PlotStyle (
234
+ {
235
+ "figsize" : new_figsize ,
236
+ "textbox_rel_pos" : (0.28 , - 0.10 ),
237
+ "sub_plot_heights_list" : [7 / 10 , 3 / 10 ],
238
+ "sub_plot_widths_list" : [1 ],
239
+ "style_name" : "residuals" ,
240
+ }
241
+ ),
242
+ previous_plotter_style ,
243
+ ),
244
+ )
245
+
246
+ def _remove_residuals_plot_config (self ):
247
+ """set options for a single plot to its cached values."""
248
+ if self .models :
249
+ self .plotter .set_figure_options (
250
+ ylabel = self ._plot_config_cache ["figure_options" ]["ylabel" ],
251
+ sharey = self ._plot_config_cache ["figure_options" ]["sharey" ],
252
+ series_params = self ._plot_config_cache ["figure_options" ]["series_params" ],
253
+ )
254
+
255
+ # Here add the style_name so the plotter will know not to print the residual data.
256
+ self .plotter .set_options (
257
+ subplots = self ._plot_config_cache ["plotter" ]["subplots" ],
258
+ style = PlotStyle .merge (
259
+ self ._plot_config_cache ["plotter" ]["style" ],
260
+ PlotStyle ({"style_name" : "canceled_residuals" }),
261
+ ),
262
+ )
263
+
264
+ self ._plot_config_cache = {}
265
+
151
266
def _run_data_processing (
152
267
self ,
153
268
raw_data : List [Dict ],
@@ -335,8 +450,13 @@ def _run_curve_fit(
335
450
fit_options = [fit_options ]
336
451
337
452
# Create convenient function to compute residual of the models.
338
- partial_residuals = []
453
+ partial_weighted_residuals = []
339
454
valid_uncertainty = np .all (np .isfinite (curve_data .y_err ))
455
+
456
+ # creating storage for residual plotting
457
+ if self .options .get ("plot_residuals" ):
458
+ residual_weights_list = []
459
+
340
460
for idx , sub_data in curve_data .iter_by_series_id ():
341
461
if valid_uncertainty :
342
462
nonzero_yerr = np .where (
@@ -350,16 +470,23 @@ def _run_curve_fit(
350
470
# some yerr values might be very close to zero, yielding significant weights.
351
471
# With such outlier, the fit doesn't sense residual of other data points.
352
472
maximum_weight = np .percentile (raw_weights , 90 )
353
- weights = np .clip (raw_weights , 0.0 , maximum_weight )
473
+ weights_list = np .clip (raw_weights , 0.0 , maximum_weight )
354
474
else :
355
- weights = None
356
- model_residual = partial (
475
+ weights_list = None
476
+ model_weighted_residual = partial (
357
477
self ._models [idx ]._residual ,
358
478
data = sub_data .y ,
359
- weights = weights ,
479
+ weights = weights_list ,
360
480
x = sub_data .x ,
361
481
)
362
- partial_residuals .append (model_residual )
482
+ partial_weighted_residuals .append (model_weighted_residual )
483
+
484
+ # adding weights to weights_list for residuals
485
+ if self .options .get ("plot_residuals" ):
486
+ if weights_list is None :
487
+ residual_weights_list .append (None )
488
+ else :
489
+ residual_weights_list .append (weights_list )
363
490
364
491
# Run fit for each configuration
365
492
res = None
@@ -379,7 +506,7 @@ def _run_curve_fit(
379
506
try :
380
507
with np .errstate (all = "ignore" ):
381
508
new = lmfit .minimize (
382
- fcn = lambda x : np .concatenate ([p (x ) for p in partial_residuals ]),
509
+ fcn = lambda x : np .concatenate ([p (x ) for p in partial_weighted_residuals ]),
383
510
params = guess_params ,
384
511
method = self .options .fit_method ,
385
512
scale_covar = not valid_uncertainty ,
@@ -396,11 +523,30 @@ def _run_curve_fit(
396
523
if new .success and res .redchi > new .redchi :
397
524
res = new
398
525
526
+ # if `plot_residuals` is ``False`` I would like the `residuals_model` be None to emphasize it
527
+ # wasn't calculated.
528
+ residuals_model = [] if self .options .get ("plot_residuals" ) else None
529
+ if res and res .success and self .options .get ("plot_residuals" ):
530
+ for weights in residual_weights_list :
531
+ if weights is None :
532
+ residuals_model .append (res .residual )
533
+ else :
534
+ residuals_model .append (
535
+ [
536
+ weighted_res / np .abs (weight )
537
+ for weighted_res , weight in zip (res .residual , weights )
538
+ ]
539
+ )
540
+
541
+ if residuals_model is not None :
542
+ residuals_model = np .array (residuals_model )
543
+
399
544
return convert_lmfit_result (
400
545
res ,
401
546
self ._models ,
402
547
curve_data .x ,
403
548
curve_data .y ,
549
+ residuals_model ,
404
550
)
405
551
406
552
def _create_figures (
@@ -449,6 +595,14 @@ def _create_figures(
449
595
y_interp_err = fit_stdev ,
450
596
)
451
597
598
+ if self .options .get ("plot_residuals" ):
599
+ residuals_data = sub_data .filter (category = "residuals" )
600
+ self .plotter .set_series_data (
601
+ series_name = model_name ,
602
+ x_residuals = residuals_data .x ,
603
+ y_residuals = residuals_data .y ,
604
+ )
605
+
452
606
return [self .plotter .figure ()]
453
607
454
608
def _run_analysis (
@@ -526,6 +680,22 @@ def _run_analysis(
526
680
category = "fitted" ,
527
681
analysis = self .name ,
528
682
)
683
+
684
+ if self .options .get ("plot_residuals" ):
685
+ # need to add here the residuals plot.
686
+ xval_residual = sub_data .x
687
+ yval_residuals = unp .nominal_values (fit_data .residuals [series_id ])
688
+
689
+ for xval , yval in zip (xval_residual , yval_residuals ):
690
+ table .add_row (
691
+ xval = xval ,
692
+ yval = yval ,
693
+ series_name = model_names [series_id ],
694
+ series_id = series_id ,
695
+ category = "residuals" ,
696
+ analysis = self .name ,
697
+ )
698
+
529
699
result_data .extend (
530
700
self ._create_analysis_results (
531
701
fit_data = fit_data ,
0 commit comments