diff --git a/src/aiidalab_qe/common/bandpdoswidget.py b/src/aiidalab_qe/common/bandpdoswidget.py index 04d43a1b5..0550c93a1 100644 --- a/src/aiidalab_qe/common/bandpdoswidget.py +++ b/src/aiidalab_qe/common/bandpdoswidget.py @@ -30,17 +30,32 @@ class BandPdosPlotly: "horizontal_range_pdos": [-10, 10], } - def __init__(self, bands_data=None, pdos_data=None): + def __init__(self, bands_data=None, pdos_data=None, project_bands=False): self.bands_data = bands_data self.pdos_data = pdos_data self.fermi_energy = self._get_fermi_energy() + self.project_bands = project_bands and "projected_bands" in self.bands_data # Plotly Axis # Plotly settings self._bands_xaxis = self._band_xaxis() self._bands_yaxis = self._band_yaxis() - self._dos_xaxis = self._dos_xaxis() - self._dos_yaxis = self._dos_yaxis() + self._pdos_xaxis = self._pdos_xaxis() + self._pdos_yaxis = self._pdos_yaxis() + + @property + def plot_type(self): + """Define the plot type.""" + if self.bands_data and self.pdos_data: + return "combined" + elif self.bands_data: + return "bands" + elif self.pdos_data: + return "pdos" + + @property + def bandspdosfigure(self): + return self._get_bandspdos_plot() def _get_fermi_energy(self): """Function to return the Fermi energy information depending on the data available.""" @@ -112,133 +127,95 @@ def _band_yaxis(self): return bandyaxis - def _dos_xaxis(self): - """Function to return the xaxis for the dos plot.""" + def _pdos_xaxis(self): + """Function to return the xaxis for the pdos plot.""" if not self.pdos_data: return None + # For combined plot + axis_settings = { + "showgrid": True, + "showline": True, + "mirror": "ticks", + "ticks": "inside", + "linewidth": 2, + "tickwidth": 2, + "linecolor": self.SETTINGS["axis_linecolor"], + "title": "Density of states", + "side": "bottom", + "automargin": True, + } - if self.bands_data: - dosxaxis = go.layout.XAxis( - title="Density of states", - side="bottom", - showgrid=True, - showline=True, - linecolor=self.SETTINGS["axis_linecolor"], - mirror="ticks", - ticks="inside", - linewidth=2, - tickwidth=2, - automargin=True, - ) - - else: - dosxaxis = go.layout.XAxis( - title="Density of states (eV)", - showgrid=True, - showline=True, - linecolor=self.SETTINGS["axis_linecolor"], - mirror="ticks", - ticks="inside", - linewidth=2, - tickwidth=2, - range=self.SETTINGS["horizontal_range_pdos"], - ) + if self.plot_type != "combined": + axis_settings["title"] = "Density of states (eV)" + axis_settings["range"] = self.SETTINGS["horizontal_range_pdos"] + axis_settings.pop("side") + axis_settings.pop("automargin") - return dosxaxis + return go.layout.XAxis(**axis_settings) - def _dos_yaxis(self): - """Function to return the yaxis for the dos plot.""" + def _pdos_yaxis(self): + """Function to return the yaxis for the pdos plot.""" if not self.pdos_data: return None - if self.bands_data: - dosyaxis = go.layout.YAxis( - # title= {"text":"Density of states (eV)", "standoff": 1}, - showgrid=True, - showline=True, - side="right", - mirror="ticks", - ticks="inside", - linewidth=2, - tickwidth=2, - linecolor=self.SETTINGS["axis_linecolor"], - zerolinewidth=2, - ) - - else: - dosyaxis = go.layout.YAxis( - # title="Density of states (eV)", - showgrid=True, - showline=True, - side="left", - mirror="ticks", - ticks="inside", - linewidth=2, - tickwidth=2, - linecolor=self.SETTINGS["axis_linecolor"], - zerolinewidth=2, - ) + axis_settings = { + "showgrid": True, + "showline": True, + "side": "right" if self.plot_type == "combined" else "left", + "mirror": "ticks", + "ticks": "inside", + "linewidth": 2, + "tickwidth": 2, + "linecolor": self.SETTINGS["axis_linecolor"], + "zerolinewidth": 2, + } - return dosyaxis + return go.layout.YAxis(**axis_settings) def _get_bandspdos_plot(self): """Function to return the bands plot widget.""" - conditions = { - (True, False): self._create_bands_only_plot, - (False, True): self._create_dos_only_plot, - (True, True): self._create_combined_plot, - } - - return conditions.get((bool(self.bands_data), bool(self.pdos_data)), None)() - def _create_bands_only_plot(self): - """Function to return the bands plot widget.""" + fig = self._create_fig() + if self.bands_data: + self._add_band_traces(fig) - fig = go.Figure() - paths = self.bands_data.get("paths") + band_labels = self.bands_data.get("pathlabels") + for label in band_labels[1]: + fig.add_vline( + x=label, + line=dict(color=self.SETTINGS["vertical_linecolor"], width=1), + ) - self._add_band_traces(fig, paths, "bands_only") + if self.project_bands: + self._add_projection_traces(fig) - band_labels = self.bands_data.get("pathlabels") - for i in band_labels[1]: - fig.add_vline( - x=i, line=dict(color=self.SETTINGS["vertical_linecolor"], width=1) - ) - fig.update_layout( - xaxis=self._bands_xaxis, - yaxis=self._bands_yaxis, - plot_bgcolor="white", - height=self.SETTINGS["bands_plot_height"], - width=self.SETTINGS["bands_plot_width"], - ) - return go.FigureWidget(fig) + if self.pdos_data: + self._add_pdos_traces(fig) + if self.plot_type == "pdos": + fig.add_vline( + x=0, + line=dict( + color=self.SETTINGS["vertical_linecolor"], width=1, dash="dot" + ), + ) - def _create_dos_only_plot(self): - """Function to return the pdos plot widget.""" + if self.plot_type == "combined": + self._customize_combined_layout(fig) + else: + self._customize_single_layout(fig) - fig = go.Figure() - # Extract DOS data - self._add_dos_traces(fig, plot_type="dos_only") - # Add a vertical line at zero energy - fig.add_vline( - x=0, - line=dict(color=self.SETTINGS["vertical_linecolor"], width=1, dash="dot"), - ) + return go.FigureWidget(fig) - # Update the layout of the Figure - fig.update_layout( - xaxis=self._dos_xaxis, - yaxis=self._dos_yaxis, - plot_bgcolor="white", - height=self.SETTINGS["pdos_plot_height"], - width=self.SETTINGS["pdos_plot_width"], - ) + def _create_fig(self): + """Create a plotly figure. - return go.FigureWidget(fig) + The figure layout is different depending on the plot type. + """ + if self.plot_type != "combined": + return go.Figure() - def _create_combined_plot(self): fig = make_subplots( rows=1, cols=2, @@ -246,106 +223,69 @@ def _create_combined_plot(self): column_widths=self.SETTINGS["combined_column_widths"], horizontal_spacing=0.015, ) - paths = self.bands_data.get("paths") - self._add_band_traces(fig, paths, plot_type="combined") - self._add_dos_traces(fig, plot_type="combined") - band_labels = self.bands_data.get("pathlabels") - for i in band_labels[1]: - fig.add_vline( - x=i, - line=dict(color=self.SETTINGS["vertical_linecolor"], width=1), - row=1, - col=1, - ) - self._customize_combined_layout(fig) - return go.FigureWidget(fig) - - def _add_band_traces(self, fig, paths, plot_type): - paths = self.bands_data.get("paths") + return fig + + def _add_traces_to_fig(self, fig, traces, col): + """Add a list of traces to a figure.""" + if self.plot_type == "combined": + rows = [1] * len(traces) + cols = [col] * len(traces) + fig.add_traces(traces, rows=rows, cols=cols) + else: + fig.add_traces(traces) + + def _add_band_traces(self, fig): + """Generate the band traces and add them to the figure.""" + colors = { + (True, 0): self.SETTINGS["bands_up_linecolor"], + (True, 1): self.SETTINGS["bands_down_linecolor"], + (False, 0): self.SETTINGS["bands_linecolor"], + } + fermi_energy_mapping = { + (False, 0): self.fermi_energy.get("fermi_energy_up", None), + (False, 1): self.fermi_energy.get("fermi_energy_down", None), + } - # Spin condition: True if spin-polarized False if not - spin_type = paths[0].get("two_band_types") + bands_data = self.bands_data # Convert paths to a list of Scatter objects scatter_objects = [] - for band in paths: - if not spin_type: - # Non-spin-polarized case - for bands in band["values"]: - bands_np = np.array(bands) - scatter_objects.append( - go.Scatter( - x=band["x"], - y=bands_np - self.fermi_energy["fermi_energy"], - mode="lines", - line=dict( - color=self.SETTINGS["bands_linecolor"], - shape="spline", - smoothing=1.3, - ), - showlegend=False, - ) - ) - else: - half_len = len(band["values"]) // 2 - first_half = band["values"][:half_len] - second_half = band["values"][half_len:] - - # Red line for the Spin up - color_first_half = self.SETTINGS["bands_up_linecolor"] - # Blue line for the Spin down - color_second_half = self.SETTINGS["bands_down_linecolor"] - if "fermi_energy" in self.fermi_energy: - for bands, color in zip( - (first_half, second_half), (color_first_half, color_second_half) - ): - bands_np = np.array(bands) - scatter_objects.append( - go.Scatter( - x=band["x"], - y=bands_np - self.fermi_energy["fermi_energy"], - mode="lines", - line=dict( - color=color, - shape="spline", - smoothing=1.3, - ), - showlegend=False, - ) - ) - else: - for bands, color, fermi_energy in zip( - (first_half, second_half), - (color_first_half, color_second_half), - ( - self.fermi_energy["fermi_energy_up"], - self.fermi_energy["fermi_energy_down"], - ), - ): - for band_values in bands: - bands_np = np.array(band_values) - scatter_objects.append( - go.Scatter( - x=band["x"], - y=bands_np - fermi_energy, - mode="lines", - line=dict( - color=color, - shape="spline", - smoothing=1.3, - ), - showlegend=False, - ) - ) - - if plot_type == "bands_only": - fig.add_traces(scatter_objects) - else: - rows = [1] * len(scatter_objects) - cols = [1] * len(scatter_objects) - fig.add_traces(scatter_objects, rows=rows, cols=cols) + spin_polarized = 1 in bands_data["band_type_idx"] + for spin in [0, 1]: + # In case of non-spin-polarized or SOC calculations, the spin index is only 0 + if spin not in bands_data["band_type_idx"]: + continue + + x_bands = np.array(bands_data["x"]).reshape(1, -1) + # New shape: (number of bands, number of kpoints) + y_bands = bands_data["y"][:, bands_data["band_type_idx"] == spin].T + # Concatenate the bands and prepare the traces + x_bands_comb, y_bands_comb = _prepare_combined_plotly_traces( + x_bands, y_bands + ) + + fermi_energy = fermi_energy_mapping.get( + ("fermi_energy" in self.fermi_energy, spin), + self.fermi_energy.get("fermi_energy"), + ) + + scatter_objects.append( + go.Scatter( + x=x_bands_comb, + y=y_bands_comb - fermi_energy, + mode="lines", + line=dict( + color=colors[(spin_polarized, spin)], + shape="spline", + smoothing=1.3, + ), + showlegend=False, + ) + ) + + self._add_traces_to_fig(fig, scatter_objects, 1) - def _add_dos_traces(self, fig, plot_type): + def _add_pdos_traces(self, fig): # Extract DOS data dos_data = self.pdos_data["dos"] @@ -353,62 +293,75 @@ def _add_dos_traces(self, fig, plot_type): num_traces = len(dos_data) scatter_objects = [None] * num_traces + # dictionary with keys (bool(spin polarized), bool(spin up)) + fermi_energy_spin_mapping = { + (False, True): self.fermi_energy.get("fermi_energy_up", None), + (False, False): self.fermi_energy.get("fermi_energy_down", None), + } + # Vectorize Scatter object creation for i, trace in enumerate(dos_data): dos_np = np.array(trace["x"]) - fill = "tozerox" if plot_type == "combined" else "tozeroy" + fill = "tozerox" if self.plot_type == "combined" else "tozeroy" + fermi_energy = fermi_energy_spin_mapping.get( + ("fermi_energy" in self.fermi_energy, trace["label"].endswith("(↑)")), + self.fermi_energy.get("fermi_energy"), + ) - if "fermi_energy" in self.fermi_energy: - y_data = ( - dos_np - self.fermi_energy["fermi_energy"] - if plot_type == "combined" - else trace["y"] - ) - x_data = ( - trace["y"] - if plot_type == "combined" - else dos_np - self.fermi_energy["fermi_energy"] - ) - else: - if trace["label"].endswith("(↑)"): - y_data = ( - dos_np - self.fermi_energy["fermi_energy_up"] - if plot_type == "combined" - else trace["y"] - ) - x_data = ( - trace["y"] - if plot_type == "combined" - else dos_np - self.fermi_energy["fermi_energy_up"] - ) - else: - y_data = ( - dos_np - self.fermi_energy["fermi_energy_down"] - if plot_type == "combined" - else trace["y"] - ) - x_data = ( - trace["y"] - if plot_type == "combined" - else dos_np - self.fermi_energy["fermi_energy_down"] - ) + x_data = ( + trace["y"] if self.plot_type == "combined" else dos_np - fermi_energy + ) + y_data = ( + dos_np - fermi_energy if self.plot_type == "combined" else trace["y"] + ) scatter_objects[i] = go.Scatter( x=x_data, y=y_data, fill=fill, name=trace["label"], line=dict(color=trace["borderColor"], shape="spline", smoothing=1.0), + legendgroup=trace["label"], ) - if plot_type == "dos_only": - fig.add_traces(scatter_objects) - else: - rows = [1] * len(scatter_objects) - cols = [2] * len(scatter_objects) - fig.add_traces(scatter_objects, rows=rows, cols=cols) + + self._add_traces_to_fig(fig, scatter_objects, 2) + + def _add_projection_traces(self, fig): + """Function to add the projected bands traces to the bands plot.""" + projected_bands = self.bands_data["projected_bands"] + # dictionary with keys (bool(spin polarized), bool(spin up)) + fermi_energy_spin_mapping = { + (False, True): self.fermi_energy.get("fermi_energy_up", None), + (False, False): self.fermi_energy.get("fermi_energy_down", None), + } + + scatter_objects = [] + for proj_bands in projected_bands: + fermi_energy = fermi_energy_spin_mapping.get( + ( + "fermi_energy" in self.fermi_energy, + proj_bands["label"].endswith("(↑)"), + ), + self.fermi_energy.get("fermi_energy"), + ) + scatter_objects.append( + go.Scatter( + x=proj_bands["x"], + y=np.array(proj_bands["y"]) - fermi_energy, + fill="toself", + legendgroup=proj_bands["label"], + mode="lines", + line=dict(width=0, color=proj_bands["color"]), + name=proj_bands["label"], + # If PDOS is present, use those legend entries + showlegend=True if self.plot_type == "bands" else False, + ) + ) + + self._add_traces_to_fig(fig, scatter_objects, 1) def _customize_combined_layout(self, fig): self._customize_layout(fig, self._bands_xaxis, self._bands_yaxis) - self._customize_layout(fig, self._dos_xaxis, self._dos_yaxis, col=2) + self._customize_layout(fig, self._pdos_xaxis, self._pdos_yaxis, col=2) fig.update_layout( legend=dict(xanchor="left", x=1.06), height=self.SETTINGS["combined_plot_height"], @@ -426,9 +379,17 @@ def _customize_layout(self, fig, xaxis, yaxis, row=1, col=1): col=col, ) - @property - def bandspdosfigure(self): - return self._get_bandspdos_plot() + def _customize_single_layout(self, fig): + xaxis = getattr(self, f"_{self.plot_type}_xaxis") + yaxis = getattr(self, f"_{self.plot_type}_yaxis") + + fig.update_layout( + xaxis=xaxis, + yaxis=yaxis, + plot_bgcolor="white", + height=self.SETTINGS[f"{self.plot_type}_plot_height"], + width=self.SETTINGS[f"{self.plot_type}_plot_width"], + ) class BandPdosWidget(ipw.VBox): @@ -446,6 +407,7 @@ class BandPdosWidget(ipw.VBox): - selected_atoms: Text widget to select specific atoms for PDOS plotting. - update_plot_button: Button widget to update the plot. - download_button: Button widget to download the data. + - project_bands_box: Checkbox widget to choose whether projected bands should be plotted. - dos_data: PDOS data. - bands_data: Band structure data. - bandsplot_widget: Plotly widget for band structure and PDOS plot. @@ -458,6 +420,7 @@ class BandPdosWidget(ipw.VBox): Select the style of plotting the projected density of states. """ ) + projected_bands_width = 0.5 def __init__(self, bands=None, pdos=None, **kwargs): if bands is None and pdos is None: @@ -504,28 +467,35 @@ def __init__(self, bands=None, pdos=None, **kwargs): disabled=False, layout=ipw.Layout(visibility="hidden"), ) + self.project_bands_box = ipw.Checkbox( + value=False, + description="Add `fat bands` projections", + ) # Information for the plot - self.dos_data = self._get_dos_data() + self.pdos_data = self._get_pdos_data() self.bands_data = self._get_bands_data() # Plotly widget self.bandsplot_widget = BandPdosPlotly( - bands_data=self.bands_data, pdos_data=self.dos_data + bands_data=self.bands_data, pdos_data=self.pdos_data ).bandspdosfigure # Output widget to display the bandsplot widget self.bands_widget = ipw.Output() # Output widget to clear the specific widgets self.pdos_options_out = ipw.Output() - self.pdos_options = ipw.VBox( - [ - self.description, - self.dos_atoms_group, - self.dos_plot_group, - ipw.HBox([self.selected_atoms, self._wrong_syntax]), - self.update_plot_button, - ] - ) + pdos_options_list = [ + self.description, + self.dos_atoms_group, + self.dos_plot_group, + ipw.HBox([self.selected_atoms, self._wrong_syntax]), + self.update_plot_button, + ] + # If projections are available in the bands data, include the box to plot fat-bands + if self.bands_data and "projected_bands" in self.bands_data: + pdos_options_list.insert(4, self.project_bands_box) + + self.pdos_options = ipw.VBox(pdos_options_list) self._initial_view() @@ -541,22 +511,31 @@ def __init__(self, bands=None, pdos=None, **kwargs): ], **kwargs, ) - if self.pdos: + + # Plot the options only if the pdos is provided or in case the bands data contains projections + if self.pdos or (self.bands_data and "projected_bands" in self.bands_data): with self.pdos_options_out: display(self.pdos_options) def download_data(self, _=None): """Function to download the data.""" file_name_bands = "bands_data.json" - file_name_dos = "dos_data.json" + file_name_pdos = "dos_data.json" if self.bands_data: - json_str = json.dumps(self.bands_data) + bands_data_export = {} + for key, value in self.bands_data.items(): + if isinstance(value, np.ndarray): + bands_data_export[key] = value.tolist() + else: + bands_data_export[key] = value + + json_str = json.dumps(bands_data_export) b64_str = base64.b64encode(json_str.encode()).decode() self._download(payload=b64_str, filename=file_name_bands) - if self.dos_data: - json_str = json.dumps(self.dos_data) + if self.pdos_data: + json_str = json.dumps(self.pdos_data) b64_str = base64.b64encode(json_str.encode()).decode() - self._download(payload=b64_str, filename=file_name_dos) + self._download(payload=b64_str, filename=file_name_pdos) @staticmethod def _download(payload, filename): @@ -575,34 +554,45 @@ def _download(payload, filename): ) display(javas) - def _get_dos_data(self): + def _get_pdos_data(self): if not self.pdos: return None expanded_selection, syntax_ok = string_range_to_list( self.selected_atoms.value, shift=-1 ) if syntax_ok: - dos = get_pdos_data( + pdos = get_pdos_data( self.pdos, group_tag=self.dos_atoms_group.value, plot_tag=self.dos_plot_group.value, selected_atoms=expanded_selection, ) - return dos - else: - return None + return pdos + return None def _get_bands_data(self): if not self.bands: return None - bands = export_bands_data(self.bands) - return bands + expanded_selection, syntax_ok = string_range_to_list( + self.selected_atoms.value, shift=-1 + ) + if syntax_ok: + bands = get_bands_projections_data( + self.bands, + group_tag=self.dos_atoms_group.value, + plot_tag=self.dos_plot_group.value, + selected_atoms=expanded_selection, + bands_width=self.projected_bands_width, + ) + return bands + return None def _initial_view(self): with self.bands_widget: self._clear_output_and_display(self.bandsplot_widget) self.download_button.layout.visibility = "visible" + self.project_bands_box.layout.visibility = "visible" def _update_plot(self, _=None): with self.bands_widget: @@ -613,9 +603,12 @@ def _update_plot(self, _=None): self._wrong_syntax.message = """
ERROR: Invalid syntax for selected atoms
""" clear_output(wait=True) else: - self.dos_data = self._get_dos_data() + self.pdos_data = self._get_pdos_data() + self.bands_data = self._get_bands_data() self.bandsplot_widget = BandPdosPlotly( - bands_data=self.bands_data, pdos_data=self.dos_data + bands_data=self.bands_data, + pdos_data=self.pdos_data, + project_bands=self.project_bands_box.value, ).bandspdosfigure self._clear_output_and_display(self.bandsplot_widget) @@ -625,6 +618,136 @@ def _clear_output_and_display(self, widget=None): display(widget) +def _prepare_combined_plotly_traces(x_to_conc, y_to_conc): + """Combine multiple lines into a single trace. + + The rows of y are concatenated with a np.nan column as a separator. Moreover, + the x values are ajduced to match the shape of the concatenated y values. These + transfomred arrays, representing multiple datasets/lines, can be plotted in a single trace. + """ + if y_to_conc.ndim != 2: + raise ValueError("y must be a 2D array") + + y_dim0 = y_to_conc.shape[0] + + # Add a np.nan column as a separator + y_transf = np.hstack( + [ + y_to_conc, + np.full((y_dim0, 1), np.nan), + ] + ).flatten() + + # Same logic for the x axis + x_transf = x_to_conc.reshape(1, -1) * np.ones(y_dim0).reshape(-1, 1) + x_transf = np.hstack([x_transf, np.full((y_dim0, 1), np.nan)]).flatten() + + return x_transf, y_transf + + +def _prepare_projections_to_plot(bands_data, projections, bands_width): + """Prepare the projected bands to be plotted. + + This function transforms the projected bands into a format that can be plotted + in a single trace. To use the fill option `toself`, + a band needs to be concatenated with its mirror image, first. + """ + projected_bands = [] + for spin in [0, 1]: + # In case of non-spin-polarized calculations, the spin index is only 0 + if spin not in bands_data["band_type_idx"]: + continue + + x_bands = bands_data["x"] + # New shape: (number of bands, number of kpoints) + y_bands = bands_data["y"][:, bands_data["band_type_idx"] == spin].T + + for proj in projections[spin]: + # Create the upper and lower boundary of the fat bands based on the orbital projections + y_bands_proj_upper = y_bands + bands_width * proj["projections"].T + y_bands_proj_lower = y_bands - bands_width * proj["projections"].T + # As mentioned above, the bands need to be concatenated with their mirror image + # to create the filled areas properly + y_bands_mirror = np.hstack( + [y_bands_proj_upper, y_bands_proj_lower[:, ::-1]] + ) + # Same logic for the energy axis + x_bands_mirror = np.concatenate([x_bands, x_bands[::-1]]).reshape(1, -1) + x_bands_comb, y_bands_proj_comb = _prepare_combined_plotly_traces( + x_bands_mirror, y_bands_mirror + ) + + projected_bands.append( + { + "x": x_bands_comb.tolist(), + "y": y_bands_proj_comb.tolist(), + "label": proj["label"], + "color": proj["color"], + } + ) + return projected_bands + + +def get_bands_projections_data( + outputs, group_tag, plot_tag, selected_atoms, bands_width, fermi_energy=None +): + """Extract the bandstructure and possibly the projections along the bands.""" + if "band_structure" not in outputs: + return None + + bands_data = outputs.band_structure._get_bandplot_data( + cartesian=True, prettify_format=None, join_symbol=None, get_segments=True + ) + # The fermi energy from band calculation is not robust. + if "fermi_energy_up" in outputs.band_parameters: + bands_data["fermi_energy_up"] = outputs.band_parameters["fermi_energy_up"] + bands_data["fermi_energy_down"] = outputs.band_parameters["fermi_energy_down"] + else: + bands_data["fermi_energy"] = ( + outputs.band_parameters["fermi_energy"] or fermi_energy + ) + + bands_data["pathlabels"] = get_bands_labeling(bands_data) + + if "projwfc" in outputs: + projections = [] + + if "projections" in outputs.projwfc: + projections.append( + _projections_curated_options( + outputs.projwfc.projections, + spin_type="none", + group_tag=group_tag, + plot_tag=plot_tag, + selected_atoms=selected_atoms, + projections_pdos="projections", + ) + ) + else: + for spin_proj, spin_type in zip( + [ + outputs.projwfc.projections_up, + outputs.projwfc.projections_down, + ], + ["up", "down"], + ): + projections.append( + _projections_curated_options( + spin_proj, + spin_type=spin_type, + group_tag=group_tag, + plot_tag=plot_tag, + selected_atoms=selected_atoms, + projections_pdos="projections", + ) + ) + + bands_data["projected_bands"] = _prepare_projections_to_plot( + bands_data, projections, bands_width + ) + return bands_data + + def get_pdos_data(pdos, group_tag, plot_tag, selected_atoms): dos = [] @@ -706,17 +829,39 @@ def get_pdos_data(pdos, group_tag, plot_tag, selected_atoms): return json.loads(json.dumps(data_dict)) -def _projections_curated_options( - projections: ProjectionData, +def _get_grouping_key( group_tag, plot_tag, - selected_atoms, - spin_type="none", - line_style="solid", + atom_position, + kind_name, + orbital_name_plotly, + orbital_angular_momentum, ): - _pdos = {} - list_positions = [] + """Generates the grouping key based on group_tag and plot_tag.""" + + key_formats = { + ("atoms", "total"): r"{var1}-{var}", + ("kinds", "total"): r"{var1}", + ("atoms", "orbital"): r"{var1}-{var}
{var2}", + ("kinds", "orbital"): r"{var1}-{var2}", + ("atoms", "angular_momentum"): r"{var1}-{var}
{var3}", + ("kinds", "angular_momentum"): r"{var1}-{var3}", + } + key = key_formats.get((group_tag, plot_tag)) + if key is not None: + return key.format( + var=atom_position, + var1=kind_name, + var2=orbital_name_plotly, + var3=orbital_angular_momentum, + ) + else: + return None + + +def _curate_orbitals(orbital): + """Curate and transform the orbital data into the desired format.""" # Constants for HTML tags HTML_TAGS = { "s": "s", @@ -743,145 +888,124 @@ def _projections_curated_options( -2.5: "-5/2", } - # Constants for spin types - SPIN_LABELS = {"up": "(↑)", "down": "(↓)", "none": ""} + orbital_data = orbital.get_orbital_dict() + kind_name = orbital_data["kind_name"] + atom_position = [round(i, 2) for i in orbital_data["position"]] + + try: + orbital_name = orbital.get_name_from_quantum_numbers( + orbital_data["angular_momentum"], orbital_data["magnetic_number"] + ).lower() + orbital_name_plotly = HTML_TAGS.get(orbital_name, orbital_name) + orbital_angular_momentum = orbital_name[0] + except AttributeError: + # Set quanutum numbers + qn_j = orbital_data["total_angular_momentum"] + qn_l = orbital_data["angular_momentum"] + qn_m_j = orbital_data["magnetic_number"] + orbital_name = "j {j} l {l} m_j{m_j}".format(j=qn_j, l=qn_l, m_j=qn_m_j) + orbital_name_plotly = "j={j} l={l} mj={m_j}".format( + j=HTML_TAGS.get(qn_j, qn_j), + l=qn_l, + m_j=HTML_TAGS.get(qn_m_j, qn_m_j), + ) + orbital_angular_momentum = "l {l} ".format(l=qn_l) - def get_key( - group_tag, - plot_tag, - atom_position, - kind_name, - orbital_name_plotly, - orbital_angular_momentum, - ): - """Generates the key based on group_tag and plot_tag.""" - - key_formats = { - ("atoms", "total"): r"{var1}-{var}", - ("kinds", "total"): r"{var1}", - ("atoms", "orbital"): r"{var1}-{var}
{var2}", - ("kinds", "orbital"): r"{var1}-{var2}", - ("atoms", "angular_momentum"): r"{var1}-{var}
{var3}", - ("kinds", "angular_momentum"): r"{var1}-{var3}", - } + return orbital_name_plotly, orbital_angular_momentum, kind_name, atom_position - key = key_formats.get((group_tag, plot_tag)) - if key is not None: - return key.format( - var=atom_position, - var1=kind_name, - var2=orbital_name_plotly, - var3=orbital_angular_momentum, - ) - else: - return None - for orbital, pdos, energy in projections.get_pdos(): - orbital_data = orbital.get_orbital_dict() - kind_name = orbital_data["kind_name"] - atom_position = [round(i, 2) for i in orbital_data["position"]] +def _projections_curated_options( + projections: ProjectionData, + group_tag, + plot_tag, + selected_atoms, + projections_pdos="pdos", + spin_type="none", + line_style="solid", +): + """Extract and curate the projections. + + This function can be used to extract the PDOS or the projections data. + """ + _proj_pdos = {} + list_positions = [] + + # Constants for spin types + SPIN_LABELS = {"up": "(↑)", "down": "(↓)", "none": ""} + SIGN_MULT_FACTOR = {"up": 1, "down": -1, "none": 1} + + if projections_pdos == "pdos": + proj_data = projections.get_pdos() + elif projections_pdos == "projections": + proj_data = projections.get_projections() + else: + raise ValueError(f"Invalid value for `projections_pdos`: {projections_pdos}") + + for orb_proj in proj_data: + if projections_pdos == "pdos": + orbital, proj_pdos, energy = orb_proj + elif projections_pdos == "projections": + orbital, proj_pdos = orb_proj + energy = None + + ( + orbital_name_plotly, + orbital_angular_momentum, + kind_name, + atom_position, + ) = _curate_orbitals(orbital) if atom_position not in list_positions: list_positions.append(atom_position) - try: - orbital_name = orbital.get_name_from_quantum_numbers( - orbital_data["angular_momentum"], orbital_data["magnetic_number"] - ).lower() - orbital_name_plotly = HTML_TAGS.get(orbital_name, orbital_name) - orbital_angular_momentum = orbital_name[0] - except AttributeError: - orbital_name = "j {j} l {l} m_j{m_j}".format( - j=orbital_data["total_angular_momentum"], - l=orbital_data["angular_momentum"], - m_j=orbital_data["magnetic_number"], - ) - orbital_name_plotly = "j={j} l={l} mj={m_j}".format( - j=HTML_TAGS.get( - orbital_data["total_angular_momentum"], - orbital_data["total_angular_momentum"], - ), - l=orbital_data["angular_momentum"], - m_j=HTML_TAGS.get( - orbital_data["magnetic_number"], orbital_data["magnetic_number"] - ), - ) - orbital_angular_momentum = "l {l} ".format( - l=orbital_data["angular_momentum"], - ) - + key = _get_grouping_key( + group_tag, + plot_tag, + atom_position, + kind_name, + orbital_name_plotly, + orbital_angular_momentum, + ) if not selected_atoms: - key = get_key( - group_tag, - plot_tag, - atom_position, - kind_name, - orbital_name_plotly, - orbital_angular_momentum, - ) - if key: - _pdos.setdefault(key, [energy, 0])[1] += pdos + _proj_pdos.setdefault(key, [energy, 0])[1] += proj_pdos else: try: index = list_positions.index(atom_position) if index in selected_atoms: - key = get_key( - group_tag, - plot_tag, - atom_position, - kind_name, - orbital_name_plotly, - orbital_angular_momentum, - ) - if key: - _pdos.setdefault(key, [energy, 0])[1] += pdos + _proj_pdos.setdefault(key, [energy, 0])[1] += proj_pdos except ValueError: pass - dos = [] - for label, (energy, pdos) in _pdos.items(): - if spin_type == "down": - pdos = -pdos - label += SPIN_LABELS[spin_type] - - if spin_type == "up": - label += SPIN_LABELS[spin_type] - - orbital_pdos = { - "label": label, - "x": energy.tolist(), - "y": pdos.tolist(), - "borderColor": cmap(label), - "lineStyle": line_style, - } - dos.append(orbital_pdos) - - return dos - - -def export_bands_data(outputs, fermi_energy=None): - if "band_structure" not in outputs: - return None + curated_proj = [] + for label, (energy, proj_pdos) in _proj_pdos.items(): + label += SPIN_LABELS[spin_type] + if projections_pdos == "pdos": + orbital_proj_pdos = { + "label": label, + "x": energy.tolist(), + "y": (SIGN_MULT_FACTOR[spin_type] * proj_pdos).tolist(), + "borderColor": cmap(label), + "lineStyle": line_style, + } + else: + orbital_proj_pdos = { + "label": label, + "projections": proj_pdos, + "color": cmap(label), + } + curated_proj.append(orbital_proj_pdos) - data = json.loads(outputs.band_structure._exportcontent("json", comments=False)[0]) - # The fermi energy from band calculation is not robust. - if "fermi_energy_up" in outputs.band_parameters: - data["fermi_energy_up"] = outputs.band_parameters["fermi_energy_up"] - data["fermi_energy_down"] = outputs.band_parameters["fermi_energy_down"] - else: - data["fermi_energy"] = outputs.band_parameters["fermi_energy"] or fermi_energy - data["pathlabels"] = get_bands_labeling(data) - return data + return curated_proj def get_bands_labeling(bandsdata: dict) -> list: """Function to return two lists containing the labels and values (kpoint) for plotting. params: - - bandsdata: dictionary from export_bands_data function + - bandsdata: dictionary from `get_bands_projections_data` function output: update bandsdata with a new key "pathlabels" including (list of str), label_values (list of float) """ UNICODE_SYMBOL = { diff --git a/tests/test_plugins_bands.py b/tests/test_plugins_bands.py index a586ca285..27533eec0 100644 --- a/tests/test_plugins_bands.py +++ b/tests/test_plugins_bands.py @@ -17,7 +17,7 @@ def test_result(generate_qeapp_workchain): # Check if data is correct assert result.children[0].bands_data is not None assert result.children[0].bands_data["pathlabels"] is not None - assert result.children[0].dos_data is None + assert result.children[0].pdos_data is None # Check Bands axis assert result.children[0].bandsplot_widget.layout.xaxis.title.text == "k-points" diff --git a/tests/test_plugins_electronic_structure.py b/tests/test_plugins_electronic_structure.py index 8e209cede..c2c0203d3 100644 --- a/tests/test_plugins_electronic_structure.py +++ b/tests/test_plugins_electronic_structure.py @@ -31,7 +31,7 @@ def test_electronic_structure(generate_qeapp_workchain): # Check if data is correct assert result.children[0].bands_data is not None assert result.children[0].bands_data["pathlabels"] is not None - assert result.children[0].dos_data is not None + assert result.children[0].pdos_data is not None # Check Bands axis assert result.children[0].bandsplot_widget.layout.xaxis.title.text == "k-points" diff --git a/tests/test_plugins_pdos.py b/tests/test_plugins_pdos.py index 354c67e7f..6d02a140e 100644 --- a/tests/test_plugins_pdos.py +++ b/tests/test_plugins_pdos.py @@ -16,7 +16,7 @@ def test_result(generate_qeapp_workchain): # Check if data is correct assert result.children[0].bands_data is None - assert result.children[0].dos_data is not None + assert result.children[0].pdos_data is not None # Check PDOS settings is not None