Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions docs/key_functionality.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,18 @@ ds = sdfxr.open_mfdataset("tutorial_dataset_1d/*.sdf")
ds["Electric_Field_Ex"]
```

On top of accessing variables you can plot these <inv:#xarray.Dataset>
using the built-in <inv:#xarray.DataArray.plot> function (see
<https://docs.xarray.dev/en/stable/user-guide/plotting.html>) which is
a simple call to `matplotlib`. This also means that you can access
all the methods from `matplotlib` to manipulate your plot.
On top of accessing variables, you can plot these datasets using
[`xarray.DataArray.epoch.plot`](project:#sdf_xarray.dataarray_accessor.EpochAccessor.plot).
This is a custom <project:#sdf-xarray> plotting routine that builds on top of
<inv:#xarray.DataArray.plot>, so you keep the familiar <inv:#xarray> plotting
behaviour while using <project:#sdf-xarray> conveniences (see
[here](project:#sdf_xarray.dataarray_accessor.EpochAccessor.plot) for details).
Under the hood, plotting is still handled by <inv:#matplotlib>, which means you
can use the full <inv:#matplotlib> API to customise your figure.

```{code-cell} ipython3
# This is discretized in both space and time
ds["Electric_Field_Ex"].plot()
ds["Electric_Field_Ex"].epoch.plot()
plt.title("Electric field along the x-axis")
plt.show()
```
Expand All @@ -231,7 +234,6 @@ done by passsing the index to the `time` parameter (e.g., `time=0` for
the first snapshot).

```{code-cell} ipython3
# We can plot the variable at a given time index
ds["Electric_Field_Ex"].isel(time=20)
```

Expand Down Expand Up @@ -296,17 +298,17 @@ ds["Laser_Absorption_Fraction_in_Simulation"] = (
ds["Laser_Absorption_Fraction_in_Simulation"].attrs["units"] = "%"
ds["Laser_Absorption_Fraction_in_Simulation"].attrs["long_name"] = "Laser Absorption Fraction"

ds["Laser_Absorption_Fraction_in_Simulation"].plot()
ds["Laser_Absorption_Fraction_in_Simulation"].epoch.plot()
plt.title("Laser absorption fraction in simulation")
plt.show()
```

You can also call the `plot()` function on several variables with
You can also call the [`xarray.DataArray.epoch.plot`](project:#sdf_xarray.dataarray_accessor.EpochAccessor.plot) function on several variables with
labels by delaying the call to `plt.show()`.

```{code-cell} ipython3
ds["Total_Particle_Energy_Electron"].plot(label="Electron")
ds["Total_Particle_Energy_Ion"].plot(label="Ion")
ds["Total_Particle_Energy_Electron"].epoch.plot(label="Electron")
ds["Total_Particle_Energy_Ion"].epoch.plot(label="Ion")
plt.title("Particle Energy in Simulation per Species")
plt.legend()
plt.show()
Expand Down
10 changes: 5 additions & 5 deletions docs/unit_conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf")
ds_in_microns = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"])

ds["Derived_Number_Density_Electron"].isel(time=0).plot(ax=ax1, x="X_Grid_mid", y="Y_Grid_mid")
ds["Derived_Number_Density_Electron"].isel(time=0).epoch.plot(ax=ax1)
ax1.set_title("Original X Coordinate (m)")

ds_in_microns["Derived_Number_Density_Electron"].isel(time=0).plot(ax=ax2, x="X_Grid_mid", y="Y_Grid_mid")
ds_in_microns["Derived_Number_Density_Electron"].isel(time=0).epoch.plot(ax=ax2)
ax2.set_title("Rescaled X Coordinate (µm)")

fig.tight_layout()
Expand Down Expand Up @@ -197,9 +197,9 @@ converted <inv:#xarray.Dataset> side by side:

```{code-cell} ipython3
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16,8))
ds["Total_Particle_Energy_Electron"].plot(ax=ax1)
total_particle_energy_ev.plot(ax=ax2)
total_particle_energy_w.plot(ax=ax3)
ds["Total_Particle_Energy_Electron"].epoch.plot(ax=ax1)
total_particle_energy_ev.epoch.plot(ax=ax2)
total_particle_energy_w.epoch.plot(ax=ax3)
ax4.set_visible(False)
fig.suptitle("Comparison of conversion from Joules to electron volts and watts")
fig.tight_layout()
Expand Down
1 change: 1 addition & 0 deletions src/sdf_xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

# NOTE: Do not delete these lines, otherwise the "epoch" dataset and dataarray
# accessors will not be imported when the user imports sdf_xarray
import sdf_xarray.dataarray_accessor
import sdf_xarray.dataset_accessor
import sdf_xarray.download
import sdf_xarray.plotting # noqa: F401
Expand Down
69 changes: 69 additions & 0 deletions src/sdf_xarray/dataarray_accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from types import MethodType

import xarray as xr
from matplotlib.animation import FuncAnimation
from xarray.plot.accessor import DataArrayPlotAccessor

from .plotting import animate, show


@xr.register_dataarray_accessor("epoch")
class EpochAccessor:
def __init__(self, xarray_obj: xr.DataArray):
self._obj = xarray_obj

def plot(self, *args, **kwargs) -> DataArrayPlotAccessor:
"""
Builds upon `xarray.DataArray.plot` while changing some of its default behaviours.

These changes are:

- Flips the default axes order for 2D plots so that x and y are on the correct axes.
This exists because plotting of 2D data in xarray uses the `xarray.plot.pcolormesh`
function which takes assumes that ``x = dim[1]`` and ``y = dim[0]``.

Parameters
----------
args
Positional arguments passed to `xarray.DataArray.plot`.
kwargs
Keyword arguments passed to `xarray.DataArray.plot`.
"""
dims = self._obj.dims
is_not_2d_data = len(dims) != 2
is_time_dim_present = "time" in dims
is_x_or_y_specified_in_kwargs = "x" in kwargs or "y" in kwargs

if is_not_2d_data or is_time_dim_present or is_x_or_y_specified_in_kwargs:
return self._obj.plot(*args, **kwargs)

updated_kwargs = dict(kwargs)
updated_kwargs.setdefault("x", dims[0])
updated_kwargs.setdefault("y", dims[1])

return self._obj.plot(*args, **updated_kwargs)

def animate(self, *args, **kwargs) -> FuncAnimation:
"""Generate animations of Epoch data.

Parameters
----------
args
Positional arguments passed to :func:`animation`.
kwargs
Keyword arguments passed to :func:`animation`.

Examples
--------
>>> anim = ds["Electric_Field_Ey"].epoch.animate()
>>> anim.save("animation.gif")
>>> # Or in a jupyter notebook:
>>> anim.show()
"""

# Add anim.show() functionality
# anim.show() will display the animation in a jupyter notebook
anim = animate(self._obj, *args, **kwargs)
anim.show = MethodType(show, anim)

return anim
32 changes: 0 additions & 32 deletions src/sdf_xarray/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from types import MethodType
from typing import TYPE_CHECKING, Any

import numpy as np
Expand Down Expand Up @@ -534,34 +533,3 @@ def show(anim):
from IPython.display import HTML # noqa: PLC0415

return HTML(anim.to_jshtml())


@xr.register_dataarray_accessor("epoch")
class EpochAccessor:
def __init__(self, xarray_obj):
self._obj = xarray_obj

def animate(self, *args, **kwargs) -> FuncAnimation:
"""Generate animations of Epoch data.

Parameters
----------
args
Positional arguments passed to :func:`animation`.
kwargs
Keyword arguments passed to :func:`animation`.

Examples
--------
>>> anim = ds["Electric_Field_Ey"].epoch.animate()
>>> anim.save("animation.gif")
>>> # Or in a jupyter notebook:
>>> anim.show()
"""

# Add anim.show() functionality
# anim.show() will display the animation in a jupyter notebook
anim = animate(self._obj, *args, **kwargs)
anim.show = MethodType(show, anim)

return anim
101 changes: 101 additions & 0 deletions tests/test_epoch_dataarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from importlib.metadata import version

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pytest
import xarray as xr
from matplotlib.animation import PillowWriter
from matplotlib.container import BarContainer
from packaging.version import Version

import sdf_xarray as sdfxr
Expand All @@ -23,6 +25,13 @@
TEST_FILES_DIR_3D = download.fetch_dataset("test_files_3D")


@pytest.fixture
def subplots():
fig, ax = plt.subplots()
yield (fig, ax)
plt.close(fig)


def test_animation_accessor():
array = xr.DataArray(
[1, 2, 3],
Expand Down Expand Up @@ -234,3 +243,95 @@ def test_compute_global_limits_NaNs():
expected_result_max = 2.70
assert result_min == pytest.approx(expected_result_min, abs=1e-2)
assert result_max == pytest.approx(expected_result_max, abs=1e-1)


def test_epoch_plot_simple_1d_dataset(subplots):
with xr.open_mfdataset(
TEST_FILES_DIR_1D.glob("*.sdf"),
compat="no_conflicts",
join="outer",
preprocess=SDFPreprocess(),
) as ds:
_, ax = subplots
ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax)

assert len(ax.lines) == 1
assert ax.get_xlabel() == "X [m]"


def test_epoch_plot_simple_2d_dataset(subplots):
with xr.open_mfdataset(
TEST_FILES_DIR_2D_MW.glob("*.sdf"),
preprocess=SDFPreprocess(),
combine="nested",
compat="no_conflicts",
join="outer",
) as ds:
_, ax = subplots
ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax)

assert len(ax.collections) > 0
assert ax.get_xlabel() == "X [m]"
assert ax.get_ylabel() == "Y [m]"


def test_epoch_plot_simple_3d_dataset_slice(subplots):
with xr.open_dataset(TEST_FILES_DIR_3D / "0001.sdf") as ds:
_, ax = subplots
ds["Derived_Number_Density_Electron"].isel(Z_Grid_mid=0).epoch.plot(ax=ax)

assert len(ax.collections) > 0
assert ax.get_xlabel() == "X [m]"
assert ax.get_ylabel() == "Y [m]"


def test_epoch_plot_flips_axis_order_for_2d_data(subplots):
with xr.open_mfdataset(
TEST_FILES_DIR_2D_MW.glob("*.sdf"),
preprocess=SDFPreprocess(),
combine="nested",
compat="no_conflicts",
join="outer",
) as ds:
_, ax = subplots
ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(ax=ax)

assert ax.get_xlabel() == "X [m]"
assert ax.get_ylabel() == "Y [m]"


def test_epoch_plot_flips_axis_order_for_2d_data_with_additional_params(subplots):
with xr.open_mfdataset(
TEST_FILES_DIR_2D_MW.glob("*.sdf"),
preprocess=SDFPreprocess(),
combine="nested",
compat="no_conflicts",
join="outer",
) as ds:
_, ax = subplots
ds["Derived_Number_Density_electron"].isel(time=0).epoch.plot(
ax=ax,
xlim=(0.5, 1.0),
ylim=(0.0, 0.5),
)

assert ax.get_xlabel() == "X [m]"
assert ax.get_ylabel() == "Y [m]"
assert ax.get_xlim() == pytest.approx((0.5, 1.0), abs=1e-2)
assert ax.get_ylim() == pytest.approx((0.0, 0.5), abs=1e-2)


def test_epoch_plot_flips_axis_order_for_2d_data_but_not_when_time_dim_present(
subplots,
):
with xr.open_mfdataset(
TEST_FILES_DIR_2D_MW.glob("*.sdf"),
preprocess=SDFPreprocess(),
combine="nested",
compat="no_conflicts",
join="outer",
) as ds:
_, ax = subplots
plot = ds["Derived_Number_Density_electron"].epoch.plot(ax=ax)

assert type(plot[2]) is BarContainer
Loading