Skip to content

Commit caf62d3

Browse files
authored
Improved duck array wrapping (#9798)
* lots more duck array compat, plus tests
* merge sliding_window_view
* namespaces constant
* revert dask allowed
* fix up some tests
* backwards compat sparse mask
* add as_array methods
* to_like_array helper
* only cast non-numpy
* better idxminmax approach
* fix mypy
* naming, add is_array_type
* add public doc and whats new
* update comments
* add support for chunked arrays in as_array_type
* revert array_type methods
* fix up whats new
* comment about bool_
* add jax to complete ci envs
* add pint and sparse to tests
* remove from windows
* mypy, xfail one more sparse
* add dask and a few other methods
* move whats new
1 parent a765ae0 commit caf62d3

15 files changed

+730
-103
lines changed

ci/requirements/environment-3.13.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,5 @@ dependencies:
4747
- toolz
4848
- typing_extensions
4949
- zarr
50+
- pip:
51+
- jax # no way to get cpu-only jaxlib from conda if gpu is present

ci/requirements/environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,5 @@ dependencies:
4949
- toolz
5050
- typing_extensions
5151
- zarr
52+
- pip:
53+
- jax # no way to get cpu-only jaxlib from conda if gpu is present

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ v.2024.11.1 (unreleased)
2121

2222
New Features
2323
~~~~~~~~~~~~
24+
- Better support for wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized
25+
duck array operations throughout more xarray methods (:issue:`7848`, :pull:`9798`).
26+
By `Sam Levang <https://github.com/slevang>`_.
2427

2528

2629
Breaking changes

xarray/core/array_api_compat.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22

3+
from xarray.namedarray.pycompat import array_type
4+
35

46
def is_weak_scalar_type(t):
57
return isinstance(t, bool | int | float | complex | str | bytes)
@@ -42,3 +44,39 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype:
4244
return xp.result_type(*arrays_and_dtypes)
4345
else:
4446
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)
47+
48+
49+
def get_array_namespace(*values):
    """Return the single array namespace shared by ``values``.

    Each value contributes the module returned by its ``__array_namespace__``
    method when it has one, ``cupy`` when it is a cupy array, and ``numpy``
    otherwise. numpy acts as the fallback: it is returned only when no value
    reports a different namespace (including when no values are given).

    Raises
    ------
    TypeError
        If the values report more than one distinct non-numpy namespace.
    """

    def _get_single_namespace(x):
        if hasattr(x, "__array_namespace__"):
            return x.__array_namespace__()
        elif isinstance(x, array_type("cupy")):
            # cupy is fully compliant from xarray's perspective, but will not expose
            # __array_namespace__ until at least v14. Special case it for now
            import cupy as cp

            return cp
        else:
            return np

    namespaces = {_get_single_namespace(t) for t in values}
    non_numpy = namespaces - {np}

    if len(non_numpy) > 1:
        # Sort the names so the message does not depend on set iteration order,
        # which varies between interpreter runs for module objects.
        names = sorted(module.__name__ for module in non_numpy)
        raise TypeError(f"Mixed array types {names} are not supported.")
    elif non_numpy:
        [xp] = non_numpy
    else:
        xp = np

    return xp
74+
75+
76+
def to_like_array(array, like):
    """Convert ``array`` into the array namespace of ``like`` when that is not numpy.

    Mostly for cupy compatibility, because cupy binary ops require all operands
    to be cupy arrays. When ``like`` resolves to the numpy namespace, ``array``
    is returned unchanged.
    """
    namespace = get_array_namespace(like)
    if namespace is np:
        # Leave the input alone so that things like pint quantities are not
        # cast to plain numpy arrays.
        return array
    return namespace.asarray(array)

xarray/core/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def clip(
496496
keep_attrs = _get_keep_attrs(default=True)
497497

498498
return apply_ufunc(
499-
np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed"
499+
duck_array_ops.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed"
500500
)
501501

502502
def get_index(self, key: Hashable) -> pd.Index:
@@ -1760,7 +1760,7 @@ def _full_like_variable(
17601760
**from_array_kwargs,
17611761
)
17621762
else:
1763-
data = np.full_like(other.data, fill_value, dtype=dtype)
1763+
data = duck_array_ops.full_like(other.data, fill_value, dtype=dtype)
17641764

17651765
return Variable(dims=other.dims, data=data, attrs=other.attrs)
17661766

xarray/core/computation.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from xarray.core import dtypes, duck_array_ops, utils
2626
from xarray.core.alignment import align, deep_align
27+
from xarray.core.array_api_compat import to_like_array
2728
from xarray.core.common import zeros_like
2829
from xarray.core.duck_array_ops import datetime_to_numeric
2930
from xarray.core.formatting import limit_lines
@@ -1702,7 +1703,7 @@ def cross(
17021703
)
17031704

17041705
c = apply_ufunc(
1705-
np.cross,
1706+
duck_array_ops.cross,
17061707
a,
17071708
b,
17081709
input_core_dims=[[dim], [dim]],
@@ -2170,13 +2171,14 @@ def _calc_idxminmax(
21702171
chunks = dict(zip(array.dims, array.chunks, strict=True))
21712172
dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
21722173
data = dask_coord[duck_array_ops.ravel(indx.data)]
2173-
res = indx.copy(data=duck_array_ops.reshape(data, indx.shape))
2174-
# we need to attach back the dim name
2175-
res.name = dim
21762174
else:
2177-
res = array[dim][(indx,)]
2178-
# The dim is gone but we need to remove the corresponding coordinate.
2179-
del res.coords[dim]
2175+
arr_coord = to_like_array(array[dim].data, array.data)
2176+
data = arr_coord[duck_array_ops.ravel(indx.data)]
2177+
2178+
# rebuild like the argmin/max output, and rename as the dim name
2179+
data = duck_array_ops.reshape(data, indx.shape)
2180+
res = indx.copy(data=data)
2181+
res.name = dim
21802182

21812183
if skipna or (skipna is None and array.dtype.kind in na_dtypes):
21822184
# Put the NaN values back in after removing them

xarray/core/dataset.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
align,
5656
)
5757
from xarray.core.arithmetic import DatasetArithmetic
58+
from xarray.core.array_api_compat import to_like_array
5859
from xarray.core.common import (
5960
DataWithCoords,
6061
_contains_datetime_like_objects,
@@ -127,7 +128,7 @@
127128
calculate_dimensions,
128129
)
129130
from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
130-
from xarray.namedarray.pycompat import array_type, is_chunked_array
131+
from xarray.namedarray.pycompat import array_type, is_chunked_array, to_numpy
131132
from xarray.plot.accessor import DatasetPlotAccessor
132133
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
133134

@@ -6620,7 +6621,7 @@ def dropna(
66206621
array = self._variables[k]
66216622
if dim in array.dims:
66226623
dims = [d for d in array.dims if d != dim]
6623-
count += np.asarray(array.count(dims))
6624+
count += to_numpy(array.count(dims).data)
66246625
size += math.prod([self.sizes[d] for d in dims])
66256626

66266627
if thresh is not None:
@@ -8734,16 +8735,17 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False):
87348735
coord_names.add(k)
87358736
else:
87368737
if k in self.data_vars and dim in v.dims:
8738+
coord_data = to_like_array(coord_var.data, like=v.data)
87378739
if _contains_datetime_like_objects(v):
87388740
v = datetime_to_numeric(v, datetime_unit=datetime_unit)
87398741
if cumulative:
87408742
integ = duck_array_ops.cumulative_trapezoid(
8741-
v.data, coord_var.data, axis=v.get_axis_num(dim)
8743+
v.data, coord_data, axis=v.get_axis_num(dim)
87428744
)
87438745
v_dims = v.dims
87448746
else:
87458747
integ = duck_array_ops.trapz(
8746-
v.data, coord_var.data, axis=v.get_axis_num(dim)
8748+
v.data, coord_data, axis=v.get_axis_num(dim)
87478749
)
87488750
v_dims = list(v.dims)
87498751
v_dims.remove(dim)

0 commit comments

Comments
 (0)