Skip to content

Commit 88e778a

Browse files
committed
Fix upcasting with python builtin numbers and numpy 2
1 parent 3ace2fb commit 88e778a

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

xarray/core/duck_array_ops.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@
3030
zeros_like, # noqa
3131
)
3232
from numpy import concatenate as _concatenate
33-
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
33+
try:
34+
# numpy 2.x
35+
from numpy.lib.array_utils import normalize_axis_index # type: ignore[attr-defined]
36+
except ImportError:
37+
# numpy 1.x
38+
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
3439
from numpy.lib.stride_tricks import sliding_window_view # noqa
3540

3641
from xarray.core import dask_array_ops, dtypes, nputils
@@ -202,13 +207,15 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
202207

203208
arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
204209
else:
205-
arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
210+
#arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
211+
arrays = [x if isinstance(x, (int, float, complex)) else asarray(x, xp=xp) for x in scalars_or_arrays]
206212
# Pass arrays directly instead of dtypes to result_type so scalars
207213
# get handled properly.
208214
# Note that result_type() safely gets the dtype from dask arrays without
209215
# evaluating them.
210216
out_type = dtypes.result_type(*arrays)
211-
return [astype(x, out_type, copy=False) for x in arrays]
217+
#return [astype(x, out_type, copy=False) for x in arrays]
218+
return [astype(x, out_type, copy=False) if hasattr(x, "dtype") else x for x in arrays]
212219

213220

214221
def broadcast_to(array, shape):

0 commit comments

Comments
 (0)