From 37d351034c9b8ff5158ff0bfea3652e8bcc5ce0d Mon Sep 17 00:00:00 2001 From: tvo Date: Sat, 2 Nov 2024 19:46:20 -0600 Subject: [PATCH 1/5] Allow wrapping astropy.units.Quantity --- xarray/core/variable.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 13053faff58..4216e574312 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -320,10 +320,17 @@ def convert_non_numpy_type(data): else: data = np.asarray(data) + _is_array_like = isinstance(data, np.ndarray | np.generic) + _is_nep18 = hasattr(data, "__array_function__") + _has_array_api = hasattr(data, "__array_namespace__") + _has_unit = hasattr(data, "_unit") + + # Allow `astropy.units.Quantity` + if _is_array_like and (_is_nep18 or _has_array_api) and _has_unit: + return cast("T_DuckArray", data) + # immediately return array-like types except `numpy.ndarray` subclasses and `numpy` scalars - if not isinstance(data, np.ndarray | np.generic) and ( - hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") - ): + if not _is_array_like and (_is_nep18 or _has_array_api): return cast("T_DuckArray", data) # validate whether the data is valid data types. Also, explicitly cast `numpy` From 0f637c9a2a3e8c31987417c13ad0f4c0cd6e908a Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 9 Nov 2024 22:03:03 -0500 Subject: [PATCH 2/5] allow all np.ndarray subclasses --- xarray/core/variable.py | 20 +++++++++----------- xarray/tests/test_variable.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 17b1c26a63a..3789da98ae4 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -321,17 +321,15 @@ def convert_non_numpy_type(data): else: data = np.asarray(data) - _is_array_like = isinstance(data, np.ndarray | np.generic) - _is_nep18 = hasattr(data, "__array_function__") - _has_array_api = hasattr(data, "__array_namespace__") - _has_unit = hasattr(data, "_unit") - - # Allow `astropy.units.Quantity` - if _is_array_like and (_is_nep18 or _has_array_api) and _has_unit: - return cast("T_DuckArray", data) - - # immediately return array-like types except `numpy.ndarray` subclasses and `numpy` scalars - if not _is_array_like and (_is_nep18 or _has_array_api): + if isinstance(data, np.matrix): + data = np.asarray(data) + + # immediately return array-like types except `numpy.ndarray` and `numpy` scalars + # compare types with `is` instead of `isinstance` to allow `numpy.ndarray` subclasses + is_numpy = type(data) is np.ndarray or isinstance(data, np.generic) + if not is_numpy and ( + hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") + ): return cast("T_DuckArray", data) # validate whether the data is valid data types. Also, explicitly cast `numpy` diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 0ed47c2b5fe..607781b643f 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2746,6 +2746,19 @@ def test_ones_like(self) -> None: assert_identical(ones_like(orig), full_like(orig, 1)) assert_identical(ones_like(orig, dtype=int), full_like(orig, 1, dtype=int)) + def test_numpy_ndarray_subclass(self): + class SubclassedArray(np.ndarray): + def __new__(cls, array, foo): + obj = np.asarray(array).view(cls) + obj.foo = foo + return obj + + data = SubclassedArray([1, 2, 3], foo="bar") + actual = as_compatible_data(data) + assert isinstance(actual, SubclassedArray) + assert actual.foo == "bar" + assert_array_equal(data, actual) + def test_unsupported_type(self): # Non indexable type class CustomArray(NDArrayMixin): From e92a39bc9e97d46a182d39e63f661ad199218aea Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 9 Nov 2024 22:17:12 -0500 Subject: [PATCH 3/5] whats new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4659978df8a..f2c063e43e3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,8 @@ New Features - Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])`` (:issue:`2852`, :issue:`757`). By `Deepak Cherian `_. +- Allow wrapping ``np.ndarray`` subclasses, e.g. ``astropy.units.Quantity`` (:issue:`9704`, :pull:`9760`). + By `Sam Levang `_ and `Tien Vo `_. Breaking changes ~~~~~~~~~~~~~~~~ From 51be6be56861ec4fecbd5121b3344bdbc4b4d111 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sun, 10 Nov 2024 13:09:53 -0500 Subject: [PATCH 4/5] test np.matrix --- xarray/tests/test_variable.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 607781b643f..9c6f50037d3 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2759,6 +2759,13 @@ def __new__(cls, array, foo): assert actual.foo == "bar" assert_array_equal(data, actual) + def test_numpy_matrix(self): + with pytest.warns(PendingDeprecationWarning): + data = np.matrix([[1, 2], [3, 4]]) + actual = as_compatible_data(data) + assert isinstance(actual, np.ndarray) + assert_array_equal(data, actual) + def test_unsupported_type(self): # Non indexable type class CustomArray(NDArrayMixin): From fd5913837828cca09ff097a859dd630ceeb2f20a Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 11 Nov 2024 13:59:24 -0500 Subject: [PATCH 5/5] fix comment --- xarray/core/variable.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 3789da98ae4..a6ea44b1ee5 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -332,8 +332,7 @@ def convert_non_numpy_type(data): ): return cast("T_DuckArray", data) - # validate whether the data is valid data types. Also, explicitly cast `numpy` - # subclasses and `numpy` scalars to `numpy.ndarray` + # anything left will be converted to `numpy.ndarray`, including `numpy` scalars data = np.asarray(data) if data.dtype.kind in "OMm":