Skip to content

Commit 6845440

Browse files
Add option to plot residuals (#1382)
### Summary This PR will give the user the option to pass `plot_residual=True` to the analysis. This will add a residual plot to the figure. ### Details and comments Some details that should be in this section include: - Came from an issue that was open #1169 . - What tests and documentation have been added/updated - What do users and developers need to know about this change What needed to do in this PR: - [x] Add support for different subplot sizes. - [x] Add automatic residual calculation. - [x] Plot residuals. - [x] Style the plot's color, legend, limits, etc. When setting `plot_residuals=True` in analysis option, a residual plot will be added. currently work only for experiments with 1 plot in its figure. Example for output with `plot_residuals=True` for `RamseyXY` experiment: ![image](https://github.com/Qiskit-Extensions/qiskit-experiments/assets/51112651/206d7900-d196-474f-a87a-4dea0d15c531) --------- Co-authored-by: Yael Ben-Haim <[email protected]>
1 parent 0bbd426 commit 6845440

File tree

9 files changed

+293
-13
lines changed

9 files changed

+293
-13
lines changed

docs/tutorials/visualization.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,20 @@ Plotters have two sets of options that customize their behavior and figure conte
105105
and ``figure_options``, which have figure-specific parameters that control aspects of the
106106
figure itself, such as axis labels and series colors.
107107

108+
To see the residual plot, set ``plot_residuals=True`` in the analysis options:
109+
110+
.. jupyter-execute::
111+
112+
# Set to ``True`` analysis option for residual plot
113+
rabi.analysis.set_options(plot_residuals=True)
114+
115+
# Run experiment
116+
rabi_data = rabi.run().block_for_results()
117+
rabi_data.figure(0)
118+
119+
120+
This option works for experiments without subplots in their figures.
121+
108122
Here is a more complicated experiment in which we customize the figure of a DRAG
109123
experiment before it's run, so that we don't need to regenerate the figure like in
110124
the previous example. First, we run the experiment without customizing the options

qiskit_experiments/curve_analysis/base_curve_analysis.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ def _default_options(cls) -> Options:
153153
the analysis result.
154154
plot_raw_data (bool): Set ``True`` to draw processed data points,
155155
dataset without formatting, on canvas. This is ``False`` by default.
156+
plot_residuals (bool): Set ``True`` to draw the residuals data for the
157+
fitting model. This is ``False`` by default.
156158
plot (bool): Set ``True`` to create figure for fit result or ``False`` to
157159
not create a figure. This overrides the behavior of ``generate_figures``.
158160
return_fit_parameters (bool): (Deprecated) Set ``True`` to return all fit model parameters
@@ -207,6 +209,7 @@ def _default_options(cls) -> Options:
207209

208210
options.plotter = CurvePlotter(MplDrawer())
209211
options.plot_raw_data = False
212+
options.plot_residuals = False
210213
options.return_fit_parameters = True
211214
options.return_data_points = False
212215
options.data_processor = None

qiskit_experiments/curve_analysis/curve_analysis.py

Lines changed: 177 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Dict, List, Tuple, Union, Optional
2121
from functools import partial
2222

23+
from copy import deepcopy
2324
import lmfit
2425
import numpy as np
2526
import pandas as pd
@@ -31,6 +32,7 @@
3132
)
3233
from qiskit_experiments.framework.containers import FigureType, ArtifactData
3334
from qiskit_experiments.data_processing.exceptions import DataProcessorError
35+
from qiskit_experiments.visualization import PlotStyle
3436

3537
from .base_curve_analysis import BaseCurveAnalysis, DATA_ENTRY_PREFIX, PARAMS_ENTRY_PREFIX
3638
from .curve_data import FitOptions, CurveFitResult
@@ -123,6 +125,7 @@ def __init__(
123125

124126
self._models = models or []
125127
self._name = name or self.__class__.__name__
128+
self._plot_config_cache = {}
126129

127130
@property
128131
def name(self) -> str:
@@ -148,6 +151,118 @@ def model_names(self) -> List[str]:
148151
"""Return model names."""
149152
return [getattr(m, "_name", f"model-{i}") for i, m in enumerate(self._models)]
150153

154+
def set_options(self, **fields):
155+
"""Set the analysis options for :meth:`run` method.
156+
157+
Args:
158+
fields: The fields to update the options
159+
160+
Raises:
161+
KeyError: When removed option ``curve_fitter`` is set.
162+
"""
163+
if fields.get("plot_residuals") and not self.options.get("plot_residuals"):
164+
# checking there are no subplots for the figure to prevent collision in subplot indices.
165+
if self.plotter.options.get("subplots") != (1, 1):
166+
warnings.warn(
167+
"Residuals plotting is currently supported for analysis with 1 subplot.",
168+
UserWarning,
169+
stacklevel=2,
170+
)
171+
fields["plot_residuals"] = False
172+
else:
173+
self._add_residuals_plot_config()
174+
if not fields.get("plot_residuals", True) and self.options.get("plot_residuals"):
175+
self._remove_residuals_plot_config()
176+
177+
super().set_options(**fields)
178+
179+
def _add_residuals_plot_config(self):
180+
"""Configure plotter options for residuals plot."""
181+
# check we have model to fit into
182+
residual_plot_y_axis_size = 3
183+
if self.models:
184+
# Cache figure options.
185+
self._plot_config_cache["figure_options"] = {}
186+
self._plot_config_cache["figure_options"]["ylabel"] = self.plotter.figure_options.get(
187+
"ylabel"
188+
)
189+
self._plot_config_cache["figure_options"]["series_params"] = deepcopy(
190+
self.plotter.figure_options.get("series_params")
191+
)
192+
self._plot_config_cache["figure_options"]["sharey"] = self.plotter.figure_options.get(
193+
"sharey"
194+
)
195+
196+
self.plotter.set_figure_options(
197+
ylabel=[
198+
self.plotter.figure_options.get("ylabel", ""),
199+
"Residuals",
200+
],
201+
)
202+
203+
model_names = self.model_names()
204+
series_params = self.plotter.figure_options["series_params"]
205+
for model_name in model_names:
206+
if series_params.get(model_name):
207+
series_params[model_name]["canvas"] = 0
208+
else:
209+
series_params[model_name] = {"canvas": 0}
210+
series_params[model_name + "_residuals"] = series_params[model_name].copy()
211+
series_params[model_name + "_residuals"]["canvas"] = 1
212+
self.plotter.set_figure_options(sharey=False, series_params=series_params)
213+
214+
# Cache plotter options.
215+
self._plot_config_cache["plotter"] = {}
216+
self._plot_config_cache["plotter"]["subplots"] = self.plotter.options.get("subplots")
217+
self._plot_config_cache["plotter"]["style"] = deepcopy(
218+
self.plotter.options.get("style", PlotStyle({}))
219+
)
220+
221+
# removing the name from the plotter style, so it will not clash with the new name
222+
previous_plotter_style = self._plot_config_cache["plotter"]["style"].copy()
223+
previous_plotter_style.pop("style_name", "")
224+
225+
# creating new fig size based on previous size
226+
new_figsize = self.plotter.drawer.options.get("figsize", (8, 5))
227+
new_figsize = (new_figsize[0], new_figsize[1] + residual_plot_y_axis_size)
228+
229+
# Here add the configuration for the residuals plot:
230+
self.plotter.set_options(
231+
subplots=(2, 1),
232+
style=PlotStyle.merge(
233+
PlotStyle(
234+
{
235+
"figsize": new_figsize,
236+
"textbox_rel_pos": (0.28, -0.10),
237+
"sub_plot_heights_list": [7 / 10, 3 / 10],
238+
"sub_plot_widths_list": [1],
239+
"style_name": "residuals",
240+
}
241+
),
242+
previous_plotter_style,
243+
),
244+
)
245+
246+
def _remove_residuals_plot_config(self):
247+
"""set options for a single plot to its cached values."""
248+
if self.models:
249+
self.plotter.set_figure_options(
250+
ylabel=self._plot_config_cache["figure_options"]["ylabel"],
251+
sharey=self._plot_config_cache["figure_options"]["sharey"],
252+
series_params=self._plot_config_cache["figure_options"]["series_params"],
253+
)
254+
255+
# Here add the style_name so the plotter will know not to print the residual data.
256+
self.plotter.set_options(
257+
subplots=self._plot_config_cache["plotter"]["subplots"],
258+
style=PlotStyle.merge(
259+
self._plot_config_cache["plotter"]["style"],
260+
PlotStyle({"style_name": "canceled_residuals"}),
261+
),
262+
)
263+
264+
self._plot_config_cache = {}
265+
151266
def _run_data_processing(
152267
self,
153268
raw_data: List[Dict],
@@ -335,8 +450,13 @@ def _run_curve_fit(
335450
fit_options = [fit_options]
336451

337452
# Create convenient function to compute residual of the models.
338-
partial_residuals = []
453+
partial_weighted_residuals = []
339454
valid_uncertainty = np.all(np.isfinite(curve_data.y_err))
455+
456+
# creating storage for residual plotting
457+
if self.options.get("plot_residuals"):
458+
residual_weights_list = []
459+
340460
for idx, sub_data in curve_data.iter_by_series_id():
341461
if valid_uncertainty:
342462
nonzero_yerr = np.where(
@@ -350,16 +470,23 @@ def _run_curve_fit(
350470
# some yerr values might be very close to zero, yielding significant weights.
351471
# With such outlier, the fit doesn't sense residual of other data points.
352472
maximum_weight = np.percentile(raw_weights, 90)
353-
weights = np.clip(raw_weights, 0.0, maximum_weight)
473+
weights_list = np.clip(raw_weights, 0.0, maximum_weight)
354474
else:
355-
weights = None
356-
model_residual = partial(
475+
weights_list = None
476+
model_weighted_residual = partial(
357477
self._models[idx]._residual,
358478
data=sub_data.y,
359-
weights=weights,
479+
weights=weights_list,
360480
x=sub_data.x,
361481
)
362-
partial_residuals.append(model_residual)
482+
partial_weighted_residuals.append(model_weighted_residual)
483+
484+
# adding weights to weights_list for residuals
485+
if self.options.get("plot_residuals"):
486+
if weights_list is None:
487+
residual_weights_list.append(None)
488+
else:
489+
residual_weights_list.append(weights_list)
363490

364491
# Run fit for each configuration
365492
res = None
@@ -379,7 +506,7 @@ def _run_curve_fit(
379506
try:
380507
with np.errstate(all="ignore"):
381508
new = lmfit.minimize(
382-
fcn=lambda x: np.concatenate([p(x) for p in partial_residuals]),
509+
fcn=lambda x: np.concatenate([p(x) for p in partial_weighted_residuals]),
383510
params=guess_params,
384511
method=self.options.fit_method,
385512
scale_covar=not valid_uncertainty,
@@ -396,11 +523,30 @@ def _run_curve_fit(
396523
if new.success and res.redchi > new.redchi:
397524
res = new
398525

526+
# if `plot_residuals` is ``False`` I would like the `residuals_model` be None to emphasize it
527+
# wasn't calculated.
528+
residuals_model = [] if self.options.get("plot_residuals") else None
529+
if res and res.success and self.options.get("plot_residuals"):
530+
for weights in residual_weights_list:
531+
if weights is None:
532+
residuals_model.append(res.residual)
533+
else:
534+
residuals_model.append(
535+
[
536+
weighted_res / np.abs(weight)
537+
for weighted_res, weight in zip(res.residual, weights)
538+
]
539+
)
540+
541+
if residuals_model is not None:
542+
residuals_model = np.array(residuals_model)
543+
399544
return convert_lmfit_result(
400545
res,
401546
self._models,
402547
curve_data.x,
403548
curve_data.y,
549+
residuals_model,
404550
)
405551

406552
def _create_figures(
@@ -449,6 +595,14 @@ def _create_figures(
449595
y_interp_err=fit_stdev,
450596
)
451597

598+
if self.options.get("plot_residuals"):
599+
residuals_data = sub_data.filter(category="residuals")
600+
self.plotter.set_series_data(
601+
series_name=model_name,
602+
x_residuals=residuals_data.x,
603+
y_residuals=residuals_data.y,
604+
)
605+
452606
return [self.plotter.figure()]
453607

454608
def _run_analysis(
@@ -526,6 +680,22 @@ def _run_analysis(
526680
category="fitted",
527681
analysis=self.name,
528682
)
683+
684+
if self.options.get("plot_residuals"):
685+
# need to add here the residuals plot.
686+
xval_residual = sub_data.x
687+
yval_residuals = unp.nominal_values(fit_data.residuals[series_id])
688+
689+
for xval, yval in zip(xval_residual, yval_residuals):
690+
table.add_row(
691+
xval=xval,
692+
yval=yval,
693+
series_name=model_names[series_id],
694+
series_id=series_id,
695+
category="residuals",
696+
analysis=self.name,
697+
)
698+
529699
result_data.extend(
530700
self._create_analysis_results(
531701
fit_data=fit_data,

qiskit_experiments/curve_analysis/curve_data.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ def __init__(
168168
var_names: Optional[List[str]] = None,
169169
x_data: Optional[np.ndarray] = None,
170170
y_data: Optional[np.ndarray] = None,
171+
weighted_residuals: Optional[np.ndarray] = None,
172+
residuals: Optional[np.ndarray] = None,
171173
covar: Optional[np.ndarray] = None,
172174
):
173175
"""Create new Qiskit curve analysis result object.
@@ -188,6 +190,8 @@ def __init__(
188190
var_names: Name of variables, i.e. fixed parameters are excluded from the list.
189191
x_data: X values used for the fitting.
190192
y_data: Y values used for the fitting.
193+
weighted_residuals: The residuals from the fitting after assigning weights for each ydata.
194+
residuals: residuals of the fitted model.
191195
covar: Covariance matrix of fitting variables.
192196
"""
193197
self.method = method
@@ -205,6 +209,8 @@ def __init__(
205209
self.var_names = var_names
206210
self.x_data = x_data
207211
self.y_data = y_data
212+
self.weighted_residuals = weighted_residuals
213+
self.residuals = residuals
208214
self.covar = covar
209215

210216
@property

qiskit_experiments/curve_analysis/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def convert_lmfit_result(
115115
models: List[lmfit.Model],
116116
xdata: np.ndarray,
117117
ydata: np.ndarray,
118+
residuals: np.ndarray,
118119
) -> CurveFitResult:
119120
"""A helper function to convert LMFIT ``MinimizerResult`` into :class:`.CurveFitResult`.
120121
@@ -128,6 +129,7 @@ def convert_lmfit_result(
128129
models: Model used for the fitting. Function description is extracted.
129130
xdata: X values used for the fitting.
130131
ydata: Y values used for the fitting.
132+
residuals: The residuals of the ydata from the model.
131133
132134
Returns:
133135
QiskitExperiments :class:`.CurveFitResult` object.
@@ -169,6 +171,8 @@ def convert_lmfit_result(
169171
var_names=result.var_names,
170172
x_data=xdata,
171173
y_data=ydata,
174+
weighted_residuals=result.residual,
175+
residuals=residuals,
172176
covar=covar,
173177
)
174178

0 commit comments

Comments
 (0)