@@ -27,6 +27,8 @@ from typing import (
27
27
)
28
28
from typing_extensions import Buffer , CapsuleType , LiteralString , Never , Protocol , Self , TypeVar , Unpack , deprecated , override
29
29
30
+ import numpy as np
31
+
30
32
from . import (
31
33
__config__ as __config__ ,
32
34
_array_api_info as _array_api_info ,
@@ -611,6 +613,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
611
613
_DT64ItemT_co = TypeVar ("_DT64ItemT_co" , bound = dt .date | int | None , default = dt .date | int | None , covariant = True )
612
614
_TD64UnitT = TypeVar ("_TD64UnitT" , bound = _TD64Unit , default = _TD64Unit )
613
615
616
+ _Array1D : TypeAlias = np .ndarray [tuple [int ], np .dtype [_ScalarT ]]
617
+
614
618
###
615
619
# Type Aliases (for internal use only)
616
620
@@ -2530,9 +2534,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
2530
2534
def __imul__ (self : NDArray [complexfloating ], rhs : _ArrayLikeComplex_co , / ) -> ndarray [_ShapeT_co , _DTypeT_co ]: ...
2531
2535
@overload
2532
2536
def __imul__ (self : NDArray [object_ ], rhs : object , / ) -> ndarray [_ShapeT_co , _DTypeT_co ]: ...
2533
-
2534
- # TODO(jorenham): Support the "1d @ 1d -> scalar" case
2535
- # https://github.com/numpy/numtype/issues/197
2537
+ @overload
2538
+ def __matmul__ (self : _Array1D [_ScalarT ], rhs : _Array1D [_ScalarT ], / ) -> _ScalarT : ...
2536
2539
@overload
2537
2540
def __matmul__ (self : NDArray [_NumberT ], rhs : _ArrayLikeBool_co , / ) -> NDArray [_NumberT ]: ...
2538
2541
@overload
@@ -2566,12 +2569,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
2566
2569
@overload
2567
2570
def __matmul__ (self : NDArray [bool_ | number ], rhs : _ArrayLikeNumber_co , / ) -> NDArray [Incomplete ]: ...
2568
2571
@overload
2569
- def __matmul__ (self : NDArray [object_ ], rhs : object , / ) -> NDArray [object_ ]: ...
2572
+ def __matmul__ (self : NDArray [object_ ], rhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
2570
2573
@overload
2571
2574
def __matmul__ (self , rhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
2572
2575
2573
2576
# keep in sync with __matmul__
2574
2577
@overload
2578
+ def __rmatmul__ (self : _Array1D [_ScalarT ], rhs : _Array1D [_ScalarT ], / ) -> _ScalarT : ...
2579
+ @overload
2575
2580
def __rmatmul__ (self : NDArray [_NumberT ], lhs : _ArrayLikeBool_co , / ) -> NDArray [_NumberT ]: ...
2576
2581
@overload
2577
2582
def __rmatmul__ (self : NDArray [bool_ ], lhs : _ArrayLike [_NumberT ], / ) -> NDArray [_NumberT ]: ...
@@ -2604,7 +2609,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
2604
2609
@overload
2605
2610
def __rmatmul__ (self : NDArray [bool_ | number ], lhs : _ArrayLikeNumber_co , / ) -> NDArray [Incomplete ]: ...
2606
2611
@overload
2607
- def __rmatmul__ (self : NDArray [object_ ], lhs : object , / ) -> NDArray [object_ ]: ...
2612
+ def __rmatmul__ (self : NDArray [object_ ], lhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
2608
2613
@overload
2609
2614
def __rmatmul__ (self , lhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
2610
2615
0 commit comments