From 493fd46731db0f40dabd4982d6aad9a0ba36f383 Mon Sep 17 00:00:00 2001 From: jorenham Date: Mon, 14 Jul 2025 11:26:47 +0200 Subject: [PATCH 1/3] =?UTF-8?q?=E2=9C=A8=20`linalg`:=20improved=20`det`=20?= =?UTF-8?q?annotations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scipy-stubs/linalg/_basic.pyi | 42 +++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/scipy-stubs/linalg/_basic.pyi b/scipy-stubs/linalg/_basic.pyi index e9e96251..8cabd577 100644 --- a/scipy-stubs/linalg/_basic.pyi +++ b/scipy-stubs/linalg/_basic.pyi @@ -1037,18 +1037,36 @@ def inv( ) -> onp.ArrayND[np.complex128, _ShapeT]: ... # TODO(jorenham): improve this -@overload # floating 2d -def det(a: onp.ToFloatStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Float: ... -@overload # floating 3d -def det(a: onp.ToFloatStrict3D, overwrite_a: bool = False, check_finite: bool = True) -> _Float1D: ... -@overload # floating -def det(a: onp.ToFloatND, overwrite_a: bool = False, check_finite: bool = True) -> _Float | _FloatND: ... -@overload # complexfloating 2d -def det(a: onp.ToJustComplexStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Inexact1D: ... -@overload # complexfloating 3d -def det(a: onp.ToJustComplexStrict3D, overwrite_a: bool = False, check_finite: bool = True) -> _InexactND: ... -@overload # complexfloating -def det(a: onp.ToComplexND, overwrite_a: bool = False, check_finite: bool = True) -> _Inexact | _InexactND: ... +@overload # +float64 2d +def det(a: onp.ToFloat64Strict2D, overwrite_a: bool = False, check_finite: bool = True) -> np.float64: ... +@overload # +float64 3d +def det(a: onp.ToFloat64Strict3D, overwrite_a: bool = False, check_finite: bool = True) -> onp.Array1D[np.float64]: ... +@overload # +float64 ND +def det(a: onp.ToFloat64_ND, overwrite_a: bool = False, check_finite: bool = True) -> np.float64 | onp.ArrayND[np.float64]: ... +@overload # complex128 | complex64 2d +def det( + a: onp.ToJustComplex128Strict2D | onp.CanArray2D[np.complex64], overwrite_a: bool = False, check_finite: bool = True +) -> np.complex128: ... +@overload # complex128 | complex64 3d +def det( + a: onp.ToJustComplex128Strict3D | onp.CanArray3D[np.complex64], overwrite_a: bool = False, check_finite: bool = True +) -> onp.Array1D[np.complex128]: ... +@overload # complex128 | complex64 Nd +def det( + a: onp.ToJustComplex128_ND, overwrite_a: bool = False, check_finite: bool = True +) -> np.complex128 | onp.ArrayND[np.complex128]: ... +@overload # +complex128 2d +def det( + a: onp.ToComplex128Strict2D | onp.CanArray2D[np.complex64], overwrite_a: bool = False, check_finite: bool = True +) -> np.float64 | np.complex128: ... +@overload # +complex128 3d +def det( + a: onp.ToComplex128Strict3D | onp.CanArray3D[np.complex64], overwrite_a: bool = False, check_finite: bool = True +) -> onp.Array1D[np.float64 | np.complex128]: ... +@overload # +complex128 Nd +def det( + a: onp.ToComplex128_ND, overwrite_a: bool = False, check_finite: bool = True +) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ... # TODO(jorenham): improve this @overload # (float[:, :], float[:]) -> (float[:], float[], ...) From 22360ced42876fbad4df63b4edcc1ff6dfecc996 Mon Sep 17 00:00:00 2001 From: jorenham Date: Mon, 14 Jul 2025 11:26:58 +0200 Subject: [PATCH 2/3] =?UTF-8?q?=E2=9C=85=20`linalg`:=20type-tests=20for=20?= =?UTF-8?q?`det`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/linalg/test__basic.pyi | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/linalg/test__basic.pyi b/tests/linalg/test__basic.pyi index eb30c91a..55eba162 100644 --- a/tests/linalg/test__basic.pyi +++ b/tests/linalg/test__basic.pyi @@ -6,7 +6,7 @@ import numpy as np import optype.numpy as onp import optype.numpy.compat as npc -from scipy.linalg import inv, solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular +from scipy.linalg import det, inv, solve, solve_banded, solve_circulant, solve_toeplitz, solve_triangular b1_nd: onp.ArrayND[np.bool_] @@ -369,7 +369,27 @@ assert_type(inv(c128_nd), onp.ArrayND[np.complex128]) assert_type(inv(c160_nd), onp.ArrayND[np.complex128]) ### -# TODO(jorenham): det +# det + +assert_type(det(f32_2d), np.float64) +assert_type(det(f64_2d), np.float64) +assert_type(det(c64_2d), np.complex128) +assert_type(det(c128_2d), np.complex128) + +assert_type(det(py_b_2d), np.float64) +assert_type(det(py_i_2d), np.float64) +assert_type(det(py_f_2d), np.float64) +assert_type(det(py_c_2d), np.complex128) + +assert_type(det(f32_3d), onp.Array1D[np.float64]) +assert_type(det(f64_3d), onp.Array1D[np.float64]) +assert_type(det(c64_3d), onp.Array1D[np.complex128]) +assert_type(det(c128_3d), onp.Array1D[np.complex128]) + +assert_type(det(py_b_3d), onp.Array1D[np.float64]) +assert_type(det(py_i_3d), onp.Array1D[np.float64]) +assert_type(det(py_f_3d), onp.Array1D[np.float64]) +assert_type(det(py_c_3d), onp.Array1D[np.complex128]) ### # TODO(jorenham): lstsq From 4a80fa1785898c28617d3c9c78c8e6b91518c741 Mon Sep 17 00:00:00 2001 From: jorenham Date: Wed, 16 Jul 2025 19:57:32 +0200 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=A4=A1=20`linalg.det`:=20shuffle=20di?= =?UTF-8?q?sjoint=20overloads=20to=20avoid=20triggering=20a=20pyright=20bu?= =?UTF-8?q?g?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scipy-stubs/linalg/_basic.pyi | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/scipy-stubs/linalg/_basic.pyi b/scipy-stubs/linalg/_basic.pyi index 8cabd577..570e3d49 100644 --- a/scipy-stubs/linalg/_basic.pyi +++ b/scipy-stubs/linalg/_basic.pyi @@ -994,7 +994,6 @@ def solve_circulant( ) -> onp.ArrayND[npc.inexact]: ... # - @overload # 2d bool sequence def inv(a: Sequence[Sequence[bool]], overwrite_a: bool = False, check_finite: bool = True) -> onp.Array2D[np.float32]: ... @overload # Nd bool sequence @@ -1036,32 +1035,30 @@ def inv( a: onp.CanArrayND[np.complex128 | np.clongdouble, _ShapeT], overwrite_a: bool = False, check_finite: bool = True ) -> onp.ArrayND[np.complex128, _ShapeT]: ... -# TODO(jorenham): improve this +# NOTE: The order of the overloads has been carefully chosen to avoid triggering a Pyright bug. @overload # +float64 2d def det(a: onp.ToFloat64Strict2D, overwrite_a: bool = False, check_finite: bool = True) -> np.float64: ... -@overload # +float64 3d -def det(a: onp.ToFloat64Strict3D, overwrite_a: bool = False, check_finite: bool = True) -> onp.Array1D[np.float64]: ... -@overload # +float64 ND -def det(a: onp.ToFloat64_ND, overwrite_a: bool = False, check_finite: bool = True) -> np.float64 | onp.ArrayND[np.float64]: ... @overload # complex128 | complex64 2d def det( - a: onp.ToJustComplex128Strict2D | onp.CanArray2D[np.complex64], overwrite_a: bool = False, check_finite: bool = True + a: onp.ToArrayStrict2D[op.JustComplex, np.complex128 | np.complex64], overwrite_a: bool = False, check_finite: bool = True ) -> np.complex128: ... +@overload # +float64 3d +def det(a: onp.ToFloat64Strict3D, overwrite_a: bool = False, check_finite: bool = True) -> onp.Array1D[np.float64]: ... @overload # complex128 | complex64 3d def det( - a: onp.ToJustComplex128Strict3D | onp.CanArray3D[np.complex64], overwrite_a: bool = False, check_finite: bool = True + a: onp.ToArrayStrict3D[op.JustComplex, np.complex128 | np.complex64], overwrite_a: bool = False, check_finite: bool = True ) -> onp.Array1D[np.complex128]: ... +@overload # +float64 ND +def det(a: onp.ToFloat64_ND, overwrite_a: bool = False, check_finite: bool = True) -> np.float64 | onp.ArrayND[np.float64]: ... @overload # complex128 | complex64 Nd def det( - a: onp.ToJustComplex128_ND, overwrite_a: bool = False, check_finite: bool = True + a: onp.ToArrayND[op.JustComplex, np.complex128 | np.complex64], overwrite_a: bool = False, check_finite: bool = True ) -> np.complex128 | onp.ArrayND[np.complex128]: ... @overload # +complex128 2d -def det( - a: onp.ToComplex128Strict2D | onp.CanArray2D[np.complex64], overwrite_a: bool = False, check_finite: bool = True -) -> np.float64 | np.complex128: ... +def det(a: onp.ToComplex128Strict2D, overwrite_a: bool = False, check_finite: bool = True) -> np.float64 | np.complex128: ... @overload # +complex128 3d def det( - a: onp.ToComplex128Strict3D | onp.CanArray3D[np.complex64], overwrite_a: bool = False, check_finite: bool = True + a: onp.ToComplex128Strict3D, overwrite_a: bool = False, check_finite: bool = True ) -> onp.Array1D[np.float64 | np.complex128]: ... @overload # +complex128 Nd def det(