Skip to content

Commit 4b48cf7

Browse files
Duck array ops for all and any (#9883)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2658c00 commit 4b48cf7

File tree

7 files changed

+32
-23
lines changed

7 files changed

+32
-23
lines changed

xarray/coding/times.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323
from xarray.core import indexing
2424
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
25-
from xarray.core.duck_array_ops import asarray, ravel, reshape
25+
from xarray.core.duck_array_ops import array_all, array_any, asarray, ravel, reshape
2626
from xarray.core.formatting import first_n_items, format_timestamp, last_item
2727
from xarray.core.pdcompat import default_precision_timestamp, timestamp_as_unit
2828
from xarray.core.utils import attempt_import, emit_user_level_warning
@@ -676,7 +676,7 @@ def _infer_time_units_from_diff(unique_timedeltas) -> str:
676676
unit_timedelta = _unit_timedelta_numpy
677677
zero_timedelta = np.timedelta64(0, "ns")
678678
for time_unit in time_units:
679-
if np.all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta):
679+
if array_all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta):
680680
return time_unit
681681
return "seconds"
682682

@@ -939,7 +939,7 @@ def encode_datetime(d):
939939

940940
def cast_to_int_if_safe(num) -> np.ndarray:
941941
int_num = np.asarray(num, dtype=np.int64)
942-
if (num == int_num).all():
942+
if array_all(num == int_num):
943943
num = int_num
944944
return num
945945

@@ -961,7 +961,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
961961
cast_num = np.asarray(num, dtype=dtype)
962962

963963
if np.issubdtype(dtype, np.integer):
964-
if not (num == cast_num).all():
964+
if not array_all(num == cast_num):
965965
if np.issubdtype(num.dtype, np.floating):
966966
raise ValueError(
967967
f"Not possible to cast all encoded times from "
@@ -979,7 +979,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
979979
"a larger integer dtype."
980980
)
981981
else:
982-
if np.isinf(cast_num).any():
982+
if array_any(np.isinf(cast_num)):
983983
raise OverflowError(
984984
f"Not possible to cast encoded times from {num.dtype!r} to "
985985
f"{dtype!r} without overflow. Consider removing the dtype "

xarray/core/duck_array_ops.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
import numpy as np
1818
import pandas as pd
19-
from numpy import all as array_all # noqa: F401
20-
from numpy import any as array_any # noqa: F401
2119
from numpy import ( # noqa: F401
2220
isclose,
2321
isnat,
@@ -319,7 +317,9 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
319317
if lazy_equiv is None:
320318
with warnings.catch_warnings():
321319
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
322-
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
320+
return bool(
321+
array_all(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True))
322+
)
323323
else:
324324
return lazy_equiv
325325

@@ -333,7 +333,7 @@ def array_equiv(arr1, arr2):
333333
with warnings.catch_warnings():
334334
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
335335
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
336-
return bool(flag_array.all())
336+
return bool(array_all(flag_array))
337337
else:
338338
return lazy_equiv
339339

@@ -349,7 +349,7 @@ def array_notnull_equiv(arr1, arr2):
349349
with warnings.catch_warnings():
350350
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
351351
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
352-
return bool(flag_array.all())
352+
return bool(array_all(flag_array))
353353
else:
354354
return lazy_equiv
355355

@@ -536,6 +536,16 @@ def f(values, axis=None, skipna=None, **kwargs):
536536
cumsum_1d.numeric_only = True
537537

538538

539+
def array_all(array, axis=None, keepdims=False, **kwargs):
540+
xp = get_array_namespace(array)
541+
return xp.all(array, axis=axis, keepdims=keepdims, **kwargs)
542+
543+
544+
def array_any(array, axis=None, keepdims=False, **kwargs):
545+
xp = get_array_namespace(array)
546+
return xp.any(array, axis=axis, keepdims=keepdims, **kwargs)
547+
548+
539549
_mean = _create_nan_agg_method("mean", invariant_0d=True)
540550

541551

xarray/core/formatting.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pandas.errors import OutOfBoundsDatetime
1919

2020
from xarray.core.datatree_render import RenderDataTree
21-
from xarray.core.duck_array_ops import array_equiv, astype
21+
from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype
2222
from xarray.core.indexing import MemoryCachedArray
2323
from xarray.core.options import OPTIONS, _get_boolean_with_default
2424
from xarray.core.treenode import group_subtrees
@@ -204,9 +204,9 @@ def format_items(x):
204204
day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]")
205205
time_needed = x[~pd.isnull(x)] != day_part
206206
day_needed = day_part != np.timedelta64(0, "ns")
207-
if np.logical_not(day_needed).all():
207+
if array_all(np.logical_not(day_needed)):
208208
timedelta_format = "time"
209-
elif np.logical_not(time_needed).all():
209+
elif array_all(np.logical_not(time_needed)):
210210
timedelta_format = "date"
211211

212212
formatted = [format_item(xi, timedelta_format) for xi in x]
@@ -232,7 +232,7 @@ def format_array_flat(array, max_width: int):
232232

233233
cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1
234234
if (array.size > 2) and (
235-
(max_possibly_relevant < array.size) or (cum_len > max_width).any()
235+
(max_possibly_relevant < array.size) or array_any(cum_len > max_width)
236236
):
237237
padding = " ... "
238238
max_len = max(int(np.argmax(cum_len + len(padding) - 1 > max_width)), 2)

xarray/core/nanops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
4545
data = getattr(np, func)(value, axis=axis, **kwargs)
4646

4747
# TODO This will evaluate dask arrays and might be costly.
48-
if (valid_count == 0).any():
48+
if duck_array_ops.array_any(valid_count == 0):
4949
raise ValueError("All-NaN slice encountered")
5050

5151
return data

xarray/core/weighted.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None:
171171

172172
def _weight_check(w):
173173
# Ref https://github.com/pydata/xarray/pull/4559/files#r515968670
174-
if duck_array_ops.isnull(w).any():
174+
if duck_array_ops.array_any(duck_array_ops.isnull(w)):
175175
raise ValueError(
176176
"`weights` cannot contain missing values. "
177177
"Missing values can be replaced by `weights.fillna(0)`."

xarray/groupers.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
from numpy.typing import ArrayLike
1818

1919
from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
20-
from xarray.core import duck_array_ops
2120
from xarray.core.computation import apply_ufunc
2221
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
2322
from xarray.core.dataarray import DataArray
24-
from xarray.core.duck_array_ops import isnull
23+
from xarray.core.duck_array_ops import array_all, isnull
2524
from xarray.core.groupby import T_Group, _DummyGroup
2625
from xarray.core.indexes import safe_cast_to_index
2726
from xarray.core.resample_cftime import CFTimeGrouper
@@ -235,7 +234,7 @@ def _factorize_unique(self) -> EncodedGroups:
235234
# look through group to find the unique values
236235
sort = not isinstance(self.group_as_index, pd.MultiIndex)
237236
unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
238-
if (codes_ == -1).all():
237+
if array_all(codes_ == -1):
239238
raise ValueError(
240239
"Failed to group data. Are you grouping by a variable that is all NaN?"
241240
)
@@ -347,7 +346,7 @@ def reset(self) -> Self:
347346
)
348347

349348
def __post_init__(self) -> None:
350-
if duck_array_ops.isnull(self.bins).all():
349+
if array_all(isnull(self.bins)):
351350
raise ValueError("All bin edges are NaN.")
352351

353352
def _cut(self, data):
@@ -381,7 +380,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
381380
f"Bin edges must be provided when grouping by chunked arrays. Received {self.bins=!r} instead"
382381
)
383382
codes = self._factorize_lazy(group)
384-
if not by_is_chunked and (codes == -1).all():
383+
if not by_is_chunked and array_all(codes == -1):
385384
raise ValueError(
386385
f"None of the data falls within bins with edges {self.bins!r}"
387386
)
@@ -547,7 +546,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
547546
def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray:
548547
# Copied from flox
549548
sorter = np.argsort(labels)
550-
is_sorted = (sorter == np.arange(sorter.size)).all()
549+
is_sorted = array_all(sorter == np.arange(sorter.size))
551550
codes = np.searchsorted(labels, data, sorter=sorter)
552551
mask = ~np.isin(data, labels) | isnull(data) | (codes == len(labels))
553552
# codes is the index in to the sorted array.

xarray/testing/assertions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def assert_duckarray_equal(x, y, err_msg="", verbose=True):
296296
if (utils.is_duck_array(x) and utils.is_scalar(y)) or (
297297
utils.is_scalar(x) and utils.is_duck_array(y)
298298
):
299-
equiv = (x == y).all()
299+
equiv = duck_array_ops.array_all(x == y)
300300
else:
301301
equiv = duck_array_ops.array_equiv(x, y)
302302
assert equiv, _format_message(x, y, err_msg=err_msg, verbose=verbose)

0 commit comments

Comments
 (0)