From 4223c723c1146c942c096bc9767daca7361460f5 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 18 Aug 2025 05:55:46 -0700 Subject: [PATCH 01/26] Pass trans_code to getrs in dpnp_solve() --- dpnp/linalg/dpnp_utils_linalg.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index fdf46174bfc..6e9037c3b35 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2875,6 +2875,12 @@ def dpnp_solve(a, b): _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events + # TODO: remove after PR #2558 is merged + # Temporarily set trans_code=1 (transpose) because the LU-factorized + # array is C-contiguous. + # For F-contiguous arrays use 0 (non-transpose) + trans_code = 1 + # use DPCTL tensor function to fill the сopy of the input array # from the input array ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( From 80ce50c9a7f215661f6ad18f8b7145eef5cde8eb Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 4 Sep 2025 07:22:10 -0700 Subject: [PATCH 02/26] Remove TODO --- dpnp/linalg/dpnp_utils_linalg.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 6e9037c3b35..fdf46174bfc 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2875,12 +2875,6 @@ def dpnp_solve(a, b): _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - # TODO: remove after PR #2558 is merged - # Temporarily set trans_code=1 (transpose) because the LU-factorized - # array is C-contiguous. - # For F-contiguous arrays use 0 (non-transpose) - trans_code = 1 - # use DPCTL tensor function to fill the сopy of the input array # from the input array ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( From af0ab7def4deb360d1c15838d62fc20dfd9e4c62 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 4 Sep 2025 06:40:10 -0700 Subject: [PATCH 03/26] Implement of dpnp.linalg.lu_solve for 2D inputs --- dpnp/linalg/dpnp_iface_linalg.py | 71 ++++++++++++++++++++ dpnp/linalg/dpnp_utils_linalg.py | 112 +++++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 1a46205ea82..d6ee9aef136 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -57,6 +57,7 @@ dpnp_inv, dpnp_lstsq, dpnp_lu_factor, + dpnp_lu_solve, dpnp_matrix_power, dpnp_matrix_rank, dpnp_multi_dot, @@ -81,6 +82,7 @@ "inv", "lstsq", "lu_factor", + "lu_solve", "matmul", "matrix_norm", "matrix_power", @@ -966,6 +968,75 @@ def lu_factor(a, overwrite_a=False, check_finite=True): return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite) +def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): + """ + Solve an equation system, a x = b, given the LU factorization of `a` + + For full documentation refer to :obj:`scipy.linalg.lu_solve`. + + Parameters + ---------- + (lu, piv) : {tuple of dpnp.ndarrays or usm_ndarrays} + LU factorization of matrix `a` ((M, N)) together with pivot indices. + b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray} + Right-hand side + trans : {0, 1, 2} , optional + Type of system to solve: + + ===== ========= + trans system + ===== ========= + 0 a x = b + 1 a^T x = b + 2 a^H x = b + ===== ========= + overwrite_b : {None, bool}, optional + Whether to overwrite data in `b` (may increase performance). + + Default: ``False``. + check_finite : {None, bool}, optional + Whether to check that the input matrix contains only finite numbers. + Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. + + Default: ``True``. + + Returns + ------- + x : {(M,), (M, K)} dpnp.ndarray + Solution to the system + + Warning + ------- + This function synchronizes in order to validate array elements + when ``check_finite=True``. + + Examples + -------- + >>> import dpnp as np + >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]) + >>> b = np.array([1, 1, 1, 1]) + >>> lu, piv = np.linalg.lu_factor(A) + >>> x = np.linalg.lu_solve((lu, piv), b) + >>> np.allclose(A @ x - b, np.zeros((4,))) + array(True) + + """ + + (lu, piv) = lu_and_piv + dpnp.check_supported_arrays_type(lu, piv, b) + assert_stacked_2d(lu) + + return dpnp_lu_solve( + lu, + piv, + b, + trans=trans, + overwrite_b=overwrite_b, + check_finite=check_finite, + ) + + def matmul(x1, x2, /): """ Computes the matrix product. diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index fdf46174bfc..04339a5f587 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2514,6 +2514,118 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): return (a_h, ipiv_h) +def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): + """ + dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True) + + Solve an equation system (SciPy-compatible behavior). + + This function mimics the behavior of `scipy.linalg.lu_solve` including + support for `trans`, `overwrite_b`, `check_finite`, + and 0-based pivot indexing. + + """ + + res_usm_type, exec_q = get_usm_allocations([lu, piv, b]) + + res_type = _common_type(lu, b) + + # TODO: add broadcasting + if lu.shape[0] != b.shape[0]: + raise ValueError( + f"Shapes of lu {lu.shape} and b {b.shape} are incompatible" + ) + + if b.size == 0: + return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) + + if lu.ndim > 2: + raise NotImplementedError("Batched matrices are not supported") + + if check_finite: + if not dpnp.isfinite(lu).all(): + raise ValueError( + "array must not contain infs or NaNs.\n" + "Note that when a singular matrix is given, unlike " + "dpnp.linalg.lu_factor returns an array containing NaN." + ) + if not dpnp.isfinite(b).all(): + raise ValueError("array must not contain infs or NaNs") + + lu_usm_arr = dpnp.get_usm_ndarray(lu) + piv_usm_arr = dpnp.get_usm_ndarray(piv) + b_usm_arr = dpnp.get_usm_ndarray(b) + + _manager = dpu.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + # oneMKL LAPACK getrf overwrites `a`. + lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) + + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=lu_usm_arr, + dst=lu_h.get_array(), + sycl_queue=lu.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, lu_copy_ev) + + # SciPy-compatible behavior + # Copy is required if: + # - overwrite_a is False (always copy), + # - dtype mismatch, + # - not F-contiguous,s + # - not writeable + if not overwrite_b or _is_copy_required(b, res_type): + b_h = dpnp.empty_like( + b, order="F", dtype=res_type, usm_type=res_usm_type + ) + ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr, + dst=b_h.get_array(), + sycl_queue=b.sycl_queue, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, dep_ev) + dep_ev = [dep_ev] + else: + # input is suitable for in-place modification + b_h = b + dep_ev = _manager.submitted_events + + # oneMKL LAPACK getrf overwrites `a`. + piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type) + + # use DPCTL tensor function to fill the сopy of the pivot array + # from the pivot array + ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=piv_usm_arr, + dst=piv_h.get_array(), + sycl_queue=piv.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, piv_copy_ev) + # MKL lapack uses 1-origin while SciPy uses 0-origin + piv_h += 1 + + # Call the LAPACK extension function _getrs + # to solve the system of linear equations with an LU-factored + # coefficient square matrix, with multiple right-hand sides. + ht_ev, getrs_ev = li._getrs( + exec_q, + lu_h.get_array(), + piv_h.get_array(), + b_h.get_array(), + trans, + depends=dep_ev, + ) + _manager.add_event_pair(ht_ev, getrs_ev) + + return b_h + + def dpnp_matrix_power(a, n): """ dpnp_matrix_power(a, n) From 17b11ae44a5cc9779454642c41b7e21d84a2490e Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 4 Sep 2025 06:41:18 -0700 Subject: [PATCH 04/26] Add dpnp.linalg.lu_solve to generated docs --- doc/reference/linalg.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/reference/linalg.rst b/doc/reference/linalg.rst index 107b5a86a5b..142c6052db8 100644 --- a/doc/reference/linalg.rst +++ b/doc/reference/linalg.rst @@ -86,6 +86,7 @@ Solving linear equations dpnp.linalg.solve dpnp.linalg.tensorsolve dpnp.linalg.lstsq + dpnp.linalg.lu_solve dpnp.linalg.inv dpnp.linalg.pinv dpnp.linalg.tensorinv From b10a8d6bbf19acb511577f47718c9e12ed904600 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 4 Sep 2025 06:42:46 -0700 Subject: [PATCH 05/26] Add TestLuSolve to test_linalg.py --- dpnp/tests/test_linalg.py | 209 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 97379876c90..c14bd9c54d9 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2144,6 +2144,215 @@ def test_check_finite_raises(self): assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True) +class TestLuSolve: + @staticmethod + def _make_nonsingular_np(shape, dtype, order): + A = generate_random_numpy_array(shape, dtype, order) + m, n = shape + k = min(m, n) + for i in range(k): + off = numpy.sum(numpy.abs(A[i, :n])) - numpy.abs(A[i, i]) + A[i, i] = A.dtype.type(off + 1.0) + return A + + @pytest.mark.parametrize("shape", [(1, 1), (2, 2), (3, 3), (5, 5)]) + @pytest.mark.parametrize("rhs_cols", [None, 1, 3]) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_none=True) + ) + def test_lu_solve(self, shape, rhs_cols, order, dtype): + a_np = self._make_nonsingular_np(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + n = shape[0] + if rhs_cols is None: + b_np = generate_random_numpy_array((n,), dtype, order) + else: + b_np = generate_random_numpy_array((n, rhs_cols), dtype, order) + b_dp = dpnp.array(b_np, order=order) + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, trans=0, overwrite_b=False, check_finite=False + ) + + # check A @ x = b + Ax = a_dp @ x + assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("trans", [0, 1, 2]) + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_trans(self, trans, dtype): + n = 4 + a_np = self._make_nonsingular_np((n, n), dtype, order="F") + a_dp = dpnp.array(a_np, order="F") + b_dp = dpnp.array(generate_random_numpy_array((n, 2), dtype, "F")) + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False + ) + + if trans == 0: + lhs = a_dp @ x + elif trans == 1: + lhs = a_dp.T @ x + else: # trans == 2 + lhs = a_dp.conj().T @ x + + assert dpnp.allclose(lhs, b_dp, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_overwrite_inplace(self, dtype): + a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F") + b_dp = dpnp.array([1, 0], dtype=dtype, order="F") + b_orig = b_dp.copy() + + lu, piv = dpnp.linalg.lu_factor( + a_dp, overwrite_a=False, check_finite=False + ) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, trans=0, overwrite_b=True, check_finite=False + ) + + assert x is b_dp + assert dpnp.allclose(a_dp @ x, b_orig, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_overwrite_copy_special(self, dtype): + a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F") + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + + # F-contig but dtype != res_type + b1 = dpnp.array([1, 0], dtype=dpnp.int32, order="F") + x1 = dpnp.linalg.lu_solve( + (lu, piv), b1, overwrite_b=True, check_finite=False + ) + assert x1 is not b1 + + # F-contig, match dtype but read-only input + b2 = dpnp.array([1, 0], dtype=dtype, order="F") + b2.flags["WRITABLE"] = False + x2 = dpnp.linalg.lu_solve( + (lu, piv), b2, overwrite_b=True, check_finite=False + ) + assert x2 is not b2 + + for x in (x1, x2): + assert dpnp.allclose( + a_dp @ x, + dpnp.array([1, 0], dtype=x.dtype), + rtol=1e-6, + atol=1e-6, + ) + + @pytest.mark.parametrize( + "dtype_a", get_all_dtypes(no_bool=True, no_none=True) + ) + @pytest.mark.parametrize( + "dtype_b", get_all_dtypes(no_bool=True, no_none=True) + ) + def test_diff_type(self, dtype_a, dtype_b): + a_np = self._make_nonsingular_np((3, 3), dtype_a, order="F") + a_dp = dpnp.array(a_np, order="F") + + b_np = generate_random_numpy_array((3,), dtype_b, order="F") + b_dp = dpnp.array(b_np, order="F") + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False) + assert dpnp.allclose( + a_dp @ x, b_dp.astype(x.dtype, copy=False), rtol=1e-6, atol=1e-6 + ) + + def test_strided_rhs(self): + n = 7 + a_np = self._make_nonsingular_np( + (n, n), dpnp.default_float_type(), order="F" + ) + a_dp = dpnp.array(a_np, order="F") + + rhs_full = ( + dpnp.arange(n * n, dtype=dpnp.default_float_type()).reshape( + n, n, order="F" + ) + + 1.0 + ) + b_dp = rhs_full[:, ::2][:, :3] + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, overwrite_b=False, check_finite=False + ) + + assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6) + + @pytest.mark.skip("Not implemented yet") + @pytest.mark.parametrize( + "b_shape", + [ + (4,), + (4, 1), + (4, 3), + # (1, 4, 3), + # (2, 4, 3), + # (1, 1, 4, 3) + ], + ) + def test_broadcast_rhs(self, b_shape): + dtype = dpnp.default_float_type() + + a_np = self._make_nonsingular_np((4, 4), dtype, order="F") + a_dp = dpnp.array(a_np, order="F") + + b_np = generate_random_numpy_array(b_shape, dtype, order="F") + b_dp = dpnp.array(b_np, order="F") + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, overwrite_b=True, check_finite=False + ) + + assert x.shape == b_dp.shape + + assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("shape", [(0, 0), (0, 5), (5, 5)]) + @pytest.mark.parametrize("rhs_cols", [None, 0, 3]) + def test_empty_shapes(self, shape, rhs_cols): + a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") + if min(shape) > 0: + for i in range(min(shape)): + a_dp[i, i] = a_dp.dtype.type(1.0) + + n = shape[0] + if rhs_cols is None: + b_shape = (n,) + else: + b_shape = (n, rhs_cols) + b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F") + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False) + + assert x.shape == b_shape + + @pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan]) + def test_check_finite_raises(self, bad): + a_dp = dpnp.array([[1.0, 0.0], [0.0, 1.0]], order="F") + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + + b_bad = dpnp.array([1.0, bad], order="F") + assert_raises( + ValueError, + dpnp.linalg.lu_solve, + (lu, piv), + b_bad, + check_finite=True, + ) + + class TestMatrixPower: @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize( From 2021f772bfff02e52ba3adc57285782496494e7f Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 4 Sep 2025 07:06:47 -0700 Subject: [PATCH 06/26] Add sycl_queue and usm_type tests --- dpnp/tests/test_sycl_queue.py | 15 +++++++++++++++ dpnp/tests/test_usm_type.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index 3f7c399a26f..c26ba12b5c6 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -9,6 +9,7 @@ from numpy.testing import assert_array_equal, assert_raises import dpnp +import dpnp.linalg from dpnp.dpnp_array import dpnp_array from dpnp.dpnp_utils import get_usm_allocations @@ -1602,6 +1603,20 @@ def test_lu_factor(self, data, device): param_queue = param.sycl_queue assert_sycl_queue_equal(param_queue, a.sycl_queue) + @pytest.mark.parametrize( + "data", + [[1.0, 2.0], numpy.empty((2, 0))], + ) + def test_lu_solve(self, data, device): + a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device) + lu, piv = dpnp.linalg.lu_factor(a) + b = dpnp.array(data, device=device) + + result = dpnp.linalg.lu_solve((lu, piv), b) + + assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue) + assert_sycl_queue_equal(result.sycl_queue, b.sycl_queue) + @pytest.mark.parametrize("n", [-1, 0, 1, 2, 3]) def test_matrix_power(self, n, device): x = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device) diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index d6143301463..0c9a3eeecf8 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -1480,6 +1480,24 @@ def test_lu_factor(self, data, usm_type): for param in result: assert param.usm_type == a.usm_type + @pytest.mark.parametrize("usm_type_rhs", list_of_usm_types) + @pytest.mark.parametrize( + "data", + [[1.0, 2.0], numpy.empty((2, 0))], + ) + def test_lu_solve(self, data, usm_type, usm_type_rhs): + a = dpnp.array(data, usm_type=usm_type) + lu, piv = dpnp.linalg.lu_factor(a) + b = dpnp.array(data, usm_type=usm_type_rhs) + + result = dpnp.linalg.lu_solve((lu, piv), b) + + assert lu.usm_type == usm_type + assert b.usm_type == usm_type_rhs + assert result.usm_type == du.get_coerced_usm_type( + [usm_type, usm_type_rhs] + ) + @pytest.mark.parametrize("n", [-1, 0, 1, 2, 3]) def test_matrix_power(self, n, usm_type): a = dpnp.array([[1, 2], [3, 5]], usm_type=usm_type) From be2725af0f3541022013d47128feb0bcf29853be Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 16 Sep 2025 03:34:11 -0700 Subject: [PATCH 07/26] Update doc/comment lines --- dpnp/linalg/dpnp_iface_linalg.py | 6 +++--- dpnp/linalg/dpnp_utils_linalg.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index d6ee9aef136..4c414383f3b 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -970,14 +970,14 @@ def lu_factor(a, overwrite_a=False, check_finite=True): def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): """ - Solve an equation system, a x = b, given the LU factorization of `a` + Solve an equation system, a x = b, given the LU factorization of `a`. For full documentation refer to :obj:`scipy.linalg.lu_solve`. Parameters ---------- - (lu, piv) : {tuple of dpnp.ndarrays or usm_ndarrays} - LU factorization of matrix `a` ((M, N)) together with pivot indices. + lu, piv : {tuple of dpnp.ndarrays or usm_ndarrays} + LU factorization of matrix `a` (M, N) together with pivot indices. b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray} Right-hand side trans : {0, 1, 2} , optional diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 04339a5f587..06893126f13 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2559,7 +2559,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - # oneMKL LAPACK getrf overwrites `a`. + # oneMKL LAPACK getrf overwrites `lu`. lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) # use DPCTL tensor function to fill the сopy of the input array @@ -2574,9 +2574,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): # SciPy-compatible behavior # Copy is required if: - # - overwrite_a is False (always copy), + # - overwrite_b is False (always copy), # - dtype mismatch, - # - not F-contiguous,s + # - not F-contiguous, # - not writeable if not overwrite_b or _is_copy_required(b, res_type): b_h = dpnp.empty_like( @@ -2595,7 +2595,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): b_h = b dep_ev = _manager.submitted_events - # oneMKL LAPACK getrf overwrites `a`. + # oneMKL LAPACK getrf overwrites `piv`. piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type) # use DPCTL tensor function to fill the сopy of the pivot array From 1e09cb75b57f26d7b6c2cc4ebb1174175b61cd58 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 18 Sep 2025 03:32:56 -0700 Subject: [PATCH 08/26] Update dependency logic --- dpnp/linalg/dpnp_utils_linalg.py | 37 ++++++++++++++++---------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 06893126f13..23c1c40b7c4 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2572,6 +2572,19 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): ) _manager.add_event_pair(ht_ev, lu_copy_ev) + # oneMKL LAPACK getrf overwrites `piv`. + piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type) + + # use DPCTL tensor function to fill the сopy of the pivot array + # from the pivot array + ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=piv_usm_arr, + dst=piv_h.get_array(), + sycl_queue=piv.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, piv_copy_ev) + # SciPy-compatible behavior # Copy is required if: # - overwrite_b is False (always copy), @@ -2582,31 +2595,19 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): b_h = dpnp.empty_like( b, order="F", dtype=res_type, usm_type=res_usm_type ) - ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray( + ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=b_usm_arr, dst=b_h.get_array(), sycl_queue=b.sycl_queue, - depends=_manager.submitted_events, + depends=dep_evs, ) - _manager.add_event_pair(ht_ev, dep_ev) - dep_ev = [dep_ev] + _manager.add_event_pair(ht_ev, b_copy_ev) + dep_evs = [lu_copy_ev, piv_copy_ev, b_copy_ev] else: # input is suitable for in-place modification b_h = b - dep_ev = _manager.submitted_events - - # oneMKL LAPACK getrf overwrites `piv`. - piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type) + dep_evs = [lu_copy_ev, piv_copy_ev] - # use DPCTL tensor function to fill the сopy of the pivot array - # from the pivot array - ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=piv_usm_arr, - dst=piv_h.get_array(), - sycl_queue=piv.sycl_queue, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, piv_copy_ev) # MKL lapack uses 1-origin while SciPy uses 0-origin piv_h += 1 @@ -2619,7 +2620,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): piv_h.get_array(), b_h.get_array(), trans, - depends=dep_ev, + depends=dep_evs, ) _manager.add_event_pair(ht_ev, getrs_ev) From 9345b7b745b8869357c3f8a87508d810504de655 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 18 Sep 2025 03:49:52 -0700 Subject: [PATCH 09/26] Add trans code handling --- dpnp/linalg/dpnp_utils_linalg.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 23c1c40b7c4..2fb43d57895 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2611,6 +2611,19 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): # MKL lapack uses 1-origin while SciPy uses 0-origin piv_h += 1 + if not isinstance(trans, int): + raise TypeError("`trans` must be an integer") + + # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums + if trans == 0: + trans_mkl = li.Transpose.N + elif trans == 1: + trans_mkl = li.Transpose.T + elif trans == 2: + trans_mkl = li.Transpose.C + else: + raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") + # Call the LAPACK extension function _getrs # to solve the system of linear equations with an LU-factored # coefficient square matrix, with multiple right-hand sides. @@ -2619,7 +2632,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): lu_h.get_array(), piv_h.get_array(), b_h.get_array(), - trans, + trans_mkl, depends=dep_evs, ) _manager.add_event_pair(ht_ev, getrs_ev) From 687006f6fb47a3cc3ee73879a19986585177fd52 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 18 Sep 2025 06:09:26 -0700 Subject: [PATCH 10/26] Fix docs for lu:must be square --- dpnp/linalg/dpnp_iface_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 4c414383f3b..59b6a5b23f4 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -977,7 +977,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): Parameters ---------- lu, piv : {tuple of dpnp.ndarrays or usm_ndarrays} - LU factorization of matrix `a` (M, N) together with pivot indices. + LU factorization of matrix `a` (M, M) together with pivot indices. b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray} Right-hand side trans : {0, 1, 2} , optional From 9aaff82636986cbc83c6d9e51c0b40e3edf07c07 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 18 Sep 2025 07:23:47 -0700 Subject: [PATCH 11/26] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e2bf08ff9af..ff44c36e92d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534) * Added implementation of `dpnp.linalg.lu_factor` (SciPy-compatible) [#2557](https://github.com/IntelPython/dpnp/pull/2557), [#2565](https://github.com/IntelPython/dpnp/pull/2565) * Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550) +* Added implementation of `dpnp.linalg.lu_solve` for 2D inputs (SciPy-compatible) [#2575](https://github.com/IntelPython/dpnp/pull/2575) ### Changed From 23ad15d7f12ec784355ce0efecd74666a3031892 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 19 Sep 2025 04:56:14 -0700 Subject: [PATCH 12/26] Apply docs remarks --- dpnp/linalg/dpnp_iface_linalg.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 59b6a5b23f4..5f3e73afbeb 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -970,7 +970,7 @@ def lu_factor(a, overwrite_a=False, check_finite=True): def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): """ - Solve an equation system, a x = b, given the LU factorization of `a`. + Solve a linear system, :math:`a x = b`, given the LU factorization of `a`. For full documentation refer to :obj:`scipy.linalg.lu_solve`. @@ -983,13 +983,15 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): trans : {0, 1, 2} , optional Type of system to solve: - ===== ========= + ===== ================= trans system - ===== ========= - 0 a x = b - 1 a^T x = b - 2 a^H x = b - ===== ========= + ===== ================= + 0 :math:`a x = b` + 1 :math:`a^T x = b` + 2 :math:`a^H x = b` + ===== ================= + + Default: ``0``. overwrite_b : {None, bool}, optional Whether to overwrite data in `b` (may increase performance). @@ -1011,6 +1013,10 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): This function synchronizes in order to validate array elements when ``check_finite=True``. + See Also + -------- + :obj:`dpnp.linalg.lu_factor` : LU factorize a matrix. + Examples -------- >>> import dpnp as np From 82de136eb4f009184fd7723ca94a68b038c51cdf Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 19 Sep 2025 05:13:20 -0700 Subject: [PATCH 13/26] Apply remarks --- dpnp/linalg/dpnp_utils_linalg.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 088d3570166..5b1e2688c7c 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2508,12 +2508,14 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): if check_finite: if not dpnp.isfinite(lu).all(): raise ValueError( - "array must not contain infs or NaNs.\n" + "LU factorization array must not contain infs or NaNs.\n" "Note that when a singular matrix is given, unlike " "dpnp.linalg.lu_factor returns an array containing NaN." ) if not dpnp.isfinite(b).all(): - raise ValueError("array must not contain infs or NaNs") + raise ValueError( + "Right-hand side array must not contain infs or NaNs" + ) lu_usm_arr = dpnp.get_usm_ndarray(lu) piv_usm_arr = dpnp.get_usm_ndarray(piv) From e586075280a5632d6dedf29d2d08f4c80f6e8014 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 19 Sep 2025 05:15:46 -0700 Subject: [PATCH 14/26] Add assert on USM data pointer to tests --- dpnp/tests/test_linalg.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index c14bd9c54d9..d2e2df3694a 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -1931,6 +1931,7 @@ def test_overwrite_inplace(self, dtype): ) assert lu is a_dp + assert lu.data.ptr == a_dp.data.ptr assert lu.flags["F_CONTIGUOUS"] is True L, U = self._split_lu(lu, 2, 2) @@ -1948,6 +1949,7 @@ def test_overwrite_copy(self, dtype): ) assert lu is not a_dp + assert lu.data.ptr != a_dp.data.ptr assert lu.flags["F_CONTIGUOUS"] is True L, U = self._split_lu(lu, 2, 2) @@ -1974,6 +1976,7 @@ def test_overwrite_copy_special(self): ) assert lu is not a_dp + assert lu.data.ptr != a_dp.data.ptr assert lu.flags["F_CONTIGUOUS"] is True L, U = self._split_lu(lu, 2, 2) @@ -2217,6 +2220,7 @@ def test_overwrite_inplace(self, dtype): ) assert x is b_dp + assert x.data.ptr == b_dp.data.ptr assert dpnp.allclose(a_dp @ x, b_orig, rtol=1e-6, atol=1e-6) @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) @@ -2230,6 +2234,7 @@ def test_overwrite_copy_special(self, dtype): (lu, piv), b1, overwrite_b=True, check_finite=False ) assert x1 is not b1 + assert x1.data.ptr != b1.data.ptr # F-contig, match dtype but read-only input b2 = dpnp.array([1, 0], dtype=dtype, order="F") @@ -2238,6 +2243,7 @@ def test_overwrite_copy_special(self, dtype): (lu, piv), b2, overwrite_b=True, check_finite=False ) assert x2 is not b2 + assert x2.data.ptr != b2.data.ptr for x in (x1, x2): assert dpnp.allclose( From 7d1fd0b1c689454138d066260c4274ca7f56ab9d Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 19 Sep 2025 05:25:02 -0700 Subject: [PATCH 15/26] Update data inputs for test_usm_type --- dpnp/tests/test_sycl_queue.py | 6 +++--- dpnp/tests/test_usm_type.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index 73b65392158..f4299325762 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -1612,13 +1612,13 @@ def test_lu_factor(self, data, device): assert_sycl_queue_equal(param_queue, a.sycl_queue) @pytest.mark.parametrize( - "data", + "b_data", [[1.0, 2.0], numpy.empty((2, 0))], ) - def test_lu_solve(self, data, device): + def test_lu_solve(self, b_data, device): a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device) lu, piv = dpnp.linalg.lu_factor(a) - b = dpnp.array(data, device=device) + b = dpnp.array(b_data, device=device) result = dpnp.linalg.lu_solve((lu, piv), b) diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index b7e66438d17..c17526649ab 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -1489,13 +1489,13 @@ def test_lu_factor(self, data, usm_type): @pytest.mark.parametrize("usm_type_rhs", list_of_usm_types) @pytest.mark.parametrize( - "data", + "b_data", [[1.0, 2.0], numpy.empty((2, 0))], ) - def test_lu_solve(self, data, usm_type, usm_type_rhs): - a = dpnp.array(data, usm_type=usm_type) + def test_lu_solve(self, b_data, usm_type, usm_type_rhs): + a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], usm_type=usm_type) lu, piv = dpnp.linalg.lu_factor(a) - b = dpnp.array(data, usm_type=usm_type_rhs) + b = dpnp.array(b_data, usm_type=usm_type_rhs) result = dpnp.linalg.lu_solve((lu, piv), b) From 87074fadd803234ef08a1d13fe084b3342326ad8 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 22 Sep 2025 03:11:27 -0700 Subject: [PATCH 16/26] Add See Also to lu_factor --- dpnp/linalg/dpnp_iface_linalg.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 5f3e73afbeb..f73443229c4 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -907,7 +907,7 @@ def lstsq(a, b, rcond=None): def lu_factor(a, overwrite_a=False, check_finite=True): """ - Compute the pivoted LU decomposition of a matrix. + Compute the pivoted LU decomposition of `a` matrix. The decomposition is:: @@ -949,6 +949,11 @@ def lu_factor(a, overwrite_a=False, check_finite=True): This function synchronizes in order to validate array elements when ``check_finite=True``. + See Also + -------- + :obj:`dpnp.linalg.lu_solve` : Solve an equation system using + the LU factorization of `a` matrix. + Examples -------- >>> import dpnp as np From d81454e88214fe78e606d2b3eeb6d81873605158 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 22 Sep 2025 04:31:49 -0700 Subject: [PATCH 17/26] Enable cupyx tests --- dpnp/tests/helper.py | 9 + .../scipy_tests/linalg_tests/__init__.py | 0 .../linalg_tests/test_decomp_lu.py | 200 ++++++++++++++++++ 3 files changed, 209 insertions(+) create mode 100644 dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/__init__.py create mode 100644 dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py diff --git a/dpnp/tests/helper.py b/dpnp/tests/helper.py index edb077a161c..93146159b11 100644 --- a/dpnp/tests/helper.py +++ b/dpnp/tests/helper.py @@ -1,3 +1,4 @@ +import importlib.util from sys import platform import dpctl @@ -488,6 +489,14 @@ def is_ptl(device=None): return _get_dev_mask(device) in (0xB000, 0xFD00) +def is_scipy_available(): + """ + Return True if SciPy is installed and can be found, + False otherwise. + """ + return importlib.util.find_spec("scipy") is not None + + def is_tgllp_iris_xe(device=None): """ Return True if a test is running on Tiger Lake-LP with Iris Xe GPU device, diff --git a/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/__init__.py b/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py b/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py new file mode 100644 index 00000000000..b510a40c138 --- /dev/null +++ b/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import unittest +import warnings + +import numpy +import pytest + +import dpnp as cupy +from dpnp.tests.third_party.cupy import testing + +if cupy.tests.helper.is_scipy_available(): + import scipy.linalg + + +# TODO: After the feature is released +# requires_scipy_linalg_backend = testing.with_requires('scipy>=1.x.x') +requires_scipy_linalg_backend = unittest.skip( + "scipy.linalg backend feature has not been released" +) + + +@testing.parameterize( + *testing.product( + { + "shape": [ + (1, 1), + (2, 2), + (3, 3), + (5, 5), + (1, 5), + (5, 1), + (2, 5), + (5, 2), + ], + } + ) +) +@testing.fix_random() +@testing.with_requires("scipy") +class TestLUFactor(unittest.TestCase): + + @testing.for_dtypes("fdFD") + def test_lu_factor(self, dtype): + if self.shape[0] != self.shape[1]: + self.skipTest( + "skip non-square tests since scipy.lu_factor requires square" + ) + a_cpu = testing.shaped_random(self.shape, numpy, dtype=dtype) + a_gpu = cupy.asarray(a_cpu) + result_cpu = scipy.linalg.lu_factor(a_cpu) + # Originally used cupyx.scipy.linalg.lu_factor + result_gpu = cupy.linalg.lu_factor(a_gpu) + assert len(result_cpu) == len(result_gpu) + assert result_cpu[0].dtype == result_gpu[0].dtype + # DPNP returns pivot indices as int64, while SciPy returns int32. + # Check for the expected dtypes explicitly. + # assert result_cpu[1].dtype == result_gpu[1].dtype + assert result_cpu[1].dtype == cupy.int32 + assert result_gpu[1].dtype == cupy.int64 + testing.assert_allclose(result_cpu[0], result_gpu[0], atol=1e-5) + testing.assert_array_equal(result_cpu[1], result_gpu[1]) + + def check_lu_factor_reconstruction(self, A): + m, n = self.shape + lu, piv = cupy.linalg.lu_factor(A) + # extract ``L`` and ``U`` from ``lu`` + L = cupy.tril(lu, k=-1) + cupy.fill_diagonal(L, 1.0) + L = L[:, :m] + U = cupy.triu(lu) + U = U[:n, :] + # check output shapes + assert lu.shape == (m, n) + assert L.shape == (m, min(m, n)) + assert U.shape == (min(m, n), n) + assert piv.shape == (min(m, n),) + # apply pivot (on CPU since slaswp is not available in cupy) + piv = cupy.asnumpy(piv) + rows = list(range(m)) + for i, row in enumerate(piv): + if i != row: + rows[i], rows[row] = rows[row], rows[i] + rows = cupy.asarray(rows) + PA = A[rows] + # check that reconstruction is close to original + LU = L.dot(U) + testing.assert_allclose(LU, PA, atol=1e-5) + + @testing.for_dtypes("fdFD") + def test_lu_factor_reconstruction(self, dtype): + A = testing.shaped_random(self.shape, cupy, dtype=dtype) + self.check_lu_factor_reconstruction(A) + + @testing.for_dtypes("fdFD") + def test_lu_factor_reconstruction_singular(self, dtype): + if self.shape[0] != self.shape[1]: + self.skipTest( + "skip non-square tests since scipy.lu_factor requires square" + ) + A = testing.shaped_random(self.shape, cupy, dtype=dtype) + A -= A.mean(axis=0, keepdims=True) + A -= A.mean(axis=1, keepdims=True) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + self.check_lu_factor_reconstruction(A) + + +@testing.parameterize( + *testing.product( + { + "shape": [ + (1, 1), + (2, 2), + (3, 3), + (5, 5), + (1, 5), + (5, 1), + (2, 5), + (5, 2), + ], + "permute_l": [False, True], + } + ) +) +@testing.fix_random() +@testing.with_requires("scipy") +class TestLU(unittest.TestCase): + @classmethod + def setUpClass(cls): + pytest.skip("lu() is not supported yet") + + @testing.for_dtypes("fdFD") + def test_lu(self, dtype): + a_cpu = testing.shaped_random(self.shape, numpy, dtype=dtype) + a_gpu = cupy.asarray(a_cpu) + result_cpu = scipy.linalg.lu(a_cpu, permute_l=self.permute_l) + result_gpu = cupy.linalg.lu(a_gpu, permute_l=self.permute_l) + assert len(result_cpu) == len(result_gpu) + if not self.permute_l: + # check permutation matrix + result_cpu = list(result_cpu) + result_gpu = list(result_gpu) + P_cpu = result_cpu.pop(0) + P_gpu = result_gpu.pop(0) + cupy.testing.assert_array_equal(P_gpu, P_cpu) + cupy.testing.assert_allclose(result_gpu[0], result_cpu[0], atol=1e-5) + cupy.testing.assert_allclose(result_gpu[1], result_cpu[1], atol=1e-5) + + @testing.for_dtypes("fdFD") + def test_lu_reconstruction(self, dtype): + m, n = self.shape + A = testing.shaped_random(self.shape, cupy, dtype=dtype) + if self.permute_l: + PL, U = cupy.linalg.lu(A, permute_l=self.permute_l) + PLU = PL @ U + else: + P, L, U = cupy.linalg.lu(A, permute_l=self.permute_l) + PLU = P @ L @ U + # check that reconstruction is close to original + cupy.testing.assert_allclose(PLU, A, atol=1e-5) + + +@testing.parameterize( + *testing.product( + { + "trans": [0, 1, 2], + "shapes": [((4, 4), (4,)), ((5, 5), (5, 2))], + } + ) +) +@testing.fix_random() +@testing.with_requires("scipy") +class TestLUSolve(unittest.TestCase): + + @testing.for_dtypes("fdFD") + @testing.numpy_cupy_allclose(atol=1e-5, scipy_name="scp") + def test_lu_solve(self, xp, scp, dtype): + a_shape, b_shape = self.shapes + A = testing.shaped_random(a_shape, xp, dtype=dtype) + b = testing.shaped_random(b_shape, xp, dtype=dtype) + lu = scp.linalg.lu_factor(A) + return scp.linalg.lu_solve(lu, b, trans=self.trans) + + @requires_scipy_linalg_backend + @testing.for_dtypes("fdFD") + @testing.numpy_cupy_allclose(atol=1e-5) + def test_lu_solve_backend(self, xp, dtype): + a_shape, b_shape = self.shapes + A = testing.shaped_random(a_shape, xp, dtype=dtype) + b = testing.shaped_random(b_shape, xp, dtype=dtype) + if xp is numpy: + lu = scipy.linalg.lu_factor(A) + backend = "scipy" + else: + lu = cupy.linalg.lu_factor(A) + backend = cupy.linalg + with scipy.linalg.set_backend(backend): + out = scipy.linalg.lu_solve(lu, b, trans=self.trans) + return out From 78a4c78ebf45d8a81eeaf32e71087eb33ec07611 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 22 Sep 2025 05:17:05 -0700 Subject: [PATCH 18/26] Adjust tolerance for test_lu_solve --- dpnp/tests/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index d2e2df3694a..68439cdf105 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2182,7 +2182,7 @@ def test_lu_solve(self, shape, rhs_cols, order, dtype): # check A @ x = b Ax = a_dp @ x - assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6) + assert dpnp.allclose(Ax, b_dp, atol=1e-5) @pytest.mark.parametrize("trans", [0, 1, 2]) @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) From d0fbd4921d2f4186609dca3b46566412dde30b64 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 22 Sep 2025 05:19:37 -0700 Subject: [PATCH 19/26] Apply remark --- .../cupyx/scipy_tests/linalg_tests/test_decomp_lu.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py b/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py index b510a40c138..2e0da004413 100644 --- a/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py +++ b/dpnp/tests/third_party/cupyx/scipy_tests/linalg_tests/test_decomp_lu.py @@ -125,10 +125,8 @@ def test_lu_factor_reconstruction_singular(self, dtype): ) @testing.fix_random() @testing.with_requires("scipy") +@pytest.mark.skip("lu() is not supported yet") class TestLU(unittest.TestCase): - @classmethod - def setUpClass(cls): - pytest.skip("lu() is not supported yet") @testing.for_dtypes("fdFD") def test_lu(self, dtype): From 52eac3df366d26631842e9b27248f898c11d3765 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 22 Sep 2025 06:55:08 -0700 Subject: [PATCH 20/26] Adjust tolerance for interger dtypes --- dpnp/tests/test_linalg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 68439cdf105..bcf8e6ae29b 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2182,7 +2182,10 @@ def test_lu_solve(self, shape, rhs_cols, order, dtype): # check A @ x = b Ax = a_dp @ x - assert dpnp.allclose(Ax, b_dp, atol=1e-5) + if dpnp.issubdtype(dtype, dpnp.integer): + assert dpnp.allclose(Ax, b_dp, rtol=1e-5, atol=1e-5) + else: + assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6) @pytest.mark.parametrize("trans", [0, 1, 2]) @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) From ccbfccfcca5d4c3765d694d339a5ace4b2860129 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 23 Sep 2025 07:33:09 -0700 Subject: [PATCH 21/26] Solve race conditions issue for pivots --- dpnp/linalg/dpnp_utils_linalg.py | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 5b1e2688c7c..f01a2158262 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2518,9 +2518,12 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): ) lu_usm_arr = dpnp.get_usm_ndarray(lu) - piv_usm_arr = dpnp.get_usm_ndarray(piv) b_usm_arr = dpnp.get_usm_ndarray(b) + # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy, + # convert to 1-based for oneMKL getrs + piv_h = piv + 1 + _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events @@ -2537,19 +2540,6 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): ) _manager.add_event_pair(ht_ev, lu_copy_ev) - # oneMKL LAPACK getrf overwrites `piv`. - piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type) - - # use DPCTL tensor function to fill the сopy of the pivot array - # from the pivot array - ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=piv_usm_arr, - dst=piv_h.get_array(), - sycl_queue=piv.sycl_queue, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, piv_copy_ev) - # SciPy-compatible behavior # Copy is required if: # - overwrite_b is False (always copy), @@ -2567,14 +2557,11 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): depends=dep_evs, ) _manager.add_event_pair(ht_ev, b_copy_ev) - dep_evs = [lu_copy_ev, piv_copy_ev, b_copy_ev] + dep_evs = [lu_copy_ev, b_copy_ev] else: # input is suitable for in-place modification b_h = b - dep_evs = [lu_copy_ev, piv_copy_ev] - - # MKL lapack uses 1-origin while SciPy uses 0-origin - piv_h += 1 + dep_evs = [lu_copy_ev] if not isinstance(trans, int): raise TypeError("`trans` must be an integer") From e91c24aafd4fab2e95444ee16a7cb6674fcd1c08 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 23 Sep 2025 07:34:59 -0700 Subject: [PATCH 22/26] Revert the tolerance adjustment --- dpnp/tests/test_linalg.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index bcf8e6ae29b..d2e2df3694a 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2182,10 +2182,7 @@ def test_lu_solve(self, shape, rhs_cols, order, dtype): # check A @ x = b Ax = a_dp @ x - if dpnp.issubdtype(dtype, dpnp.integer): - assert dpnp.allclose(Ax, b_dp, rtol=1e-5, atol=1e-5) - else: - assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6) + assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6) @pytest.mark.parametrize("trans", [0, 1, 2]) @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) From dfc76e0e5bfd15244177566ef8585b670c74cd8b Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 24 Sep 2025 02:18:34 -0700 Subject: [PATCH 23/26] Apply remark --- dpnp/linalg/dpnp_utils_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index f01a2158262..44a3816cc16 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2527,7 +2527,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - # oneMKL LAPACK getrf overwrites `lu`. + # oneMKL LAPACK getrs overwrites `lu`. lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) # use DPCTL tensor function to fill the сopy of the input array From 312719090b266fbdf83e4c484076ba4ca2806d2d Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 24 Sep 2025 02:19:36 -0700 Subject: [PATCH 24/26] Adjust tolerance for interger dtypes --- dpnp/tests/test_linalg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index d2e2df3694a..bcf8e6ae29b 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2182,7 +2182,10 @@ def test_lu_solve(self, shape, rhs_cols, order, dtype): # check A @ x = b Ax = a_dp @ x - assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6) + if dpnp.issubdtype(dtype, dpnp.integer): + assert dpnp.allclose(Ax, b_dp, rtol=1e-5, atol=1e-5) + else: + assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6) @pytest.mark.parametrize("trans", [0, 1, 2]) @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) From 05907b166214bcee517673b067535051a37f0eaa Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 24 Sep 2025 02:33:11 -0700 Subject: [PATCH 25/26] Enable test_broadcast_rhs --- dpnp/tests/test_linalg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index bcf8e6ae29b..f8eed95099d 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2297,7 +2297,6 @@ def test_strided_rhs(self): assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6) - @pytest.mark.skip("Not implemented yet") @pytest.mark.parametrize( "b_shape", [ @@ -2320,7 +2319,7 @@ def test_broadcast_rhs(self, b_shape): lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) x = dpnp.linalg.lu_solve( - (lu, piv), b_dp, overwrite_b=True, check_finite=False + (lu, piv), b_dp, overwrite_b=False, check_finite=False ) assert x.shape == b_dp.shape From 56fea1e46a87080dc31a7d52bb12b21439f42eac Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 24 Sep 2025 03:51:13 -0700 Subject: [PATCH 26/26] Make test more stable by adjusting tolerance --- dpnp/tests/test_linalg.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index f8eed95099d..a25c237f846 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2182,10 +2182,7 @@ def test_lu_solve(self, shape, rhs_cols, order, dtype): # check A @ x = b Ax = a_dp @ x - if dpnp.issubdtype(dtype, dpnp.integer): - assert dpnp.allclose(Ax, b_dp, rtol=1e-5, atol=1e-5) - else: - assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6) + assert dpnp.allclose(Ax, b_dp, rtol=1e-5, atol=1e-5) @pytest.mark.parametrize("trans", [0, 1, 2]) @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) @@ -2207,7 +2204,7 @@ def test_trans(self, trans, dtype): else: # trans == 2 lhs = a_dp.conj().T @ x - assert dpnp.allclose(lhs, b_dp, rtol=1e-6, atol=1e-6) + assert dpnp.allclose(lhs, b_dp, rtol=1e-5, atol=1e-5) @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) def test_overwrite_inplace(self, dtype): @@ -2272,7 +2269,7 @@ def test_diff_type(self, dtype_a, dtype_b): lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False) assert dpnp.allclose( - a_dp @ x, b_dp.astype(x.dtype, copy=False), rtol=1e-6, atol=1e-6 + a_dp @ x, b_dp.astype(x.dtype, copy=False), rtol=1e-5, atol=1e-5 ) def test_strided_rhs(self): @@ -2295,7 +2292,7 @@ def test_strided_rhs(self): (lu, piv), b_dp, overwrite_b=False, check_finite=False ) - assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6) + assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5) @pytest.mark.parametrize( "b_shape", @@ -2324,7 +2321,7 @@ def test_broadcast_rhs(self, b_shape): assert x.shape == b_dp.shape - assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6) + assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5) @pytest.mark.parametrize("shape", [(0, 0), (0, 5), (5, 5)]) @pytest.mark.parametrize("rhs_cols", [None, 0, 3])