From ed6e363bb350c347193529a35aa421de2f8866e2 Mon Sep 17 00:00:00 2001 From: Joel Adams Date: Fri, 13 Mar 2026 12:54:12 +0100 Subject: [PATCH 1/5] monkey patch xarray plotting to fix default axes for 2D data --- src/sdf_xarray/__init__.py | 36 +++++++++++++++++ tests/test_epoch_dataarray_accessor.py | 55 ++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/src/sdf_xarray/__init__.py b/src/sdf_xarray/__init__.py index 55068ec..8756749 100644 --- a/src/sdf_xarray/__init__.py +++ b/src/sdf_xarray/__init__.py @@ -20,6 +20,7 @@ from xarray.core.types import T_Chunks from xarray.core.utils import close_on_error, try_read_magic_number_from_path from xarray.core.variable import Variable +from xarray.plot.accessor import DataArrayPlotAccessor # NOTE: Do not delete these lines, otherwise the "epoch" dataset and dataarray # accessors will not be imported when the user imports sdf_xarray @@ -41,6 +42,39 @@ PathLike = str | os_PathLike +def _patch_xarray_plot_axes() -> None: + """ + Changes the default axes order for 2D plots so that x and y are on the correct axes. + + This exists because plotting of 2D data in xarray uses the `xarray.plot.pcolormesh` + function which takes assumes that ``x = dim[1]`` and ``y = dim[0]``. + """ + original_call = DataArrayPlotAccessor.__call__ + + def _sdf_plot_call( + self, + **kwargs, + ): + dims = self._da.dims + is_not_2d_data = len(dims) != 2 + is_time_dim_present = "time" in dims + is_x_or_y_specified_in_kwargs = "x" in kwargs or "y" in kwargs + + if is_not_2d_data or is_time_dim_present or is_x_or_y_specified_in_kwargs: + return original_call(self, **kwargs) + + updated_kwargs = dict(kwargs) + updated_kwargs.setdefault("x", dims[0]) + updated_kwargs.setdefault("y", dims[1]) + + return original_call( + self, + **updated_kwargs, + ) + + DataArrayPlotAccessor.__call__ = _sdf_plot_call + + def _rename_with_underscore(name: str) -> str: """A lot of the variable names have spaces, forward slashes and dashes in them, which are not valid in netCDF names so we replace them with underscores.""" @@ -768,6 +802,7 @@ def _process_grid_name(grid_name: str, transform_func) -> str: data_attrs = {} data_attrs["full_name"] = key data_attrs["long_name"] = base_name.replace("_", " ") + if value.units is not None: data_attrs["units"] = value.units @@ -871,6 +906,7 @@ def _process_grid_name(grid_name: str, transform_func) -> str: ds = xr.Dataset(data_vars, attrs=attrs, coords=coords) ds.attrs["deck"] = _load_deck(ds.attrs["filename"], self.deck_path) + _patch_xarray_plot_axes() ds.set_close(self.ds.close) return ds diff --git a/tests/test_epoch_dataarray_accessor.py b/tests/test_epoch_dataarray_accessor.py index 9069a74..cbe79ed 100644 --- a/tests/test_epoch_dataarray_accessor.py +++ b/tests/test_epoch_dataarray_accessor.py @@ -2,10 +2,12 @@ from importlib.metadata import version import matplotlib as mpl +import matplotlib.pyplot as plt import numpy as np import pytest import xarray as xr from matplotlib.animation import PillowWriter +from matplotlib.container import BarContainer from packaging.version import Version import sdf_xarray as sdfxr @@ -234,3 +236,56 @@ def test_compute_global_limits_NaNs(): expected_result_max = 2.70 assert result_min == pytest.approx(expected_result_min, abs=1e-2) assert result_max == pytest.approx(expected_result_max, abs=1e-1) + + +def test_default_plot_flips_axis_order_for_2d_data(): + with xr.open_mfdataset( + TEST_FILES_DIR_2D_MW.glob("*.sdf"), + preprocess=SDFPreprocess(), + combine="nested", + compat="no_conflicts", + join="outer", + ) as ds: + plot = ds["Derived_Number_Density_electron"].isel(time=0).plot() + ax = plot.axes + + assert ax.get_xlabel() == "X [m]" + assert ax.get_ylabel() == "Y [m]" + plt.close(ax.figure) + + +def test_default_plot_flips_axis_order_for_2d_data_with_additional_params(): + with xr.open_mfdataset( + TEST_FILES_DIR_2D_MW.glob("*.sdf"), + preprocess=SDFPreprocess(), + combine="nested", + compat="no_conflicts", + join="outer", + ) as ds: + fig, ax = plt.subplots() + ds["Derived_Number_Density_electron"].isel(time=0).plot( + ax=ax, + xlim=(0.5, 1.0), + ylim=(0.0, 0.5), + ) + + assert ax.get_xlabel() == "X [m]" + assert ax.get_ylabel() == "Y [m]" + assert ax.get_xlim() == pytest.approx((0.5, 1.0), abs=1e-2) + assert ax.get_ylim() == pytest.approx((0.0, 0.5), abs=1e-2) + plt.close(fig) + + +def test_default_plot_flips_axis_order_for_2d_data_but_not_when_time_dim_present(): + with xr.open_mfdataset( + TEST_FILES_DIR_2D_MW.glob("*.sdf"), + preprocess=SDFPreprocess(), + combine="nested", + compat="no_conflicts", + join="outer", + ) as ds: + fig, ax = plt.subplots() + plot = ds["Derived_Number_Density_electron"].plot(ax=ax) + + assert type(plot[2]) is BarContainer + plt.close(fig) From 584db27f12da5a3d9ecf2757891f9460f40562e7 Mon Sep 17 00:00:00 2001 From: Joel Adams Date: Mon, 16 Mar 2026 08:47:31 +0000 Subject: [PATCH 2/5] refactor plotting monkeypatch to dataarray accessor --- src/sdf_xarray/__init__.py | 37 +--------------- src/sdf_xarray/dataarray_accessor.py | 61 ++++++++++++++++++++++++++ src/sdf_xarray/plotting.py | 32 -------------- tests/test_epoch_dataarray_accessor.py | 59 +++++++++++++++++++++---- 4 files changed, 113 insertions(+), 76 deletions(-) create mode 100644 src/sdf_xarray/dataarray_accessor.py diff --git a/src/sdf_xarray/__init__.py b/src/sdf_xarray/__init__.py index 8756749..0785803 100644 --- a/src/sdf_xarray/__init__.py +++ b/src/sdf_xarray/__init__.py @@ -20,10 +20,10 @@ from xarray.core.types import T_Chunks from xarray.core.utils import close_on_error, try_read_magic_number_from_path from xarray.core.variable import Variable -from xarray.plot.accessor import DataArrayPlotAccessor # NOTE: Do not delete these lines, otherwise the "epoch" dataset and dataarray # accessors will not be imported when the user imports sdf_xarray +import sdf_xarray.dataarray_accessor import sdf_xarray.dataset_accessor import sdf_xarray.download import sdf_xarray.plotting # noqa: F401 @@ -42,39 +42,6 @@ PathLike = str | os_PathLike -def _patch_xarray_plot_axes() -> None: - """ - Changes the default axes order for 2D plots so that x and y are on the correct axes. - - This exists because plotting of 2D data in xarray uses the `xarray.plot.pcolormesh` - function which takes assumes that ``x = dim[1]`` and ``y = dim[0]``. - """ - original_call = DataArrayPlotAccessor.__call__ - - def _sdf_plot_call( - self, - **kwargs, - ): - dims = self._da.dims - is_not_2d_data = len(dims) != 2 - is_time_dim_present = "time" in dims - is_x_or_y_specified_in_kwargs = "x" in kwargs or "y" in kwargs - - if is_not_2d_data or is_time_dim_present or is_x_or_y_specified_in_kwargs: - return original_call(self, **kwargs) - - updated_kwargs = dict(kwargs) - updated_kwargs.setdefault("x", dims[0]) - updated_kwargs.setdefault("y", dims[1]) - - return original_call( - self, - **updated_kwargs, - ) - - DataArrayPlotAccessor.__call__ = _sdf_plot_call - - def _rename_with_underscore(name: str) -> str: """A lot of the variable names have spaces, forward slashes and dashes in them, which are not valid in netCDF names so we replace them with underscores.""" @@ -802,7 +769,6 @@ def _process_grid_name(grid_name: str, transform_func) -> str: data_attrs = {} data_attrs["full_name"] = key data_attrs["long_name"] = base_name.replace("_", " ") - if value.units is not None: data_attrs["units"] = value.units @@ -906,7 +872,6 @@ def _process_grid_name(grid_name: str, transform_func) -> str: ds = xr.Dataset(data_vars, attrs=attrs, coords=coords) ds.attrs["deck"] = _load_deck(ds.attrs["filename"], self.deck_path) - _patch_xarray_plot_axes() ds.set_close(self.ds.close) return ds diff --git a/src/sdf_xarray/dataarray_accessor.py b/src/sdf_xarray/dataarray_accessor.py new file mode 100644 index 0000000..5e66d99 --- /dev/null +++ b/src/sdf_xarray/dataarray_accessor.py @@ -0,0 +1,61 @@ +from types import MethodType +from typing import Any + +import xarray as xr +from matplotlib.animation import FuncAnimation + +from .plotting import animate, show + + +@xr.register_dataarray_accessor("epoch") +class EpochAccessor: + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def plot(self, *args, **kwargs) -> Any: + """ + Builds upon `xarray.plot` while changing some of its default behaviours. + + Those changes are: + - Flips the default axes order for 2D plots so that x and y are on the correct axes. + This exists because plotting of 2D data in xarray uses the `xarray.plot.pcolormesh` + function which takes assumes that ``x = dim[1]`` and ``y = dim[0]``. + """ + dims = self._obj.dims + is_not_2d_data = len(dims) != 2 + is_time_dim_present = "time" in dims + is_x_or_y_specified_in_kwargs = "x" in kwargs or "y" in kwargs + + if is_not_2d_data or is_time_dim_present or is_x_or_y_specified_in_kwargs: + return self._obj.plot(*args, **kwargs) + + updated_kwargs = dict(kwargs) + updated_kwargs.setdefault("x", dims[0]) + updated_kwargs.setdefault("y", dims[1]) + + return self._obj.plot(*args, **updated_kwargs) + + def animate(self, *args, **kwargs) -> FuncAnimation: + """Generate animations of Epoch data. + + Parameters + ---------- + args + Positional arguments passed to :func:`animation`. + kwargs + Keyword arguments passed to :func:`animation`. + + Examples + -------- + >>> anim = ds["Electric_Field_Ey"].epoch.animate() + >>> anim.save("animation.gif") + >>> # Or in a jupyter notebook: + >>> anim.show() + """ + + # Add anim.show() functionality + # anim.show() will display the animation in a jupyter notebook + anim = animate(self._obj, *args, **kwargs) + anim.show = MethodType(show, anim) + + return anim diff --git a/src/sdf_xarray/plotting.py b/src/sdf_xarray/plotting.py index a3c2db9..fe312a9 100644 --- a/src/sdf_xarray/plotting.py +++ b/src/sdf_xarray/plotting.py @@ -3,7 +3,6 @@ import warnings from collections.abc import Callable from dataclasses import dataclass -from types import MethodType from typing import TYPE_CHECKING, Any import numpy as np @@ -534,34 +533,3 @@ def show(anim): from IPython.display import HTML # noqa: PLC0415 return HTML(anim.to_jshtml()) - - -@xr.register_dataarray_accessor("epoch") -class EpochAccessor: - def __init__(self, xarray_obj): - self._obj = xarray_obj - - def animate(self, *args, **kwargs) -> FuncAnimation: - """Generate animations of Epoch data. - - Parameters - ---------- - args - Positional arguments passed to :func:`animation`. - kwargs - Keyword arguments passed to :func:`animation`. - - Examples - -------- - >>> anim = ds["Electric_Field_Ey"].epoch.animate() - >>> anim.save("animation.gif") - >>> # Or in a jupyter notebook: - >>> anim.show() - """ - - # Add anim.show() functionality - # anim.show() will display the animation in a jupyter notebook - anim = animate(self._obj, *args, **kwargs) - anim.show = MethodType(show, anim) - - return anim diff --git a/tests/test_epoch_dataarray_accessor.py b/tests/test_epoch_dataarray_accessor.py index cbe79ed..ffccb26 100644 --- a/tests/test_epoch_dataarray_accessor.py +++ b/tests/test_epoch_dataarray_accessor.py @@ -238,7 +238,50 @@ def test_compute_global_limits_NaNs(): assert result_max == pytest.approx(expected_result_max, abs=1e-1) -def test_default_plot_flips_axis_order_for_2d_data(): +def test_epoch_plot_simple_1d_dataset(): + with xr.open_mfdataset( + TEST_FILES_DIR_1D.glob("*.sdf"), + compat="no_conflicts", + join="outer", + preprocess=SDFPreprocess(), + ) as ds: + fig, ax = plt.subplots() + ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax) + + assert len(ax.lines) == 1 + assert ax.get_xlabel() == "X [m]" + plt.close(fig) + + +def test_epoch_plot_simple_2d_dataset(): + with xr.open_mfdataset( + TEST_FILES_DIR_2D_MW.glob("*.sdf"), + preprocess=SDFPreprocess(), + combine="nested", + compat="no_conflicts", + join="outer", + ) as ds: + fig, ax = plt.subplots() + ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax) + + assert len(ax.collections) > 0 + assert ax.get_xlabel() == "X [m]" + assert ax.get_ylabel() == "Y [m]" + plt.close(fig) + + +def test_epoch_plot_simple_3d_dataset_slice(): + with xr.open_dataset(TEST_FILES_DIR_3D / "0001.sdf") as ds: + fig, ax = plt.subplots() + ds["Derived_Number_Density_Electron"].isel(Z_Grid_mid=0).epoch.plot(ax=ax) + + assert len(ax.collections) > 0 + assert ax.get_xlabel() == "X [m]" + assert ax.get_ylabel() == "Y [m]" + plt.close(fig) + + +def test_epoch_plot_flips_axis_order_for_2d_data(): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -246,15 +289,15 @@ def test_default_plot_flips_axis_order_for_2d_data(): compat="no_conflicts", join="outer", ) as ds: - plot = ds["Derived_Number_Density_electron"].isel(time=0).plot() - ax = plot.axes + fig, ax = plt.subplots() + ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax) assert ax.get_xlabel() == "X [m]" assert ax.get_ylabel() == "Y [m]" - plt.close(ax.figure) + plt.close(fig) -def test_default_plot_flips_axis_order_for_2d_data_with_additional_params(): +def test_epoch_plot_flips_axis_order_for_2d_data_with_additional_params(): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -263,7 +306,7 @@ def test_default_plot_flips_axis_order_for_2d_data_with_additional_params(): join="outer", ) as ds: fig, ax = plt.subplots() - ds["Derived_Number_Density_electron"].isel(time=0).plot( + ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot( ax=ax, xlim=(0.5, 1.0), ylim=(0.0, 0.5), @@ -276,7 +319,7 @@ def test_default_plot_flips_axis_order_for_2d_data_with_additional_params(): plt.close(fig) -def test_default_plot_flips_axis_order_for_2d_data_but_not_when_time_dim_present(): +def test_epoch_plot_flips_axis_order_for_2d_data_but_not_when_time_dim_present(): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -285,7 +328,7 @@ def test_default_plot_flips_axis_order_for_2d_data_but_not_when_time_dim_present join="outer", ) as ds: fig, ax = plt.subplots() - plot = ds["Derived_Number_Density_electron"].plot(ax=ax) + plot = ds["Derived_Number_Density_electron"].epoch.plot(ax=ax) assert type(plot[2]) is BarContainer plt.close(fig) From c5cbb41bf033c702865837641d48fd8c236c04d4 Mon Sep 17 00:00:00 2001 From: Joel Adams Date: Mon, 16 Mar 2026 09:25:10 +0000 Subject: [PATCH 3/5] improve docstring and typing --- src/sdf_xarray/dataarray_accessor.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/sdf_xarray/dataarray_accessor.py b/src/sdf_xarray/dataarray_accessor.py index 5e66d99..e31fa13 100644 --- a/src/sdf_xarray/dataarray_accessor.py +++ b/src/sdf_xarray/dataarray_accessor.py @@ -1,25 +1,33 @@ from types import MethodType -from typing import Any import xarray as xr from matplotlib.animation import FuncAnimation +from xarray.plot.accessor import DataArrayPlotAccessor from .plotting import animate, show @xr.register_dataarray_accessor("epoch") class EpochAccessor: - def __init__(self, xarray_obj): + def __init__(self, xarray_obj: xr.DataArray): self._obj = xarray_obj - def plot(self, *args, **kwargs) -> Any: + def plot(self, *args, **kwargs) -> DataArrayPlotAccessor: """ - Builds upon `xarray.plot` while changing some of its default behaviours. + Builds upon `xarray.DataArray.plot` while changing some of its default behaviours. + + These changes are: - Those changes are: - Flips the default axes order for 2D plots so that x and y are on the correct axes. - This exists because plotting of 2D data in xarray uses the `xarray.plot.pcolormesh` - function which takes assumes that ``x = dim[1]`` and ``y = dim[0]``. + This exists because plotting of 2D data in xarray uses the `xarray.plot.pcolormesh` + function which takes assumes that ``x = dim[1]`` and ``y = dim[0]``. + + Parameters + ---------- + args + Positional arguments passed to `xarray.DataArray.plot`. + kwargs + Keyword arguments passed to `xarray.DataArray.plot`. """ dims = self._obj.dims is_not_2d_data = len(dims) != 2 From 87cb1f3ebdd463f12893cd39c260f2d3fc6ff65e Mon Sep 17 00:00:00 2001 From: Joel Adams Date: Mon, 16 Mar 2026 09:25:52 +0000 Subject: [PATCH 4/5] Update existing documentation to use `.epoch.plot()` instead of `xarray.plot()` --- docs/key_functionality.md | 24 +++++++++++++----------- docs/unit_conversion.md | 10 +++++----- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/docs/key_functionality.md b/docs/key_functionality.md index 13d16c8..1b4a743 100644 --- a/docs/key_functionality.md +++ b/docs/key_functionality.md @@ -197,15 +197,18 @@ ds = sdfxr.open_mfdataset("tutorial_dataset_1d/*.sdf") ds["Electric_Field_Ex"] ``` -On top of accessing variables you can plot these -using the built-in function (see -) which is -a simple call to `matplotlib`. This also means that you can access -all the methods from `matplotlib` to manipulate your plot. +On top of accessing variables, you can plot these datasets using +[`xarray.DataArray.epoch.plot`](project:#sdf_xarray.dataarray_accessor.EpochAccessor.plot). +This is a custom plotting routine that builds on top of +, so you keep the familiar plotting +behaviour while using conveniences (see +[here](project:#sdf_xarray.dataarray_accessor.EpochAccessor.plot) for details). +Under the hood, plotting is still handled by , which means you +can use the full API to customise your figure. ```{code-cell} ipython3 # This is discretized in both space and time -ds["Electric_Field_Ex"].plot() +ds["Electric_Field_Ex"].epoch.plot() plt.title("Electric field along the x-axis") plt.show() ``` @@ -231,7 +234,6 @@ done by passsing the index to the `time` parameter (e.g., `time=0` for the first snapshot). ```{code-cell} ipython3 -# We can plot the variable at a given time index ds["Electric_Field_Ex"].isel(time=20) ``` @@ -296,17 +298,17 @@ ds["Laser_Absorption_Fraction_in_Simulation"] = ( ds["Laser_Absorption_Fraction_in_Simulation"].attrs["units"] = "%" ds["Laser_Absorption_Fraction_in_Simulation"].attrs["long_name"] = "Laser Absorption Fraction" -ds["Laser_Absorption_Fraction_in_Simulation"].plot() +ds["Laser_Absorption_Fraction_in_Simulation"].epoch.plot() plt.title("Laser absorption fraction in simulation") plt.show() ``` -You can also call the `plot()` function on several variables with +You can also call the [`xarray.DataArray.epoch.plot`](project:#sdf_xarray.dataarray_accessor.EpochAccessor.plot) function on several variables with labels by delaying the call to `plt.show()`. ```{code-cell} ipython3 -ds["Total_Particle_Energy_Electron"].plot(label="Electron") -ds["Total_Particle_Energy_Ion"].plot(label="Ion") +ds["Total_Particle_Energy_Electron"].epoch.plot(label="Electron") +ds["Total_Particle_Energy_Ion"].epoch.plot(label="Ion") plt.title("Particle Energy in Simulation per Species") plt.legend() plt.show() diff --git a/docs/unit_conversion.md b/docs/unit_conversion.md index b98f2d7..113c530 100644 --- a/docs/unit_conversion.md +++ b/docs/unit_conversion.md @@ -43,10 +43,10 @@ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf") ds_in_microns = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"]) -ds["Derived_Number_Density_Electron"].isel(time=0).plot(ax=ax1, x="X_Grid_mid", y="Y_Grid_mid") +ds["Derived_Number_Density_Electron"].isel(time=0).epoch.plot(ax=ax1) ax1.set_title("Original X Coordinate (m)") -ds_in_microns["Derived_Number_Density_Electron"].isel(time=0).plot(ax=ax2, x="X_Grid_mid", y="Y_Grid_mid") +ds_in_microns["Derived_Number_Density_Electron"].isel(time=0).epoch.plot(ax=ax2) ax2.set_title("Rescaled X Coordinate (µm)") fig.tight_layout() @@ -197,9 +197,9 @@ converted side by side: ```{code-cell} ipython3 fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16,8)) -ds["Total_Particle_Energy_Electron"].plot(ax=ax1) -total_particle_energy_ev.plot(ax=ax2) -total_particle_energy_w.plot(ax=ax3) +ds["Total_Particle_Energy_Electron"].epoch.plot(ax=ax1) +total_particle_energy_ev.epoch.plot(ax=ax2) +total_particle_energy_w.epoch.plot(ax=ax3) ax4.set_visible(False) fig.suptitle("Comparison of conversion from Joules to electron volts and watts") fig.tight_layout() From f7454bb08abb072dcf168e9313b39e94524b2075 Mon Sep 17 00:00:00 2001 From: Joel Adams Date: Wed, 18 Mar 2026 11:55:10 +0000 Subject: [PATCH 5/5] add pytest fixture for subplots to avoid figures not getting deleted after tests --- tests/test_epoch_dataarray_accessor.py | 39 ++++++++++++++------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/tests/test_epoch_dataarray_accessor.py b/tests/test_epoch_dataarray_accessor.py index ffccb26..d9b6513 100644 --- a/tests/test_epoch_dataarray_accessor.py +++ b/tests/test_epoch_dataarray_accessor.py @@ -25,6 +25,13 @@ TEST_FILES_DIR_3D = download.fetch_dataset("test_files_3D") +@pytest.fixture +def subplots(): + fig, ax = plt.subplots() + yield (fig, ax) + plt.close(fig) + + def test_animation_accessor(): array = xr.DataArray( [1, 2, 3], @@ -238,22 +245,21 @@ def test_compute_global_limits_NaNs(): assert result_max == pytest.approx(expected_result_max, abs=1e-1) -def test_epoch_plot_simple_1d_dataset(): +def test_epoch_plot_simple_1d_dataset(subplots): with xr.open_mfdataset( TEST_FILES_DIR_1D.glob("*.sdf"), compat="no_conflicts", join="outer", preprocess=SDFPreprocess(), ) as ds: - fig, ax = plt.subplots() + _, ax = subplots ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax) assert len(ax.lines) == 1 assert ax.get_xlabel() == "X [m]" - plt.close(fig) -def test_epoch_plot_simple_2d_dataset(): +def test_epoch_plot_simple_2d_dataset(subplots): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -261,27 +267,25 @@ def test_epoch_plot_simple_2d_dataset(): compat="no_conflicts", join="outer", ) as ds: - fig, ax = plt.subplots() + _, ax = subplots ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax) assert len(ax.collections) > 0 assert ax.get_xlabel() == "X [m]" assert ax.get_ylabel() == "Y [m]" - plt.close(fig) -def test_epoch_plot_simple_3d_dataset_slice(): +def test_epoch_plot_simple_3d_dataset_slice(subplots): with xr.open_dataset(TEST_FILES_DIR_3D / "0001.sdf") as ds: - fig, ax = plt.subplots() + _, ax = subplots ds["Derived_Number_Density_Electron"].isel(Z_Grid_mid=0).epoch.plot(ax=ax) assert len(ax.collections) > 0 assert ax.get_xlabel() == "X [m]" assert ax.get_ylabel() == "Y [m]" - plt.close(fig) -def test_epoch_plot_flips_axis_order_for_2d_data(): +def test_epoch_plot_flips_axis_order_for_2d_data(subplots): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -289,15 +293,14 @@ def test_epoch_plot_flips_axis_order_for_2d_data(): compat="no_conflicts", join="outer", ) as ds: - fig, ax = plt.subplots() + _, ax = subplots ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax) assert ax.get_xlabel() == "X [m]" assert ax.get_ylabel() == "Y [m]" - plt.close(fig) -def test_epoch_plot_flips_axis_order_for_2d_data_with_additional_params(): +def test_epoch_plot_flips_axis_order_for_2d_data_with_additional_params(subplots): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -305,7 +308,7 @@ def test_epoch_plot_flips_axis_order_for_2d_data_with_additional_params(): compat="no_conflicts", join="outer", ) as ds: - fig, ax = plt.subplots() + _, ax = subplots ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot( ax=ax, xlim=(0.5, 1.0), @@ -316,10 +319,11 @@ def test_epoch_plot_flips_axis_order_for_2d_data_with_additional_params(): assert ax.get_ylabel() == "Y [m]" assert ax.get_xlim() == pytest.approx((0.5, 1.0), abs=1e-2) assert ax.get_ylim() == pytest.approx((0.0, 0.5), abs=1e-2) - plt.close(fig) -def test_epoch_plot_flips_axis_order_for_2d_data_but_not_when_time_dim_present(): +def test_epoch_plot_flips_axis_order_for_2d_data_but_not_when_time_dim_present( + subplots, +): with xr.open_mfdataset( TEST_FILES_DIR_2D_MW.glob("*.sdf"), preprocess=SDFPreprocess(), @@ -327,8 +331,7 @@ def test_epoch_plot_flips_axis_order_for_2d_data_but_not_when_time_dim_present() compat="no_conflicts", join="outer", ) as ds: - fig, ax = plt.subplots() + _, ax = subplots plot = ds["Derived_Number_Density_electron"].epoch.plot(ax=ax) assert type(plot[2]) is BarContainer - plt.close(fig)