Skip to content

TYP: Compact Python scalar types #149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 40 additions & 40 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __array__(
# NumPy behavior

def _check_allowed_dtypes(
self, other: Array | bool | int | float | complex, dtype_category: str, op: str
self, other: Array | complex, dtype_category: str, op: str
) -> Array:
"""
Helper function for operators to only allow specific input dtypes
Expand Down Expand Up @@ -233,7 +233,7 @@ def _check_allowed_dtypes(

return other

def _check_type_device(self, other: Array | bool | int | float | complex) -> None:
def _check_device(self, other: Array | complex) -> None:
"""Check that other is either a Python scalar or an array on a device
compatible with the current array.
"""
Expand All @@ -245,7 +245,7 @@ def _check_type_device(self, other: Array | bool | int | float | complex) -> Non
raise TypeError(f"Expected Array or Python scalar; got {type(other)}")

# Helper function to match the type promotion rules in the spec
def _promote_scalar(self, scalar: bool | int | float | complex) -> Array:
def _promote_scalar(self, scalar: complex) -> Array:
"""
Returns a promoted version of a Python scalar appropriate for use with
operations on self.
Expand Down Expand Up @@ -539,7 +539,7 @@ def __abs__(self) -> Array:
res = self._array.__abs__()
return self.__class__._new(res, device=self.device)

def __add__(self, other: Array | int | float | complex, /) -> Array:
def __add__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __add__.
"""
Expand All @@ -551,7 +551,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
res = self._array.__add__(other._array)
return self.__class__._new(res, device=self.device)

def __and__(self, other: Array | bool | int, /) -> Array:
def __and__(self, other: Array | int, /) -> Array:
"""
Performs the operation __and__.
"""
Expand Down Expand Up @@ -648,7 +648,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]:
# Note: device support is required for this
return self._array.__dlpack_device__()

def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
def __eq__(self, other: Array | complex, /) -> Array: # type: ignore[override]
"""
Performs the operation __eq__.
"""
Expand All @@ -674,7 +674,7 @@ def __float__(self) -> float:
res = self._array.__float__()
return res

def __floordiv__(self, other: Array | int | float, /) -> Array:
def __floordiv__(self, other: Array | float, /) -> Array:
"""
Performs the operation __floordiv__.
"""
Expand All @@ -686,7 +686,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
res = self._array.__floordiv__(other._array)
return self.__class__._new(res, device=self.device)

def __ge__(self, other: Array | int | float, /) -> Array:
def __ge__(self, other: Array | float, /) -> Array:
"""
Performs the operation __ge__.
"""
Expand Down Expand Up @@ -738,7 +738,7 @@ def __getitem__(
res = self._array.__getitem__(np_key)
return self._new(res, device=self.device)

def __gt__(self, other: Array | int | float, /) -> Array:
def __gt__(self, other: Array | float, /) -> Array:
"""
Performs the operation __gt__.
"""
Expand Down Expand Up @@ -793,7 +793,7 @@ def __iter__(self) -> Iterator[Array]:
# implemented, which implies iteration on 1-D arrays.
return (Array._new(i, device=self.device) for i in self._array)

def __le__(self, other: Array | int | float, /) -> Array:
def __le__(self, other: Array | float, /) -> Array:
"""
Performs the operation __le__.
"""
Expand All @@ -817,7 +817,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
res = self._array.__lshift__(other._array)
return self.__class__._new(res, device=self.device)

def __lt__(self, other: Array | int | float, /) -> Array:
def __lt__(self, other: Array | float, /) -> Array:
"""
Performs the operation __lt__.
"""
Expand All @@ -842,7 +842,7 @@ def __matmul__(self, other: Array, /) -> Array:
res = self._array.__matmul__(other._array)
return self.__class__._new(res, device=self.device)

def __mod__(self, other: Array | int | float, /) -> Array:
def __mod__(self, other: Array | float, /) -> Array:
"""
Performs the operation __mod__.
"""
Expand All @@ -854,7 +854,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
res = self._array.__mod__(other._array)
return self.__class__._new(res, device=self.device)

def __mul__(self, other: Array | int | float | complex, /) -> Array:
def __mul__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __mul__.
"""
Expand All @@ -866,7 +866,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
res = self._array.__mul__(other._array)
return self.__class__._new(res, device=self.device)

def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
def __ne__(self, other: Array | complex, /) -> Array: # type: ignore[override]
"""
Performs the operation __ne__.
"""
Expand All @@ -887,7 +887,7 @@ def __neg__(self) -> Array:
res = self._array.__neg__()
return self.__class__._new(res, device=self.device)

def __or__(self, other: Array | bool | int, /) -> Array:
def __or__(self, other: Array | int, /) -> Array:
"""
Performs the operation __or__.
"""
Expand All @@ -908,7 +908,7 @@ def __pos__(self) -> Array:
res = self._array.__pos__()
return self.__class__._new(res, device=self.device)

def __pow__(self, other: Array | int | float | complex, /) -> Array:
def __pow__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __pow__.
"""
Expand Down Expand Up @@ -945,7 +945,7 @@ def __setitem__(
| Array
| tuple[int | slice | EllipsisType, ...]
),
value: Array | bool | int | float | complex,
value: Array | complex,
/,
) -> None:
"""
Expand All @@ -958,7 +958,7 @@ def __setitem__(
np_key = key._array if isinstance(key, Array) else key
self._array.__setitem__(np_key, asarray(value)._array)

def __sub__(self, other: Array | int | float | complex, /) -> Array:
def __sub__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __sub__.
"""
Expand All @@ -972,7 +972,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:

# PEP 484 requires int to be a subtype of float, but __truediv__ should
# not accept int.
def __truediv__(self, other: Array | int | float | complex, /) -> Array:
def __truediv__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __truediv__.
"""
Expand All @@ -984,7 +984,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
res = self._array.__truediv__(other._array)
return self.__class__._new(res, device=self.device)

def __xor__(self, other: Array | bool | int, /) -> Array:
def __xor__(self, other: Array | int, /) -> Array:
"""
Performs the operation __xor__.
"""
Expand All @@ -996,7 +996,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
res = self._array.__xor__(other._array)
return self.__class__._new(res, device=self.device)

def __iadd__(self, other: Array | int | float | complex, /) -> Array:
def __iadd__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __iadd__.
"""
Expand All @@ -1007,7 +1007,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
self._array.__iadd__(other._array)
return self

def __radd__(self, other: Array | int | float | complex, /) -> Array:
def __radd__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __radd__.
"""
Expand All @@ -1019,7 +1019,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
res = self._array.__radd__(other._array)
return self.__class__._new(res, device=self.device)

def __iand__(self, other: Array | bool | int, /) -> Array:
def __iand__(self, other: Array | int, /) -> Array:
"""
Performs the operation __iand__.
"""
Expand All @@ -1030,7 +1030,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
self._array.__iand__(other._array)
return self

def __rand__(self, other: Array | bool | int, /) -> Array:
def __rand__(self, other: Array | int, /) -> Array:
"""
Performs the operation __rand__.
"""
Expand All @@ -1042,7 +1042,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
res = self._array.__rand__(other._array)
return self.__class__._new(res, device=self.device)

def __ifloordiv__(self, other: Array | int | float, /) -> Array:
def __ifloordiv__(self, other: Array | float, /) -> Array:
"""
Performs the operation __ifloordiv__.
"""
Expand All @@ -1053,7 +1053,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
self._array.__ifloordiv__(other._array)
return self

def __rfloordiv__(self, other: Array | int | float, /) -> Array:
def __rfloordiv__(self, other: Array | float, /) -> Array:
"""
Performs the operation __rfloordiv__.
"""
Expand Down Expand Up @@ -1114,7 +1114,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
res = self._array.__rmatmul__(other._array)
return self.__class__._new(res, device=self.device)

def __imod__(self, other: Array | int | float, /) -> Array:
def __imod__(self, other: Array | float, /) -> Array:
"""
Performs the operation __imod__.
"""
Expand All @@ -1124,7 +1124,7 @@ def __imod__(self, other: Array | int | float, /) -> Array:
self._array.__imod__(other._array)
return self

def __rmod__(self, other: Array | int | float, /) -> Array:
def __rmod__(self, other: Array | float, /) -> Array:
"""
Performs the operation __rmod__.
"""
Expand All @@ -1136,7 +1136,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
res = self._array.__rmod__(other._array)
return self.__class__._new(res, device=self.device)

def __imul__(self, other: Array | int | float | complex, /) -> Array:
def __imul__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __imul__.
"""
Expand All @@ -1146,7 +1146,7 @@ def __imul__(self, other: Array | int | float | complex, /) -> Array:
self._array.__imul__(other._array)
return self

def __rmul__(self, other: Array | int | float | complex, /) -> Array:
def __rmul__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __rmul__.
"""
Expand All @@ -1158,7 +1158,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
res = self._array.__rmul__(other._array)
return self.__class__._new(res, device=self.device)

def __ior__(self, other: Array | bool | int, /) -> Array:
def __ior__(self, other: Array | int, /) -> Array:
"""
Performs the operation __ior__.
"""
Expand All @@ -1168,7 +1168,7 @@ def __ior__(self, other: Array | bool | int, /) -> Array:
self._array.__ior__(other._array)
return self

def __ror__(self, other: Array | bool | int, /) -> Array:
def __ror__(self, other: Array | int, /) -> Array:
"""
Performs the operation __ror__.
"""
Expand All @@ -1180,7 +1180,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
res = self._array.__ror__(other._array)
return self.__class__._new(res, device=self.device)

def __ipow__(self, other: Array | int | float | complex, /) -> Array:
def __ipow__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __ipow__.
"""
Expand All @@ -1190,7 +1190,7 @@ def __ipow__(self, other: Array | int | float | complex, /) -> Array:
self._array.__ipow__(other._array)
return self

def __rpow__(self, other: Array | int | float | complex, /) -> Array:
def __rpow__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __rpow__.
"""
Expand Down Expand Up @@ -1225,7 +1225,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
res = self._array.__rrshift__(other._array)
return self.__class__._new(res, device=self.device)

def __isub__(self, other: Array | int | float | complex, /) -> Array:
def __isub__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __isub__.
"""
Expand All @@ -1235,7 +1235,7 @@ def __isub__(self, other: Array | int | float | complex, /) -> Array:
self._array.__isub__(other._array)
return self

def __rsub__(self, other: Array | int | float | complex, /) -> Array:
def __rsub__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __rsub__.
"""
Expand All @@ -1247,7 +1247,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
res = self._array.__rsub__(other._array)
return self.__class__._new(res, device=self.device)

def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
def __itruediv__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __itruediv__.
"""
Expand All @@ -1257,7 +1257,7 @@ def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
self._array.__itruediv__(other._array)
return self

def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
def __rtruediv__(self, other: Array | complex, /) -> Array:
"""
Performs the operation __rtruediv__.
"""
Expand All @@ -1269,7 +1269,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
res = self._array.__rtruediv__(other._array)
return self.__class__._new(res, device=self.device)

def __ixor__(self, other: Array | bool | int, /) -> Array:
def __ixor__(self, other: Array | int, /) -> Array:
"""
Performs the operation __ixor__.
"""
Expand All @@ -1279,7 +1279,7 @@ def __ixor__(self, other: Array | bool | int, /) -> Array:
self._array.__ixor__(other._array)
return self

def __rxor__(self, other: Array | bool | int, /) -> Array:
def __rxor__(self, other: Array | int, /) -> Array:
"""
Performs the operation __rxor__.
"""
Expand Down
Loading
Loading