From 3c1e694d9ff03829c6dd382edca74fc49dd19b56 Mon Sep 17 00:00:00 2001
From: crusaderky <crusaderky@gmail.com>
Date: Tue, 1 Apr 2025 09:42:19 +0100
Subject: [PATCH 1/2] TYP: Compact Python scalar types

---
 array_api_strict/_array_object.py          | 80 +++++++++++-----------
 array_api_strict/_creation_functions.py    | 22 +++---
 array_api_strict/_data_type_functions.py   |  4 +-
 array_api_strict/_elementwise_functions.py | 14 ++--
 array_api_strict/_helpers.py               |  4 +-
 array_api_strict/_linalg.py                |  2 +-
 array_api_strict/_searching_functions.py   |  7 +-
 array_api_strict/_statistical_functions.py |  4 +-
 8 files changed, 62 insertions(+), 75 deletions(-)

diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py
index 483952e..edcc8e5 100644
--- a/array_api_strict/_array_object.py
+++ b/array_api_strict/_array_object.py
@@ -191,7 +191,7 @@ def __array__(
     # NumPy behavior
 
     def _check_allowed_dtypes(
-        self, other: Array | bool | int | float | complex, dtype_category: str, op: str
+        self, other: Array | complex, dtype_category: str, op: str
     ) -> Array:
         """
         Helper function for operators to only allow specific input dtypes
@@ -233,7 +233,7 @@ def _check_allowed_dtypes(
 
         return other
 
-    def _check_device(self, other: Array | bool | int | float | complex) -> None:
+    def _check_device(self, other: Array | complex) -> None:
         """Check that other is on a device compatible with the current array"""
         if isinstance(other, (bool, int, float, complex)):
             return
@@ -244,7 +244,7 @@ def _check_device(self, other: Array | bool | int | float | complex) -> None:
             raise TypeError(f"Expected Array | python scalar; got {type(other)}")
 
     # Helper function to match the type promotion rules in the spec
-    def _promote_scalar(self, scalar: bool | int | float | complex) -> Array:
+    def _promote_scalar(self, scalar: complex) -> Array:
         """
         Returns a promoted version of a Python scalar appropriate for use with
         operations on self.
@@ -538,7 +538,7 @@ def __abs__(self) -> Array:
         res = self._array.__abs__()
         return self.__class__._new(res, device=self.device)
 
-    def __add__(self, other: Array | int | float | complex, /) -> Array:
+    def __add__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __add__.
         """
@@ -550,7 +550,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
         res = self._array.__add__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __and__(self, other: Array | bool | int, /) -> Array:
+    def __and__(self, other: Array | int, /) -> Array:
         """
         Performs the operation __and__.
         """
@@ -647,7 +647,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]:
         # Note: device support is required for this
         return self._array.__dlpack_device__()
 
-    def __eq__(self, other: Array | bool | int | float | complex, /) -> Array:  # type: ignore[override]
+    def __eq__(self, other: Array | complex, /) -> Array:  # type: ignore[override]
         """
         Performs the operation __eq__.
         """
@@ -673,7 +673,7 @@ def __float__(self) -> float:
         res = self._array.__float__()
         return res
 
-    def __floordiv__(self, other: Array | int | float, /) -> Array:
+    def __floordiv__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __floordiv__.
         """
@@ -685,7 +685,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
         res = self._array.__floordiv__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __ge__(self, other: Array | int | float, /) -> Array:
+    def __ge__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __ge__.
         """
@@ -737,7 +737,7 @@ def __getitem__(
         res = self._array.__getitem__(np_key)
         return self._new(res, device=self.device)
 
-    def __gt__(self, other: Array | int | float, /) -> Array:
+    def __gt__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __gt__.
         """
@@ -792,7 +792,7 @@ def __iter__(self) -> Iterator[Array]:
         # implemented, which implies iteration on 1-D arrays.
         return (Array._new(i, device=self.device) for i in self._array)
 
-    def __le__(self, other: Array | int | float, /) -> Array:
+    def __le__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __le__.
         """
@@ -816,7 +816,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
         res = self._array.__lshift__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __lt__(self, other: Array | int | float, /) -> Array:
+    def __lt__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __lt__.
         """
@@ -841,7 +841,7 @@ def __matmul__(self, other: Array, /) -> Array:
         res = self._array.__matmul__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __mod__(self, other: Array | int | float, /) -> Array:
+    def __mod__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __mod__.
         """
@@ -853,7 +853,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
         res = self._array.__mod__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __mul__(self, other: Array | int | float | complex, /) -> Array:
+    def __mul__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __mul__.
         """
@@ -865,7 +865,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
         res = self._array.__mul__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __ne__(self, other: Array | bool | int | float | complex, /) -> Array:  # type: ignore[override]
+    def __ne__(self, other: Array | complex, /) -> Array:  # type: ignore[override]
         """
         Performs the operation __ne__.
         """
@@ -886,7 +886,7 @@ def __neg__(self) -> Array:
         res = self._array.__neg__()
         return self.__class__._new(res, device=self.device)
 
-    def __or__(self, other: Array | bool | int, /) -> Array:
+    def __or__(self, other: Array | int, /) -> Array:
         """
         Performs the operation __or__.
         """
@@ -907,7 +907,7 @@ def __pos__(self) -> Array:
         res = self._array.__pos__()
         return self.__class__._new(res, device=self.device)
 
-    def __pow__(self, other: Array | int | float | complex, /) -> Array:
+    def __pow__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __pow__.
         """
@@ -944,7 +944,7 @@ def __setitem__(
             | Array
             | tuple[int | slice | EllipsisType, ...]
         ),
-        value: Array | bool | int | float | complex,
+        value: Array | complex,
         /,
     ) -> None:
         """
@@ -957,7 +957,7 @@ def __setitem__(
         np_key = key._array if isinstance(key, Array) else key
         self._array.__setitem__(np_key, asarray(value)._array)
 
-    def __sub__(self, other: Array | int | float | complex, /) -> Array:
+    def __sub__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __sub__.
         """
@@ -971,7 +971,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:
 
     # PEP 484 requires int to be a subtype of float, but __truediv__ should
     # not accept int.
-    def __truediv__(self, other: Array | int | float | complex, /) -> Array:
+    def __truediv__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __truediv__.
         """
@@ -983,7 +983,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
         res = self._array.__truediv__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __xor__(self, other: Array | bool | int, /) -> Array:
+    def __xor__(self, other: Array | int, /) -> Array:
         """
         Performs the operation __xor__.
         """
@@ -995,7 +995,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
         res = self._array.__xor__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __iadd__(self, other: Array | int | float | complex, /) -> Array:
+    def __iadd__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __iadd__.
         """
@@ -1006,7 +1006,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
         self._array.__iadd__(other._array)
         return self
 
-    def __radd__(self, other: Array | int | float | complex, /) -> Array:
+    def __radd__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __radd__.
         """
@@ -1018,7 +1018,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
         res = self._array.__radd__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __iand__(self, other: Array | bool | int, /) -> Array:
+    def __iand__(self, other: Array | int, /) -> Array:
         """
         Performs the operation __iand__.
         """
@@ -1029,7 +1029,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
         self._array.__iand__(other._array)
         return self
 
-    def __rand__(self, other: Array | bool | int, /) -> Array:
+    def __rand__(self, other: Array | int, /) -> Array:
         """
         Performs the operation __rand__.
         """
@@ -1041,7 +1041,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
         res = self._array.__rand__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __ifloordiv__(self, other: Array | int | float, /) -> Array:
+    def __ifloordiv__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __ifloordiv__.
         """
@@ -1052,7 +1052,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
         self._array.__ifloordiv__(other._array)
         return self
 
-    def __rfloordiv__(self, other: Array | int | float, /) -> Array:
+    def __rfloordiv__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __rfloordiv__.
         """
@@ -1113,7 +1113,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
         res = self._array.__rmatmul__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __imod__(self, other: Array | int | float, /) -> Array:
+    def __imod__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __imod__.
         """
@@ -1123,7 +1123,7 @@ def __imod__(self, other: Array | int | float, /) -> Array:
         self._array.__imod__(other._array)
         return self
 
-    def __rmod__(self, other: Array | int | float, /) -> Array:
+    def __rmod__(self, other: Array | float, /) -> Array:
         """
         Performs the operation __rmod__.
         """
@@ -1135,7 +1135,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
         res = self._array.__rmod__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __imul__(self, other: Array | int | float | complex, /) -> Array:
+    def __imul__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __imul__.
         """
@@ -1145,7 +1145,7 @@ def __imul__(self, other: Array | int | float | complex, /) -> Array:
         self._array.__imul__(other._array)
         return self
 
-    def __rmul__(self, other: Array | int | float | complex, /) -> Array:
+    def __rmul__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __rmul__.
         """
@@ -1157,7 +1157,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
         res = self._array.__rmul__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __ior__(self, other: Array | bool | int, /) -> Array:
+    def __ior__(self, other: Array | int, /) -> Array:
         """
         Performs the operation __ior__.
         """
@@ -1167,7 +1167,7 @@ def __ior__(self, other: Array | bool | int, /) -> Array:
         self._array.__ior__(other._array)
         return self
 
-    def __ror__(self, other: Array | bool | int, /) -> Array:
+    def __ror__(self, other: Array | int, /) -> Array:
         """
         Performs the operation __ror__.
         """
@@ -1179,7 +1179,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
         res = self._array.__ror__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __ipow__(self, other: Array | int | float | complex, /) -> Array:
+    def __ipow__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __ipow__.
         """
@@ -1189,7 +1189,7 @@ def __ipow__(self, other: Array | int | float | complex, /) -> Array:
         self._array.__ipow__(other._array)
         return self
 
-    def __rpow__(self, other: Array | int | float | complex, /) -> Array:
+    def __rpow__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __rpow__.
         """
@@ -1224,7 +1224,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
         res = self._array.__rrshift__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __isub__(self, other: Array | int | float | complex, /) -> Array:
+    def __isub__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __isub__.
         """
@@ -1234,7 +1234,7 @@ def __isub__(self, other: Array | int | float | complex, /) -> Array:
         self._array.__isub__(other._array)
         return self
 
-    def __rsub__(self, other: Array | int | float | complex, /) -> Array:
+    def __rsub__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __rsub__.
         """
@@ -1246,7 +1246,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
         res = self._array.__rsub__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
+    def __itruediv__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __itruediv__.
         """
@@ -1256,7 +1256,7 @@ def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
         self._array.__itruediv__(other._array)
         return self
 
-    def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
+    def __rtruediv__(self, other: Array | complex, /) -> Array:
         """
         Performs the operation __rtruediv__.
         """
@@ -1268,7 +1268,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
         res = self._array.__rtruediv__(other._array)
         return self.__class__._new(res, device=self.device)
 
-    def __ixor__(self, other: Array | bool | int, /) -> Array:
+    def __ixor__(self, other: Array | int, /) -> Array:
         """
         Performs the operation __ixor__.
         """
@@ -1278,7 +1278,7 @@ def __ixor__(self, other: Array | bool | int, /) -> Array:
         self._array.__ixor__(other._array)
         return self
 
-    def __rxor__(self, other: Array | bool | int, /) -> Array:
+    def __rxor__(self, other: Array | int, /) -> Array:
         """
         Performs the operation __rxor__.
         """
diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py
index db3897c..69d37aa 100644
--- a/array_api_strict/_creation_functions.py
+++ b/array_api_strict/_creation_functions.py
@@ -68,13 +68,7 @@ def _check_device(device: Device | None) -> None:
 
 
 def asarray(
-    obj: Array
-    | bool
-    | int
-    | float
-    | complex
-    | NestedSequence[bool | int | float | complex]
-    | SupportsBufferProtocol,
+    obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
     /,
     *,
     dtype: DType | None = None,
@@ -135,10 +129,10 @@ def asarray(
 
 
 def arange(
-    start: int | float,
+    start: float,
     /,
-    stop: int | float | None = None,
-    step: int | float = 1,
+    stop: float | None = None,
+    step: float = 1,
     *,
     dtype: DType | None = None,
     device: Device | None = None,
@@ -248,7 +242,7 @@ def from_dlpack(
 
 def full(
     shape: int | tuple[int, ...],
-    fill_value: bool | int | float | complex,
+    fill_value: complex,
     *,
     dtype: DType | None = None,
     device: Device | None = None,
@@ -276,7 +270,7 @@ def full(
 def full_like(
     x: Array,
     /,
-    fill_value: bool | int | float | complex,
+    fill_value: complex,
     *,
     dtype: DType | None = None,
     device: Device | None = None,
@@ -302,8 +296,8 @@ def full_like(
 
 
 def linspace(
-    start: int | float | complex,
-    stop: int | float | complex,
+    start: complex,
+    stop: complex,
     /,
     num: int,
     *,
diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py
index e318724..82d438f 100644
--- a/array_api_strict/_data_type_functions.py
+++ b/array_api_strict/_data_type_functions.py
@@ -210,9 +210,7 @@ def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool:
         raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")
 
 
-def result_type(
-    *arrays_and_dtypes: DType | Array | bool | int | float | complex,
-) -> DType:
+def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType:
     """
     Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
 
diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py
index b05e0fd..8cd42cf 100644
--- a/array_api_strict/_elementwise_functions.py
+++ b/array_api_strict/_elementwise_functions.py
@@ -51,15 +51,15 @@ def inner(x1, x2, /) -> Array:
 # static type annotation for ArrayOrPythonScalar arguments given a category
 # NB: keep the keys in sync with the _dtype_categories dict
 _annotations = {
-    "all": "bool | int | float | complex | Array",
-    "real numeric": "int | float | Array",
-    "numeric": "int | float | complex | Array",
+    "all": "complex | Array",
+    "real numeric": "float | Array",
+    "numeric": "complex | Array",
     "integer": "int | Array",
-    "integer or boolean": "bool | int | Array",
+    "integer or boolean": "int | Array",
     "boolean": "bool | Array",
     "real floating-point": "float | Array",
     "complex floating-point": "complex | Array",
-    "floating-point": "float | complex | Array",
+    "floating-point": "complex | Array",
 }
 
 
@@ -268,8 +268,8 @@ def ceil(x: Array, /) -> Array:
 def clip(
     x: Array,
     /,
-    min: Array | int | float | None = None,
-    max: Array | int | float | None = None,
+    min: Array | float | None = None,
+    max: Array | float | None = None,
 ) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.clip <numpy.clip>`.
diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py
index e8c6767..db58667 100644
--- a/array_api_strict/_helpers.py
+++ b/array_api_strict/_helpers.py
@@ -8,8 +8,8 @@
 
 
 def _maybe_normalize_py_scalars(
-    x1: Array | bool | int | float | complex,
-    x2: Array | bool | int | float | complex,
+    x1: Array | complex,
+    x2: Array | complex,
     dtype_category: str,
     func_name: str,
 ) -> tuple[Array, Array]:
diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py
index 72d7f0a..84a31f6 100644
--- a/array_api_strict/_linalg.py
+++ b/array_api_strict/_linalg.py
@@ -415,7 +415,7 @@ def vector_norm(
     *,
     axis: int | tuple[int, ...] | None = None,
     keepdims: bool = False,
-    ord: int | float = 2,
+    ord: float = 2,
 ) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py
index c42ccc7..3fbb0cf 100644
--- a/array_api_strict/_searching_functions.py
+++ b/array_api_strict/_searching_functions.py
@@ -92,12 +92,7 @@ def searchsorted(
     )
 
 
-def where(
-    condition: Array,
-    x1: Array | bool | int | float | complex,
-    x2: Array | bool | int | float | complex,
-    /,
-) -> Array:
+def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array:
     """
     Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
 
diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py
index 4160f7a..35876dd 100644
--- a/array_api_strict/_statistical_functions.py
+++ b/array_api_strict/_statistical_functions.py
@@ -149,7 +149,7 @@ def std(
     /,
     *,
     axis: int | tuple[int, ...] | None = None,
-    correction: int | float = 0.0,
+    correction: float = 0.0,
     keepdims: bool = False,
 ) -> Array:
     # Note: the keyword argument correction is different here
@@ -181,7 +181,7 @@ def var(
     /,
     *,
     axis: int | tuple[int, ...] | None = None,
-    correction: int | float = 0.0,
+    correction: float = 0.0,
     keepdims: bool = False,
 ) -> Array:
     # Note: the keyword argument correction is different here

From d5a6b67bb05fb73e86a4d850a901683a50c93c5a Mon Sep 17 00:00:00 2001
From: Guido Imperiale <crusaderky@gmail.com>
Date: Wed, 7 May 2025 16:01:42 +0100
Subject: [PATCH 2/2] Update array_api_strict/_array_object.py

---
 array_api_strict/_array_object.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py
index a46d950..7242055 100644
--- a/array_api_strict/_array_object.py
+++ b/array_api_strict/_array_object.py
@@ -233,7 +233,7 @@ def _check_allowed_dtypes(
 
         return other
 
-    def _check_device(self, other: Array | complex) -> None:
+    def _check_type_device(self, other: Array | complex) -> None:
         """Check that other is either a Python scalar or an array on a device
         compatible with the current array.
         """