Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 214 additions & 2 deletions src/skyborn/plot/curved_quiver_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from typing import TYPE_CHECKING, Literal

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# import numpy as np
# import matplotlib
from matplotlib.artist import Artist, allow_rasterization
from matplotlib.backend_bases import RendererBase
Expand All @@ -24,7 +24,7 @@
if TYPE_CHECKING:
from matplotlib.axes import Axes

__all__ = ["curved_quiver", "add_curved_quiverkey"]
__all__ = ["curved_quiver", "add_curved_quiverkey", "curved_quiver_numpy", "curved_quiver_profile"]


def curved_quiver(
Expand Down Expand Up @@ -690,3 +690,215 @@ def add_curved_quiverkey(
return CurvedQuiverLegend(
ax, curved_quiver_set, U, units=units, loc=loc, labelpos=labelpos, **kwargs
)


def curved_quiver_numpy(
x: np.ndarray,
y: np.ndarray,
u: np.ndarray,
v: np.ndarray,
ax: Axes | None = None,
density=1,
linewidth=None,
color=None,
cmap=None,
norm=None,
arrowsize=1,
arrowstyle="-|>",
transform=None,
zorder=None,
start_points=None,
integration_direction="both",
grains=15,
broken_streamlines=True,
interpolate_irregular=True,
nx=None,
ny=None,
interpolation_method="linear",
) -> CurvedQuiverplotSet:
"""
Plot curved arrows using numpy arrays directly (xarray not required).

.. warning::

This function is experimental and the API is subject to change.

Parameters
----------
x, y : 1D arrays
Coordinate arrays for the grid. Can be irregularly spaced.
u, v : 2D arrays
Velocity components. Shape must match (len(y), len(x)).
ax : matplotlib.axes.Axes, optional
Axes on which to plot. By default, use the current axes.
density : float or (float, float)
Controls the closeness of streamlines. When ``density = 1``, the domain
is divided into a 30x30 grid. *density* linearly scales this grid.
linewidth : float or 2D array
The width of the streamlines.
color : color or 2D array
The streamline color.
cmap, norm
Data normalization and colormapping parameters for *color*.
arrowsize : float
Scaling factor for the arrow size.
arrowstyle : str
Arrow style specification. See `~matplotlib.patches.FancyArrowPatch`.
transform : Transform, optional
Coordinate transformation for the plot.
zorder : float
The zorder of the streamlines and arrows.
start_points : (N, 2) array
Coordinates of starting points for the streamlines.
integration_direction : {'forward', 'backward', 'both'}, default: 'both'
Integrate the streamline in forward, backward or both directions.
grains : int, default: 15
Number of grains used in streamline integration.
broken_streamlines : boolean, default: True
If False, forces streamlines to continue until they leave the plot domain.
interpolate_irregular : boolean, default: True
If True, interpolate irregular grids to regular grid. If False, raise error.
nx, ny : int, optional
Target grid resolution for interpolation. If None, use original resolution.
interpolation_method : str, default: 'linear'
Interpolation method: 'linear', 'nearest', or 'cubic'.

Returns
-------
CurvedQuiverplotSet
Container object with lines and arrows.

Example
-------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> x = np.linspace(0, 10, 50)
>>> y = np.linspace(0, 5, 25)
>>> X, Y = np.meshgrid(x, y)
>>> U = np.sin(X) * np.cos(Y)
>>> V = np.cos(X) * np.sin(Y)
>>> fig, ax = plt.subplots()
>>> curved_quiver_numpy(x, y, U, V, ax=ax)
"""
from .modplot import velovect, _is_regular_grid, _interpolate_to_regular_grid

if ax is None:
ax = plt.gca()

if type(transform).__name__ == "PlateCarree":
transform = transform._as_mpl_transform(ax)

x = np.asarray(x)
y = np.asarray(y)
u = np.asarray(u)
v = np.asarray(v)

if x.ndim != 1 or y.ndim != 1:
raise ValueError("'x' and 'y' must be 1D arrays")

if u.shape != (len(y), len(x)) or v.shape != (len(y), len(x)):
raise ValueError(
f"'u' and 'v' must have shape ({len(y)}, {len(x)}), "
f"got {u.shape} and {v.shape}"
)

if _is_regular_grid(x, y):
x_plot, y_plot, u_plot, v_plot = x, y, u, v
else:
if not interpolate_irregular:
raise ValueError(
"Grid is irregularly spaced. Set interpolate_irregular=True "
"to interpolate to a regular grid."
)
x_plot, y_plot, u_plot, v_plot = _interpolate_to_regular_grid(
x, y, u, v, nx=nx, ny=ny, method=interpolation_method
)

obj = velovect(
ax,
x_plot,
y_plot,
u_plot,
v_plot,
density=density,
linewidth=linewidth,
color=color,
cmap=cmap,
norm=norm,
arrowsize=arrowsize,
arrowstyle=arrowstyle,
transform=transform,
zorder=zorder,
start_points=start_points,
integration_direction=integration_direction,
grains=grains,
broken_streamlines=broken_streamlines,
)
return obj


def curved_quiver_profile(
distance: np.ndarray,
height: np.ndarray,
u: np.ndarray,
v: np.ndarray,
ax: Axes | None = None,
distance_units: str = "km",
height_units: str = "m",
**kwargs,
) -> CurvedQuiverplotSet:
"""
Plot curved arrows on a vertical profile/cross-section.

This function is designed for vertical cross-section data where the
coordinates represent distance along a profile and vertical height/depth.

.. warning::

This function is experimental and the API is subject to change.

Parameters
----------
distance : 1D array
Distance along the profile (e.g., longitude, or actual distance in km).
height : 1D array
Vertical coordinate (e.g., altitude, pressure level, depth).
u, v : 2D arrays
Velocity components. Shape must match (len(height), len(distance)).
ax : matplotlib.axes.Axes, optional
Axes on which to plot. By default, use the current axes.
distance_units : str, default: 'km'
Unit string for distance axis (for labeling purposes).
height_units : str, default: 'm'
Unit string for height axis (for labeling purposes).
**kwargs
Additional keyword arguments passed to `curved_quiver_numpy`.

Returns
-------
CurvedQuiverplotSet
Container object with lines and arrows.

Example
-------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> distance = np.linspace(0, 100, 50) # km
>>> height = np.linspace(0, 20, 30) # km
>>> H, D = np.meshgrid(height, distance, indexing='ij')
>>> U = np.sin(H / 5) * np.cos(D / 20)
>>> V = np.cos(H / 5) * np.sin(D / 20)
>>> fig, ax = plt.subplots()
>>> curved_quiver_profile(distance, height, U, V, ax=ax)
>>> ax.set_xlabel(f'Distance ({distance_units})')
>>> ax.set_ylabel(f'Height ({height_units})')
"""
obj = curved_quiver_numpy(
x=distance,
y=height,
u=u,
v=v,
ax=ax,
**kwargs,
)
return obj
72 changes: 71 additions & 1 deletion src/skyborn/plot/modplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
from matplotlib import cm, patches

__all__ = ["velovect"]
__all__ = ["velovect", "_is_regular_grid", "_interpolate_to_regular_grid"]


def velovect(
Expand Down Expand Up @@ -789,3 +789,73 @@ def _gen_starting_points(x, y, grains):
seed_points = np.array([xs, ys])

return seed_points.T


def _is_regular_grid(x, y, rtol=1e-5, atol=1e-8):
"""
Check if grid is regular (equally spaced).

Parameters
----------
x, y : 1D arrays
Coordinate arrays
rtol, atol : float
Relative and absolute tolerances for np.allclose

Returns
-------
bool
True if grid is regular, False otherwise
"""
x_regular = np.allclose(np.diff(x), (x[-1] - x[0]) / (len(x) - 1), rtol=rtol, atol=atol)
y_regular = np.allclose(np.diff(y), (y[-1] - y[0]) / (len(y) - 1), rtol=rtol, atol=atol)
return x_regular and y_regular


def _interpolate_to_regular_grid(x, y, u, v, nx=None, ny=None, method='linear'):
"""
Interpolate irregular grid data to a regular grid.

Parameters
----------
x, y : 1D arrays
Original coordinate arrays (can be irregular)
u, v : 2D arrays
Velocity fields on original grid
nx, ny : int, optional
Target grid resolution. If None, use original resolution.
method : str
Interpolation method: 'linear' (default), 'nearest', 'cubic'

Returns
-------
x_reg, y_reg : 1D arrays
Regular coordinate arrays
u_reg, v_reg : 2D arrays
Interpolated velocity fields on regular grid
"""
try:
from scipy.interpolate import griddata
except ImportError:
raise ImportError(
"scipy is required for interpolation on irregular grids. "
"Install with: pip install scipy"
)

if nx is None:
nx = len(x)
if ny is None:
ny = len(y)

X, Y = np.meshgrid(x, y, indexing='xy')

points = np.column_stack([X.ravel(), Y.ravel()])

xi = np.linspace(x.min(), x.max(), nx)
yi = np.linspace(y.min(), y.max(), ny)
Xi, Yi = np.meshgrid(xi, yi, indexing='xy')

u_reg = griddata(points, u.ravel(), (Xi, Yi), method=method)
v_reg = griddata(points, v.ravel(), (Xi, Yi), method=method)

return xi, yi, u_reg, v_reg
Loading