From 42a2ba91c1115c9efb5d302a946d0784236f2466 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 15 Apr 2025 14:18:23 +0100 Subject: [PATCH 1/2] ENH: `torch.result_type` for uint types --- array_api_compat/torch/_aliases.py | 40 +++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 027a0261..bd4d5f8e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -22,9 +22,9 @@ try: # torch >=2.3 _int_dtypes |= {torch.uint16, torch.uint32, torch.uint64} + _HAS_LARGE_UINT = True except AttributeError: - pass - + _HAS_LARGE_UINT = False _array_api_dtypes = { torch.bool, @@ -59,6 +59,28 @@ (torch.float64, torch.complex128): torch.complex128, } +if _HAS_LARGE_UINT: # torch >=2.3 + _promotion_table.update( + { + # uints + (torch.uint8, torch.uint16): torch.uint16, + (torch.uint8, torch.uint32): torch.uint32, + (torch.uint8, torch.uint64): torch.uint64, + (torch.uint16, torch.uint32): torch.uint32, + (torch.uint16, torch.uint64): torch.uint64, + (torch.uint32, torch.uint64): torch.uint64, + # ints and uints (mixed sign) + (torch.uint16, torch.int8): torch.int32, + (torch.uint16, torch.int16): torch.int32, + (torch.uint16, torch.int32): torch.int32, + (torch.uint16, torch.int64): torch.int64, + (torch.uint32, torch.int8): torch.int64, + (torch.uint32, torch.int16): torch.int64, + (torch.uint32, torch.int32): torch.int64, + (torch.uint32, torch.int64): torch.int64, + } + ) + _promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()}) _promotion_table.update({(a, a): a for a in _array_api_dtypes}) @@ -295,10 +317,16 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: if dtype is not None: return x.clone() if dtype == x.dtype else x.to(dtype) - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what prod does - # when axis=None. - if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32): + if x.dtype in (torch.int8, torch.int16, torch.int32): + return x.to(torch.int64) + + if _HAS_LARGE_UINT and x.dtype in (torch.uint8, torch.uint16, torch.uint32): + return x.to(torch.uint64) + + if x.dtype == torch.uint8: + # We can't upcast uint8 according to the spec because there is no + # torch.uint64, so at least upcast to int64 which is what prod does + # when axis=None. return x.to(torch.int64) return x.clone() From 9cb105c856e118b7d8577553fa5e79716361d396 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 15 Apr 2025 14:35:21 +0100 Subject: [PATCH 2/2] Torch info for large uints --- .github/workflows/array-api-tests-torch.yml | 2 - array_api_compat/torch/_info.py | 95 ++++++++------------- 2 files changed, 37 insertions(+), 60 deletions(-) diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index ac20df25..7a228812 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -8,6 +8,4 @@ jobs: with: package-name: torch extra-requires: '--index-url https://download.pytorch.org/whl/cpu' - extra-env-vars: | - ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 python-versions: '[''3.10'', ''3.13'']' diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 818e5d37..3835f024 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -170,78 +170,58 @@ def default_dtypes(self, *, device=None): "indexing": default_integral, } - def _dtypes(self, kind): - bool = torch.bool - int8 = torch.int8 - int16 = torch.int16 - int32 = torch.int32 - int64 = torch.int64 - uint8 = torch.uint8 - # uint16, uint32, and uint64 are present in newer versions of pytorch, - # but they aren't generally supported by the array API functions, so - # we omit them from this function. - float32 = torch.float32 - float64 = torch.float64 - complex64 = torch.complex64 - complex128 = torch.complex128 - if kind is None: - return { - "bool": bool, - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, - } + return self._dtypes( + ( + "bool", + "signed integer", + "unsigned integer", + "real floating", + "complex floating", + ) + ) if kind == "bool": - return {"bool": bool} + return {"bool": torch.bool} if kind == "signed integer": return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, + "int8": torch.int8, + "int16": torch.int16, + "int32": torch.int32, + "int64": torch.int64, } if kind == "unsigned integer": - return { - "uint8": uint8, - } + try: + # torch >=2.3 + return { + "uint8": torch.uint8, + "uint16": torch.uint16, + "uint32": torch.uint32, + "uint64": torch.uint32, + } + except AttributeError: + return {"uint8": torch.uint8} if kind == "integral": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - } + return self._dtypes(("signed integer", "unsigned integer")) if kind == "real floating": return { - "float32": float32, - "float64": float64, + "float32": torch.float32, + "float64": torch.float64, } if kind == "complex floating": return { - "complex64": complex64, - "complex128": complex128, + "complex64": torch.complex64, + "complex128": torch.complex128, } if kind == "numeric": - return { - "int8": int8, - "int16": int16, - "int32": int32, - "int64": int64, - "uint8": uint8, - "float32": float32, - "float64": float64, - "complex64": complex64, - "complex128": complex128, - } + return self._dtypes( + ( + "signed integer", + "unsigned integer", + "real floating", + "complex floating", + ) + ) if isinstance(kind, tuple): res = {} for k in kind: @@ -261,7 +241,6 @@ def dtypes(self, *, device=None, kind=None): ---------- device : Device, optional The device to get the data types for. - Unused for PyTorch, as all devices use the same dtypes. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned.