Skip to content

Commit 5d70f4d

Browse files
slevang and shoyer authored
Namespace-aware xarray.ufuncs (#9776)
* initial namespace-aware implementation
* use np subclass, test duck dask arrays
* remove dask special casing and numpy fallback
* add isnat
* hard code the supported ufuncs
* handle np versions, separate unary/binary path
* explicit unary/binary creators
* add to api docs
* add whats new
* move numpy version check to tests
* fix docs for aliased np funcs
* fix whats new

---------

Co-authored-by: Stephan Hoyer <[email protected]>
1 parent fa62c2d commit 5d70f4d

File tree

7 files changed

+598
-2
lines changed

7 files changed

+598
-2
lines changed

doc/api.rst

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,120 @@ Methods copied from :py:class:`numpy.ndarray` objects, here applying to the data
894894
.. DataTree.sortby
895895
.. DataTree.broadcast_like
896896
897+
Universal functions
898+
===================
899+
900+
These functions are equivalent to their NumPy versions, but for xarray
901+
objects backed by non-NumPy array types (e.g. ``cupy``, ``sparse``, or ``jax``),
902+
they will ensure that the computation is dispatched to the appropriate
903+
backend. You can find them in the ``xarray.ufuncs`` module:
904+
905+
.. autosummary::
906+
:toctree: generated/
907+
908+
ufuncs.abs
909+
ufuncs.absolute
910+
ufuncs.acos
911+
ufuncs.acosh
912+
ufuncs.arccos
913+
ufuncs.arccosh
914+
ufuncs.arcsin
915+
ufuncs.arcsinh
916+
ufuncs.arctan
917+
ufuncs.arctanh
918+
ufuncs.asin
919+
ufuncs.asinh
920+
ufuncs.atan
921+
ufuncs.atanh
922+
ufuncs.bitwise_count
923+
ufuncs.bitwise_invert
924+
ufuncs.bitwise_not
925+
ufuncs.cbrt
926+
ufuncs.ceil
927+
ufuncs.conj
928+
ufuncs.conjugate
929+
ufuncs.cos
930+
ufuncs.cosh
931+
ufuncs.deg2rad
932+
ufuncs.degrees
933+
ufuncs.exp
934+
ufuncs.exp2
935+
ufuncs.expm1
936+
ufuncs.fabs
937+
ufuncs.floor
938+
ufuncs.invert
939+
ufuncs.isfinite
940+
ufuncs.isinf
941+
ufuncs.isnan
942+
ufuncs.isnat
943+
ufuncs.log
944+
ufuncs.log10
945+
ufuncs.log1p
946+
ufuncs.log2
947+
ufuncs.logical_not
948+
ufuncs.negative
949+
ufuncs.positive
950+
ufuncs.rad2deg
951+
ufuncs.radians
952+
ufuncs.reciprocal
953+
ufuncs.rint
954+
ufuncs.sign
955+
ufuncs.signbit
956+
ufuncs.sin
957+
ufuncs.sinh
958+
ufuncs.spacing
959+
ufuncs.sqrt
960+
ufuncs.square
961+
ufuncs.tan
962+
ufuncs.tanh
963+
ufuncs.trunc
964+
ufuncs.add
965+
ufuncs.arctan2
966+
ufuncs.atan2
967+
ufuncs.bitwise_and
968+
ufuncs.bitwise_left_shift
969+
ufuncs.bitwise_or
970+
ufuncs.bitwise_right_shift
971+
ufuncs.bitwise_xor
972+
ufuncs.copysign
973+
ufuncs.divide
974+
ufuncs.equal
975+
ufuncs.float_power
976+
ufuncs.floor_divide
977+
ufuncs.fmax
978+
ufuncs.fmin
979+
ufuncs.fmod
980+
ufuncs.gcd
981+
ufuncs.greater
982+
ufuncs.greater_equal
983+
ufuncs.heaviside
984+
ufuncs.hypot
985+
ufuncs.lcm
986+
ufuncs.ldexp
987+
ufuncs.left_shift
988+
ufuncs.less
989+
ufuncs.less_equal
990+
ufuncs.logaddexp
991+
ufuncs.logaddexp2
992+
ufuncs.logical_and
993+
ufuncs.logical_or
994+
ufuncs.logical_xor
995+
ufuncs.maximum
996+
ufuncs.minimum
997+
ufuncs.mod
998+
ufuncs.multiply
999+
ufuncs.nextafter
1000+
ufuncs.not_equal
1001+
ufuncs.pow
1002+
ufuncs.power
1003+
ufuncs.remainder
1004+
ufuncs.right_shift
1005+
ufuncs.subtract
1006+
ufuncs.true_divide
1007+
ufuncs.angle
1008+
ufuncs.isreal
1009+
ufuncs.iscomplex
1010+
8971011
IO / Conversion
8981012
===============
8991013

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ New Features
4141
- Optimize :py:meth:`DataArray.polyfit` and :py:meth:`Dataset.polyfit` with dask, when used with
4242
arrays with more than two dimensions.
4343
(:issue:`5629`). By `Deepak Cherian <https://github.com/dcherian>`_.
44+
- Re-implement the :py:mod:`ufuncs` module, which now dynamically dispatches to the
45+
underlying array's backend. Provides better support for certain wrapped array types
46+
like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`).
47+
By `Sam Levang <https://github.com/slevang>`_.
4448

4549
Breaking changes
4650
~~~~~~~~~~~~~~~~

xarray/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from importlib.metadata import version as _version
22

3-
from xarray import groupers, testing, tutorial
3+
from xarray import groupers, testing, tutorial, ufuncs
44
from xarray.backends.api import (
55
load_dataarray,
66
load_dataset,
@@ -69,6 +69,7 @@
6969
"groupers",
7070
"testing",
7171
"tutorial",
72+
"ufuncs",
7273
# Top-level functions
7374
"align",
7475
"apply_ufunc",

xarray/tests/test_dask.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pytest
1212

1313
import xarray as xr
14+
import xarray.ufuncs as xu
1415
from xarray import DataArray, Dataset, Variable
1516
from xarray.core import duck_array_ops
1617
from xarray.core.duck_array_ops import lazy_array_equiv
@@ -274,6 +275,17 @@ def test_bivariate_ufunc(self):
274275
self.assertLazyAndAllClose(np.maximum(u, 0), np.maximum(v, 0))
275276
self.assertLazyAndAllClose(np.maximum(u, 0), np.maximum(0, v))
276277

278+
def test_univariate_xufunc(self):
279+
u = self.eager_var
280+
v = self.lazy_var
281+
self.assertLazyAndAllClose(np.sin(u), xu.sin(v))
282+
283+
def test_bivariate_xufunc(self):
284+
u = self.eager_var
285+
v = self.lazy_var
286+
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
287+
self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))
288+
277289
def test_compute(self):
278290
u = self.eager_var
279291
v = self.lazy_var

xarray/tests/test_sparse.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010

1111
import xarray as xr
12+
import xarray.ufuncs as xu
1213
from xarray import DataArray, Variable
1314
from xarray.namedarray.pycompat import array_type
1415
from xarray.tests import assert_equal, assert_identical, requires_dask
@@ -294,6 +295,13 @@ def test_bivariate_ufunc(self):
294295
assert_sparse_equal(np.maximum(self.data, 0), np.maximum(self.var, 0).data)
295296
assert_sparse_equal(np.maximum(self.data, 0), np.maximum(0, self.var).data)
296297

298+
def test_univariate_xufunc(self):
299+
assert_sparse_equal(xu.sin(self.var).data, np.sin(self.data))
300+
301+
def test_bivariate_xufunc(self):
302+
assert_sparse_equal(xu.multiply(self.var, 0).data, np.multiply(self.data, 0))
303+
assert_sparse_equal(xu.multiply(0, self.var).data, np.multiply(0, self.data))
304+
297305
def test_repr(self):
298306
expected = dedent(
299307
"""\

xarray/tests/test_ufuncs.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from __future__ import annotations
22

3+
import pickle
4+
from unittest.mock import patch
5+
36
import numpy as np
47
import pytest
58

69
import xarray as xr
7-
from xarray.tests import assert_allclose, assert_array_equal, mock
10+
import xarray.ufuncs as xu
11+
from xarray.tests import assert_allclose, assert_array_equal, mock, requires_dask
812
from xarray.tests import assert_identical as assert_identical_
913

1014

@@ -155,3 +159,108 @@ def test_gufuncs():
155159
fake_gufunc = mock.Mock(signature="(n)->()", autospec=np.sin)
156160
with pytest.raises(NotImplementedError, match=r"generalized ufuncs"):
157161
xarray_obj.__array_ufunc__(fake_gufunc, "__call__", xarray_obj)
162+
163+
164+
class DuckArray(np.ndarray):
165+
# Minimal subclassed duck array with its own self-contained namespace,
166+
# which implements a few ufuncs
167+
def __new__(cls, array):
168+
obj = np.asarray(array).view(cls)
169+
return obj
170+
171+
def __array_namespace__(self):
172+
return DuckArray
173+
174+
@staticmethod
175+
def sin(x):
176+
return np.sin(x)
177+
178+
@staticmethod
179+
def add(x, y):
180+
return x + y
181+
182+
183+
class DuckArray2(DuckArray):
184+
def __array_namespace__(self):
185+
return DuckArray2
186+
187+
188+
class TestXarrayUfuncs:
189+
@pytest.fixture(autouse=True)
190+
def setUp(self):
191+
self.x = xr.DataArray([1, 2, 3])
192+
self.xd = xr.DataArray(DuckArray([1, 2, 3]))
193+
self.xd2 = xr.DataArray(DuckArray2([1, 2, 3]))
194+
self.xt = xr.DataArray(np.datetime64("2021-01-01", "ns"))
195+
196+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
197+
@pytest.mark.parametrize("name", xu.__all__)
198+
def test_ufuncs(self, name, request):
199+
xu_func = getattr(xu, name)
200+
np_func = getattr(np, name, None)
201+
if np_func is None and np.lib.NumpyVersion(np.__version__) < "2.0.0":
202+
pytest.skip(f"Ufunc {name} is not available in numpy {np.__version__}.")
203+
204+
if name == "isnat":
205+
args = (self.xt,)
206+
elif hasattr(np_func, "nin") and np_func.nin == 2:
207+
args = (self.x, self.x)
208+
else:
209+
args = (self.x,)
210+
211+
expected = np_func(*args)
212+
actual = xu_func(*args)
213+
214+
if name in ["angle", "iscomplex"]:
215+
np.testing.assert_equal(expected, actual.values)
216+
else:
217+
assert_identical(actual, expected)
218+
219+
def test_ufunc_pickle(self):
220+
a = 1.0
221+
cos_pickled = pickle.loads(pickle.dumps(xu.cos))
222+
assert_identical(cos_pickled(a), xu.cos(a))
223+
224+
def test_ufunc_scalar(self):
225+
actual = xu.sin(1)
226+
assert isinstance(actual, float)
227+
228+
def test_ufunc_duck_array_dataarray(self):
229+
actual = xu.sin(self.xd)
230+
assert isinstance(actual.data, DuckArray)
231+
232+
def test_ufunc_duck_array_variable(self):
233+
actual = xu.sin(self.xd.variable)
234+
assert isinstance(actual.data, DuckArray)
235+
236+
def test_ufunc_duck_array_dataset(self):
237+
ds = xr.Dataset({"a": self.xd})
238+
actual = xu.sin(ds)
239+
assert isinstance(actual.a.data, DuckArray)
240+
241+
@requires_dask
242+
def test_ufunc_duck_dask(self):
243+
import dask.array as da
244+
245+
x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3]))))
246+
actual = xu.sin(x)
247+
assert isinstance(actual.data._meta, DuckArray)
248+
249+
@requires_dask
250+
@pytest.mark.xfail(reason="dask ufuncs currently dispatch to numpy")
251+
def test_ufunc_duck_dask_no_array_ufunc(self):
252+
import dask.array as da
253+
254+
# dask ufuncs currently only preserve duck arrays that implement __array_ufunc__
255+
with patch.object(DuckArray, "__array_ufunc__", new=None, create=True):
256+
x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3]))))
257+
actual = xu.sin(x)
258+
assert isinstance(actual.data._meta, DuckArray)
259+
260+
def test_ufunc_mixed_arrays_compatible(self):
261+
actual = xu.add(self.xd, self.x)
262+
assert isinstance(actual.data, DuckArray)
263+
264+
def test_ufunc_mixed_arrays_incompatible(self):
265+
with pytest.raises(ValueError, match=r"Mixed array types"):
266+
xu.add(self.xd, self.xd2)

0 commit comments

Comments
 (0)