diff --git a/docs/key_functionality.md b/docs/key_functionality.md index 13d16c8..1b4a743 100644 --- a/docs/key_functionality.md +++ b/docs/key_functionality.md @@ -197,15 +197,18 @@ ds = sdfxr.open_mfdataset("tutorial_dataset_1d/*.sdf") ds["Electric_Field_Ex"] ``` -On top of accessing variables you can plot these -using the built-in function (see -) which is -a simple call to `matplotlib`. This also means that you can access -all the methods from `matplotlib` to manipulate your plot. +On top of accessing variables, you can plot these datasets using +[`xarray.DataArray.epoch.plot`](project:#sdf_xarray.dataarray_accessor.EpochAccessor.plot). +This is a custom plotting routine that builds on top of +, so you keep the familiar plotting +behaviour while using conveniences (see +[here](project:#sdf_xarray.dataarray_accessor.EpochAccessor.plot) for details). +Under the hood, plotting is still handled by , which means you +can use the full API to customise your figure. ```{code-cell} ipython3 # This is discretized in both space and time -ds["Electric_Field_Ex"].plot() +ds["Electric_Field_Ex"].epoch.plot() plt.title("Electric field along the x-axis") plt.show() ``` @@ -231,7 +234,6 @@ done by passsing the index to the `time` parameter (e.g., `time=0` for the first snapshot). ```{code-cell} ipython3 -# We can plot the variable at a given time index ds["Electric_Field_Ex"].isel(time=20) ``` @@ -296,17 +298,17 @@ ds["Laser_Absorption_Fraction_in_Simulation"] = ( ds["Laser_Absorption_Fraction_in_Simulation"].attrs["units"] = "%" ds["Laser_Absorption_Fraction_in_Simulation"].attrs["long_name"] = "Laser Absorption Fraction" -ds["Laser_Absorption_Fraction_in_Simulation"].plot() +ds["Laser_Absorption_Fraction_in_Simulation"].epoch.plot() plt.title("Laser absorption fraction in simulation") plt.show() ``` -You can also call the `plot()` function on several variables with +You can also call the [`xarray.DataArray.epoch.plot`](project:#sdf_xarray.dataarray_accessor.EpochAccessor.plot) function on several variables with labels by delaying the call to `plt.show()`. ```{code-cell} ipython3 -ds["Total_Particle_Energy_Electron"].plot(label="Electron") -ds["Total_Particle_Energy_Ion"].plot(label="Ion") +ds["Total_Particle_Energy_Electron"].epoch.plot(label="Electron") +ds["Total_Particle_Energy_Ion"].epoch.plot(label="Ion") plt.title("Particle Energy in Simulation per Species") plt.legend() plt.show() diff --git a/docs/unit_conversion.md b/docs/unit_conversion.md index b98f2d7..113c530 100644 --- a/docs/unit_conversion.md +++ b/docs/unit_conversion.md @@ -43,10 +43,10 @@ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf") ds_in_microns = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"]) -ds["Derived_Number_Density_Electron"].isel(time=0).plot(ax=ax1, x="X_Grid_mid", y="Y_Grid_mid") +ds["Derived_Number_Density_Electron"].isel(time=0).epoch.plot(ax=ax1) ax1.set_title("Original X Coordinate (m)") -ds_in_microns["Derived_Number_Density_Electron"].isel(time=0).plot(ax=ax2, x="X_Grid_mid", y="Y_Grid_mid") +ds_in_microns["Derived_Number_Density_Electron"].isel(time=0).epoch.plot(ax=ax2) ax2.set_title("Rescaled X Coordinate (µm)") fig.tight_layout() @@ -197,9 +197,9 @@ converted side by side: ```{code-cell} ipython3 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16,8)) -ds["Total_Particle_Energy_Electron"].plot(ax=ax1) -total_particle_energy_ev.plot(ax=ax2) -total_particle_energy_w.plot(ax=ax3) +ds["Total_Particle_Energy_Electron"].epoch.plot(ax=ax1) +total_particle_energy_ev.epoch.plot(ax=ax2) +total_particle_energy_w.epoch.plot(ax=ax3) ax4.set_visible(False) fig.suptitle("Comparison of conversion from Joules to electron volts and watts") fig.tight_layout() diff --git a/src/sdf_xarray/__init__.py b/src/sdf_xarray/__init__.py index 55068ec..0785803 100644 --- a/src/sdf_xarray/__init__.py +++ b/src/sdf_xarray/__init__.py @@ -23,6 +23,7 @@ # NOTE: Do not delete these lines, otherwise the "epoch" dataset and dataarray # accessors will not be imported when the user imports sdf_xarray +import sdf_xarray.dataarray_accessor import sdf_xarray.dataset_accessor import sdf_xarray.download import sdf_xarray.plotting # noqa: F401 diff --git a/src/sdf_xarray/dataarray_accessor.py b/src/sdf_xarray/dataarray_accessor.py new file mode 100644 index 0000000..e31fa13 --- /dev/null +++ b/src/sdf_xarray/dataarray_accessor.py @@ -0,0 +1,69 @@ +from types import MethodType + +import xarray as xr +from matplotlib.animation import FuncAnimation +from xarray.plot.accessor import DataArrayPlotAccessor + +from .plotting import animate, show + + +@xr.register_dataarray_accessor("epoch") +class EpochAccessor: + def __init__(self, xarray_obj: xr.DataArray): + self._obj = xarray_obj + + def plot(self, *args, **kwargs) -> DataArrayPlotAccessor: + """ + Builds upon `xarray.DataArray.plot` while changing some of its default behaviours. + + These changes are: + + - Flips the default axes order for 2D plots so that x and y are on the correct axes. + This exists because plotting of 2D data in xarray uses the `xarray.plot.pcolormesh` + function which takes assumes that ``x = dim[1]`` and ``y = dim[0]``. + + Parameters + ---------- + args + Positional arguments passed to `xarray.DataArray.plot`. + kwargs + Keyword arguments passed to `xarray.DataArray.plot`. + """ + dims = self._obj.dims + is_not_2d_data = len(dims) != 2 + is_time_dim_present = "time" in dims + is_x_or_y_specified_in_kwargs = "x" in kwargs or "y" in kwargs + + if is_not_2d_data or is_time_dim_present or is_x_or_y_specified_in_kwargs: + return self._obj.plot(*args, **kwargs) + + updated_kwargs = dict(kwargs) + updated_kwargs.setdefault("x", dims[0]) + updated_kwargs.setdefault("y", dims[1]) + + return self._obj.plot(*args, **updated_kwargs) + + def animate(self, *args, **kwargs) -> FuncAnimation: + """Generate animations of Epoch data. + + Parameters + ---------- + args + Positional arguments passed to :func:`animation`. + kwargs + Keyword arguments passed to :func:`animation`. + + Examples + -------- + >>> anim = ds["Electric_Field_Ey"].epoch.animate() + >>> anim.save("animation.gif") + >>> # Or in a jupyter notebook: + >>> anim.show() + """ + + # Add anim.show() functionality + # anim.show() will display the animation in a jupyter notebook + anim = animate(self._obj, *args, **kwargs) + anim.show = MethodType(show, anim) + + return anim diff --git a/src/sdf_xarray/plotting.py b/src/sdf_xarray/plotting.py index a3c2db9..fe312a9 100644 --- a/src/sdf_xarray/plotting.py +++ b/src/sdf_xarray/plotting.py @@ -3,7 +3,6 @@ import warnings from collections.abc import Callable from dataclasses import dataclass -from types import MethodType from typing import TYPE_CHECKING, Any import numpy as np @@ -534,34 +533,3 @@ def show(anim): from IPython.display import HTML # noqa: PLC0415 return HTML(anim.to_jshtml()) - - -@xr.register_dataarray_accessor("epoch") -class EpochAccessor: - def __init__(self, xarray_obj): - self._obj = xarray_obj - - def animate(self, *args, **kwargs) -> FuncAnimation: - """Generate animations of Epoch data. - - Parameters - ---------- - args - Positional arguments passed to :func:`animation`. - kwargs - Keyword arguments passed to :func:`animation`. - - Examples - -------- - >>> anim = ds["Electric_Field_Ey"].epoch.animate() - >>> anim.save("animation.gif") - >>> # Or in a jupyter notebook: - >>> anim.show() - """ - - # Add anim.show() functionality - # anim.show() will display the animation in a jupyter notebook - anim = animate(self._obj, *args, **kwargs) - anim.show = MethodType(show, anim) - - return anim diff --git a/tests/test_epoch_dataarray_accessor.py b/tests/test_epoch_dataarray_accessor.py index 9069a74..d9b6513 100644 --- a/tests/test_epoch_dataarray_accessor.py +++ b/tests/test_epoch_dataarray_accessor.py @@ -2,10 +2,12 @@ from importlib.metadata import version import matplotlib as mpl +import matplotlib.pyplot as plt import numpy as np import pytest import xarray as xr from matplotlib.animation import PillowWriter +from matplotlib.container import BarContainer from packaging.version import Version import sdf_xarray as sdfxr @@ -23,6 +25,13 @@ TEST_FILES_DIR_3D = download.fetch_dataset("test_files_3D") +@pytest.fixture +def subplots(): + fig, ax = plt.subplots() + yield (fig, ax) + plt.close(fig) + + def test_animation_accessor(): array = xr.DataArray( [1, 2, 3], @@ -234,3 +243,95 @@ def test_compute_global_limits_NaNs(): expected_result_max = 2.70 assert result_min == pytest.approx(expected_result_min, abs=1e-2) assert result_max == pytest.approx(expected_result_max, abs=1e-1) + + +def test_epoch_plot_simple_1d_dataset(subplots): + with xr.open_mfdataset( + TEST_FILES_DIR_1D.glob("*.sdf"), + compat="no_conflicts", + join="outer", + preprocess=SDFPreprocess(), + ) as ds: + _, ax = subplots + ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax) + + assert len(ax.lines) == 1 + assert ax.get_xlabel() == "X [m]" + + +def test_epoch_plot_simple_2d_dataset(subplots): + with xr.open_mfdataset( + TEST_FILES_DIR_2D_MW.glob("*.sdf"), + preprocess=SDFPreprocess(), + combine="nested", + compat="no_conflicts", + join="outer", + ) as ds: + _, ax = subplots + ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax) + + assert len(ax.collections) > 0 + assert ax.get_xlabel() == "X [m]" + assert ax.get_ylabel() == "Y [m]" + + +def test_epoch_plot_simple_3d_dataset_slice(subplots): + with xr.open_dataset(TEST_FILES_DIR_3D / "0001.sdf") as ds: + _, ax = subplots + ds["Derived_Number_Density_Electron"].isel(Z_Grid_mid=0).epoch.plot(ax=ax) + + assert len(ax.collections) > 0 + assert ax.get_xlabel() == "X [m]" + assert ax.get_ylabel() == "Y [m]" + + +def test_epoch_plot_flips_axis_order_for_2d_data(subplots): + with xr.open_mfdataset( + TEST_FILES_DIR_2D_MW.glob("*.sdf"), + preprocess=SDFPreprocess(), + combine="nested", + compat="no_conflicts", + join="outer", + ) as ds: + _, ax = subplots + ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax) + + assert ax.get_xlabel() == "X [m]" + assert ax.get_ylabel() == "Y [m]" + + +def test_epoch_plot_flips_axis_order_for_2d_data_with_additional_params(subplots): + with xr.open_mfdataset( + TEST_FILES_DIR_2D_MW.glob("*.sdf"), + preprocess=SDFPreprocess(), + combine="nested", + compat="no_conflicts", + join="outer", + ) as ds: + _, ax = subplots + ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot( + ax=ax, + xlim=(0.5, 1.0), + ylim=(0.0, 0.5), + ) + + assert ax.get_xlabel() == "X [m]" + assert ax.get_ylabel() == "Y [m]" + assert ax.get_xlim() == pytest.approx((0.5, 1.0), abs=1e-2) + assert ax.get_ylim() == pytest.approx((0.0, 0.5), abs=1e-2) + + +def test_epoch_plot_flips_axis_order_for_2d_data_but_not_when_time_dim_present( + subplots, +): + with xr.open_mfdataset( + TEST_FILES_DIR_2D_MW.glob("*.sdf"), + preprocess=SDFPreprocess(), + combine="nested", + compat="no_conflicts", + join="outer", + ) as ds: + _, ax = subplots + plot = ds["Derived_Number_Density_electron"].epoch.plot(ax=ax) + + assert type(plot[2]) is BarContainer