diff --git a/xarray/coding/times.py b/xarray/coding/times.py index ad5e8653e2a..7fffa595d94 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -22,7 +22,7 @@ ) from xarray.core import indexing from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like -from xarray.core.duck_array_ops import asarray, ravel, reshape +from xarray.core.duck_array_ops import array_all, array_any, asarray, ravel, reshape from xarray.core.formatting import first_n_items, format_timestamp, last_item from xarray.core.pdcompat import default_precision_timestamp, timestamp_as_unit from xarray.core.utils import attempt_import, emit_user_level_warning @@ -676,7 +676,7 @@ def _infer_time_units_from_diff(unique_timedeltas) -> str: unit_timedelta = _unit_timedelta_numpy zero_timedelta = np.timedelta64(0, "ns") for time_unit in time_units: - if np.all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta): + if array_all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta): return time_unit return "seconds" @@ -939,7 +939,7 @@ def encode_datetime(d): def cast_to_int_if_safe(num) -> np.ndarray: int_num = np.asarray(num, dtype=np.int64) - if (num == int_num).all(): + if array_all(num == int_num): num = int_num return num @@ -961,7 +961,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray: cast_num = np.asarray(num, dtype=dtype) if np.issubdtype(dtype, np.integer): - if not (num == cast_num).all(): + if not array_all(num == cast_num): if np.issubdtype(num.dtype, np.floating): raise ValueError( f"Not possible to cast all encoded times from " @@ -979,7 +979,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray: "a larger integer dtype." ) else: - if np.isinf(cast_num).any(): + if array_any(np.isinf(cast_num)): raise OverflowError( f"Not possible to cast encoded times from {num.dtype!r} to " f"{dtype!r} without overflow. Consider removing the dtype " diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index c3f1598050a..45fdaee9768 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -16,8 +16,6 @@ import numpy as np import pandas as pd -from numpy import all as array_all # noqa: F401 -from numpy import any as array_any # noqa: F401 from numpy import ( # noqa: F401 isclose, isnat, @@ -319,7 +317,9 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): if lazy_equiv is None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") - return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + return bool( + array_all(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True)) + ) else: return lazy_equiv @@ -333,7 +333,7 @@ def array_equiv(arr1, arr2): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "In the future, 'NAT == x'") flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2)) - return bool(flag_array.all()) + return bool(array_all(flag_array)) else: return lazy_equiv @@ -349,7 +349,7 @@ def array_notnull_equiv(arr1, arr2): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "In the future, 'NAT == x'") flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2) - return bool(flag_array.all()) + return bool(array_all(flag_array)) else: return lazy_equiv @@ -536,6 +536,16 @@ def f(values, axis=None, skipna=None, **kwargs): cumsum_1d.numeric_only = True +def array_all(array, axis=None, keepdims=False, **kwargs): + xp = get_array_namespace(array) + return xp.all(array, axis=axis, keepdims=keepdims, **kwargs) + + +def array_any(array, axis=None, keepdims=False, **kwargs): + xp = get_array_namespace(array) + return xp.any(array, axis=axis, keepdims=keepdims, **kwargs) + + _mean = _create_nan_agg_method("mean", invariant_0d=True) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index ab17fa85381..a6bacccbeef 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -18,7 +18,7 @@ from pandas.errors import OutOfBoundsDatetime from xarray.core.datatree_render import RenderDataTree -from xarray.core.duck_array_ops import array_equiv, astype +from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype from xarray.core.indexing import MemoryCachedArray from xarray.core.options import OPTIONS, _get_boolean_with_default from xarray.core.treenode import group_subtrees @@ -204,9 +204,9 @@ def format_items(x): day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]") time_needed = x[~pd.isnull(x)] != day_part day_needed = day_part != np.timedelta64(0, "ns") - if np.logical_not(day_needed).all(): + if array_all(np.logical_not(day_needed)): timedelta_format = "time" - elif np.logical_not(time_needed).all(): + elif array_all(np.logical_not(time_needed)): timedelta_format = "date" formatted = [format_item(xi, timedelta_format) for xi in x] @@ -232,7 +232,7 @@ def format_array_flat(array, max_width: int): cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1 if (array.size > 2) and ( - (max_possibly_relevant < array.size) or (cum_len > max_width).any() + (max_possibly_relevant < array.size) or array_any(cum_len > max_width) ): padding = " ... " max_len = max(int(np.argmax(cum_len + len(padding) - 1 > max_width)), 2) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 4894cf02be2..17c60b6f663 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -45,7 +45,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): data = getattr(np, func)(value, axis=axis, **kwargs) # TODO This will evaluate dask arrays and might be costly. - if (valid_count == 0).any(): + if duck_array_ops.array_any(valid_count == 0): raise ValueError("All-NaN slice encountered") return data diff --git a/xarray/core/utils.py b/xarray/core/utils.py index c3187b77722..0e8d69e4f84 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -119,12 +119,12 @@ def did_you_mean( word: Hashable, possibilities: Iterable[Hashable], *, n: int = 10 ) -> str: """ - Suggest a few correct words based on a list of possibilites + Suggest a few correct words based on a list of possibilities Parameters ---------- word : Hashable - Word to compare to a list of possibilites. + Word to compare to a list of possibilities. possibilities : Iterable of Hashable The iterable of Hashable that contains the correct values. n : int, default: 10 @@ -142,15 +142,15 @@ def did_you_mean( https://en.wikipedia.org/wiki/String_metric """ # Convert all values to string, get_close_matches doesn't handle all hashables: - possibilites_str: dict[str, Hashable] = {str(k): k for k in possibilities} + possibilities_str: dict[str, Hashable] = {str(k): k for k in possibilities} msg = "" if len( best_str := difflib.get_close_matches( - str(word), list(possibilites_str.keys()), n=n + str(word), list(possibilities_str.keys()), n=n ) ): - best = tuple(possibilites_str[k] for k in best_str) + best = tuple(possibilities_str[k] for k in best_str) msg = f"Did you mean one of {best}?" return msg diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 269cb49a2c1..cd24091b18e 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -171,7 +171,7 @@ def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None: def _weight_check(w): # Ref https://github.com/pydata/xarray/pull/4559/files#r515968670 - if duck_array_ops.isnull(w).any(): + if duck_array_ops.array_any(duck_array_ops.isnull(w)): raise ValueError( "`weights` cannot contain missing values. " "Missing values can be replaced by `weights.fillna(0)`." diff --git a/xarray/groupers.py b/xarray/groupers.py index dac4c4309de..32e5e712196 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -17,11 +17,10 @@ from numpy.typing import ArrayLike from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq -from xarray.core import duck_array_ops from xarray.core.computation import apply_ufunc from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.dataarray import DataArray -from xarray.core.duck_array_ops import isnull +from xarray.core.duck_array_ops import array_all, isnull from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper @@ -235,7 +234,7 @@ def _factorize_unique(self) -> EncodedGroups: # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort) - if (codes_ == -1).all(): + if array_all(codes_ == -1): raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" ) @@ -347,7 +346,7 @@ def reset(self) -> Self: ) def __post_init__(self) -> None: - if duck_array_ops.isnull(self.bins).all(): + if array_all(isnull(self.bins)): raise ValueError("All bin edges are NaN.") def _cut(self, data): @@ -381,7 +380,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: f"Bin edges must be provided when grouping by chunked arrays. Received {self.bins=!r} instead" ) codes = self._factorize_lazy(group) - if not by_is_chunked and (codes == -1).all(): + if not by_is_chunked and array_all(codes == -1): raise ValueError( f"None of the data falls within bins with edges {self.bins!r}" ) @@ -547,7 +546,7 @@ def factorize(self, group: T_Group) -> EncodedGroups: def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray: # Copied from flox sorter = np.argsort(labels) - is_sorted = (sorter == np.arange(sorter.size)).all() + is_sorted = array_all(sorter == np.arange(sorter.size)) codes = np.searchsorted(labels, data, sorter=sorter) mask = ~np.isin(data, labels) | isnull(data) | (codes == len(labels)) # codes is the index in to the sorted array. diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 6d87537a523..8a2dba9261f 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -296,7 +296,7 @@ def assert_duckarray_equal(x, y, err_msg="", verbose=True): if (utils.is_duck_array(x) and utils.is_scalar(y)) or ( utils.is_scalar(x) and utils.is_duck_array(y) ): - equiv = (x == y).all() + equiv = duck_array_ops.array_all(x == y) else: equiv = duck_array_ops.array_equiv(x, y) assert equiv, _format_message(x, y, err_msg=err_msg, verbose=verbose)