@@ -22,9 +22,9 @@
 try:
     # torch >=2.3
     _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
+    _HAS_LARGE_UINT = True
 except AttributeError:
-    pass
-
+    _HAS_LARGE_UINT = False

 _array_api_dtypes = {
     torch.bool,
@@ -59,6 +59,28 @@
     (torch.float64, torch.complex128): torch.complex128,
 }

+if _HAS_LARGE_UINT:  # torch >=2.3
+    _promotion_table.update(
+        {
+            # uints
+            (torch.uint8, torch.uint16): torch.uint16,
+            (torch.uint8, torch.uint32): torch.uint32,
+            (torch.uint8, torch.uint64): torch.uint64,
+            (torch.uint16, torch.uint32): torch.uint32,
+            (torch.uint16, torch.uint64): torch.uint64,
+            (torch.uint32, torch.uint64): torch.uint64,
+            # ints and uints (mixed sign)
+            (torch.uint16, torch.int8): torch.int32,
+            (torch.uint16, torch.int16): torch.int32,
+            (torch.uint16, torch.int32): torch.int32,
+            (torch.uint16, torch.int64): torch.int64,
+            (torch.uint32, torch.int8): torch.int64,
+            (torch.uint32, torch.int16): torch.int64,
+            (torch.uint32, torch.int32): torch.int64,
+            (torch.uint32, torch.int64): torch.int64,
+        }
+    )
+
 _promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
 _promotion_table.update({(a, a): a for a in _array_api_dtypes})

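For context, the two `update` calls after the table make lookups order-insensitive (the `(b, a)` mirror) and reflexive (the `(a, a)` entries), so each mixed pair only needs to be stored once. A minimal sketch of how such a table is typically consulted; the `_result_type` helper below is hypothetical and not part of this diff, and it assumes torch >= 2.3 so the large uint dtypes exist:

import torch

# Hypothetical lookup helper; only _promotion_table comes from the diff
# above, everything else here is an illustrative assumption.
def _result_type(a: torch.dtype, b: torch.dtype) -> torch.dtype:
    try:
        return _promotion_table[(a, b)]
    except KeyError:
        raise TypeError(f"{a} and {b} cannot be promoted together")

# Mixed-sign example: uint16 cannot hold negative values and int16 cannot
# hold all of uint16, so the pair promotes to int32:
#   _result_type(torch.int16, torch.uint16) -> torch.int32  (via the (b, a) mirror)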
@@ -295,10 +317,16 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
     if dtype is not None:
         return x.clone() if dtype == x.dtype else x.to(dtype)

-    # We can't upcast uint8 according to the spec because there is no
-    # torch.uint64, so at least upcast to int64 which is what prod does
-    # when axis=None.
-    if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
+    if x.dtype in (torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
+        return x.to(torch.uint64)
+
+    if x.dtype == torch.uint8:
+        # We can't upcast uint8 according to the spec because there is no
+        # torch.uint64, so at least upcast to int64 which is what prod does
+        # when axis=None.
         return x.to(torch.int64)

     return x.clone()
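To illustrate the behavioral change in `_sum_prod_no_axis`: with torch >= 2.3 the helper can now honor the spec's rule that sums and products of unsigned integers accumulate in the default unsigned dtype (uint64) instead of falling back to int64. A rough sketch; the wrapped-namespace entry point and the exact routing through this helper are assumptions, not confirmed by the diff:

import torch
import array_api_compat.torch as xp  # assumes array-api-compat is installed

x = torch.full((1000,), 255, dtype=torch.uint8)

# A uint8 accumulator would wrap around modulo 256 (255 * 1000 = 255000,
# far above 255). When _HAS_LARGE_UINT is True, the input is upcast to
# uint64 before summing, so the result is exact and spec-compliant:
total = xp.sum(x)
# expected: total.dtype == torch.uint64 and int(total) == 255000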