Commit 801c859 (parent 9194c5c)

ENH: torch.result_type for uint types

1 file changed: array_api_compat/torch/_aliases.py (+34, -6 lines)
@@ -22,9 +22,9 @@
 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
 except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False
 
 _array_api_dtypes = {
     torch.bool,
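
Note: the try/except above doubles as feature detection. On torch < 2.3 the reference to torch.uint16 raises AttributeError, so _HAS_LARGE_UINT records whether the wide unsigned dtypes exist. A rough, illustrative equivalent of that check (the variable name below is not part of the library):

    import torch

    # True on torch >= 2.3, where uint16/uint32/uint64 were added; False on older releases.
    has_large_uint = all(hasattr(torch, name) for name in ("uint16", "uint32", "uint64"))
    print(has_large_uint)
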
@@ -59,6 +59,28 @@
     (torch.float64, torch.complex128): torch.complex128,
 }
 
+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
+        }
+    )
+
 _promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
 _promotion_table.update({(a, a): a for a in _array_api_dtypes})
 
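
Note: the existing mirrored update at the end of the hunk (_promotion_table.update({(b, a): c ...})) makes the table symmetric, so both argument orders resolve to the same dtype. A hedged usage sketch, assuming torch >= 2.3 and a version of array_api_compat that includes this change; the comments give the promotions expected from the table above, not verified output:

    import torch
    import array_api_compat.torch as xp

    print(xp.result_type(torch.uint16, torch.uint32))  # expected: torch.uint32
    print(xp.result_type(torch.uint16, torch.int16))   # expected: torch.int32 (mixed sign)
    print(xp.result_type(torch.int8, torch.uint32))    # expected: torch.int64, via the mirrored pair
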

@@ -295,10 +317,16 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
     if dtype is not None:
         return x.clone() if dtype == x.dtype else x.to(dtype)
 
-    # We can't upcast uint8 according to the spec because there is no
-    # torch.uint64, so at least upcast to int64 which is what prod does
-    # when axis=None.
-    if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # We can't upcast uint8 according to the spec because there is no
+        # torch.uint64, so at least upcast to int64 which is what prod does
+        # when axis=None.
         return x.to(torch.int64)
 
     return x.clone()
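
Note: a minimal sketch of the accumulation-dtype rule this hunk implements, written as a standalone helper (the function name is hypothetical, not part of the library). It assumes torch >= 2.3, so the uint64 branch applies, matching the _HAS_LARGE_UINT path above; on older torch, uint8 falls back to int64 as the retained comment explains.

    import torch

    def accumulation_dtype(dtype: torch.dtype) -> torch.dtype:
        # Mirror the no-axis sum/prod upcast: small signed ints accumulate in
        # int64, small unsigned ints in uint64; everything else is unchanged.
        if dtype in (torch.int8, torch.int16, torch.int32):
            return torch.int64
        if dtype in (torch.uint8, torch.uint16, torch.uint32):
            return torch.uint64
        return dtype

    print(accumulation_dtype(torch.uint16))  # torch.uint64
    print(accumulation_dtype(torch.int16))   # torch.int64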
