diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 579da90..a46d950 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -191,7 +191,7 @@ def __array__( # NumPy behavior def _check_allowed_dtypes( - self, other: Array | bool | int | float | complex, dtype_category: str, op: str + self, other: Array | complex, dtype_category: str, op: str ) -> Array: """ Helper function for operators to only allow specific input dtypes @@ -233,7 +233,7 @@ def _check_allowed_dtypes( return other - def _check_type_device(self, other: Array | bool | int | float | complex) -> None: + def _check_device(self, other: Array | complex) -> None: """Check that other is either a Python scalar or an array on a device compatible with the current array. """ @@ -245,7 +245,7 @@ def _check_type_device(self, other: Array | bool | int | float | complex) -> Non raise TypeError(f"Expected Array or Python scalar; got {type(other)}") # Helper function to match the type promotion rules in the spec - def _promote_scalar(self, scalar: bool | int | float | complex) -> Array: + def _promote_scalar(self, scalar: complex) -> Array: """ Returns a promoted version of a Python scalar appropriate for use with operations on self. @@ -539,7 +539,7 @@ def __abs__(self) -> Array: res = self._array.__abs__() return self.__class__._new(res, device=self.device) - def __add__(self, other: Array | int | float | complex, /) -> Array: + def __add__(self, other: Array | complex, /) -> Array: """ Performs the operation __add__. """ @@ -551,7 +551,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__add__(other._array) return self.__class__._new(res, device=self.device) - def __and__(self, other: Array | bool | int, /) -> Array: + def __and__(self, other: Array | int, /) -> Array: """ Performs the operation __and__. """ @@ -648,7 +648,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]: # Note: device support is required for this return self._array.__dlpack_device__() - def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override] + def __eq__(self, other: Array | complex, /) -> Array: # type: ignore[override] """ Performs the operation __eq__. """ @@ -674,7 +674,7 @@ def __float__(self) -> float: res = self._array.__float__() return res - def __floordiv__(self, other: Array | int | float, /) -> Array: + def __floordiv__(self, other: Array | float, /) -> Array: """ Performs the operation __floordiv__. """ @@ -686,7 +686,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array: res = self._array.__floordiv__(other._array) return self.__class__._new(res, device=self.device) - def __ge__(self, other: Array | int | float, /) -> Array: + def __ge__(self, other: Array | float, /) -> Array: """ Performs the operation __ge__. """ @@ -738,7 +738,7 @@ def __getitem__( res = self._array.__getitem__(np_key) return self._new(res, device=self.device) - def __gt__(self, other: Array | int | float, /) -> Array: + def __gt__(self, other: Array | float, /) -> Array: """ Performs the operation __gt__. """ @@ -793,7 +793,7 @@ def __iter__(self) -> Iterator[Array]: # implemented, which implies iteration on 1-D arrays. return (Array._new(i, device=self.device) for i in self._array) - def __le__(self, other: Array | int | float, /) -> Array: + def __le__(self, other: Array | float, /) -> Array: """ Performs the operation __le__. """ @@ -817,7 +817,7 @@ def __lshift__(self, other: Array | int, /) -> Array: res = self._array.__lshift__(other._array) return self.__class__._new(res, device=self.device) - def __lt__(self, other: Array | int | float, /) -> Array: + def __lt__(self, other: Array | float, /) -> Array: """ Performs the operation __lt__. """ @@ -842,7 +842,7 @@ def __matmul__(self, other: Array, /) -> Array: res = self._array.__matmul__(other._array) return self.__class__._new(res, device=self.device) - def __mod__(self, other: Array | int | float, /) -> Array: + def __mod__(self, other: Array | float, /) -> Array: """ Performs the operation __mod__. """ @@ -854,7 +854,7 @@ def __mod__(self, other: Array | int | float, /) -> Array: res = self._array.__mod__(other._array) return self.__class__._new(res, device=self.device) - def __mul__(self, other: Array | int | float | complex, /) -> Array: + def __mul__(self, other: Array | complex, /) -> Array: """ Performs the operation __mul__. """ @@ -866,7 +866,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__mul__(other._array) return self.__class__._new(res, device=self.device) - def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override] + def __ne__(self, other: Array | complex, /) -> Array: # type: ignore[override] """ Performs the operation __ne__. """ @@ -887,7 +887,7 @@ def __neg__(self) -> Array: res = self._array.__neg__() return self.__class__._new(res, device=self.device) - def __or__(self, other: Array | bool | int, /) -> Array: + def __or__(self, other: Array | int, /) -> Array: """ Performs the operation __or__. """ @@ -908,7 +908,7 @@ def __pos__(self) -> Array: res = self._array.__pos__() return self.__class__._new(res, device=self.device) - def __pow__(self, other: Array | int | float | complex, /) -> Array: + def __pow__(self, other: Array | complex, /) -> Array: """ Performs the operation __pow__. """ @@ -945,7 +945,7 @@ def __setitem__( | Array | tuple[int | slice | EllipsisType, ...] ), - value: Array | bool | int | float | complex, + value: Array | complex, /, ) -> None: """ @@ -958,7 +958,7 @@ def __setitem__( np_key = key._array if isinstance(key, Array) else key self._array.__setitem__(np_key, asarray(value)._array) - def __sub__(self, other: Array | int | float | complex, /) -> Array: + def __sub__(self, other: Array | complex, /) -> Array: """ Performs the operation __sub__. """ @@ -972,7 +972,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array: # PEP 484 requires int to be a subtype of float, but __truediv__ should # not accept int. - def __truediv__(self, other: Array | int | float | complex, /) -> Array: + def __truediv__(self, other: Array | complex, /) -> Array: """ Performs the operation __truediv__. """ @@ -984,7 +984,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__truediv__(other._array) return self.__class__._new(res, device=self.device) - def __xor__(self, other: Array | bool | int, /) -> Array: + def __xor__(self, other: Array | int, /) -> Array: """ Performs the operation __xor__. """ @@ -996,7 +996,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array: res = self._array.__xor__(other._array) return self.__class__._new(res, device=self.device) - def __iadd__(self, other: Array | int | float | complex, /) -> Array: + def __iadd__(self, other: Array | complex, /) -> Array: """ Performs the operation __iadd__. """ @@ -1007,7 +1007,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array: self._array.__iadd__(other._array) return self - def __radd__(self, other: Array | int | float | complex, /) -> Array: + def __radd__(self, other: Array | complex, /) -> Array: """ Performs the operation __radd__. """ @@ -1019,7 +1019,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__radd__(other._array) return self.__class__._new(res, device=self.device) - def __iand__(self, other: Array | bool | int, /) -> Array: + def __iand__(self, other: Array | int, /) -> Array: """ Performs the operation __iand__. """ @@ -1030,7 +1030,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array: self._array.__iand__(other._array) return self - def __rand__(self, other: Array | bool | int, /) -> Array: + def __rand__(self, other: Array | int, /) -> Array: """ Performs the operation __rand__. """ @@ -1042,7 +1042,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array: res = self._array.__rand__(other._array) return self.__class__._new(res, device=self.device) - def __ifloordiv__(self, other: Array | int | float, /) -> Array: + def __ifloordiv__(self, other: Array | float, /) -> Array: """ Performs the operation __ifloordiv__. """ @@ -1053,7 +1053,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array: self._array.__ifloordiv__(other._array) return self - def __rfloordiv__(self, other: Array | int | float, /) -> Array: + def __rfloordiv__(self, other: Array | float, /) -> Array: """ Performs the operation __rfloordiv__. """ @@ -1114,7 +1114,7 @@ def __rmatmul__(self, other: Array, /) -> Array: res = self._array.__rmatmul__(other._array) return self.__class__._new(res, device=self.device) - def __imod__(self, other: Array | int | float, /) -> Array: + def __imod__(self, other: Array | float, /) -> Array: """ Performs the operation __imod__. """ @@ -1124,7 +1124,7 @@ def __imod__(self, other: Array | int | float, /) -> Array: self._array.__imod__(other._array) return self - def __rmod__(self, other: Array | int | float, /) -> Array: + def __rmod__(self, other: Array | float, /) -> Array: """ Performs the operation __rmod__. """ @@ -1136,7 +1136,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array: res = self._array.__rmod__(other._array) return self.__class__._new(res, device=self.device) - def __imul__(self, other: Array | int | float | complex, /) -> Array: + def __imul__(self, other: Array | complex, /) -> Array: """ Performs the operation __imul__. """ @@ -1146,7 +1146,7 @@ def __imul__(self, other: Array | int | float | complex, /) -> Array: self._array.__imul__(other._array) return self - def __rmul__(self, other: Array | int | float | complex, /) -> Array: + def __rmul__(self, other: Array | complex, /) -> Array: """ Performs the operation __rmul__. """ @@ -1158,7 +1158,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__rmul__(other._array) return self.__class__._new(res, device=self.device) - def __ior__(self, other: Array | bool | int, /) -> Array: + def __ior__(self, other: Array | int, /) -> Array: """ Performs the operation __ior__. """ @@ -1168,7 +1168,7 @@ def __ior__(self, other: Array | bool | int, /) -> Array: self._array.__ior__(other._array) return self - def __ror__(self, other: Array | bool | int, /) -> Array: + def __ror__(self, other: Array | int, /) -> Array: """ Performs the operation __ror__. """ @@ -1180,7 +1180,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array: res = self._array.__ror__(other._array) return self.__class__._new(res, device=self.device) - def __ipow__(self, other: Array | int | float | complex, /) -> Array: + def __ipow__(self, other: Array | complex, /) -> Array: """ Performs the operation __ipow__. """ @@ -1190,7 +1190,7 @@ def __ipow__(self, other: Array | int | float | complex, /) -> Array: self._array.__ipow__(other._array) return self - def __rpow__(self, other: Array | int | float | complex, /) -> Array: + def __rpow__(self, other: Array | complex, /) -> Array: """ Performs the operation __rpow__. """ @@ -1225,7 +1225,7 @@ def __rrshift__(self, other: Array | int, /) -> Array: res = self._array.__rrshift__(other._array) return self.__class__._new(res, device=self.device) - def __isub__(self, other: Array | int | float | complex, /) -> Array: + def __isub__(self, other: Array | complex, /) -> Array: """ Performs the operation __isub__. """ @@ -1235,7 +1235,7 @@ def __isub__(self, other: Array | int | float | complex, /) -> Array: self._array.__isub__(other._array) return self - def __rsub__(self, other: Array | int | float | complex, /) -> Array: + def __rsub__(self, other: Array | complex, /) -> Array: """ Performs the operation __rsub__. """ @@ -1247,7 +1247,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__rsub__(other._array) return self.__class__._new(res, device=self.device) - def __itruediv__(self, other: Array | int | float | complex, /) -> Array: + def __itruediv__(self, other: Array | complex, /) -> Array: """ Performs the operation __itruediv__. """ @@ -1257,7 +1257,7 @@ def __itruediv__(self, other: Array | int | float | complex, /) -> Array: self._array.__itruediv__(other._array) return self - def __rtruediv__(self, other: Array | int | float | complex, /) -> Array: + def __rtruediv__(self, other: Array | complex, /) -> Array: """ Performs the operation __rtruediv__. """ @@ -1269,7 +1269,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array: res = self._array.__rtruediv__(other._array) return self.__class__._new(res, device=self.device) - def __ixor__(self, other: Array | bool | int, /) -> Array: + def __ixor__(self, other: Array | int, /) -> Array: """ Performs the operation __ixor__. """ @@ -1279,7 +1279,7 @@ def __ixor__(self, other: Array | bool | int, /) -> Array: self._array.__ixor__(other._array) return self - def __rxor__(self, other: Array | bool | int, /) -> Array: + def __rxor__(self, other: Array | int, /) -> Array: """ Performs the operation __rxor__. """ diff --git a/array_api_strict/_creation_functions.py b/array_api_strict/_creation_functions.py index db3897c..69d37aa 100644 --- a/array_api_strict/_creation_functions.py +++ b/array_api_strict/_creation_functions.py @@ -68,13 +68,7 @@ def _check_device(device: Device | None) -> None: def asarray( - obj: Array - | bool - | int - | float - | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol, + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, @@ -135,10 +129,10 @@ def asarray( def arange( - start: int | float, + start: float, /, - stop: int | float | None = None, - step: int | float = 1, + stop: float | None = None, + step: float = 1, *, dtype: DType | None = None, device: Device | None = None, @@ -248,7 +242,7 @@ def from_dlpack( def full( shape: int | tuple[int, ...], - fill_value: bool | int | float | complex, + fill_value: complex, *, dtype: DType | None = None, device: Device | None = None, @@ -276,7 +270,7 @@ def full( def full_like( x: Array, /, - fill_value: bool | int | float | complex, + fill_value: complex, *, dtype: DType | None = None, device: Device | None = None, @@ -302,8 +296,8 @@ def full_like( def linspace( - start: int | float | complex, - stop: int | float | complex, + start: complex, + stop: complex, /, num: int, *, diff --git a/array_api_strict/_data_type_functions.py b/array_api_strict/_data_type_functions.py index e318724..82d438f 100644 --- a/array_api_strict/_data_type_functions.py +++ b/array_api_strict/_data_type_functions.py @@ -210,9 +210,7 @@ def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool: raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}") -def result_type( - *arrays_and_dtypes: DType | Array | bool | int | float | complex, -) -> DType: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: """ Array API compatible wrapper for :py:func:`np.result_type `. diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index b05e0fd..8cd42cf 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -51,15 +51,15 @@ def inner(x1, x2, /) -> Array: # static type annotation for ArrayOrPythonScalar arguments given a category # NB: keep the keys in sync with the _dtype_categories dict _annotations = { - "all": "bool | int | float | complex | Array", - "real numeric": "int | float | Array", - "numeric": "int | float | complex | Array", + "all": "complex | Array", + "real numeric": "float | Array", + "numeric": "complex | Array", "integer": "int | Array", - "integer or boolean": "bool | int | Array", + "integer or boolean": "int | Array", "boolean": "bool | Array", "real floating-point": "float | Array", "complex floating-point": "complex | Array", - "floating-point": "float | complex | Array", + "floating-point": "complex | Array", } @@ -268,8 +268,8 @@ def ceil(x: Array, /) -> Array: def clip( x: Array, /, - min: Array | int | float | None = None, - max: Array | int | float | None = None, + min: Array | float | None = None, + max: Array | float | None = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.clip `. diff --git a/array_api_strict/_helpers.py b/array_api_strict/_helpers.py index e8c6767..db58667 100644 --- a/array_api_strict/_helpers.py +++ b/array_api_strict/_helpers.py @@ -8,8 +8,8 @@ def _maybe_normalize_py_scalars( - x1: Array | bool | int | float | complex, - x2: Array | bool | int | float | complex, + x1: Array | complex, + x2: Array | complex, dtype_category: str, func_name: str, ) -> tuple[Array, Array]: diff --git a/array_api_strict/_linalg.py b/array_api_strict/_linalg.py index 72d7f0a..84a31f6 100644 --- a/array_api_strict/_linalg.py +++ b/array_api_strict/_linalg.py @@ -415,7 +415,7 @@ def vector_norm( *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: int | float = 2, + ord: float = 2, ) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index c42ccc7..3fbb0cf 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -92,12 +92,7 @@ def searchsorted( ) -def where( - condition: Array, - x1: Array | bool | int | float | complex, - x2: Array | bool | int | float | complex, - /, -) -> Array: +def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array: """ Array API compatible wrapper for :py:func:`np.where `. diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 4160f7a..35876dd 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -149,7 +149,7 @@ def std( /, *, axis: int | tuple[int, ...] | None = None, - correction: int | float = 0.0, + correction: float = 0.0, keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here @@ -181,7 +181,7 @@ def var( /, *, axis: int | tuple[int, ...] | None = None, - correction: int | float = 0.0, + correction: float = 0.0, keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here