Skip to content

Commit

Permalink
Merge branch 'main' into cftime_deprecation
Browse files Browse the repository at this point in the history
  • Loading branch information
Maddogghoek authored Feb 4, 2025
2 parents 87f0cde + 4b48cf7 commit 635aa2c
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 28 deletions.
10 changes: 5 additions & 5 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from xarray.core import indexing
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
from xarray.core.duck_array_ops import asarray, ravel, reshape
from xarray.core.duck_array_ops import array_all, array_any, asarray, ravel, reshape
from xarray.core.formatting import first_n_items, format_timestamp, last_item
from xarray.core.pdcompat import default_precision_timestamp, timestamp_as_unit
from xarray.core.utils import attempt_import, emit_user_level_warning
Expand Down Expand Up @@ -676,7 +676,7 @@ def _infer_time_units_from_diff(unique_timedeltas) -> str:
unit_timedelta = _unit_timedelta_numpy
zero_timedelta = np.timedelta64(0, "ns")
for time_unit in time_units:
if np.all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta):
if array_all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta):
return time_unit
return "seconds"

Expand Down Expand Up @@ -939,7 +939,7 @@ def encode_datetime(d):

def cast_to_int_if_safe(num) -> np.ndarray:
int_num = np.asarray(num, dtype=np.int64)
if (num == int_num).all():
if array_all(num == int_num):
num = int_num
return num

Expand All @@ -961,7 +961,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
cast_num = np.asarray(num, dtype=dtype)

if np.issubdtype(dtype, np.integer):
if not (num == cast_num).all():
if not array_all(num == cast_num):
if np.issubdtype(num.dtype, np.floating):
raise ValueError(
f"Not possible to cast all encoded times from "
Expand All @@ -979,7 +979,7 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
"a larger integer dtype."
)
else:
if np.isinf(cast_num).any():
if array_any(np.isinf(cast_num)):
raise OverflowError(
f"Not possible to cast encoded times from {num.dtype!r} to "
f"{dtype!r} without overflow. Consider removing the dtype "
Expand Down
20 changes: 15 additions & 5 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

import numpy as np
import pandas as pd
from numpy import all as array_all # noqa: F401
from numpy import any as array_any # noqa: F401
from numpy import ( # noqa: F401
isclose,
isnat,
Expand Down Expand Up @@ -319,7 +317,9 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
if lazy_equiv is None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
return bool(
array_all(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True))
)
else:
return lazy_equiv

Expand All @@ -333,7 +333,7 @@ def array_equiv(arr1, arr2):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
return bool(flag_array.all())
return bool(array_all(flag_array))
else:
return lazy_equiv

Expand All @@ -349,7 +349,7 @@ def array_notnull_equiv(arr1, arr2):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
return bool(flag_array.all())
return bool(array_all(flag_array))
else:
return lazy_equiv

Expand Down Expand Up @@ -536,6 +536,16 @@ def f(values, axis=None, skipna=None, **kwargs):
cumsum_1d.numeric_only = True


def array_all(array, axis=None, keepdims=False, **kwargs):
    """Reduce ``array`` with the ``all`` function of its own array namespace.

    The namespace is looked up via ``get_array_namespace(array)`` and its
    ``all`` function is called on ``array``; every argument is forwarded
    unchanged.

    Parameters
    ----------
    array : array-like
        Array to reduce. NOTE(review): presumably any duck array that
        ``get_array_namespace`` can resolve -- confirm against that helper.
    axis : optional
        Forwarded as ``axis=`` to the namespace's ``all``. Default ``None``.
    keepdims : bool, default: False
        Forwarded as ``keepdims=`` to the namespace's ``all``.
    **kwargs
        Extra keyword arguments passed through untouched.

    Returns
    -------
    Whatever the namespace's ``all`` returns for these arguments.
    """
    namespace = get_array_namespace(array)
    reduce_all = namespace.all
    return reduce_all(array, axis=axis, keepdims=keepdims, **kwargs)


def array_any(array, axis=None, keepdims=False, **kwargs):
    """Reduce ``array`` with the ``any`` function of its own array namespace.

    The namespace is looked up via ``get_array_namespace(array)`` and its
    ``any`` function is called on ``array``; every argument is forwarded
    unchanged.

    Parameters
    ----------
    array : array-like
        Array to reduce. NOTE(review): presumably any duck array that
        ``get_array_namespace`` can resolve -- confirm against that helper.
    axis : optional
        Forwarded as ``axis=`` to the namespace's ``any``. Default ``None``.
    keepdims : bool, default: False
        Forwarded as ``keepdims=`` to the namespace's ``any``.
    **kwargs
        Extra keyword arguments passed through untouched.

    Returns
    -------
    Whatever the namespace's ``any`` returns for these arguments.
    """
    namespace = get_array_namespace(array)
    reduce_any = namespace.any
    return reduce_any(array, axis=axis, keepdims=keepdims, **kwargs)


_mean = _create_nan_agg_method("mean", invariant_0d=True)


Expand Down
8 changes: 4 additions & 4 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pandas.errors import OutOfBoundsDatetime

from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_equiv, astype
from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype
from xarray.core.indexing import MemoryCachedArray
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.treenode import group_subtrees
Expand Down Expand Up @@ -204,9 +204,9 @@ def format_items(x):
day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]")
time_needed = x[~pd.isnull(x)] != day_part
day_needed = day_part != np.timedelta64(0, "ns")
if np.logical_not(day_needed).all():
if array_all(np.logical_not(day_needed)):
timedelta_format = "time"
elif np.logical_not(time_needed).all():
elif array_all(np.logical_not(time_needed)):
timedelta_format = "date"

formatted = [format_item(xi, timedelta_format) for xi in x]
Expand All @@ -232,7 +232,7 @@ def format_array_flat(array, max_width: int):

cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1
if (array.size > 2) and (
(max_possibly_relevant < array.size) or (cum_len > max_width).any()
(max_possibly_relevant < array.size) or array_any(cum_len > max_width)
):
padding = " ... "
max_len = max(int(np.argmax(cum_len + len(padding) - 1 > max_width)), 2)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
data = getattr(np, func)(value, axis=axis, **kwargs)

# TODO This will evaluate dask arrays and might be costly.
if (valid_count == 0).any():
if duck_array_ops.array_any(valid_count == 0):
raise ValueError("All-NaN slice encountered")

return data
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ def did_you_mean(
word: Hashable, possibilities: Iterable[Hashable], *, n: int = 10
) -> str:
"""
Suggest a few correct words based on a list of possibilites
Suggest a few correct words based on a list of possibilities
Parameters
----------
word : Hashable
Word to compare to a list of possibilites.
Word to compare to a list of possibilities.
possibilities : Iterable of Hashable
The iterable of Hashable that contains the correct values.
n : int, default: 10
Expand All @@ -142,15 +142,15 @@ def did_you_mean(
https://en.wikipedia.org/wiki/String_metric
"""
# Convert all values to string, get_close_matches doesn't handle all hashables:
possibilites_str: dict[str, Hashable] = {str(k): k for k in possibilities}
possibilities_str: dict[str, Hashable] = {str(k): k for k in possibilities}

msg = ""
if len(
best_str := difflib.get_close_matches(
str(word), list(possibilites_str.keys()), n=n
str(word), list(possibilities_str.keys()), n=n
)
):
best = tuple(possibilites_str[k] for k in best_str)
best = tuple(possibilities_str[k] for k in best_str)
msg = f"Did you mean one of {best}?"

return msg
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None:

def _weight_check(w):
# Ref https://github.com/pydata/xarray/pull/4559/files#r515968670
if duck_array_ops.isnull(w).any():
if duck_array_ops.array_any(duck_array_ops.isnull(w)):
raise ValueError(
"`weights` cannot contain missing values. "
"Missing values can be replaced by `weights.fillna(0)`."
Expand Down
11 changes: 5 additions & 6 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
from numpy.typing import ArrayLike

from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
from xarray.core import duck_array_ops
from xarray.core.computation import apply_ufunc
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.dataarray import DataArray
from xarray.core.duck_array_ops import isnull
from xarray.core.duck_array_ops import array_all, isnull
from xarray.core.groupby import T_Group, _DummyGroup
from xarray.core.indexes import safe_cast_to_index
from xarray.core.resample_cftime import CFTimeGrouper
Expand Down Expand Up @@ -235,7 +234,7 @@ def _factorize_unique(self) -> EncodedGroups:
# look through group to find the unique values
sort = not isinstance(self.group_as_index, pd.MultiIndex)
unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
if (codes_ == -1).all():
if array_all(codes_ == -1):
raise ValueError(
"Failed to group data. Are you grouping by a variable that is all NaN?"
)
Expand Down Expand Up @@ -347,7 +346,7 @@ def reset(self) -> Self:
)

def __post_init__(self) -> None:
if duck_array_ops.isnull(self.bins).all():
if array_all(isnull(self.bins)):
raise ValueError("All bin edges are NaN.")

def _cut(self, data):
Expand Down Expand Up @@ -381,7 +380,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
f"Bin edges must be provided when grouping by chunked arrays. Received {self.bins=!r} instead"
)
codes = self._factorize_lazy(group)
if not by_is_chunked and (codes == -1).all():
if not by_is_chunked and array_all(codes == -1):
raise ValueError(
f"None of the data falls within bins with edges {self.bins!r}"
)
Expand Down Expand Up @@ -547,7 +546,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray:
# Copied from flox
sorter = np.argsort(labels)
is_sorted = (sorter == np.arange(sorter.size)).all()
is_sorted = array_all(sorter == np.arange(sorter.size))
codes = np.searchsorted(labels, data, sorter=sorter)
mask = ~np.isin(data, labels) | isnull(data) | (codes == len(labels))
# codes is the index in to the sorted array.
Expand Down
2 changes: 1 addition & 1 deletion xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def assert_duckarray_equal(x, y, err_msg="", verbose=True):
if (utils.is_duck_array(x) and utils.is_scalar(y)) or (
utils.is_scalar(x) and utils.is_duck_array(y)
):
equiv = (x == y).all()
equiv = duck_array_ops.array_all(x == y)
else:
equiv = duck_array_ops.array_equiv(x, y)
assert equiv, _format_message(x, y, err_msg=err_msg, verbose=verbose)
Expand Down

0 comments on commit 635aa2c

Please sign in to comment.