diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 56ab81a3..1b01f755 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -7,5 +7,3 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: torch - extra-env-vars: | - ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 5b20aabc..ed47af78 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -96,6 +96,53 @@ } +try: + # torch >=2.3 + _uint_promotion_table = { + # uints + (torch.uint8, torch.uint16): torch.uint16, + (torch.uint8, torch.uint32): torch.uint32, + (torch.uint8, torch.uint64): torch.uint64, + (torch.uint16, torch.uint8): torch.uint16, + (torch.uint16, torch.uint16): torch.uint16, + (torch.uint16, torch.uint32): torch.uint32, + (torch.uint16, torch.uint64): torch.uint64, + (torch.uint32, torch.uint8): torch.uint32, + (torch.uint32, torch.uint16): torch.uint32, + (torch.uint32, torch.uint32): torch.uint32, + (torch.uint32, torch.uint64): torch.uint64, + (torch.uint64, torch.uint8): torch.uint64, + (torch.uint64, torch.uint16): torch.uint64, + (torch.uint64, torch.uint32): torch.uint64, + (torch.uint64, torch.uint64): torch.uint64, + # ints and uints (mixed sign) + (torch.int8, torch.uint16): torch.int32, + (torch.int8, torch.uint32): torch.int64, + (torch.int16, torch.uint8): torch.int16, + (torch.int16, torch.uint16): torch.int32, + (torch.int16, torch.uint32): torch.int64, + (torch.int32, torch.uint8): torch.int32, + (torch.int32, torch.uint16): torch.int32, + (torch.int32, torch.uint32): torch.int64, + (torch.int64, torch.uint8): torch.int64, + (torch.int64, torch.uint16): torch.int64, + (torch.int64, torch.uint32): torch.int64, + (torch.uint16, torch.int8): torch.int32, + (torch.uint16, torch.int16): torch.int32, + (torch.uint16, torch.int32): torch.int32, + (torch.uint16, torch.int64): torch.int64, + (torch.uint32, torch.int8): torch.int64, + (torch.uint32, torch.int16): torch.int64, + (torch.uint32, torch.int32): torch.int64, + (torch.uint32, torch.int64): torch.int64, + } +except AttributeError: + _uint_promotion_table = {} + pass + +_promotion_table.update(_uint_promotion_table) + + def _two_arg(f): @_wraps(f) def _f(x1, x2, /, **kwargs): diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 34fbcb21..d1cce55b 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -169,16 +169,26 @@ def _dtypes(self, kind): int32 = torch.int32 int64 = torch.int64 uint8 = torch.uint8 - # uint16, uint32, and uint64 are present in newer versions of pytorch, - # but they aren't generally supported by the array API functions, so - # we omit them from this function. + try: + # pytorch >= 2.3 + uint16 = torch.uint16 + uint32 = torch.uint32 + uint64 = torch.uint64 + uint_kinds = { + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + except AttributeError: + uint_kinds = {} + float32 = torch.float32 float64 = torch.float64 complex64 = torch.complex64 complex128 = torch.complex128 if kind is None: - return { + kinds = { "bool": bool, "int8": int8, "int16": int16, @@ -190,6 +200,8 @@ def _dtypes(self, kind): "complex64": complex64, "complex128": complex128, } + kinds.update(uint_kinds) + return kinds if kind == "bool": return {"bool": bool} if kind == "signed integer": @@ -200,17 +212,21 @@ def _dtypes(self, kind): "int64": int64, } if kind == "unsigned integer": - return { + kinds= { "uint8": uint8, } + kinds.update(uint_kinds) + return kinds if kind == "integral": - return { + kinds= { "int8": int8, "int16": int16, "int32": int32, "int64": int64, "uint8": uint8, } + kinds.update(uint_kinds) + return kinds if kind == "real floating": return { "float32": float32, @@ -222,7 +238,7 @@ def _dtypes(self, kind): "complex128": complex128, } if kind == "numeric": - return { + kinds = { "int8": int8, "int16": int16, "int32": int32, @@ -233,6 +249,9 @@ def _dtypes(self, kind): "complex64": complex64, "complex128": complex128, } + kinds.update(uint_kinds) + return kinds + if isinstance(kind, tuple): res = {} for k in kind: