6
6
import numpy as np
7
7
from pandas .api .types import is_extension_array_dtype
8
8
9
- from xarray .core import npcompat , utils
9
+ from xarray .core import array_api_compat , npcompat , utils
10
10
11
11
# Use as a sentinel value to indicate a dtype appropriate NA value.
12
12
NA = utils .ReprObject ("<NA>" )
@@ -131,7 +131,10 @@ def get_pos_infinity(dtype, max_for_int=False):
131
131
if isdtype (dtype , "complex floating" ):
132
132
return np .inf + 1j * np .inf
133
133
134
- return INF
134
+ if isdtype (dtype , "bool" ):
135
+ return True
136
+
137
+ return np .array (INF , dtype = object )
135
138
136
139
137
140
def get_neg_infinity (dtype , min_for_int = False ):
@@ -159,7 +162,10 @@ def get_neg_infinity(dtype, min_for_int=False):
159
162
if isdtype (dtype , "complex floating" ):
160
163
return - np .inf - 1j * np .inf
161
164
162
- return NINF
165
+ if isdtype (dtype , "bool" ):
166
+ return False
167
+
168
+ return np .array (NINF , dtype = object )
163
169
164
170
165
171
def is_datetime_like (dtype ) -> bool :
@@ -209,8 +215,16 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
209
215
return xp .isdtype (dtype , kind )
210
216
211
217
218
+ def preprocess_scalar_types (t ):
219
+ if isinstance (t , (str , bytes )):
220
+ return type (t )
221
+ else :
222
+ return t
223
+
224
+
212
225
def result_type (
213
226
* arrays_and_dtypes : np .typing .ArrayLike | np .typing .DTypeLike ,
227
+ xp = None ,
214
228
) -> np .dtype :
215
229
"""Like np.result_type, but with type promotion rules matching pandas.
216
230
@@ -227,26 +241,26 @@ def result_type(
227
241
-------
228
242
numpy.dtype for the result.
229
243
"""
244
+ # TODO (keewis): replace `array_api_compat.result_type` with `xp.result_type` once we
245
+ # can require a version of the Array API that supports passing scalars to it.
230
246
from xarray .core .duck_array_ops import get_array_namespace
231
247
232
- # TODO(shoyer): consider moving this logic into get_array_namespace()
233
- # or another helper function.
234
- namespaces = {get_array_namespace (t ) for t in arrays_and_dtypes }
235
- non_numpy = namespaces - {np }
236
- if non_numpy :
237
- [xp ] = non_numpy
238
- else :
239
- xp = np
240
-
241
- types = {xp .result_type (t ) for t in arrays_and_dtypes }
248
+ if xp is None :
249
+ xp = get_array_namespace (arrays_and_dtypes )
242
250
251
+ types = {
252
+ array_api_compat .result_type (preprocess_scalar_types (t ), xp = xp )
253
+ for t in arrays_and_dtypes
254
+ }
243
255
if any (isinstance (t , np .dtype ) for t in types ):
244
256
# only check if there's numpy dtypes – the array API does not
245
257
# define the types we're checking for
246
258
for left , right in PROMOTE_TO_OBJECT :
247
259
if any (np .issubdtype (t , left ) for t in types ) and any (
248
260
np .issubdtype (t , right ) for t in types
249
261
):
250
- return xp .dtype (object )
262
+ return np .dtype (object )
251
263
252
- return xp .result_type (* arrays_and_dtypes )
264
+ return array_api_compat .result_type (
265
+ * map (preprocess_scalar_types , arrays_and_dtypes ), xp = xp
266
+ )
0 commit comments