diff --git a/src/aiidalab_qe/common/bandpdoswidget.py b/src/aiidalab_qe/common/bandpdoswidget.py index 04d43a1b5..0550c93a1 100644 --- a/src/aiidalab_qe/common/bandpdoswidget.py +++ b/src/aiidalab_qe/common/bandpdoswidget.py @@ -30,17 +30,32 @@ class BandPdosPlotly: "horizontal_range_pdos": [-10, 10], } - def __init__(self, bands_data=None, pdos_data=None): + def __init__(self, bands_data=None, pdos_data=None, project_bands=False): self.bands_data = bands_data self.pdos_data = pdos_data self.fermi_energy = self._get_fermi_energy() + self.project_bands = project_bands and "projected_bands" in self.bands_data # Plotly Axis # Plotly settings self._bands_xaxis = self._band_xaxis() self._bands_yaxis = self._band_yaxis() - self._dos_xaxis = self._dos_xaxis() - self._dos_yaxis = self._dos_yaxis() + self._pdos_xaxis = self._pdos_xaxis() + self._pdos_yaxis = self._pdos_yaxis() + + @property + def plot_type(self): + """Define the plot type.""" + if self.bands_data and self.pdos_data: + return "combined" + elif self.bands_data: + return "bands" + elif self.pdos_data: + return "pdos" + + @property + def bandspdosfigure(self): + return self._get_bandspdos_plot() def _get_fermi_energy(self): """Function to return the Fermi energy information depending on the data available.""" @@ -112,133 +127,95 @@ def _band_yaxis(self): return bandyaxis - def _dos_xaxis(self): - """Function to return the xaxis for the dos plot.""" + def _pdos_xaxis(self): + """Function to return the xaxis for the pdos plot.""" if not self.pdos_data: return None + # For combined plot + axis_settings = { + "showgrid": True, + "showline": True, + "mirror": "ticks", + "ticks": "inside", + "linewidth": 2, + "tickwidth": 2, + "linecolor": self.SETTINGS["axis_linecolor"], + "title": "Density of states", + "side": "bottom", + "automargin": True, + } - if self.bands_data: - dosxaxis = go.layout.XAxis( - title="Density of states", - side="bottom", - showgrid=True, - showline=True, - linecolor=self.SETTINGS["axis_linecolor"], - mirror="ticks", - ticks="inside", - linewidth=2, - tickwidth=2, - automargin=True, - ) - - else: - dosxaxis = go.layout.XAxis( - title="Density of states (eV)", - showgrid=True, - showline=True, - linecolor=self.SETTINGS["axis_linecolor"], - mirror="ticks", - ticks="inside", - linewidth=2, - tickwidth=2, - range=self.SETTINGS["horizontal_range_pdos"], - ) + if self.plot_type != "combined": + axis_settings["title"] = "Density of states (eV)" + axis_settings["range"] = self.SETTINGS["horizontal_range_pdos"] + axis_settings.pop("side") + axis_settings.pop("automargin") - return dosxaxis + return go.layout.XAxis(**axis_settings) - def _dos_yaxis(self): - """Function to return the yaxis for the dos plot.""" + def _pdos_yaxis(self): + """Function to return the yaxis for the pdos plot.""" if not self.pdos_data: return None - if self.bands_data: - dosyaxis = go.layout.YAxis( - # title= {"text":"Density of states (eV)", "standoff": 1}, - showgrid=True, - showline=True, - side="right", - mirror="ticks", - ticks="inside", - linewidth=2, - tickwidth=2, - linecolor=self.SETTINGS["axis_linecolor"], - zerolinewidth=2, - ) - - else: - dosyaxis = go.layout.YAxis( - # title="Density of states (eV)", - showgrid=True, - showline=True, - side="left", - mirror="ticks", - ticks="inside", - linewidth=2, - tickwidth=2, - linecolor=self.SETTINGS["axis_linecolor"], - zerolinewidth=2, - ) + axis_settings = { + "showgrid": True, + "showline": True, + "side": "right" if self.plot_type == "combined" else "left", + "mirror": "ticks", + "ticks": "inside", + "linewidth": 2, + "tickwidth": 2, + "linecolor": self.SETTINGS["axis_linecolor"], + "zerolinewidth": 2, + } - return dosyaxis + return go.layout.YAxis(**axis_settings) def _get_bandspdos_plot(self): """Function to return the bands plot widget.""" - conditions = { - (True, False): self._create_bands_only_plot, - (False, True): self._create_dos_only_plot, - (True, True): self._create_combined_plot, - } - - return conditions.get((bool(self.bands_data), bool(self.pdos_data)), None)() - def _create_bands_only_plot(self): - """Function to return the bands plot widget.""" + fig = self._create_fig() + if self.bands_data: + self._add_band_traces(fig) - fig = go.Figure() - paths = self.bands_data.get("paths") + band_labels = self.bands_data.get("pathlabels") + for label in band_labels[1]: + fig.add_vline( + x=label, + line=dict(color=self.SETTINGS["vertical_linecolor"], width=1), + ) - self._add_band_traces(fig, paths, "bands_only") + if self.project_bands: + self._add_projection_traces(fig) - band_labels = self.bands_data.get("pathlabels") - for i in band_labels[1]: - fig.add_vline( - x=i, line=dict(color=self.SETTINGS["vertical_linecolor"], width=1) - ) - fig.update_layout( - xaxis=self._bands_xaxis, - yaxis=self._bands_yaxis, - plot_bgcolor="white", - height=self.SETTINGS["bands_plot_height"], - width=self.SETTINGS["bands_plot_width"], - ) - return go.FigureWidget(fig) + if self.pdos_data: + self._add_pdos_traces(fig) + if self.plot_type == "pdos": + fig.add_vline( + x=0, + line=dict( + color=self.SETTINGS["vertical_linecolor"], width=1, dash="dot" + ), + ) - def _create_dos_only_plot(self): - """Function to return the pdos plot widget.""" + if self.plot_type == "combined": + self._customize_combined_layout(fig) + else: + self._customize_single_layout(fig) - fig = go.Figure() - # Extract DOS data - self._add_dos_traces(fig, plot_type="dos_only") - # Add a vertical line at zero energy - fig.add_vline( - x=0, - line=dict(color=self.SETTINGS["vertical_linecolor"], width=1, dash="dot"), - ) + return go.FigureWidget(fig) - # Update the layout of the Figure - fig.update_layout( - xaxis=self._dos_xaxis, - yaxis=self._dos_yaxis, - plot_bgcolor="white", - height=self.SETTINGS["pdos_plot_height"], - width=self.SETTINGS["pdos_plot_width"], - ) + def _create_fig(self): + """Create a plotly figure. - return go.FigureWidget(fig) + The figure layout is different depending on the plot type. + """ + if self.plot_type != "combined": + return go.Figure() - def _create_combined_plot(self): fig = make_subplots( rows=1, cols=2, @@ -246,106 +223,69 @@ def _create_combined_plot(self): column_widths=self.SETTINGS["combined_column_widths"], horizontal_spacing=0.015, ) - paths = self.bands_data.get("paths") - self._add_band_traces(fig, paths, plot_type="combined") - self._add_dos_traces(fig, plot_type="combined") - band_labels = self.bands_data.get("pathlabels") - for i in band_labels[1]: - fig.add_vline( - x=i, - line=dict(color=self.SETTINGS["vertical_linecolor"], width=1), - row=1, - col=1, - ) - self._customize_combined_layout(fig) - return go.FigureWidget(fig) - - def _add_band_traces(self, fig, paths, plot_type): - paths = self.bands_data.get("paths") + return fig + + def _add_traces_to_fig(self, fig, traces, col): + """Add a list of traces to a figure.""" + if self.plot_type == "combined": + rows = [1] * len(traces) + cols = [col] * len(traces) + fig.add_traces(traces, rows=rows, cols=cols) + else: + fig.add_traces(traces) + + def _add_band_traces(self, fig): + """Generate the band traces and add them to the figure.""" + colors = { + (True, 0): self.SETTINGS["bands_up_linecolor"], + (True, 1): self.SETTINGS["bands_down_linecolor"], + (False, 0): self.SETTINGS["bands_linecolor"], + } + fermi_energy_mapping = { + (False, 0): self.fermi_energy.get("fermi_energy_up", None), + (False, 1): self.fermi_energy.get("fermi_energy_down", None), + } - # Spin condition: True if spin-polarized False if not - spin_type = paths[0].get("two_band_types") + bands_data = self.bands_data # Convert paths to a list of Scatter objects scatter_objects = [] - for band in paths: - if not spin_type: - # Non-spin-polarized case - for bands in band["values"]: - bands_np = np.array(bands) - scatter_objects.append( - go.Scatter( - x=band["x"], - y=bands_np - self.fermi_energy["fermi_energy"], - mode="lines", - line=dict( - color=self.SETTINGS["bands_linecolor"], - shape="spline", - smoothing=1.3, - ), - showlegend=False, - ) - ) - else: - half_len = len(band["values"]) // 2 - first_half = band["values"][:half_len] - second_half = band["values"][half_len:] - - # Red line for the Spin up - color_first_half = self.SETTINGS["bands_up_linecolor"] - # Blue line for the Spin down - color_second_half = self.SETTINGS["bands_down_linecolor"] - if "fermi_energy" in self.fermi_energy: - for bands, color in zip( - (first_half, second_half), (color_first_half, color_second_half) - ): - bands_np = np.array(bands) - scatter_objects.append( - go.Scatter( - x=band["x"], - y=bands_np - self.fermi_energy["fermi_energy"], - mode="lines", - line=dict( - color=color, - shape="spline", - smoothing=1.3, - ), - showlegend=False, - ) - ) - else: - for bands, color, fermi_energy in zip( - (first_half, second_half), - (color_first_half, color_second_half), - ( - self.fermi_energy["fermi_energy_up"], - self.fermi_energy["fermi_energy_down"], - ), - ): - for band_values in bands: - bands_np = np.array(band_values) - scatter_objects.append( - go.Scatter( - x=band["x"], - y=bands_np - fermi_energy, - mode="lines", - line=dict( - color=color, - shape="spline", - smoothing=1.3, - ), - showlegend=False, - ) - ) - - if plot_type == "bands_only": - fig.add_traces(scatter_objects) - else: - rows = [1] * len(scatter_objects) - cols = [1] * len(scatter_objects) - fig.add_traces(scatter_objects, rows=rows, cols=cols) + spin_polarized = 1 in bands_data["band_type_idx"] + for spin in [0, 1]: + # In case of non-spin-polarized or SOC calculations, the spin index is only 0 + if spin not in bands_data["band_type_idx"]: + continue + + x_bands = np.array(bands_data["x"]).reshape(1, -1) + # New shape: (number of bands, number of kpoints) + y_bands = bands_data["y"][:, bands_data["band_type_idx"] == spin].T + # Concatenate the bands and prepare the traces + x_bands_comb, y_bands_comb = _prepare_combined_plotly_traces( + x_bands, y_bands + ) + + fermi_energy = fermi_energy_mapping.get( + ("fermi_energy" in self.fermi_energy, spin), + self.fermi_energy.get("fermi_energy"), + ) + + scatter_objects.append( + go.Scatter( + x=x_bands_comb, + y=y_bands_comb - fermi_energy, + mode="lines", + line=dict( + color=colors[(spin_polarized, spin)], + shape="spline", + smoothing=1.3, + ), + showlegend=False, + ) + ) + + self._add_traces_to_fig(fig, scatter_objects, 1) - def _add_dos_traces(self, fig, plot_type): + def _add_pdos_traces(self, fig): # Extract DOS data dos_data = self.pdos_data["dos"] @@ -353,62 +293,75 @@ def _add_dos_traces(self, fig, plot_type): num_traces = len(dos_data) scatter_objects = [None] * num_traces + # dictionary with keys (bool(spin polarized), bool(spin up)) + fermi_energy_spin_mapping = { + (False, True): self.fermi_energy.get("fermi_energy_up", None), + (False, False): self.fermi_energy.get("fermi_energy_down", None), + } + # Vectorize Scatter object creation for i, trace in enumerate(dos_data): dos_np = np.array(trace["x"]) - fill = "tozerox" if plot_type == "combined" else "tozeroy" + fill = "tozerox" if self.plot_type == "combined" else "tozeroy" + fermi_energy = fermi_energy_spin_mapping.get( + ("fermi_energy" in self.fermi_energy, trace["label"].endswith("(↑)")), + self.fermi_energy.get("fermi_energy"), + ) - if "fermi_energy" in self.fermi_energy: - y_data = ( - dos_np - self.fermi_energy["fermi_energy"] - if plot_type == "combined" - else trace["y"] - ) - x_data = ( - trace["y"] - if plot_type == "combined" - else dos_np - self.fermi_energy["fermi_energy"] - ) - else: - if trace["label"].endswith("(↑)"): - y_data = ( - dos_np - self.fermi_energy["fermi_energy_up"] - if plot_type == "combined" - else trace["y"] - ) - x_data = ( - trace["y"] - if plot_type == "combined" - else dos_np - self.fermi_energy["fermi_energy_up"] - ) - else: - y_data = ( - dos_np - self.fermi_energy["fermi_energy_down"] - if plot_type == "combined" - else trace["y"] - ) - x_data = ( - trace["y"] - if plot_type == "combined" - else dos_np - self.fermi_energy["fermi_energy_down"] - ) + x_data = ( + trace["y"] if self.plot_type == "combined" else dos_np - fermi_energy + ) + y_data = ( + dos_np - fermi_energy if self.plot_type == "combined" else trace["y"] + ) scatter_objects[i] = go.Scatter( x=x_data, y=y_data, fill=fill, name=trace["label"], line=dict(color=trace["borderColor"], shape="spline", smoothing=1.0), + legendgroup=trace["label"], ) - if plot_type == "dos_only": - fig.add_traces(scatter_objects) - else: - rows = [1] * len(scatter_objects) - cols = [2] * len(scatter_objects) - fig.add_traces(scatter_objects, rows=rows, cols=cols) + + self._add_traces_to_fig(fig, scatter_objects, 2) + + def _add_projection_traces(self, fig): + """Function to add the projected bands traces to the bands plot.""" + projected_bands = self.bands_data["projected_bands"] + # dictionary with keys (bool(spin polarized), bool(spin up)) + fermi_energy_spin_mapping = { + (False, True): self.fermi_energy.get("fermi_energy_up", None), + (False, False): self.fermi_energy.get("fermi_energy_down", None), + } + + scatter_objects = [] + for proj_bands in projected_bands: + fermi_energy = fermi_energy_spin_mapping.get( + ( + "fermi_energy" in self.fermi_energy, + proj_bands["label"].endswith("(↑)"), + ), + self.fermi_energy.get("fermi_energy"), + ) + scatter_objects.append( + go.Scatter( + x=proj_bands["x"], + y=np.array(proj_bands["y"]) - fermi_energy, + fill="toself", + legendgroup=proj_bands["label"], + mode="lines", + line=dict(width=0, color=proj_bands["color"]), + name=proj_bands["label"], + # If PDOS is present, use those legend entries + showlegend=True if self.plot_type == "bands" else False, + ) + ) + + self._add_traces_to_fig(fig, scatter_objects, 1) def _customize_combined_layout(self, fig): self._customize_layout(fig, self._bands_xaxis, self._bands_yaxis) - self._customize_layout(fig, self._dos_xaxis, self._dos_yaxis, col=2) + self._customize_layout(fig, self._pdos_xaxis, self._pdos_yaxis, col=2) fig.update_layout( legend=dict(xanchor="left", x=1.06), height=self.SETTINGS["combined_plot_height"], @@ -426,9 +379,17 @@ def _customize_layout(self, fig, xaxis, yaxis, row=1, col=1): col=col, ) - @property - def bandspdosfigure(self): - return self._get_bandspdos_plot() + def _customize_single_layout(self, fig): + xaxis = getattr(self, f"_{self.plot_type}_xaxis") + yaxis = getattr(self, f"_{self.plot_type}_yaxis") + + fig.update_layout( + xaxis=xaxis, + yaxis=yaxis, + plot_bgcolor="white", + height=self.SETTINGS[f"{self.plot_type}_plot_height"], + width=self.SETTINGS[f"{self.plot_type}_plot_width"], + ) class BandPdosWidget(ipw.VBox): @@ -446,6 +407,7 @@ class BandPdosWidget(ipw.VBox): - selected_atoms: Text widget to select specific atoms for PDOS plotting. - update_plot_button: Button widget to update the plot. - download_button: Button widget to download the data. + - project_bands_box: Checkbox widget to choose whether projected bands should be plotted. - dos_data: PDOS data. - bands_data: Band structure data. - bandsplot_widget: Plotly widget for band structure and PDOS plot. @@ -458,6 +420,7 @@ class BandPdosWidget(ipw.VBox): Select the style of plotting the projected density of states. """ ) + projected_bands_width = 0.5 def __init__(self, bands=None, pdos=None, **kwargs): if bands is None and pdos is None: @@ -504,28 +467,35 @@ def __init__(self, bands=None, pdos=None, **kwargs): disabled=False, layout=ipw.Layout(visibility="hidden"), ) + self.project_bands_box = ipw.Checkbox( + value=False, + description="Add `fat bands` projections", + ) # Information for the plot - self.dos_data = self._get_dos_data() + self.pdos_data = self._get_pdos_data() self.bands_data = self._get_bands_data() # Plotly widget self.bandsplot_widget = BandPdosPlotly( - bands_data=self.bands_data, pdos_data=self.dos_data + bands_data=self.bands_data, pdos_data=self.pdos_data ).bandspdosfigure # Output widget to display the bandsplot widget self.bands_widget = ipw.Output() # Output widget to clear the specific widgets self.pdos_options_out = ipw.Output() - self.pdos_options = ipw.VBox( - [ - self.description, - self.dos_atoms_group, - self.dos_plot_group, - ipw.HBox([self.selected_atoms, self._wrong_syntax]), - self.update_plot_button, - ] - ) + pdos_options_list = [ + self.description, + self.dos_atoms_group, + self.dos_plot_group, + ipw.HBox([self.selected_atoms, self._wrong_syntax]), + self.update_plot_button, + ] + # If projections are available in the bands data, include the box to plot fat-bands + if self.bands_data and "projected_bands" in self.bands_data: + pdos_options_list.insert(4, self.project_bands_box) + + self.pdos_options = ipw.VBox(pdos_options_list) self._initial_view() @@ -541,22 +511,31 @@ def __init__(self, bands=None, pdos=None, **kwargs): ], **kwargs, ) - if self.pdos: + + # Plot the options only if the pdos is provided or in case the bands data contains projections + if self.pdos or (self.bands_data and "projected_bands" in self.bands_data): with self.pdos_options_out: display(self.pdos_options) def download_data(self, _=None): """Function to download the data.""" file_name_bands = "bands_data.json" - file_name_dos = "dos_data.json" + file_name_pdos = "dos_data.json" if self.bands_data: - json_str = json.dumps(self.bands_data) + bands_data_export = {} + for key, value in self.bands_data.items(): + if isinstance(value, np.ndarray): + bands_data_export[key] = value.tolist() + else: + bands_data_export[key] = value + + json_str = json.dumps(bands_data_export) b64_str = base64.b64encode(json_str.encode()).decode() self._download(payload=b64_str, filename=file_name_bands) - if self.dos_data: - json_str = json.dumps(self.dos_data) + if self.pdos_data: + json_str = json.dumps(self.pdos_data) b64_str = base64.b64encode(json_str.encode()).decode() - self._download(payload=b64_str, filename=file_name_dos) + self._download(payload=b64_str, filename=file_name_pdos) @staticmethod def _download(payload, filename): @@ -575,34 +554,45 @@ def _download(payload, filename): ) display(javas) - def _get_dos_data(self): + def _get_pdos_data(self): if not self.pdos: return None expanded_selection, syntax_ok = string_range_to_list( self.selected_atoms.value, shift=-1 ) if syntax_ok: - dos = get_pdos_data( + pdos = get_pdos_data( self.pdos, group_tag=self.dos_atoms_group.value, plot_tag=self.dos_plot_group.value, selected_atoms=expanded_selection, ) - return dos - else: - return None + return pdos + return None def _get_bands_data(self): if not self.bands: return None - bands = export_bands_data(self.bands) - return bands + expanded_selection, syntax_ok = string_range_to_list( + self.selected_atoms.value, shift=-1 + ) + if syntax_ok: + bands = get_bands_projections_data( + self.bands, + group_tag=self.dos_atoms_group.value, + plot_tag=self.dos_plot_group.value, + selected_atoms=expanded_selection, + bands_width=self.projected_bands_width, + ) + return bands + return None def _initial_view(self): with self.bands_widget: self._clear_output_and_display(self.bandsplot_widget) self.download_button.layout.visibility = "visible" + self.project_bands_box.layout.visibility = "visible" def _update_plot(self, _=None): with self.bands_widget: @@ -613,9 +603,12 @@ def _update_plot(self, _=None): self._wrong_syntax.message = """