diff --git a/changelog/181.feature.rst b/changelog/181.feature.rst new file mode 100644 index 00000000..304dc344 --- /dev/null +++ b/changelog/181.feature.rst @@ -0,0 +1 @@ +Created new `stixpy.visualisation.plotters.PixelPlotter` class to create pixel plots by moving and refactored code from current mixin so current API is maintained. This new approach simplifies the `~stixpy.product.sources.science` module and programmatic access to the plot controls enabling animations. diff --git a/docs/reference/visualisation.rst b/docs/reference/visualisation.rst index d1724a49..7184f4c6 100644 --- a/docs/reference/visualisation.rst +++ b/docs/reference/visualisation.rst @@ -8,3 +8,5 @@ The ``visualisation`` submodule contain classes and function related to the visu .. automodapi:: stixpy.visualisation .. automodapi:: stixpy.visualisation.map_reprojection + +.. automodapi:: stixpy.visualisation.plotters diff --git a/docs/tutorials/quickstart.rst b/docs/tutorials/quickstart.rst index 760b2e0f..1dc14606 100644 --- a/docs/tutorials/quickstart.rst +++ b/docs/tutorials/quickstart.rst @@ -110,6 +110,7 @@ quicklook data above but change the query to search a narrower time window and s sci_query = Fido.search(a.Time('2020-06-07T21:30', '2020-06-07T22:00'), a.Instrument.stix, a.stix.DataType.sci) + sci_query['stix'].filter_for_latest_version() # only keep latest versions This should return a list of data files similar to this. diff --git a/stixpy/product/sources/science.py b/stixpy/product/sources/science.py index fd900342..4754cf64 100644 --- a/stixpy/product/sources/science.py +++ b/stixpy/product/sources/science.py @@ -1,18 +1,14 @@ -import copy from pathlib import Path from itertools import product -from collections import defaultdict import astropy.units as u import numpy as np from astropy.table import QTable, vstack from astropy.time import Time from astropy.visualization import quantity_support -from matplotlib import cm, colors from matplotlib import pyplot as plt from matplotlib.colors import LogNorm -from matplotlib.dates import DateFormatter, HourLocator -from matplotlib.patches import Patch +from matplotlib.dates import ConciseDateFormatter, DateFormatter, HourLocator from matplotlib.widgets import Slider from sunpy.time.timerange import TimeRange @@ -36,6 +32,7 @@ "EnergyEdgeMasks", ] +from stixpy.visualisation.plotters import PixelPlotter quantity_support() @@ -301,7 +298,7 @@ def plot_timeseries( if axes is None: fig, axes = plt.subplots() else: - fig = axes.get_figure() + axes = axes.get_figure() if detector_indices == "all": detector_indices = [[0, 31]] @@ -336,9 +333,7 @@ def plot_timeseries( lines = axes.plot(times.to_datetime(), counts[:, did, pid, eid], label=labels[eid], **plot_kwarg) axes.set_yscale("log") - axes.xaxis.set_major_formatter(DateFormatter("%d %H:%M")) - fig.autofmt_xdate() - fig.tight_layout() + axes.xaxis.set_major_formatter(ConciseDateFormatter(axes.xaxis.get_major_locator())) return lines @@ -348,435 +343,10 @@ class PixelPlotMixin: Pixel plot mixin providing pixel plotting for pixel data. """ - def plot_pixels(self, *, kind="pixels", time_indices=None, energy_indices=None, fig=None, cmap=None): - """ - Plot individual pixel data for each detector. - - Parameters - ---------- - kind : `string` the options: 'pixels', 'errorbar', 'config' - This sets the visualization type of the subplots. The data will then be shown in the selected style. - time_indices : `list` or `numpy.ndarray` - If an 1xN array will be treated as mask if 2XN array will sum data between given - indices. For example `time_indices=[0, 2, 5]` would return only the first, third and - sixth times while `time_indices=[[0, 2],[3, 5]]` would sum the data between. - energy_indices : `list` or `numpy.ndarray` - If an 1xN array will be treated as mask if 2XN array will sum data between given - indices. For example `energy_indices=[0, 2, 5]` would return only the first, third and - sixth times while `energy_indices=[[0, 2],[3, 5]]` would sum the data between. - fig : optional `matplotlib.figure` - The figure where to which the pixel plot will be added. - cmap : `string` | `colormap` optional - If the kind is `pixels` a colormap will be shown. - String : default colormap name - colormap: a custom colormap - NOTE: If the color of the special detectors 'cfl', 'bkg' is way above - the imaging detectors, the color will be automatically set to white. - - Returns - ------- - `matplotlib.figure` - The figure - """ - - if kind not in ["pixels", "errorbar", "config"]: - kind = "pixels" - - if fig: - axes = fig.subplots(nrows=4, ncols=8, sharex=True, sharey=True, figsize=(7, 7)) - else: - fig, axes = plt.subplots(nrows=4, ncols=8, sharex=True, sharey=True, figsize=(7, 7)) - - counts, count_err, times, dt, energies = self.get_data(time_indices=time_indices, energy_indices=energy_indices) - - imaging_mask = np.ones(32, bool) - imaging_mask[8:10] = False - - max_counts = counts[:, imaging_mask, :, :].max().value - min_counts = counts[:, imaging_mask, :, :].min().value - - norm = plt.Normalize(min_counts, max_counts) # Needed to select the color values for the pixels plot. - det_font = {"weight": "regular", "size": 8} - axes_font = {"weight": "regular", "size": 7} - quadrant_font = {"weight": "regular", "size": 15} - - # pad counts back to have 12 pixels - unit = counts.unit - counts_pad = np.full((*counts.shape[:2], 12, counts.shape[-1]), np.nan) - counts_pad[..., self.pixel_masks.masks.astype(bool).flatten(), :] = counts - - count_err_pad = np.full((*counts.shape[:2], 12, counts.shape[-1]), np.nan) - count_err_pad[..., self.pixel_masks.masks.astype(bool).flatten(), :] = count_err - - counts = counts_pad << unit - count_err = count_err_pad << unit - - if cmap is None: - clrmap = copy.copy(cm.get_cmap("viridis")) - clrmap.set_over("w") - elif isinstance(cmap, str): - clrmap = copy.copy(cm.get_cmap(cmap)) - else: - clrmap = cmap - - def timeval(val): - return times[val].isot - - def energyval(val): - return f"{energies[val]['e_low'].value}-{energies[val]['e_high']}" - - def det_pixels_plot(counts, norm, axes, clrmap, fig, last=False): - """ - Shows a plot to visualize the pixel counts; the pixels plot. - - Parameters - ---------- - counts : `List` - data collection with the number of counts - norm : `function` - normalizes the data in the parentheses - axes : `matplotlib.axes` - the axes in which the data will be plotted - clrmap : `colormap` - the colormap which will be used to visualize the data/counts - fig : `matplotlib.figure` - the current figure to use - - Returns - ------- - top: The 4 pixels positioned at the top - bottom: The 4 pixels positioned at the bottom - small: The 4 small pixels in the middle - """ - - # Set the variables needed. - bar1 = [1, 1, 1, 1] - bar2 = [-1, -1, -1, -1] - bar3 = [0.2, 0.2, 0.2, 0.2] - x_pos = ["A", "B", "C", "D"] - - counts = counts.reshape(3, 4) - - # plot the pixels - top = axes.bar( - x_pos, bar1, color=clrmap(norm(counts[0, :])), width=1, zorder=1, edgecolor="w", linewidth=0.5 - ) - bottom = axes.bar( - x_pos, bar2, color=clrmap(norm(counts[1, :])), width=1, zorder=1, edgecolor="w", linewidth=0.5 - ) - small = axes.bar( - x_pos, - bar3, - color=clrmap(norm(counts[2, :])), - width=-0.5, - align="edge", - bottom=-0.1, - zorder=1, - edgecolor="w", - linewidth=0.5, - ) - - # hide most of the axes ticks - if last: - axes.set_xticks(range(4)) - axes.set_xticklabels(x_pos) - axes.axes.get_xaxis().set_visible(True) - axes.axes.get_yaxis().set_visible(False) - else: - axes.set_xticks([]) - axes.axes.get_xaxis().set_visible(False) - axes.axes.get_yaxis().set_visible(False) - - for i in range(4): - top[i].data = counts[0, i] - bottom[i].data = counts[1, i] - small[i].data = counts[2, i] - - # Create the label annotation - annot = axes.annotate( - "", - xy=(0, 0), - xytext=(-60, 20), - textcoords="offset points", - bbox=dict(boxstyle="round", fc="w"), - arrowprops=dict(arrowstyle="-"), - zorder=33, - ) - - annot.set_visible(False) - - # Create a hover function - def update_annot(artist, annot): - """update tooltip when hovering a given plotted object""" - # find the middle of the bar - center_x = artist.get_x() + artist.get_width() / 2 - center_y = artist.get_y() + artist.get_height() / 2 - annot.xy = (center_x, center_y) - - annot.set_text(artist.data.round(decimals=3)) - # annot.get_bbox_patch().set_alpha(1) - - def hover(event): - """update and show a tooltip while hovering an object; hide it otherwise""" - # one wants to hide the annotation only if no artist in the graph is hovered - annot.set_visible(False) - if isinstance(event.inaxes, type(axes)): - for p in [top, bottom, small]: - for artist in p: - contains, _ = artist.contains(event) - if contains: - update_annot(artist, annot) - annot.set_visible(True) - if last: - fig.canvas.draw_idle() - - fig.canvas.mpl_connect("motion_notify_event", hover) - return top, bottom, small - - def det_errorbar_plot(counts, count_err, pixel_ids, detector_id, axes): - """Shows a plot to visualize the counts; the errorbar plot.""" - plot_cont = [] - for pixel_id in pixel_ids: - plot_cont.append( - axes.errorbar((0.5, 1.5, 2.5, 3.5), counts[pixel_id], yerr=count_err[pixel_id], xerr=0.5, ls="") - ) - axes.set_xticks([]) - if detector_id > 0: - axes.set_ylabel("") - return plot_cont - - def det_config_plot(detector_config, axes, font, detector_id): - """Shows a plot with the configurations of the detectors; the config plot.""" - - # Create Functions to convert 'Front' and 'Rear Orient'. - def mm2deg(x): - return x * 360.0 / 1 - - def deg2mm(x): - return x / 360.0 * 1 - - # get the information that will be plotted - if detector_config["Phase Sense"] > 0: - phase_sense = "+" - elif detector_config["Phase Sense"] < 0: - phase_sense = "-" - else: - phase_sense = "n" - - y = [ - detector_config["Slit Width"], - detector_config["Front Pitch"], - detector_config["Rear Pitch"], - 0, - deg2mm(detector_config["Front Orient"]), - deg2mm(detector_config["Rear Orient"]), - ] - - x = np.arange(len(y)) - color = ["black", "orange", "#1f77b4", "b", "orange", "#1f77b4"] - - # plot the information on axes - axes.bar(x, y, color=color) - axes.text(x=0.8, y=0.7, s=f"Phase: {phase_sense}", **font) - axes.set_ylim(0, 1) - axes.axes.get_xaxis().set_visible(False) - - # Create secondary y axis - ax2 = axes.secondary_yaxis("right", functions=(mm2deg, deg2mm)) - ax2.set_yticks([0, 90, 270, 360]) - ax2.set_yticklabels(["0°", "90°", "270°", "360°"], fontsize=8) - ax2.set_visible(False) - axes.axes.get_yaxis().set_visible(False) - - # Create axes labeling and legend - if detector_id == 0: - axes.set_yticks([0, 1]) - axes.set_ylabel("mm", **font) - axes.yaxis.set_label_coords(-0.1, 0.5) - axes.axes.get_yaxis().set_visible(True) - legend_bars = [Patch(facecolor="orange"), Patch(facecolor="#1f77b4")] - axes.legend(legend_bars, ["Front", "Rear"], loc="center right", bbox_to_anchor=(0, 2.5)) - if detector_id == 31: - ax2.set_visible(True) - axes.axes.get_xaxis().set_visible(True) - axes.set_xticks([0, 1.5, 4.5]) - axes.set_xticklabels(["Slit Width", "Pitch", "Orientation"], rotation=90) - # leave the spaces to set the correct x position of the label! - ax2.set_ylabel(" deg °", rotation=0, **font) - # x parameter doesn't change anything because it's a secondary - # y axis (has only 1 x position). - ax2.yaxis.set_label_coords(x=1, y=0.55) - - def colorbar(counts, min_counts, max_counts, clrmap, fig): - """ - Creates a colormap at the left side of the created figure. - - NOTE: If the color of the special detectors 'cfl', 'bkg' is way above - the rest, the color will be automatically set to white. - """ - - norm = colors.Normalize(vmin=min_counts, vmax=max_counts) - cax = fig.add_axes([0.05, 0.15, 0.025, 0.8]) - cbar = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=clrmap), orientation="vertical", cax=cax) - cbar.ax.set_title(f"{str(counts.unit)}", rotation=90, x=-0.8, y=0.4) - - def instrument_layout(fig, font): - """Shows the layout of the instrument to make it easier to locate the detectors.""" - x = [0, 2] - y = [1, 1] - fig.add_axes([0.06, 0.055, 0.97, 0.97]) - plt.plot(x, y, c="b") - plt.plot(y, x, c="b") - plt.axis("off") - fig.add_axes([0.09, 0.08, 0.91, 0.92]) - draw_circle_1 = plt.Circle((0.545, 0.540), 0.443, color="b", alpha=0.1) - draw_circle_2 = plt.Circle((0.545, 0.540), 0.07, color="#2b330b", alpha=0.95) - fig.add_artist(draw_circle_1) - fig.add_artist(draw_circle_2) - plt.axis("off") - - # Label the quandrants of the instrument - fig.add_axes([0, 0, 1, 1]) - plt.text(0.19, 0.89, "Q1", **font) - plt.text(0.19, 0.17, "Q2", **font) - plt.text(0.86, 0.17, "Q3", **font) - plt.text(0.86, 0.89, "Q4", **font) - plt.axis("off") - - instrument_layout(fig, quadrant_font) # Call the instrument layout - - # Create the energy and time slider add the bottom of the figure - axcolor = "lightgoldenrodyellow" - axenergy = plt.axes([0.15, 0.05, 0.55, 0.03], facecolor=axcolor) - senergy = SliderCustomValue( - ax=axenergy, label="Energy", valmin=0, valmax=len(energies) - 1, format_func=energyval, valinit=0, valstep=1 - ) - axetime = plt.axes([0.15, 0.01, 0.55, 0.03], facecolor=axcolor) - stime = SliderCustomValue( - ax=axetime, label="Time", valmin=0, valmax=counts.shape[0] - 1, format_func=timeval, valinit=1, valstep=1 - ) - - pixel_ids = [slice(0, 4), slice(4, 8), slice(8, 12)] - if counts.shape[2] == 4: - pixel_ids = [slice(0, 4)] - - containers = defaultdict(list) - - xnorm = plt.Normalize(SubCollimatorConfig["SC Xcen"].min() * 1.5, SubCollimatorConfig["SC Xcen"].max() * 1.5) - ynorm = plt.Normalize(SubCollimatorConfig["SC Ycen"].min() * 1.4, SubCollimatorConfig["SC Ycen"].max() * 1.4) - if kind == "pixels": - colorbar(counts, min_counts, max_counts, clrmap, fig) - - # plot the layout of the 32 detectors - for detector_id in range(32): - row, col = divmod(detector_id, 8) - plot_cont = object - if kind == "pixels": - plot_cont = det_pixels_plot( - counts[0, detector_id, :, 0], norm, axes[row, col], clrmap, fig, last=(detector_id == 31) - ) - elif kind == "errorbar": - plot_cont = det_errorbar_plot( - counts[0, detector_id, :, 0], - count_err[0, detector_id, :, 0], - pixel_ids, - detector_id, - axes[row, col], - ) - elif kind == "config": - plot_cont = det_config_plot(SubCollimatorConfig[detector_id], axes[row, col], axes_font, detector_id) - - axes[row, col].set_zorder(100) - - # set the custom position of the detectors - axes[row, col].set_position( - [ - xnorm(SubCollimatorConfig["SC Xcen"][detector_id]), - ynorm(SubCollimatorConfig["SC Ycen"][detector_id]), - 1 / 11.0, - 1 / 11.0, - ] - ) - - containers[row, col].append(plot_cont) - axes[row, col].set_title(f"Det {SubCollimatorConfig['Grid Label'][detector_id]}", y=0.89, **det_font) - - def update_void(_): - """get the value as this will update the slider""" - _ = senergy.val - _ = stime.val - - def update_pixels(_): - """Update the value of the pixels plot when the energy and time slider is being used.""" - energy_index = senergy.val - time_index = stime.val - - for detector_id in range(32): - row, col = divmod(detector_id, 8) - cnts = counts[time_index, detector_id, :, energy_index] - top, bottom, small = containers[row, col][0] - cnts = cnts.reshape([3, 4]) - for pix_artist, pix in zip(range(4), range(12)): - norm_counts = norm(cnts[0][pix].value) - top[pix_artist].set_color(clrmap(norm_counts)) - top[pix_artist].data = cnts[0][pix] - top[pix_artist].set_edgecolor("w") - - norm_counts = norm(cnts[1][pix].value) - bottom[pix_artist].set_color(clrmap(norm_counts)) - bottom[pix_artist].data = cnts[1][pix] - bottom[pix_artist].set_edgecolor("w") - - norm_counts = norm(cnts[2][pix].value) - small[pix_artist].set_color(clrmap(norm_counts)) - small[pix_artist].data = cnts[2][pix] - small[pix_artist].set_edgecolor("w") - - def update_errorbar(_): - """Update the errorbar plot when the energy and time slider is being used.""" - energy_index = senergy.val - time_index = stime.val - pids_ = [slice(0, 4), slice(4, 8), slice(8, 12)] - if counts.shape[2] == 4: - pids_ = [slice(0, 4)] - - for did in range(32): - r, c = divmod(did, 8) - axes[r, c].set_ylim(0, np.nanmax(counts[time_index, imaging_mask, :, energy_index]) * 1.2) - - for i, pid in enumerate(pids_): - lines, caps, bars = containers[r, c][0][i] - lines.set_ydata(counts[time_index, did, pid, energy_index]) - - # horizontal bars at value - segs = np.array(bars[0].get_segments()) - if segs.size > 0: - segs[:, 0, 0] = [0.0, 1.0, 2.0, 3.0] - segs[:, 1, 0] = [1.0, 2.0, 3.0, 4.0] - segs[:, 0, 1] = counts[time_index, did, pid, energy_index] - segs[:, 1, 1] = counts[time_index, did, pid, energy_index] - bars[0].set_segments(segs) - # vertical bars at +/- error - segs = np.array(bars[1].get_segments()) - segs[:, 0, 0] = [0.5, 1.5, 2.5, 3.5] - segs[:, 1, 0] = [0.5, 1.5, 2.5, 3.5] - segs[:, 0, 1] = ( - counts[time_index, did, pid, energy_index] - count_err[time_index, did, pid, energy_index] - ) - segs[:, 1, 1] = ( - counts[time_index, did, pid, energy_index] + count_err[time_index, did, pid, energy_index] - ) - bars[1].set_segments(segs) - - update_function = update_pixels - if kind == "config": - update_function = update_void - elif kind == "errorbar": - update_function = update_errorbar - - # Call the update functions - senergy.on_changed(update_function) - stime.on_changed(update_function) + def plot_pixels(self, *, kind="pixel", time_indices=None, energy_indices=None, fig=None, cmap=None, **kwargs): + pixel_plotter = PixelPlotter(self, time_indices=time_indices, energy_indices=energy_indices) + pixel_plotter.plot(kind=kind, fig=fig, cmap=cmap, **kwargs) + return pixel_plotter class ScienceData(L1Product): @@ -1002,7 +572,7 @@ def get_data( e_norm = e_norm.reshape(1, 1, 1, -1) if t_norm.size != 1: - t_norm = t_norm.reshape(-1, 1, 1, 1) + t_norm = t_norm.reshape(-1, 1, 1, 1).to("s") counts_err = np.sqrt(counts * u.ct + counts_var) / (e_norm * t_norm) counts = counts / (e_norm * t_norm) diff --git a/stixpy/visualisation/plotters.py b/stixpy/visualisation/plotters.py new file mode 100644 index 00000000..b4ea0c2e --- /dev/null +++ b/stixpy/visualisation/plotters.py @@ -0,0 +1,530 @@ +import copy +from pathlib import Path +from collections import defaultdict + +import astropy.units as u +import numpy as np +from matplotlib import cm +from matplotlib import pyplot as plt +from matplotlib.colors import LogNorm, Normalize +from matplotlib.patches import Circle, Patch +from matplotlib.widgets import Slider + +from stixpy.io.readers import read_subc_params + +SubCollimatorConfig = read_subc_params( + Path(__file__).parent.parent / "config" / "data" / "detector" / "stx_subc_params.csv" +) + + +__all__ = ["PixelPlotter", "SliderCustomValue"] + + +class PixelPlotter: + """ + Plot individual pixel data for each detector. + + Support three kinds of plots: + * 'pixel' which show counts as rectangular patches in the correct pixel locations using a color map + * 'errorbar' which shows the counts and error as error bar plots one per detector + * 'config' which display per sub-collimator configuration + + Parameters + ---------- + prod : `Product` + Pixel data product to plot + kind : `string` optional + This sets the visualization type of the subplots the supported options are: 'pixel', 'errorbar', 'config'. + time_indices : `list` or `numpy.ndarray` + If an 1xN array will be treated as mask if 2XN array will sum data between given + indices. For example `time_indices=[0, 2, 5]` would return only the first, third and + sixth times while `time_indices=[[0, 2],[3, 5]]` would sum the data between. + energy_indices : `list` or `numpy.ndarray` + If an 1xN array will be treated as mask if 2XN array will sum data between given + indices. For example `energy_indices=[0, 2, 5]` would return only the first, third and + sixth times while `energy_indices=[[0, 2],[3, 5]]` would sum the data between. + fig : optional `matplotlib.figure` + The figure where to which the pixel plot will be added. + cmap : `string` | `colormap` optional + If the kind is `pixels` a colormap will be shown. + + NOTE: If the color of the special detectors 'cfl', 'bkg' is way above + the imaging detectors, the color will be automatically set to white. + + Returns + ------- + `tuple[matplotlib.figure.Figure,matplotlib.axes.Axes]` + """ + + def __init__(self, prod, time_indices=None, energy_indices=None): + self.time_indices = time_indices + self.energy_indices = energy_indices + from stixpy.product.sources import CompressedPixelData, RawPixelData, SummedCompressedPixelData + + if not isinstance(prod, (RawPixelData, CompressedPixelData, SummedCompressedPixelData)): + raise ValueError(f"Can not create a pixel plot as {prod.__class__} does not contain pixel data.") + self.prod = prod + self.kind = "pixel" + self.fig = None + self.axes = None + + self.containers = defaultdict(list) + self._prepare_data() + + def plot(self, kind="pixel", fig=None, cmap=None): + r""" + Generates and returns the main plot figure and axes + + Parameters + ---------- + kind : str + The visualization type: 'pixels', 'errorbar', or 'config'. + fig : matplotlib.figure.Figure, optional + An existing figure to draw on. + cmap : str or colormap, optional + The colormap for the 'pixels' plot. + + Returns + ------- + + """ + if kind not in ["pixel", "errorbar", "config"]: + raise ValueError(f"Kind must be 'pixel', 'errorbar' or 'config' not '{kind}'.") + self.kind = kind + + if fig is None: + fig, axes = plt.subplots(nrows=4, ncols=8, sharex=True, sharey=True, figsize=(7, 7)) + else: + axes = fig.subplots(nrows=4, ncols=8, sharex=True, sharey=True) + + self.fig = fig + self.axes = axes + + self._setup_plot_elements(cmap) + self._create_main_layout() + self._create_sliders() + self._connect_update_function() + + return self.fig, self.axes + + def _setup_plot_elements(self, cmap): + """Sets up normalization, colormaps, and fonts.""" + max_counts = np.max(self.counts[np.isfinite(self.counts)]).value + min_counts = np.min(self.counts[self.counts > 0]).value + self.norm = CountNorm(min_counts, max_counts) + self.det_font = {"weight": "regular", "size": 8} + self.axes_font = {"weight": "regular", "size": 7} + self.quadrant_font = {"weight": "regular", "size": 15} + + if cmap is None: + self.clrmap = copy.copy(cm.get_cmap("viridis")) + self.clrmap.set_over("gray") + self.clrmap.set_under("white") + self.clrmap.set_bad("gray") + + elif isinstance(cmap, str): + self.clrmap = copy.copy(cm.get_cmap(cmap)) + else: + self.clrmap = cmap + + def _prepare_data(self): + # Get the necessary data from the product + counts, count_err, times, durations, energies = self.prod.get_data( + time_indices=self.time_indices, energy_indices=self.energy_indices + ) + + nt, ndet, npix, ne = counts.shape + dmask = self.prod.detector_masks.masks[0].astype(bool) + + counts_pad = [] + count_err_pad = [] + for i, pm in enumerate(self.prod.data["pixel_masks"].value): + tmp_counts = np.full((32, 12, ne), np.nan) + tmp_err = np.full((32, 12, ne), np.nan) + tmp_counts[np.ix_(dmask, pm.astype(bool))] = counts[i][:, pm.astype(bool)[: counts.shape[2]], :] + tmp_err[np.ix_(dmask, pm.astype(bool))] = count_err[i][:, pm.astype(bool)[: counts.shape[2]], :] + + counts_pad.append(tmp_counts) + count_err_pad.append(tmp_err) + + counts_pad = np.stack(counts_pad) + count_err_pad = np.stack(count_err_pad) + + self.times = times + self.energies = energies + self.counts = counts_pad * counts.unit + self.count_err = count_err_pad * count_err.unit + + def _create_main_layout(self): + """Draws the instrument layout and the 32 detector subplots.""" + self._draw_instrument_layout() + if self.kind == "pixel": + self._draw_colorbar() + + xnorm = Normalize(SubCollimatorConfig["SC Xcen"].min() * 1.5, SubCollimatorConfig["SC Xcen"].max() * 1.5) + ynorm = Normalize(SubCollimatorConfig["SC Ycen"].min() * 1.4, SubCollimatorConfig["SC Ycen"].max() * 1.4) + + pixel_ids = [slice(0, 4), slice(4, 8), slice(8, 12)] + if self.counts.shape[2] == 4: + pixel_ids = [slice(0, 4)] + + for det_id in range(32): + row, col = divmod(det_id, 8) + ax = self.axes[row, col] + plot_container = None + + if self.kind == "pixel": + plot_container = self._det_pixels_plot(self.counts[0, det_id, :, 0], ax, last=(det_id == 31)) + elif self.kind == "errorbar": + plot_container = self._det_errorbar_plot( + self.counts[0, det_id, :, 0], self.count_err[0, det_id, :, 0], pixel_ids, det_id, ax + ) + elif self.kind == "config": + plot_container = self._det_config_plot(SubCollimatorConfig[det_id], ax, det_id) + + ax.set_zorder(100) + ax.set_position( + [ + xnorm(SubCollimatorConfig["SC Xcen"][det_id]), + ynorm(SubCollimatorConfig["SC Ycen"][det_id]), + 1 / 11.0, + 1 / 11.0, + ] + ) + self.containers[row, col].append(plot_container) + resolutions = np.arctan2(0.5 * SubCollimatorConfig["Front Pitch"].to("um"), 545.30 * u.mm).to("arcsec") + ax.set_title( + f"{SubCollimatorConfig['Det #'][det_id]}" + f" {SubCollimatorConfig['Grid Label'][det_id]}" + f'{resolutions[det_id].value: 0.1f}"', + y=0.89, + **self.det_font, + ) + + def _create_sliders(self): + """Creates the time and energy sliders.""" + axcolor = "lightgoldenrodyellow" + axenergy = plt.axes([0.15, 0.05, 0.55, 0.03], facecolor=axcolor) + self.senergy = SliderCustomValue( + ax=axenergy, + label="Energy", + valmin=0, + valmax=len(self.energies) - 1, + format_func=self._format_energy, + valinit=0, + valstep=1, + ) + axetime = plt.axes([0.15, 0.01, 0.55, 0.03], facecolor=axcolor) + self.stime = SliderCustomValue( + ax=axetime, + label="Time", + valmin=0, + valmax=self.counts.shape[0] - 1, + format_func=self._format_time, + valinit=1, + valstep=1, + ) + + # --- Formatting and Drawing Helpers --- + + def _format_time(self, val): + return self.times[val].isot + + def _format_energy(self, val): + return f"{self.energies[val]['e_low'].value}-{self.energies[val]['e_high']}" + + def _draw_colorbar(self): + """Creates a colormap on the left side of the figure.""" + cax = self.fig.add_axes([0.05, 0.15, 0.025, 0.8]) + cbar = self.fig.colorbar(cm.ScalarMappable(norm=self.norm, cmap=self.clrmap), orientation="vertical", cax=cax) + cbar.ax.set_title(f"{str(self.counts.unit)}", rotation=90, x=-0.8, y=0.4) + + def _draw_instrument_layout(self): + """Shows the layout of the instrument.""" + x = [0, 2] + y = [1, 1] + ax = self.fig.add_axes([0.06, 0.055, 0.97, 0.97]) + ax.plot(x, y, c="b") + ax.plot(y, x, c="b") + ax.axis("off") + + ax = self.fig.add_axes([0.09, 0.08, 0.91, 0.92]) + draw_circle_1 = Circle((0.545, 0.540), 0.443, color="b", alpha=0.1) + draw_circle_2 = Circle((0.545, 0.540), 0.07, color="#2b330b", alpha=0.95) + self.fig.add_artist(draw_circle_1) + self.fig.add_artist(draw_circle_2) + ax.axis("off") + + ax = self.fig.add_axes([0, 0, 1, 1]) + ax.text(0.19, 0.89, "Q1", **self.quadrant_font) + ax.text(0.19, 0.17, "Q2", **self.quadrant_font) + ax.text(0.86, 0.17, "Q3", **self.quadrant_font) + ax.text(0.86, 0.89, "Q4", **self.quadrant_font) + ax.axis("off") + + # --- Per-Detector Plotting Logic --- + + def _det_pixels_plot(self, counts, axes, last=False): + """Shows a plot to visualize the pixel counts.""" + x_pos, bar1, bar2, bar3 = ["A", "B", "C", "D"], [1] * 4, [-1] * 4, [0.2] * 4 + counts = counts.reshape(3, 4) + + top = axes.bar( + x_pos, bar1, color=self.clrmap(self.norm(counts[0, :])), width=1, zorder=1, edgecolor="k", linewidth=0.5 + ) + bottom = axes.bar( + x_pos, bar2, color=self.clrmap(self.norm(counts[1, :])), width=1, zorder=1, edgecolor="k", linewidth=0.5 + ) + small = axes.bar( + x_pos, + bar3, + color=self.clrmap(self.norm(counts[2, :])), + width=-0.5, + align="edge", + bottom=-0.1, + zorder=1, + edgecolor="k", + linewidth=0.5, + ) + + axes.axes.get_yaxis().set_visible(False) + if last: + axes.set_xticks(range(4)) + axes.set_xticklabels(x_pos) + axes.axes.get_xaxis().set_visible(True) + else: + axes.set_xticks([]) + axes.axes.get_xaxis().set_visible(False) + + for i in range(4): + top[i].data = counts[0, i] + bottom[i].data = counts[1, i] + small[i].data = counts[2, i] + + self._create_hover_tooltip(axes, [top, bottom, small], last) + return top, bottom, small + + def _det_errorbar_plot(self, counts, count_err, pixel_ids, detector_id, axes): + """Shows an errorbar plot of counts.""" + plot_cont = [ + axes.errorbar((0.5, 1.5, 2.5, 3.5), counts[pid], yerr=count_err[pid], xerr=0.5, ls="") for pid in pixel_ids + ] + axes.set_xticks([]) + if detector_id > 0: + axes.set_ylabel("") + return plot_cont + + def _det_config_plot(self, detector_config, axes, detector_id): + """Shows a plot with detector configurations.""" + + # Create Functions to convert 'Front' and 'Rear Orient'. + def mm2deg(x): + return x * 360.0 / 1 + + def deg2mm(x): + return x / 360.0 * 1 + + # get the information that will be plotted + if detector_config["Phase Sense"] > 0: + phase_sense = "+" + elif detector_config["Phase Sense"] < 0: + phase_sense = "-" + else: + phase_sense = "n" + + y = [ + detector_config["Slit Width"], + detector_config["Front Pitch"], + detector_config["Rear Pitch"], + 0, + deg2mm(detector_config["Front Orient"]), + deg2mm(detector_config["Rear Orient"]), + ] + + x = np.arange(len(y)) + color = ["black", "orange", "#1f77b4", "b", "orange", "#1f77b4"] + + # plot the information on axes + axes.bar(x, y, color=color) + axes.text(x=0.8, y=0.7, s=f"Phase: {phase_sense}", **self.axes_font) + axes.set_ylim(0, 1) + axes.axes.get_xaxis().set_visible(False) + + # Create secondary y axis + ax2 = axes.secondary_yaxis("right", functions=(mm2deg, deg2mm)) + ax2.set_yticks([0, 90, 270, 360]) + ax2.set_yticklabels(["0°", "90°", "270°", "360°"], fontsize=8) + ax2.set_visible(False) + axes.axes.get_yaxis().set_visible(False) + + # Create axes labeling and legend + if detector_id == 0: + axes.set_yticks([0, 1]) + axes.set_ylabel("mm", **self.axes_font) + axes.yaxis.set_label_coords(-0.1, 0.5) + axes.axes.get_yaxis().set_visible(True) + legend_bars = [Patch(facecolor="orange"), Patch(facecolor="#1f77b4")] + axes.legend(legend_bars, ["Front", "Rear"], loc="center right", bbox_to_anchor=(0, 2.5)) + if detector_id == 31: + ax2.set_visible(True) + axes.axes.get_xaxis().set_visible(True) + axes.set_xticks([0, 1.5, 4.5]) + axes.set_xticklabels(["Slit Width", "Pitch", "Orientation"], rotation=90) + # leave the spaces to set the correct x position of the label! + ax2.set_ylabel(" deg °", rotation=0, **self.axes_font) + # x parameter doesn't change anything because it's a secondary + # y axis (has only 1 x position). + ax2.yaxis.set_label_coords(x=1, y=0.55) + + def _create_hover_tooltip(self, axes, artists_list, last): + """Creates and manages the hover annotation for a subplot.""" + annot = axes.annotate( + "", + xy=(0, 0), + xytext=(-60, 20), + textcoords="offset points", + bbox=dict(boxstyle="round", fc="w"), + arrowprops=dict(arrowstyle="-"), + zorder=33, + ) + annot.set_visible(False) + + def update_annot(artist): + center_x = artist.get_x() + artist.get_width() / 2 + center_y = artist.get_y() + artist.get_height() / 2 + annot.xy = (center_x, center_y) + annot.set_text(format(artist.data, ".2e")) + + def hover(event): + annot.set_visible(False) + if event.inaxes == axes: + for artist_group in artists_list: + for artist in artist_group: + contains, _ = artist.contains(event) + if contains: + update_annot(artist) + annot.set_visible(True) + break + if last: + self.fig.canvas.draw_idle() + + self.fig.canvas.mpl_connect("motion_notify_event", hover) + + # --- Update Functions for Sliders --- + + def _connect_update_function(self): + """Connects the appropriate update function to the sliders.""" + update_function = self._update_void + if self.kind == "pixel": + update_function = self._update_pixels + elif self.kind == "errorbar": + update_function = self._update_errorbar + + self.senergy.on_changed(update_function) + self.stime.on_changed(update_function) + + def _update_void(self, _): + """Dummy update function for static plots.""" + pass + + def _update_pixels(self, _): + """Updates the pixel colors based on slider values.""" + energy_index, time_index = self.senergy.val, self.stime.val + for detector_id in range(32): + row, col = divmod(detector_id, 8) + top, bottom, small = self.containers[row, col][0] + cnts = self.counts[time_index, detector_id, :, energy_index].reshape([3, 4]) + + for idx in range(4): + norm_counts = self.norm(cnts[0][idx].value) + top[idx].set_color(self.clrmap(norm_counts)) + top[idx].data = cnts[0][idx] + top[idx].set_edgecolor("k") + + norm_counts = self.norm(cnts[1][idx].value) + bottom[idx].set_color(self.clrmap(norm_counts)) + bottom[idx].data = cnts[1][idx] + bottom[idx].set_edgecolor("k") + + norm_counts = self.norm(cnts[2][idx].value) + small[idx].set_color(self.clrmap(norm_counts)) + small[idx].data = cnts[2][idx] + small[idx].set_edgecolor("k") + + self.fig.canvas.draw_idle() + + def _update_errorbar(self, _): + energy_index = self.senergy.val + time_index = self.stime.val + pids_ = [slice(0, 4), slice(4, 8), slice(8, 12)] + if self.counts.shape[2] == 4: + pids_ = [slice(0, 4)] + + for did in range(32): + r, c = divmod(did, 8) + self.axes[r, c].set_ylim(0, np.nanmax(self.counts[time_index, :, :, energy_index]) * 1.2) + + for i, pid in enumerate(pids_): + lines, caps, bars = self.containers[r, c][0][i] + lines.set_ydata(self.counts[time_index, did, pid, energy_index]) + + # horizontal bars at value + segs = np.array(bars[0].get_segments()) + if segs.size > 0: + segs[:, 0, 0] = [0.0, 1.0, 2.0, 3.0] + segs[:, 1, 0] = [1.0, 2.0, 3.0, 4.0] + segs[:, 0, 1] = self.counts[time_index, did, pid, energy_index] + segs[:, 1, 1] = self.counts[time_index, did, pid, energy_index] + bars[0].set_segments(segs) + # vertical bars at +/- error + segs = np.array(bars[1].get_segments()) + segs[:, 0, 0] = [0.5, 1.5, 2.5, 3.5] + segs[:, 1, 0] = [0.5, 1.5, 2.5, 3.5] + segs[:, 0, 1] = ( + self.counts[time_index, did, pid, energy_index] + - self.count_err[time_index, did, pid, energy_index] + ) + segs[:, 1, 1] = ( + self.counts[time_index, did, pid, energy_index] + + self.count_err[time_index, did, pid, energy_index] + ) + bars[1].set_segments(segs) + + self.fig.canvas.draw_idle() + + +class SliderCustomValue(Slider): + """ + A slider with a customisable formatter + """ + + def __init__(self, *args, format_func=None, **kwargs): + if format_func is not None: + self._format = format_func + super().__init__(*args, **kwargs) + + +class CountNorm(Normalize): + """ + A LogNorm but allows 0s to be kept and plotted with colormaps under color + """ + + def __init__(self, vmin=None, vmax=None, **kwargs): + super().__init__(vmin, vmax, **kwargs) + self.lognorm = LogNorm(vmin=vmin, vmax=vmax) + + def __call__(self, value): + tmp, is_scaler = self.process_value(value) + unit = tmp.unit if hasattr(tmp, "unit") else 1 + zeros = np.nonzero(tmp == 0) + tmp[zeros] = np.finfo(np.float32).tiny * unit + res = self.lognorm(tmp) + return res if not is_scaler else res[0] + + def inverse(self, value): + tmp, is_scaler = self.process_value(value) + unit = tmp.unit if hasattr(tmp, "unit") else 1 + zeros = np.nonzero(tmp == 0) + tmp[zeros] = np.finfo(np.float32).tiny * unit + res = self.lognorm.inverse(tmp) + return res if not is_scaler else res[0]