try:
    # torch >=2.3
    _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False

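On older torch versions the `torch.uint16`/`torch.uint32`/`torch.uint64` attributes do not exist, so the `|=` line raises `AttributeError` and the flag ends up `False`; the try/except doubles as a feature probe. A roughly equivalent check, shown only for illustration (not what the patch uses):

    _HAS_LARGE_UINT = hasattr(torch, "uint16")  # True on torch >= 2.3, False otherwise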
_array_api_dtypes = {
    torch.bool,
    ...
    torch.complex128,
}

-_promotion_table = {
-    # bool
-    (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
    # ints
-    (torch.int8, torch.int8): torch.int8,
    (torch.int8, torch.int16): torch.int16,
    (torch.int8, torch.int32): torch.int32,
    (torch.int8, torch.int64): torch.int64,
-    (torch.int16, torch.int8): torch.int16,
-    (torch.int16, torch.int16): torch.int16,
    (torch.int16, torch.int32): torch.int32,
    (torch.int16, torch.int64): torch.int64,
-    (torch.int32, torch.int8): torch.int32,
-    (torch.int32, torch.int16): torch.int32,
-    (torch.int32, torch.int32): torch.int32,
    (torch.int32, torch.int64): torch.int64,
-    (torch.int64, torch.int8): torch.int64,
-    (torch.int64, torch.int16): torch.int64,
-    (torch.int64, torch.int32): torch.int64,
-    (torch.int64, torch.int64): torch.int64,
-    # uints
-    (torch.uint8, torch.uint8): torch.uint8,
    # ints and uints (mixed sign)
-    (torch.int8, torch.uint8): torch.int16,
-    (torch.int16, torch.uint8): torch.int16,
-    (torch.int32, torch.uint8): torch.int32,
-    (torch.int64, torch.uint8): torch.int64,
    (torch.uint8, torch.int8): torch.int16,
    (torch.uint8, torch.int16): torch.int16,
    (torch.uint8, torch.int32): torch.int32,
    (torch.uint8, torch.int64): torch.int64,
    # floats
-    (torch.float32, torch.float32): torch.float32,
    (torch.float32, torch.float64): torch.float64,
-    (torch.float64, torch.float32): torch.float64,
-    (torch.float64, torch.float64): torch.float64,
    # complexes
-    (torch.complex64, torch.complex64): torch.complex64,
    (torch.complex64, torch.complex128): torch.complex128,
-    (torch.complex128, torch.complex64): torch.complex128,
-    (torch.complex128, torch.complex128): torch.complex128,
    # Mixed float and complex
    (torch.float32, torch.complex64): torch.complex64,
    (torch.float32, torch.complex128): torch.complex128,
    (torch.float64, torch.complex64): torch.complex128,
    (torch.float64, torch.complex128): torch.complex128,
}

+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
+        }
+    )
+
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+
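Note that the literal table above now lists only one ordering of each mixed pair and omits same-dtype pairs; the two `update` calls then mirror every entry and fill the diagonal from `_array_api_dtypes`. Because the mirroring runs after the `_HAS_LARGE_UINT` block, the reversed uint16/uint32 pairs are covered as well. A small illustration of how lookups behave afterwards (hypothetical helper name, not part of the patch):

    def _result_type_sketch(a, b):
        # (torch.int16, torch.int8) hits the mirrored copy of (torch.int8, torch.int16);
        # (torch.bool, torch.bool) hits the diagonal entry added from _array_api_dtypes.
        return _promotion_table[a, b]

    assert _result_type_sketch(torch.int16, torch.int8) == torch.int16
    assert _result_type_sketch(torch.bool, torch.bool) == torch.bool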


def _two_arg(f):
    @_wraps(f)
@@ -301,27 +302,41 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
            out = torch.unsqueeze(out, a)
    return out


+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
+    """
+    if dtype is not None:
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # We can't upcast uint8 according to the spec because there is no
+        # torch.uint64, so at least upcast to int64 which is what prod does
+        # when axis=None.
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
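As an illustration of the new helper (values assumed, not taken from the patch): with `dtype=None`, small integer inputs are upcast the way `sum`/`prod` with `axis=None` would upcast them, while floats and explicit dtypes pass through as a copy.

    x = torch.tensor([1, 2, 3], dtype=torch.int16)
    _sum_prod_no_axis(x, None).dtype                      # torch.int64 (default upcast)
    _sum_prod_no_axis(x, torch.int16).dtype               # torch.int16 (clone, dtype already matches)
    _sum_prod_no_axis(x.to(torch.float32), None).dtype    # torch.float32 (plain copy)

In other words, `axis=()` is treated as an identity reduction: only the dtype may change, the values never do.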
def prod(x: Array,
         /,
         *,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim

-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
    if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
    # torch.prod doesn't support multiple axes
    # (https://github.com/pytorch/pytorch/issues/56586).
    if isinstance(axis, tuple):
@@ -330,7 +345,7 @@ def prod(x: Array,
        # torch doesn't support keepdims with axis=None
        # (https://github.com/pytorch/pytorch/issues/71209)
        res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
        return res

    return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
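`_axis_none_keepdims` itself is not changed by this patch; only its `ndim` argument now comes straight from `x.ndim` instead of a local saved earlier. Its body is not shown in this diff, but since torch collapses to a 0-d result and ignores `keepdims` when `axis=None`, a helper like it presumably re-inserts the reduced dimensions. A rough, hypothetical sketch:

    def _axis_none_keepdims_sketch(x, ndim, keepdims):
        # Hypothetical stand-in: re-add one size-1 dimension per reduced axis
        # so the result keeps the input's rank when keepdims=True.
        if keepdims:
            for _ in range(ndim):
                x = torch.unsqueeze(x, 0)
        return x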
@@ -343,25 +358,14 @@ def sum(x: Array,
        dtype: Optional[DType] = None,
        keepdims: bool = False,
        **kwargs) -> Array:
-    ndim = x.ndim

-    # https://github.com/pytorch/pytorch/issues/29137.
-    # Make sure it upcasts.
    if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
    if axis is None:
        # torch doesn't support keepdims with axis=None
        # (https://github.com/pytorch/pytorch/issues/71209)
        res = torch.sum(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
        return res

    return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -372,7 +376,7 @@ def any(x: Array,
        axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdims: bool = False,
        **kwargs) -> Array:
-    ndim = x.ndim
+
    if axis == ():
        return x.to(torch.bool)
    # torch.any doesn't support multiple axes
@@ -384,7 +388,7 @@ def any(x: Array,
        # torch doesn't support keepdims with axis=None
        # (https://github.com/pytorch/pytorch/issues/71209)
        res = torch.any(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
        return res.to(torch.bool)

    # torch.any doesn't return bool for uint8
@@ -396,7 +400,7 @@ def all(x: Array,
        axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdims: bool = False,
        **kwargs) -> Array:
-    ndim = x.ndim
+
    if axis == ():
        return x.to(torch.bool)
    # torch.all doesn't support multiple axes
@@ -408,7 +412,7 @@ def all(x: Array,
        # torch doesn't support keepdims with axis=None
        # (https://github.com/pytorch/pytorch/issues/71209)
        res = torch.all(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
        return res.to(torch.bool)

    # torch.all doesn't return bool for uint8