From 7ce90f8436a1f8ef73413026e6eaabbff810705d Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 7 Jul 2024 19:55:16 +0200 Subject: [PATCH 1/6] use `nested_duck_arrays` to detect `cupy` beneath duck array layers --- xarray/core/duck_array_ops.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 8993c136ba6..0cbbda397d9 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -51,6 +51,13 @@ normalize_axis_index, ) +try: + from nested_duck_arrays import first_layer +except ImportError: + + def first_layer(x): + return type(x) + dask_available = module_available("dask") @@ -268,7 +275,7 @@ def as_shared_dtype(scalars_or_arrays, xp=None): # Avoid calling array_type("cupy") repeatidely in the any check array_type_cupy = array_type("cupy") - if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays): + if any(first_layer(x) is array_type_cupy for x in scalars_or_arrays): import cupy as cp xp = cp From e2f0058701a9984afffed68adba99755245abfda Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 7 Jul 2024 20:53:47 +0200 Subject: [PATCH 2/6] typo --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 0cbbda397d9..403c985391c 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -273,7 +273,7 @@ def as_shared_dtype(scalars_or_arrays, xp=None): f" array types {[x.dtype for x in scalars_or_arrays]}" ) - # Avoid calling array_type("cupy") repeatidely in the any check + # Avoid calling array_type("cupy") repeatedly in the any check array_type_cupy = array_type("cupy") if any(first_layer(x) is array_type_cupy for x in scalars_or_arrays): import cupy as cp From c79edde830f8ecded16bdfbc2d3e977cdc8b9136 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 7 Jul 2024 20:53:56 +0200 Subject: [PATCH 3/6] use `issubclass` instead --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 403c985391c..f23125ee113 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -275,7 +275,7 @@ def as_shared_dtype(scalars_or_arrays, xp=None): # Avoid calling array_type("cupy") repeatedly in the any check array_type_cupy = array_type("cupy") - if any(first_layer(x) is array_type_cupy for x in scalars_or_arrays): + if any(issubclass(first_layer(x), array_type_cupy) for x in scalars_or_arrays): import cupy as cp xp = cp From 2c03fb3be5431bbc7a133cdfb1bc8c430adaf244 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 7 Jul 2024 21:03:17 +0200 Subject: [PATCH 4/6] check if chunked cupy also works --- xarray/tests/test_cupy.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py index 94776902c11..08066796553 100644 --- a/xarray/tests/test_cupy.py +++ b/xarray/tests/test_cupy.py @@ -7,6 +7,7 @@ import xarray as xr cp = pytest.importorskip("cupy") +from xarray.tests import requires_dask @pytest.fixture @@ -60,3 +61,20 @@ def test_where() -> None: output = where(data < 1, 1, data).all() assert output assert isinstance(output, cp.ndarray) + + +@requires_dask +def test_where_dask() -> None: + import dask.array as da + + from xarray.core.duck_array_ops import where + + data = cp.zeros(10) + chunked = da.from_array(data, chunks=(2,)) + + chunked_output = where(chunked < 1, 1, chunked).all() + assert isinstance(chunked_output, da.Array) + + output = chunked_output.compute() + assert output + assert isinstance(output, cp.ndarray) From 6aa2b51e8f1a1c046b7f5393f8429a6203106fd8 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 7 Jul 2024 21:05:30 +0200 Subject: [PATCH 5/6] require nested duck arrays to be installed for this --- xarray/tests/__init__.py | 3 +++ xarray/tests/test_cupy.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 0caab6e8247..1f3ab79feed 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -113,6 +113,9 @@ def _importorskip( category=DeprecationWarning, ) has_dask_expr, requires_dask_expr = _importorskip("dask_expr") +has_nested_duck_arrays, requires_nested_duck_arrays = _importorskip( + "nested_duck_arrays" +) has_bottleneck, requires_bottleneck = _importorskip("bottleneck") has_rasterio, requires_rasterio = _importorskip("rasterio") has_zarr, requires_zarr = _importorskip("zarr") diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py index 08066796553..4f1d48f2d8a 100644 --- a/xarray/tests/test_cupy.py +++ b/xarray/tests/test_cupy.py @@ -5,9 +5,9 @@ import pytest import xarray as xr +from xarray.tests import requires_dask, requires_nested_duck_arrays cp = pytest.importorskip("cupy") -from xarray.tests import requires_dask @pytest.fixture @@ -63,6 +63,7 @@ def test_where() -> None: assert isinstance(output, cp.ndarray) +@requires_nested_duck_arrays @requires_dask def test_where_dask() -> None: import dask.array as da From 9eb2d512a2a1b5bae6b44a304799c5447291502c Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sun, 7 Jul 2024 21:06:19 +0200 Subject: [PATCH 6/6] explicitly import `nested_duck_arrays.dask` --- xarray/tests/test_cupy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py index 4f1d48f2d8a..3f9d1cfd98e 100644 --- a/xarray/tests/test_cupy.py +++ b/xarray/tests/test_cupy.py @@ -67,6 +67,7 @@ def test_where() -> None: @requires_dask def test_where_dask() -> None: import dask.array as da + import nested_duck_arrays.dask # noqa: F401 from xarray.core.duck_array_ops import where