Skip to content

Commit 2013e7f

Browse files
djhoesepre-commit-ci[bot]dcheriankeewis
authored
Fix upcasting with python builtin numbers and numpy 2 (#8946)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: Justus Magin <[email protected]> Co-authored-by: Justus Magin <[email protected]>
1 parent f0ee037 commit 2013e7f

7 files changed

+134
-36
lines changed

xarray/core/array_api_compat.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import numpy as np
2+
3+
4+
def is_weak_scalar_type(t):
5+
return isinstance(t, (bool, int, float, complex, str, bytes))
6+
7+
8+
def _future_array_api_result_type(*arrays_and_dtypes, xp):
9+
# fallback implementation for `xp.result_type` with python scalars. Can be removed once a
10+
# version of the Array API that includes https://github.com/data-apis/array-api/issues/805
11+
# can be required
12+
strongly_dtyped = [t for t in arrays_and_dtypes if not is_weak_scalar_type(t)]
13+
weakly_dtyped = [t for t in arrays_and_dtypes if is_weak_scalar_type(t)]
14+
15+
if not strongly_dtyped:
16+
strongly_dtyped = [
17+
xp.asarray(x) if not isinstance(x, type) else x for x in weakly_dtyped
18+
]
19+
weakly_dtyped = []
20+
21+
dtype = xp.result_type(*strongly_dtyped)
22+
if not weakly_dtyped:
23+
return dtype
24+
25+
possible_dtypes = {
26+
complex: "complex64",
27+
float: "float32",
28+
int: "int8",
29+
bool: "bool",
30+
str: "str",
31+
bytes: "bytes",
32+
}
33+
dtypes = [possible_dtypes.get(type(x), "object") for x in weakly_dtyped]
34+
35+
return xp.result_type(dtype, *dtypes)
36+
37+
38+
def result_type(*arrays_and_dtypes, xp) -> np.dtype:
39+
if xp is np or any(
40+
isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes
41+
):
42+
return xp.result_type(*arrays_and_dtypes)
43+
else:
44+
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)

xarray/core/dtypes.py

+29-15
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
from pandas.api.types import is_extension_array_dtype
88

9-
from xarray.core import npcompat, utils
9+
from xarray.core import array_api_compat, npcompat, utils
1010

1111
# Use as a sentinel value to indicate a dtype appropriate NA value.
1212
NA = utils.ReprObject("<NA>")
@@ -131,7 +131,10 @@ def get_pos_infinity(dtype, max_for_int=False):
131131
if isdtype(dtype, "complex floating"):
132132
return np.inf + 1j * np.inf
133133

134-
return INF
134+
if isdtype(dtype, "bool"):
135+
return True
136+
137+
return np.array(INF, dtype=object)
135138

136139

137140
def get_neg_infinity(dtype, min_for_int=False):
@@ -159,7 +162,10 @@ def get_neg_infinity(dtype, min_for_int=False):
159162
if isdtype(dtype, "complex floating"):
160163
return -np.inf - 1j * np.inf
161164

162-
return NINF
165+
if isdtype(dtype, "bool"):
166+
return False
167+
168+
return np.array(NINF, dtype=object)
163169

164170

165171
def is_datetime_like(dtype) -> bool:
@@ -209,8 +215,16 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
209215
return xp.isdtype(dtype, kind)
210216

211217

218+
def preprocess_scalar_types(t):
219+
if isinstance(t, (str, bytes)):
220+
return type(t)
221+
else:
222+
return t
223+
224+
212225
def result_type(
213226
*arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
227+
xp=None,
214228
) -> np.dtype:
215229
"""Like np.result_type, but with type promotion rules matching pandas.
216230
@@ -227,26 +241,26 @@ def result_type(
227241
-------
228242
numpy.dtype for the result.
229243
"""
244+
# TODO (keewis): replace `array_api_compat.result_type` with `xp.result_type` once we
245+
# can require a version of the Array API that supports passing scalars to it.
230246
from xarray.core.duck_array_ops import get_array_namespace
231247

232-
# TODO(shoyer): consider moving this logic into get_array_namespace()
233-
# or another helper function.
234-
namespaces = {get_array_namespace(t) for t in arrays_and_dtypes}
235-
non_numpy = namespaces - {np}
236-
if non_numpy:
237-
[xp] = non_numpy
238-
else:
239-
xp = np
240-
241-
types = {xp.result_type(t) for t in arrays_and_dtypes}
248+
if xp is None:
249+
xp = get_array_namespace(arrays_and_dtypes)
242250

251+
types = {
252+
array_api_compat.result_type(preprocess_scalar_types(t), xp=xp)
253+
for t in arrays_and_dtypes
254+
}
243255
if any(isinstance(t, np.dtype) for t in types):
244256
# only check if there's numpy dtypes – the array API does not
245257
# define the types we're checking for
246258
for left, right in PROMOTE_TO_OBJECT:
247259
if any(np.issubdtype(t, left) for t in types) and any(
248260
np.issubdtype(t, right) for t in types
249261
):
250-
return xp.dtype(object)
262+
return np.dtype(object)
251263

252-
return xp.result_type(*arrays_and_dtypes)
264+
return array_api_compat.result_type(
265+
*map(preprocess_scalar_types, arrays_and_dtypes), xp=xp
266+
)

xarray/core/duck_array_ops.py

+39-13
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,26 @@
5555
dask_available = module_available("dask")
5656

5757

58-
def get_array_namespace(x):
59-
if hasattr(x, "__array_namespace__"):
60-
return x.__array_namespace__()
58+
def get_array_namespace(*values):
59+
def _get_array_namespace(x):
60+
if hasattr(x, "__array_namespace__"):
61+
return x.__array_namespace__()
62+
else:
63+
return np
64+
65+
namespaces = {_get_array_namespace(t) for t in values}
66+
non_numpy = namespaces - {np}
67+
68+
if len(non_numpy) > 1:
69+
raise TypeError(
70+
"cannot deal with more than one type supporting the array API at the same time"
71+
)
72+
elif non_numpy:
73+
[xp] = non_numpy
6174
else:
62-
return np
75+
xp = np
76+
77+
return xp
6378

6479

6580
def einsum(*args, **kwargs):
@@ -224,11 +239,19 @@ def astype(data, dtype, **kwargs):
224239
return data.astype(dtype, **kwargs)
225240

226241

227-
def asarray(data, xp=np):
228-
return data if is_duck_array(data) else xp.asarray(data)
242+
def asarray(data, xp=np, dtype=None):
243+
converted = data if is_duck_array(data) else xp.asarray(data)
229244

245+
if dtype is None or converted.dtype == dtype:
246+
return converted
230247

231-
def as_shared_dtype(scalars_or_arrays, xp=np):
248+
if xp is np or not hasattr(xp, "astype"):
249+
return converted.astype(dtype)
250+
else:
251+
return xp.astype(converted, dtype)
252+
253+
254+
def as_shared_dtype(scalars_or_arrays, xp=None):
232255
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
233256
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
234257
extension_array_types = [
@@ -239,23 +262,26 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
239262
):
240263
return scalars_or_arrays
241264
raise ValueError(
242-
f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
265+
"Cannot cast arrays to shared type, found"
266+
f" array types {[x.dtype for x in scalars_or_arrays]}"
243267
)
244268

245269
# Avoid calling array_type("cupy") repeatidely in the any check
246270
array_type_cupy = array_type("cupy")
247271
if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
248272
import cupy as cp
249273

250-
arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
251-
else:
252-
arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
274+
xp = cp
275+
elif xp is None:
276+
xp = get_array_namespace(scalars_or_arrays)
277+
253278
# Pass arrays directly instead of dtypes to result_type so scalars
254279
# get handled properly.
255280
# Note that result_type() safely gets the dtype from dask arrays without
256281
# evaluating them.
257-
out_type = dtypes.result_type(*arrays)
258-
return [astype(x, out_type, copy=False) for x in arrays]
282+
dtype = dtypes.result_type(*scalars_or_arrays, xp=xp)
283+
284+
return [asarray(x, dtype=dtype, xp=xp) for x in scalars_or_arrays]
259285

260286

261287
def broadcast_to(array, shape):

xarray/tests/test_dataarray.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3004,7 +3004,7 @@ def test_fillna(self) -> None:
30043004
expected = b.copy()
30053005
assert_identical(expected, actual)
30063006

3007-
actual = a.fillna(range(4))
3007+
actual = a.fillna(np.arange(4))
30083008
assert_identical(expected, actual)
30093009

30103010
actual = a.fillna(b[:3])
@@ -3017,7 +3017,7 @@ def test_fillna(self) -> None:
30173017
a.fillna({0: 0})
30183018

30193019
with pytest.raises(ValueError, match=r"broadcast"):
3020-
a.fillna([1, 2])
3020+
a.fillna(np.array([1, 2]))
30213021

30223022
def test_align(self) -> None:
30233023
array = DataArray(

xarray/tests/test_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5209,7 +5209,7 @@ def test_fillna(self) -> None:
52095209
actual6 = ds.fillna(expected)
52105210
assert_identical(expected, actual6)
52115211

5212-
actual7 = ds.fillna(range(4))
5212+
actual7 = ds.fillna(np.arange(4))
52135213
assert_identical(expected, actual7)
52145214

52155215
actual8 = ds.fillna(b[:3])

xarray/tests/test_dtypes.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,23 @@ def test_result_type(args, expected) -> None:
3535
assert actual == expected
3636

3737

38-
def test_result_type_scalar() -> None:
39-
actual = dtypes.result_type(np.arange(3, dtype=np.float32), np.nan)
40-
assert actual == np.float32
38+
@pytest.mark.parametrize(
39+
["values", "expected"],
40+
(
41+
([np.arange(3, dtype="float32"), np.nan], np.float32),
42+
([np.arange(3, dtype="int8"), 1], np.int8),
43+
([np.array(["a", "b"], dtype=str), np.nan], object),
44+
([np.array([b"a", b"b"], dtype=bytes), True], object),
45+
([np.array([b"a", b"b"], dtype=bytes), "c"], object),
46+
([np.array(["a", "b"], dtype=str), "c"], np.dtype(str)),
47+
([np.array(["a", "b"], dtype=str), None], object),
48+
([0, 1], np.dtype("int")),
49+
),
50+
)
51+
def test_result_type_scalars(values, expected) -> None:
52+
actual = dtypes.result_type(*values)
53+
54+
assert np.issubdtype(actual, expected)
4155

4256

4357
def test_result_type_dask_array() -> None:

xarray/tests/test_duck_array_ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_count(self):
157157
assert 1 == count(np.datetime64("2000-01-01"))
158158

159159
def test_where_type_promotion(self):
160-
result = where([True, False], [1, 2], ["a", "b"])
160+
result = where(np.array([True, False]), np.array([1, 2]), np.array(["a", "b"]))
161161
assert_array_equal(result, np.array([1, "b"], dtype=object))
162162

163163
result = where([True, False], np.array([1, 2], np.float32), np.nan)
@@ -214,7 +214,7 @@ def test_stack_type_promotion(self):
214214
assert_array_equal(result, np.array([1, "b"], dtype=object))
215215

216216
def test_concatenate_type_promotion(self):
217-
result = concatenate([[1], ["b"]])
217+
result = concatenate([np.array([1]), np.array(["b"])])
218218
assert_array_equal(result, np.array([1, "b"], dtype=object))
219219

220220
@pytest.mark.filterwarnings("error")

0 commit comments

Comments
 (0)