Skip to content

Commit 845d8b5

Browse files
committed
ENH: torch dtype promotions
1 parent 3e5fdc0 commit 845d8b5

File tree

1 file changed

+61
-57
lines changed

1 file changed

+61
-57
lines changed

array_api_compat/torch/_aliases.py

+61-57
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
try:
2323
# torch >=2.3
2424
_int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
25+
_HAS_LARGE_UINT = True
2526
except AttributeError:
26-
pass
27-
27+
_HAS_LARGE_UINT = False
2828

2929
_array_api_dtypes = {
3030
torch.bool,
@@ -35,54 +35,55 @@
3535
torch.complex128,
3636
}
3737

38-
_promotion_table = {
39-
# bool
40-
(torch.bool, torch.bool): torch.bool,
38+
_promotion_table = {
4139
# ints
42-
(torch.int8, torch.int8): torch.int8,
4340
(torch.int8, torch.int16): torch.int16,
4441
(torch.int8, torch.int32): torch.int32,
4542
(torch.int8, torch.int64): torch.int64,
46-
(torch.int16, torch.int8): torch.int16,
47-
(torch.int16, torch.int16): torch.int16,
4843
(torch.int16, torch.int32): torch.int32,
4944
(torch.int16, torch.int64): torch.int64,
50-
(torch.int32, torch.int8): torch.int32,
51-
(torch.int32, torch.int16): torch.int32,
52-
(torch.int32, torch.int32): torch.int32,
5345
(torch.int32, torch.int64): torch.int64,
54-
(torch.int64, torch.int8): torch.int64,
55-
(torch.int64, torch.int16): torch.int64,
56-
(torch.int64, torch.int32): torch.int64,
57-
(torch.int64, torch.int64): torch.int64,
58-
# uints
59-
(torch.uint8, torch.uint8): torch.uint8,
6046
# ints and uints (mixed sign)
61-
(torch.int8, torch.uint8): torch.int16,
62-
(torch.int16, torch.uint8): torch.int16,
63-
(torch.int32, torch.uint8): torch.int32,
64-
(torch.int64, torch.uint8): torch.int64,
6547
(torch.uint8, torch.int8): torch.int16,
6648
(torch.uint8, torch.int16): torch.int16,
6749
(torch.uint8, torch.int32): torch.int32,
6850
(torch.uint8, torch.int64): torch.int64,
6951
# floats
70-
(torch.float32, torch.float32): torch.float32,
7152
(torch.float32, torch.float64): torch.float64,
72-
(torch.float64, torch.float32): torch.float64,
73-
(torch.float64, torch.float64): torch.float64,
7453
# complexes
75-
(torch.complex64, torch.complex64): torch.complex64,
7654
(torch.complex64, torch.complex128): torch.complex128,
77-
(torch.complex128, torch.complex64): torch.complex128,
78-
(torch.complex128, torch.complex128): torch.complex128,
7955
# Mixed float and complex
8056
(torch.float32, torch.complex64): torch.complex64,
8157
(torch.float32, torch.complex128): torch.complex128,
8258
(torch.float64, torch.complex64): torch.complex128,
8359
(torch.float64, torch.complex128): torch.complex128,
8460
}
8561

62+
if _HAS_LARGE_UINT: # torch >=2.3
63+
_promotion_table.update(
64+
{
65+
# uints
66+
(torch.uint8, torch.uint16): torch.uint16,
67+
(torch.uint8, torch.uint32): torch.uint32,
68+
(torch.uint8, torch.uint64): torch.uint64,
69+
(torch.uint16, torch.uint32): torch.uint32,
70+
(torch.uint16, torch.uint64): torch.uint64,
71+
(torch.uint32, torch.uint64): torch.uint64,
72+
# ints and uints (mixed sign)
73+
(torch.uint16, torch.int8): torch.int32,
74+
(torch.uint16, torch.int16): torch.int32,
75+
(torch.uint16, torch.int32): torch.int32,
76+
(torch.uint16, torch.int64): torch.int64,
77+
(torch.uint32, torch.int8): torch.int64,
78+
(torch.uint32, torch.int16): torch.int64,
79+
(torch.uint32, torch.int32): torch.int64,
80+
(torch.uint32, torch.int64): torch.int64,
81+
}
82+
)
83+
84+
_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
85+
_promotion_table.update({(a, a): a for a in _array_api_dtypes})
86+
8687

8788
def _two_arg(f):
8889
@_wraps(f)
@@ -301,27 +302,41 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
301302
out = torch.unsqueeze(out, a)
302303
return out
303304

305+
306+
def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
307+
"""
308+
Implements `sum(..., axis=())` and `prod(..., axis=())`.
309+
310+
Works around https://github.com/pytorch/pytorch/issues/29137
311+
"""
312+
if dtype is not None:
313+
return x.clone() if dtype == x.dtype else x.to(dtype)
314+
315+
if x.dtype in (torch.int8, torch.int16, torch.int32):
316+
return x.to(torch.int64)
317+
318+
if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32):
319+
return x.to(torch.uint64)
320+
321+
if x.dtype == torch.uint8:
322+
# We can't upcast uint8 according to the spec because there is no
323+
# torch.uint64, so at least upcast to int64 which is what prod does
324+
# when axis=None.
325+
return x.to(torch.int64)
326+
327+
return x.clone()
328+
329+
304330
def prod(x: Array,
305331
/,
306332
*,
307333
axis: Optional[Union[int, Tuple[int, ...]]] = None,
308334
dtype: Optional[DType] = None,
309335
keepdims: bool = False,
310336
**kwargs) -> Array:
311-
ndim = x.ndim
312337

313-
# https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
314-
# below because it still needs to upcast.
315338
if axis == ():
316-
if dtype is None:
317-
# We can't upcast uint8 according to the spec because there is no
318-
# torch.uint64, so at least upcast to int64 which is what sum does
319-
# when axis=None.
320-
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
321-
return x.to(torch.int64)
322-
return x.clone()
323-
return x.to(dtype)
324-
339+
return _sum_prod_no_axis(x, dtype)
325340
# torch.prod doesn't support multiple axes
326341
# (https://github.com/pytorch/pytorch/issues/56586).
327342
if isinstance(axis, tuple):
@@ -330,7 +345,7 @@ def prod(x: Array,
330345
# torch doesn't support keepdims with axis=None
331346
# (https://github.com/pytorch/pytorch/issues/71209)
332347
res = torch.prod(x, dtype=dtype, **kwargs)
333-
res = _axis_none_keepdims(res, ndim, keepdims)
348+
res = _axis_none_keepdims(res, x.ndim, keepdims)
334349
return res
335350

336351
return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -343,25 +358,14 @@ def sum(x: Array,
343358
dtype: Optional[DType] = None,
344359
keepdims: bool = False,
345360
**kwargs) -> Array:
346-
ndim = x.ndim
347361

348-
# https://github.com/pytorch/pytorch/issues/29137.
349-
# Make sure it upcasts.
350362
if axis == ():
351-
if dtype is None:
352-
# We can't upcast uint8 according to the spec because there is no
353-
# torch.uint64, so at least upcast to int64 which is what sum does
354-
# when axis=None.
355-
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
356-
return x.to(torch.int64)
357-
return x.clone()
358-
return x.to(dtype)
359-
363+
return _sum_prod_no_axis(x, dtype)
360364
if axis is None:
361365
# torch doesn't support keepdims with axis=None
362366
# (https://github.com/pytorch/pytorch/issues/71209)
363367
res = torch.sum(x, dtype=dtype, **kwargs)
364-
res = _axis_none_keepdims(res, ndim, keepdims)
368+
res = _axis_none_keepdims(res, x.ndim, keepdims)
365369
return res
366370

367371
return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -372,7 +376,7 @@ def any(x: Array,
372376
axis: Optional[Union[int, Tuple[int, ...]]] = None,
373377
keepdims: bool = False,
374378
**kwargs) -> Array:
375-
ndim = x.ndim
379+
376380
if axis == ():
377381
return x.to(torch.bool)
378382
# torch.any doesn't support multiple axes
@@ -384,7 +388,7 @@ def any(x: Array,
384388
# torch doesn't support keepdims with axis=None
385389
# (https://github.com/pytorch/pytorch/issues/71209)
386390
res = torch.any(x, **kwargs)
387-
res = _axis_none_keepdims(res, ndim, keepdims)
391+
res = _axis_none_keepdims(res, x.ndim, keepdims)
388392
return res.to(torch.bool)
389393

390394
# torch.any doesn't return bool for uint8
@@ -396,7 +400,7 @@ def all(x: Array,
396400
axis: Optional[Union[int, Tuple[int, ...]]] = None,
397401
keepdims: bool = False,
398402
**kwargs) -> Array:
399-
ndim = x.ndim
403+
400404
if axis == ():
401405
return x.to(torch.bool)
402406
# torch.all doesn't support multiple axes
@@ -408,7 +412,7 @@ def all(x: Array,
408412
# torch doesn't support keepdims with axis=None
409413
# (https://github.com/pytorch/pytorch/issues/71209)
410414
res = torch.all(x, **kwargs)
411-
res = _axis_none_keepdims(res, ndim, keepdims)
415+
res = _axis_none_keepdims(res, x.ndim, keepdims)
412416
return res.to(torch.bool)
413417

414418
# torch.all doesn't return bool for uint8

0 commit comments

Comments
 (0)