diff --git a/library/src/include/rocblas.hpp b/library/src/include/rocblas.hpp index 922b7d9d6..5ee3cbfc4 100644 --- a/library/src/include/rocblas.hpp +++ b/library/src/include/rocblas.hpp @@ -59,17 +59,24 @@ constexpr auto rocblas2string_status(rocblas_status status) } } -#define ROCBLAS_CHECK(fcn) \ - { \ - rocblas_status _status = (fcn); \ - if(_status != rocblas_status_success) \ - return _status; \ +#define HIP_CHECK(...) \ + { \ + hipError_t _status = (__VA_ARGS__); \ + if(_status != hipSuccess) \ + return get_rocblas_status_for_hip_status(_status); \ } -#define THROW_IF_ROCBLAS_ERROR(fcn) \ - { \ - rocblas_status _status = (fcn); \ - if(_status != rocblas_status_success) \ - throw _status; \ + +#define ROCBLAS_CHECK(...) \ + { \ + rocblas_status _status = (__VA_ARGS__); \ + if(_status != rocblas_status_success) \ + return _status; \ + } +#define THROW_IF_ROCBLAS_ERROR(...) \ + { \ + rocblas_status _status = (__VA_ARGS__); \ + if(_status != rocblas_status_success) \ + throw _status; \ } template diff --git a/library/src/include/rocsolver_run_specialized_kernels.hpp b/library/src/include/rocsolver_run_specialized_kernels.hpp index 082e83286..2d156cf30 100644 --- a/library/src/include/rocsolver_run_specialized_kernels.hpp +++ b/library/src/include/rocsolver_run_specialized_kernels.hpp @@ -55,96 +55,96 @@ void rocsolver_trsm_mem(const rocblas_side side, const rocblas_int incb = 1); template -void rocsolver_trsm_lower(rocblas_handle handle, - const rocblas_side side, - const rocblas_operation trans, - const rocblas_diagonal diag, - const rocblas_int m, - const rocblas_int n, - U A, - const rocblas_int shiftA, - const rocblas_int lda, - const rocblas_stride strideA, - U B, - const rocblas_int shiftB, - const rocblas_int ldb, - const rocblas_stride strideB, - const rocblas_int batch_count, - const bool optim_mem, - void* work1, - void* work2, - void* work3, - void* work4); +rocblas_status rocsolver_trsm_lower(rocblas_handle handle, + const rocblas_side side, + const rocblas_operation trans, + const rocblas_diagonal diag, + const rocblas_int m, + const rocblas_int n, + U A, + const rocblas_int shiftA, + const rocblas_int lda, + const rocblas_stride strideA, + U B, + const rocblas_int shiftB, + const rocblas_int ldb, + const rocblas_stride strideB, + const rocblas_int batch_count, + const bool optim_mem, + void* work1, + void* work2, + void* work3, + void* work4); template -void rocsolver_trsm_lower(rocblas_handle handle, - const rocblas_side side, - const rocblas_operation trans, - const rocblas_diagonal diag, - const rocblas_int m, - const rocblas_int n, - U A, - const rocblas_int shiftA, - const rocblas_int inca, - const rocblas_int lda, - const rocblas_stride strideA, - U B, - const rocblas_int shiftB, - const rocblas_int incb, - const rocblas_int ldb, - const rocblas_stride strideB, - const rocblas_int batch_count, - const bool optim_mem, - void* work1, - void* work2, - void* work3, - void* work4); +rocblas_status rocsolver_trsm_lower(rocblas_handle handle, + const rocblas_side side, + const rocblas_operation trans, + const rocblas_diagonal diag, + const rocblas_int m, + const rocblas_int n, + U A, + const rocblas_int shiftA, + const rocblas_int inca, + const rocblas_int lda, + const rocblas_stride strideA, + U B, + const rocblas_int shiftB, + const rocblas_int incb, + const rocblas_int ldb, + const rocblas_stride strideB, + const rocblas_int batch_count, + const bool optim_mem, + void* work1, + void* work2, + void* work3, + void* work4); template -void rocsolver_trsm_upper(rocblas_handle handle, - const rocblas_side side, - const rocblas_operation trans, - const rocblas_diagonal diag, - const rocblas_int m, - const rocblas_int n, - U A, - const rocblas_int shiftA, - const rocblas_int lda, - const rocblas_stride strideA, - U B, - const rocblas_int shiftB, - const rocblas_int ldb, - const rocblas_stride strideB, - const rocblas_int batch_count, - const bool optim_mem, - void* work1, - void* work2, - void* work3, - void* work4); +rocblas_status rocsolver_trsm_upper(rocblas_handle handle, + const rocblas_side side, + const rocblas_operation trans, + const rocblas_diagonal diag, + const rocblas_int m, + const rocblas_int n, + U A, + const rocblas_int shiftA, + const rocblas_int lda, + const rocblas_stride strideA, + U B, + const rocblas_int shiftB, + const rocblas_int ldb, + const rocblas_stride strideB, + const rocblas_int batch_count, + const bool optim_mem, + void* work1, + void* work2, + void* work3, + void* work4); template -void rocsolver_trsm_upper(rocblas_handle handle, - const rocblas_side side, - const rocblas_operation trans, - const rocblas_diagonal diag, - const rocblas_int m, - const rocblas_int n, - U A, - const rocblas_int shiftA, - const rocblas_int inca, - const rocblas_int lda, - const rocblas_stride strideA, - U B, - const rocblas_int shiftB, - const rocblas_int incb, - const rocblas_int ldb, - const rocblas_stride strideB, - const rocblas_int batch_count, - const bool optim_mem, - void* work1, - void* work2, - void* work3, - void* work4); +rocblas_status rocsolver_trsm_upper(rocblas_handle handle, + const rocblas_side side, + const rocblas_operation trans, + const rocblas_diagonal diag, + const rocblas_int m, + const rocblas_int n, + U A, + const rocblas_int shiftA, + const rocblas_int inca, + const rocblas_int lda, + const rocblas_stride strideA, + U B, + const rocblas_int shiftB, + const rocblas_int incb, + const rocblas_int ldb, + const rocblas_stride strideB, + const rocblas_int batch_count, + const bool optim_mem, + void* work1, + void* work2, + void* work3, + void* work4); // gemm template diff --git a/library/src/include/rocsparse.hpp b/library/src/include/rocsparse.hpp index 7b5fac7cd..7983226ba 100644 --- a/library/src/include/rocsparse.hpp +++ b/library/src/include/rocsparse.hpp @@ -67,15 +67,15 @@ constexpr auto rocsparse2rocblas_status(rocsparse_status status) } } -#define ROCSPARSE_CHECK(fcn) \ +#define ROCSPARSE_CHECK(...) \ { \ - rocsparse_status _status = (fcn); \ + rocsparse_status _status = (__VA_ARGS__); \ if(_status != rocsparse_status_success) \ return rocsparse2rocblas_status(_status); \ } -#define THROW_IF_ROCSPARSE_ERROR(fcn) \ +#define THROW_IF_ROCSPARSE_ERROR(...) \ { \ - rocsparse_status _status = (fcn); \ + rocsparse_status _status = (__VA_ARGS__); \ if(_status != rocsparse_status_success) \ throw rocsparse2rocblas_status(_status); \ } diff --git a/library/src/lapack/roclapack_syevdx_heevdx_inplace.hpp b/library/src/lapack/roclapack_syevdx_heevdx_inplace.hpp index 1a4887825..eafa3291b 100644 --- a/library/src/lapack/roclapack_syevdx_heevdx_inplace.hpp +++ b/library/src/lapack/roclapack_syevdx_heevdx_inplace.hpp @@ -4,7 +4,7 @@ * Univ. of Tennessee, Univ. of California Berkeley, * Univ. of Colorado Denver and NAG Ltd.. * December 2016 - * Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -290,14 +290,9 @@ rocblas_status rocsolver_syevdx_heevdx_inplace_template(rocblas_handle handle, // copy nev from device to host if(h_nev) { - hipError_t status = hipMemcpyAsync(h_nev, d_nev, sizeof(rocblas_int) * batch_count, - hipMemcpyDeviceToHost, stream); - if(status != hipSuccess) - return get_rocblas_status_for_hip_status(status); - - status = hipStreamSynchronize(stream); - if(status != hipSuccess) - return get_rocblas_status_for_hip_status(status); + HIP_CHECK(hipMemcpyAsync(h_nev, d_nev, sizeof(rocblas_int) * batch_count, + hipMemcpyDeviceToHost, stream)); + HIP_CHECK(hipStreamSynchronize(stream)); } return rocblas_status_success; diff --git a/library/src/lapack/roclapack_syevj_heevj.hpp b/library/src/lapack/roclapack_syevj_heevj.hpp index d8dd90591..840f3aebd 100644 --- a/library/src/lapack/roclapack_syevj_heevj.hpp +++ b/library/src/lapack/roclapack_syevj_heevj.hpp @@ -6,7 +6,7 @@ * and * Hari & Kovac (2019). On the Convergence of Complex Jacobi Methods. * Linear and Multilinear Algebra 69(3), p. 489-514. - * Copyright (c) 2021-2023 Advanced Micro Devices, Inc. + * Copyright (C) 2021-2023 Advanced Micro Devices, Inc. * ***********************************************************************/ #pragma once @@ -1493,14 +1493,9 @@ rocblas_status rocsolver_syevj_heevj_template(rocblas_handle handle, while(h_sweeps < max_sweeps) { // if all instances in the batch have finished, exit the loop - hipError_t status = hipMemcpyAsync(&h_completed, completed, sizeof(rocblas_int), - hipMemcpyDeviceToHost, stream); - if(status != hipSuccess) - return get_rocblas_status_for_hip_status(status); - - status = hipStreamSynchronize(stream); - if(status != hipSuccess) - return get_rocblas_status_for_hip_status(status); + HIP_CHECK(hipMemcpyAsync(&h_completed, completed, sizeof(rocblas_int), + hipMemcpyDeviceToHost, stream)); + HIP_CHECK(hipStreamSynchronize(stream)); if(h_completed == batch_count) break; diff --git a/library/src/lapack/roclapack_sygvdx_hegvdx_inplace.hpp b/library/src/lapack/roclapack_sygvdx_hegvdx_inplace.hpp index 940cc7841..115d9f187 100644 --- a/library/src/lapack/roclapack_sygvdx_hegvdx_inplace.hpp +++ b/library/src/lapack/roclapack_sygvdx_hegvdx_inplace.hpp @@ -4,7 +4,7 @@ * Univ. of Tennessee, Univ. of California Berkeley, * Univ. of Colorado Denver and NAG Ltd.. * December 2016 - * Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions @@ -320,14 +320,9 @@ rocblas_status rocsolver_sygvdx_hegvdx_inplace_template(rocblas_handle handle, // copy nev from device to host if(h_nev) { - hipError_t status = hipMemcpyAsync(h_nev, d_nev, sizeof(rocblas_int) * batch_count, - hipMemcpyDeviceToHost, stream); - if(status != hipSuccess) - return get_rocblas_status_for_hip_status(status); - - status = hipStreamSynchronize(stream); - if(status != hipSuccess) - return get_rocblas_status_for_hip_status(status); + HIP_CHECK(hipMemcpyAsync(h_nev, d_nev, sizeof(rocblas_int) * batch_count, + hipMemcpyDeviceToHost, stream)); + HIP_CHECK(hipStreamSynchronize(stream)); } rocblas_set_pointer_mode(handle, old_mode); diff --git a/library/src/refact/rocrefact_csrrf_refactchol.cpp b/library/src/refact/rocrefact_csrrf_refactchol.cpp index 1f00d0a81..088e31d05 100644 --- a/library/src/refact/rocrefact_csrrf_refactchol.cpp +++ b/library/src/refact/rocrefact_csrrf_refactchol.cpp @@ -81,8 +81,8 @@ rocblas_status rocsolver_csrrf_refactchol_impl(rocblas_handle handle, work = mem[0]; // execution - return rocsolver_csrrf_refactchol_template(handle, n, nnzA, ptrA, indA, valA, nnzT, ptrT, - indT, valT, pivQ, rfinfo, work); + return rocsolver_csrrf_refactchol_template(handle, n, nnzA, ptrA, indA, valA, nnzT, ptrT, + indT, valT, pivQ, rfinfo, work); #else return rocblas_status_not_implemented; #endif diff --git a/library/src/refact/rocrefact_csrrf_refactchol.hpp b/library/src/refact/rocrefact_csrrf_refactchol.hpp index 0d6a0d2e5..008bd4f28 100644 --- a/library/src/refact/rocrefact_csrrf_refactchol.hpp +++ b/library/src/refact/rocrefact_csrrf_refactchol.hpp @@ -51,7 +51,6 @@ ROCSOLVER_KERNEL void rf_add_QAQ_kernel(const rocblas_int n, rocblas_int* Ap, rocblas_int* Ai, T* Ax, - const T beta, rocblas_int* LUp, rocblas_int* LUi, T* LUx) @@ -76,18 +75,6 @@ ROCSOLVER_KERNEL void rf_add_QAQ_kernel(const rocblas_int n, T aij; - // ---------------- - // scale B by beta - // ---------------- - for(i = istart + tiy; i < iend; i += hipBlockDim_y) - { - // only access lower triangle of B - if(irow < LUi[i]) - break; - LUx[i] *= beta; - } - __syncthreads(); - // ------------------------------ // scale A by alpha and add to B // ------------------------------ @@ -216,16 +203,17 @@ rocblas_status rocsolver_csrrf_refactchol_template(rocblas_handle handle, ROCSOLVER_LAUNCH_KERNEL(rf_ipvec_kernel, dim3(nblocks), dim3(BS2), 0, stream, n, pivQ, (rocblas_int*)work); + // set T to zero + HIP_CHECK(hipMemsetAsync((void*)valT, 0, sizeof(T) * nnzT, stream)); + // -------------------------------------------------------------- // copy Q'*A*Q into T // // Note: assume A and B are symmetric and ONLY the LOWER triangular parts of A and T are touched // -------------------------------------------------------------- T const alpha = static_cast(1); - T const beta = static_cast(0); ROCSOLVER_LAUNCH_KERNEL(rf_add_QAQ_kernel, dim3(nblocks, 1), dim3(BS2, BS2), 0, stream, n, - pivQ, (rocblas_int*)work, alpha, ptrA, indA, valA, beta, ptrT, indT, - valT); + pivQ, (rocblas_int*)work, alpha, ptrA, indA, valA, ptrT, indT, valT); // perform incomplete factorization of T ROCSPARSE_CHECK(rocsparseCall_csric0(rfinfo->sphandle, n, nnzT, rfinfo->descrT, valT, ptrT, diff --git a/library/src/refact/rocrefact_csrrf_refactlu.cpp b/library/src/refact/rocrefact_csrrf_refactlu.cpp index 78d794734..67a4c2266 100644 --- a/library/src/refact/rocrefact_csrrf_refactlu.cpp +++ b/library/src/refact/rocrefact_csrrf_refactlu.cpp @@ -82,8 +82,8 @@ rocblas_status rocsolver_csrrf_refactlu_impl(rocblas_handle handle, work = mem[0]; // execution - return rocsolver_csrrf_refactlu_template(handle, n, nnzA, ptrA, indA, valA, nnzT, ptrT, indT, - valT, pivP, pivQ, rfinfo, work); + return rocsolver_csrrf_refactlu_template(handle, n, nnzA, ptrA, indA, valA, nnzT, ptrT, + indT, valT, pivP, pivQ, rfinfo, work); #else return rocblas_status_not_implemented; #endif diff --git a/library/src/refact/rocrefact_csrrf_refactlu.hpp b/library/src/refact/rocrefact_csrrf_refactlu.hpp index dff0bb7f1..bf57cd9e3 100644 --- a/library/src/refact/rocrefact_csrrf_refactlu.hpp +++ b/library/src/refact/rocrefact_csrrf_refactlu.hpp @@ -51,7 +51,6 @@ ROCSOLVER_KERNEL void rf_add_PAQ_kernel(const rocblas_int n, rocblas_int* Ap, rocblas_int* Ai, T* Ax, - const T beta, rocblas_int* LUp, rocblas_int* LUi, T* LUx) @@ -74,15 +73,6 @@ ROCSOLVER_KERNEL void rf_add_PAQ_kernel(const rocblas_int n, rocblas_int iend_old = Ap[irow_old + 1]; rocblas_int i_old, icol_old; - // ---------------- - // scale B by beta - // ---------------- - for(i = istart + tiy; i < iend; i += hipBlockDim_y) - { - LUx[i] *= beta; - } - __syncthreads(); - // ------------------------------ // scale A by alpha and add to B // ------------------------------ @@ -194,6 +184,9 @@ rocblas_status rocsolver_csrrf_refactlu_template(rocblas_handle handle, ROCSOLVER_LAUNCH_KERNEL(rf_ipvec_kernel, dim3(nblocks), dim3(BS2), 0, stream, n, pivQ, (rocblas_int*)work); + // set T to zero + HIP_CHECK(hipMemsetAsync((void*)valT, 0, sizeof(T) * nnzT, stream)); + // --------------------------------------------------------------------- // copy P*A*Q into T // Note: the sparsity pattern of A is a subset of T, and since the re-orderings @@ -201,10 +194,8 @@ rocblas_status rocsolver_csrrf_refactlu_template(rocblas_handle handle, // yields the complete factorization of A. // --------------------------------------------------------------------- T const alpha = static_cast(1); - T const beta = static_cast(0); ROCSOLVER_LAUNCH_KERNEL(rf_add_PAQ_kernel, dim3(nblocks, 1), dim3(BS2, BS2), 0, stream, n, - pivP, (rocblas_int*)work, alpha, ptrA, indA, valA, beta, ptrT, indT, - valT); + pivP, (rocblas_int*)work, alpha, ptrA, indA, valA, ptrT, indT, valT); // perform incomplete factorization of T ROCSPARSE_CHECK(rocsparseCall_csrilu0(rfinfo->sphandle, n, nnzT, rfinfo->descrT, valT, ptrT, diff --git a/library/src/specialized/roclapack_trsm_specialized_kernels.hpp b/library/src/specialized/roclapack_trsm_specialized_kernels.hpp index 650f4c120..2aff0dce6 100644 --- a/library/src/specialized/roclapack_trsm_specialized_kernels.hpp +++ b/library/src/specialized/roclapack_trsm_specialized_kernels.hpp @@ -787,28 +787,28 @@ void rocsolver_trsm_mem(const rocblas_side side, This is blocked implementation that calls the internal forward/backward subtitution kernels to solve the diagonal blocks, and uses gemm to update the right/left -hand-sides **/ template -void rocsolver_trsm_lower(rocblas_handle handle, - const rocblas_side side, - const rocblas_operation trans, - const rocblas_diagonal diag, - const rocblas_int m, - const rocblas_int n, - U A, - const rocblas_int shiftA, - const rocblas_int inca, - const rocblas_int lda, - const rocblas_stride strideA, - U B, - const rocblas_int shiftB, - const rocblas_int incb, - const rocblas_int ldb, - const rocblas_stride strideB, - const rocblas_int batch_count, - const bool optim_mem, - void* work1, - void* work2, - void* work3, - void* work4) +rocblas_status rocsolver_trsm_lower(rocblas_handle handle, + const rocblas_side side, + const rocblas_operation trans, + const rocblas_diagonal diag, + const rocblas_int m, + const rocblas_int n, + U A, + const rocblas_int shiftA, + const rocblas_int inca, + const rocblas_int lda, + const rocblas_stride strideA, + U B, + const rocblas_int shiftB, + const rocblas_int incb, + const rocblas_int ldb, + const rocblas_stride strideB, + const rocblas_int batch_count, + const bool optim_mem, + void* work1, + void* work2, + void* work3, + void* work4) { ROCSOLVER_ENTER("trsm_lower", "side:", side, "trans:", trans, "diag:", diag, "m:", m, "n:", n, "shiftA:", shiftA, "lda:", lda, "shiftB:", shiftB, "ldb:", ldb, @@ -842,22 +842,20 @@ void rocsolver_trsm_lower(rocblas_handle handle, if(blk == 0) { - rocblasCall_trsm(handle, side, rocblas_fill_lower, trans, diag, m, n, &one, A, shiftA, lda, - strideA, B, shiftB, ldb, strideB, batch_count, optim_mem, work1, work2, - work3, work4); - return; + return rocblasCall_trsm(handle, side, rocblas_fill_lower, trans, diag, m, n, &one, A, + shiftA, lda, strideA, B, shiftB, ldb, strideB, batch_count, + optim_mem, work1, work2, work3, work4); } // TODO: Some architectures require synchronization between rocSOLVER and rocBLAS kernels; more investigation needed int device; - hipGetDevice(&device); + HIP_CHECK(hipGetDevice(&device)); hipDeviceProp_t deviceProperties; - hipGetDeviceProperties(&deviceProperties, device); + HIP_CHECK(hipGetDeviceProperties(&deviceProperties, device)); std::string deviceFullString(deviceProperties.gcnArchName); std::string deviceString = deviceFullString.substr(0, deviceFullString.find(":")); bool do_sync = (deviceString.find("gfx940") != std::string::npos - || deviceString.find("gfx941") != std::string::npos - || deviceString.find("gfx942") != std::string::npos); + || deviceString.find("gfx941") != std::string::npos); // ****** MAIN LOOP *********** if(isleft) @@ -889,15 +887,15 @@ void rocsolver_trsm_lower(rocblas_handle handle, FORWARD_SUBSTITUTIONS; if(do_sync) - hipStreamSynchronize(stream); + HIP_CHECK(hipStreamSynchronize(stream)); // update right hand sides - rocsolver_gemm( + ROCBLAS_CHECK(rocsolver_gemm( handle, rocblas_operation_none, rocblas_operation_none, m - nextpiv, n, blk, &minone, A, shiftA + idx2D(nextpiv, j, inca, lda), inca, lda, strideA, B, shiftB + idx2D(j, 0, incb, ldb), incb, ldb, strideB, &one, B, shiftB + idx2D(nextpiv, 0, incb, ldb), incb, ldb, strideB, batch_count, - (T**)nullptr); + (T**)nullptr)); j = nextpiv; } @@ -933,14 +931,14 @@ void rocsolver_trsm_lower(rocblas_handle handle, BACKWARD_SUBSTITUTIONS; if(do_sync) - hipStreamSynchronize(stream); + HIP_CHECK(hipStreamSynchronize(stream)); // update right hand sides - rocsolver_gemm( + ROCBLAS_CHECK(rocsolver_gemm( handle, trans, rocblas_operation_none, m - nextpiv, n, blk, &minone, A, shiftA + idx2D(m - nextpiv, 0, inca, lda), inca, lda, strideA, B, shiftB + idx2D(m - nextpiv, 0, incb, ldb), incb, ldb, strideB, &one, B, - shiftB + idx2D(0, 0, incb, ldb), incb, ldb, strideB, batch_count, (T**)nullptr); + shiftB + idx2D(0, 0, incb, ldb), incb, ldb, strideB, batch_count, (T**)nullptr)); j = nextpiv; } @@ -989,14 +987,14 @@ void rocsolver_trsm_lower(rocblas_handle handle, BACKWARD_SUBSTITUTIONS; if(do_sync) - hipStreamSynchronize(stream); + HIP_CHECK(hipStreamSynchronize(stream)); // update left hand sides - rocsolver_gemm( + ROCBLAS_CHECK(rocsolver_gemm( handle, rocblas_operation_none, rocblas_operation_none, m, n - nextpiv, blk, &minone, B, shiftB + idx2D(0, n - nextpiv, incb, ldb), incb, ldb, strideB, A, shiftA + idx2D(n - nextpiv, 0, inca, lda), inca, lda, strideA, &one, B, - shiftB + idx2D(0, 0, incb, ldb), incb, ldb, strideB, batch_count, (T**)nullptr); + shiftB + idx2D(0, 0, incb, ldb), incb, ldb, strideB, batch_count, (T**)nullptr)); j = nextpiv; } @@ -1032,15 +1030,15 @@ void rocsolver_trsm_lower(rocblas_handle handle, FORWARD_SUBSTITUTIONS; if(do_sync) - hipStreamSynchronize(stream); + HIP_CHECK(hipStreamSynchronize(stream)); // update left hand sides - rocsolver_gemm( + ROCBLAS_CHECK(rocsolver_gemm( handle, rocblas_operation_none, trans, m, n - nextpiv, blk, &minone, B, shiftB + idx2D(0, j, incb, ldb), incb, ldb, strideB, A, shiftA + idx2D(nextpiv, j, inca, lda), inca, lda, strideA, &one, B, shiftB + idx2D(0, nextpiv, incb, ldb), incb, ldb, strideB, batch_count, - (T**)nullptr); + (T**)nullptr)); j = nextpiv; } @@ -1059,6 +1057,8 @@ void rocsolver_trsm_lower(rocblas_handle handle, FORWARD_SUBSTITUTIONS; } } + + return rocblas_status_success; } /** Internal TRSM (upper case): @@ -1070,28 +1070,28 @@ void rocsolver_trsm_lower(rocblas_handle handle, This is blocked implementation that calls the internal forward/backward subtitution kernels to solve the diagonal blocks, and uses gemm to update the right/left -hand-sides **/ template -void rocsolver_trsm_upper(rocblas_handle handle, - const rocblas_side side, - const rocblas_operation trans, - const rocblas_diagonal diag, - const rocblas_int m, - const rocblas_int n, - U A, - const rocblas_int shiftA, - const rocblas_int inca, - const rocblas_int lda, - const rocblas_stride strideA, - U B, - const rocblas_int shiftB, - const rocblas_int incb, - const rocblas_int ldb, - const rocblas_stride strideB, - const rocblas_int batch_count, - const bool optim_mem, - void* work1, - void* work2, - void* work3, - void* work4) +rocblas_status rocsolver_trsm_upper(rocblas_handle handle, + const rocblas_side side, + const rocblas_operation trans, + const rocblas_diagonal diag, + const rocblas_int m, + const rocblas_int n, + U A, + const rocblas_int shiftA, + const rocblas_int inca, + const rocblas_int lda, + const rocblas_stride strideA, + U B, + const rocblas_int shiftB, + const rocblas_int incb, + const rocblas_int ldb, + const rocblas_stride strideB, + const rocblas_int batch_count, + const bool optim_mem, + void* work1, + void* work2, + void* work3, + void* work4) { ROCSOLVER_ENTER("trsm_upper", "side:", side, "trans:", trans, "diag:", diag, "m:", m, "n:", n, "shiftA:", shiftA, "lda:", lda, "shiftB:", shiftB, "ldb:", ldb, @@ -1125,22 +1125,20 @@ void rocsolver_trsm_upper(rocblas_handle handle, if(blk == 0) { - rocblasCall_trsm(handle, side, rocblas_fill_upper, trans, diag, m, n, &one, A, shiftA, lda, - strideA, B, shiftB, ldb, strideB, batch_count, optim_mem, work1, work2, - work3, work4); - return; + return rocblasCall_trsm(handle, side, rocblas_fill_upper, trans, diag, m, n, &one, A, + shiftA, lda, strideA, B, shiftB, ldb, strideB, batch_count, + optim_mem, work1, work2, work3, work4); } // TODO: Some architectures require synchronization between rocSOLVER and rocBLAS kernels; more investigation needed int device; - hipGetDevice(&device); + HIP_CHECK(hipGetDevice(&device)); hipDeviceProp_t deviceProperties; - hipGetDeviceProperties(&deviceProperties, device); + HIP_CHECK(hipGetDeviceProperties(&deviceProperties, device)); std::string deviceFullString(deviceProperties.gcnArchName); std::string deviceString = deviceFullString.substr(0, deviceFullString.find(":")); bool do_sync = (deviceString.find("gfx940") != std::string::npos - || deviceString.find("gfx941") != std::string::npos - || deviceString.find("gfx942") != std::string::npos); + || deviceString.find("gfx941") != std::string::npos); // ****** MAIN LOOP *********** if(isleft) @@ -1172,15 +1170,15 @@ void rocsolver_trsm_upper(rocblas_handle handle, FORWARD_SUBSTITUTIONS; if(do_sync) - hipStreamSynchronize(stream); + HIP_CHECK(hipStreamSynchronize(stream)); // update right hand sides - rocsolver_gemm( + ROCBLAS_CHECK(rocsolver_gemm( handle, trans, rocblas_operation_none, m - nextpiv, n, blk, &minone, A, shiftA + idx2D(j, nextpiv, inca, lda), inca, lda, strideA, B, shiftB + idx2D(j, 0, incb, ldb), incb, ldb, strideB, &one, B, shiftB + idx2D(nextpiv, 0, incb, ldb), incb, ldb, strideB, batch_count, - (T**)nullptr); + (T**)nullptr)); j = nextpiv; } @@ -1216,14 +1214,14 @@ void rocsolver_trsm_upper(rocblas_handle handle, BACKWARD_SUBSTITUTIONS; if(do_sync) - hipStreamSynchronize(stream); + HIP_CHECK(hipStreamSynchronize(stream)); // update right hand sides - rocsolver_gemm( + ROCBLAS_CHECK(rocsolver_gemm( handle, rocblas_operation_none, rocblas_operation_none, m - nextpiv, n, blk, &minone, A, shiftA + idx2D(0, m - nextpiv, inca, lda), inca, lda, strideA, B, shiftB + idx2D(m - nextpiv, 0, incb, ldb), incb, ldb, strideB, &one, B, - shiftB + idx2D(0, 0, incb, ldb), incb, ldb, strideB, batch_count, (T**)nullptr); + shiftB + idx2D(0, 0, incb, ldb), incb, ldb, strideB, batch_count, (T**)nullptr)); j = nextpiv; } @@ -1272,14 +1270,14 @@ void rocsolver_trsm_upper(rocblas_handle handle, BACKWARD_SUBSTITUTIONS; if(do_sync) - hipStreamSynchronize(stream); + HIP_CHECK(hipStreamSynchronize(stream)); // update left hand sides - rocsolver_gemm( + ROCBLAS_CHECK(rocsolver_gemm( handle, rocblas_operation_none, trans, m, n - nextpiv, blk, &minone, B, shiftB + idx2D(0, n - nextpiv, incb, ldb), incb, ldb, strideB, A, shiftA + idx2D(0, n - nextpiv, inca, lda), inca, lda, strideA, &one, B, - shiftB + idx2D(0, 0, incb, ldb), incb, ldb, strideB, batch_count, (T**)nullptr); + shiftB + idx2D(0, 0, incb, ldb), incb, ldb, strideB, batch_count, (T**)nullptr)); j = nextpiv; } @@ -1315,15 +1313,15 @@ void rocsolver_trsm_upper(rocblas_handle handle, FORWARD_SUBSTITUTIONS; if(do_sync) - hipStreamSynchronize(stream); + HIP_CHECK(hipStreamSynchronize(stream)); // update left hand sides - rocsolver_gemm( + ROCBLAS_CHECK(rocsolver_gemm( handle, rocblas_operation_none, rocblas_operation_none, m, n - nextpiv, blk, &minone, B, shiftB + idx2D(0, j, incb, ldb), incb, ldb, strideB, A, shiftA + idx2D(j, nextpiv, inca, lda), inca, lda, strideA, &one, B, shiftB + idx2D(0, nextpiv, incb, ldb), incb, ldb, strideB, batch_count, - (T**)nullptr); + (T**)nullptr)); j = nextpiv; } @@ -1342,6 +1340,8 @@ void rocsolver_trsm_upper(rocblas_handle handle, FORWARD_SUBSTITUTIONS; } } + + return rocblas_status_success; } /************************************************************* @@ -1349,57 +1349,57 @@ void rocsolver_trsm_upper(rocblas_handle handle, *************************************************************/ template -inline void rocsolver_trsm_lower(rocblas_handle handle, - const rocblas_side side, - const rocblas_operation trans, - const rocblas_diagonal diag, - const rocblas_int m, - const rocblas_int n, - U A, - const rocblas_int shiftA, - const rocblas_int lda, - const rocblas_stride strideA, - U B, - const rocblas_int shiftB, - const rocblas_int ldb, - const rocblas_stride strideB, - const rocblas_int batch_count, - const bool optim_mem, - void* work1, - void* work2, - void* work3, - void* work4) +inline rocblas_status rocsolver_trsm_lower(rocblas_handle handle, + const rocblas_side side, + const rocblas_operation trans, + const rocblas_diagonal diag, + const rocblas_int m, + const rocblas_int n, + U A, + const rocblas_int shiftA, + const rocblas_int lda, + const rocblas_stride strideA, + U B, + const rocblas_int shiftB, + const rocblas_int ldb, + const rocblas_stride strideB, + const rocblas_int batch_count, + const bool optim_mem, + void* work1, + void* work2, + void* work3, + void* work4) { - rocsolver_trsm_lower(handle, side, trans, diag, m, n, A, shiftA, 1, lda, - strideA, B, shiftB, 1, ldb, strideB, batch_count, - optim_mem, work1, work2, work3, work4); + return rocsolver_trsm_lower( + handle, side, trans, diag, m, n, A, shiftA, 1, lda, strideA, B, shiftB, 1, ldb, strideB, + batch_count, optim_mem, work1, work2, work3, work4); } template -inline void rocsolver_trsm_upper(rocblas_handle handle, - const rocblas_side side, - const rocblas_operation trans, - const rocblas_diagonal diag, - const rocblas_int m, - const rocblas_int n, - U A, - const rocblas_int shiftA, - const rocblas_int lda, - const rocblas_stride strideA, - U B, - const rocblas_int shiftB, - const rocblas_int ldb, - const rocblas_stride strideB, - const rocblas_int batch_count, - const bool optim_mem, - void* work1, - void* work2, - void* work3, - void* work4) +inline rocblas_status rocsolver_trsm_upper(rocblas_handle handle, + const rocblas_side side, + const rocblas_operation trans, + const rocblas_diagonal diag, + const rocblas_int m, + const rocblas_int n, + U A, + const rocblas_int shiftA, + const rocblas_int lda, + const rocblas_stride strideA, + U B, + const rocblas_int shiftB, + const rocblas_int ldb, + const rocblas_stride strideB, + const rocblas_int batch_count, + const bool optim_mem, + void* work1, + void* work2, + void* work3, + void* work4) { - rocsolver_trsm_upper(handle, side, trans, diag, m, n, A, shiftA, 1, lda, - strideA, B, shiftB, 1, ldb, strideB, batch_count, - optim_mem, work1, work2, work3, work4); + return rocsolver_trsm_upper( + handle, side, trans, diag, m, n, A, shiftA, 1, lda, strideA, B, shiftB, 1, ldb, strideB, + batch_count, optim_mem, work1, work2, work3, work4); } /************************************************************* @@ -1413,7 +1413,7 @@ inline void rocsolver_trsm_upper(rocblas_handle handle, size_t* size_work2, size_t* size_work3, size_t* size_work4, bool* optim_mem, \ bool inblocked, const rocblas_int, const rocblas_int) #define INSTANTIATE_TRSM_LOWER(BATCHED, STRIDED, T, U) \ - template void rocsolver_trsm_lower( \ + template rocblas_status rocsolver_trsm_lower( \ rocblas_handle handle, const rocblas_side side, const rocblas_operation trans, \ const rocblas_diagonal diag, const rocblas_int m, const rocblas_int n, U A, \ const rocblas_int shiftA, const rocblas_int lda, const rocblas_stride strideA, U B, \ @@ -1421,7 +1421,7 @@ inline void rocsolver_trsm_upper(rocblas_handle handle, const rocblas_int batch_count, const bool optim_mem, void* work1, void* work2, \ void* work3, void* work4) #define INSTANTIATE_TRSM_UPPER(BATCHED, STRIDED, T, U) \ - template void rocsolver_trsm_upper( \ + template rocblas_status rocsolver_trsm_upper( \ rocblas_handle handle, const rocblas_side side, const rocblas_operation trans, \ const rocblas_diagonal diag, const rocblas_int m, const rocblas_int n, U A, \ const rocblas_int shiftA, const rocblas_int lda, const rocblas_stride strideA, U B, \