Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new LASR in STEQR and BDSQR and parallelize STEQR #897

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 32 additions & 143 deletions library/src/auxiliary/rocauxiliary_bdsqr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* Univ. of Tennessee, Univ. of California Berkeley,
* Univ. of Colorado Denver and NAG Ltd..
* June 2017
* Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -33,6 +33,7 @@
#pragma once

#include "lapack_device_functions.hpp"
#include "rocauxiliary_lasr.hpp"
#include "rocblas.hpp"
#include "rocsolver/rocsolver.h"

Expand Down Expand Up @@ -80,6 +81,7 @@ __device__ T bdsqr_estimate(const rocblas_int n, T* D, T* E, int t2b, T tol, int
the n-by-n bidiagonal matrix given by D and E using shift = sh **/
template <typename T, typename S>
__device__ void bdsqr_QRstep(const rocblas_int tid,
const rocblas_int tid_inc,
const rocblas_int t2b,
const rocblas_int n,
const rocblas_int nv,
Expand Down Expand Up @@ -130,7 +132,7 @@ __device__ void bdsqr_QRstep(const rocblas_int tid,
if(t2b && nv)
{
rots[ek] = c;
rots[ek + n] = s;
rots[ek + n] = -s;
}
if(b2t && (nu || nc))
{
Expand Down Expand Up @@ -158,7 +160,7 @@ __device__ void bdsqr_QRstep(const rocblas_int tid,
if(t2b && (nu || nc))
{
rots[ek + nr] = c;
rots[ek + nr + n] = s;
rots[ek + nr + n] = -s;
}

dk += dir;
Expand All @@ -171,71 +173,21 @@ __device__ void bdsqr_QRstep(const rocblas_int tid,
__syncthreads();

// update singular vectors
rocblas_direct direc = (t2b ? rocblas_forward_direction : rocblas_backward_direction);
if(V && nv)
{
// rotate from the left
for(rocblas_int j = tid; j < nv; j += hipBlockDim_x)
{
rocblas_int k = (t2b ? 0 : n - 1);
rocblas_int rk = (t2b ? 0 : n - 2);

temp1 = V[k + j * ldv];
for(rocblas_int kk = 0; kk < n - 1; kk++)
{
temp2 = V[(k + dir) + j * ldv];
c = rots[rk];
s = rots[rk + n];
V[k + j * ldv] = c * temp1 - s * temp2;
V[(k + dir) + j * ldv] = temp1 = c * temp2 + s * temp1;

k += dir;
rk += dir;
}
}
run_lasr(rocblas_side_left, rocblas_pivot_variable, direc, n, nv, rots, rots + n, V, ldv,
tid, tid_inc);
}
if(U && nu)
{
// rotate from the right
for(rocblas_int i = tid; i < nu; i += hipBlockDim_x)
{
rocblas_int k = (t2b ? 0 : n - 1);
rocblas_int rk = (t2b ? nr : (n - 2) + nr);

temp1 = U[i + k * ldu];
for(rocblas_int kk = 0; kk < n - 1; kk++)
{
temp2 = U[i + (k + dir) * ldu];
c = rots[rk];
s = rots[rk + n];
U[i + k * ldu] = c * temp1 - s * temp2;
U[i + (k + dir) * ldu] = temp1 = c * temp2 + s * temp1;

k += dir;
rk += dir;
}
}
run_lasr(rocblas_side_right, rocblas_pivot_variable, direc, nu, n, rots + nr, rots + nr + n,
U, ldu, tid, tid_inc);
}
if(C && nc)
{
// rotate from the left
for(rocblas_int j = tid; j < nc; j += hipBlockDim_x)
{
rocblas_int k = (t2b ? 0 : n - 1);
rocblas_int rk = (t2b ? nr : (n - 2) + nr);

temp1 = C[k + j * ldc];
for(rocblas_int kk = 0; kk < n - 1; kk++)
{
temp2 = C[(k + dir) + j * ldc];
c = rots[rk];
s = rots[rk + n];
C[k + j * ldc] = c * temp1 - s * temp2;
C[(k + dir) + j * ldc] = temp1 = c * temp2 + s * temp1;

k += dir;
rk += dir;
}
}
run_lasr(rocblas_side_left, rocblas_pivot_variable, direc, n, nc, rots + nr, rots + nr + n,
C, ldc, tid, tid_inc);
}
}

Expand Down Expand Up @@ -475,6 +427,7 @@ ROCSOLVER_KERNEL void bdsqr_lower2upper(const rocblas_int n,
rocblas_int* completed)
{
rocblas_int tid = hipThreadIdx_x;
rocblas_int tid_inc = hipBlockDim_x;
rocblas_int bid = hipBlockIdx_y;

if(completed[bid + 2])
Expand Down Expand Up @@ -510,7 +463,7 @@ ROCSOLVER_KERNEL void bdsqr_lower2upper(const rocblas_int n,
if(nu || nc)
{
rots[i] = c;
rots[i + n] = s;
rots[i + n] = -s;
}
}
D[n - 1] = f;
Expand All @@ -520,35 +473,13 @@ ROCSOLVER_KERNEL void bdsqr_lower2upper(const rocblas_int n,
// update singular vectors
if(nu)
{
// rotate from the right (forward direction)
for(i = tid; i < nu; i += hipBlockDim_x)
{
temp1 = U[i + 0 * ldu];
for(j = 0; j < n - 1; j++)
{
temp2 = U[i + (j + 1) * ldu];
c = rots[j];
s = rots[j + n];
U[i + j * ldu] = c * temp1 - s * temp2;
U[i + (j + 1) * ldu] = temp1 = c * temp2 + s * temp1;
}
}
run_lasr(rocblas_side_right, rocblas_pivot_variable, rocblas_forward_direction, nu, n, rots,
rots + n, U, ldu, tid, tid_inc);
}
if(nc)
{
// rotate from the left (forward direction)
for(j = tid; j < nc; j += hipBlockDim_x)
{
temp1 = C[0 + j * ldc];
for(i = 0; i < n - 1; i++)
{
temp2 = C[(i + 1) + j * ldc];
c = rots[i];
s = rots[i + n];
C[i + j * ldc] = c * temp1 - s * temp2;
C[(i + 1) + j * ldc] = temp1 = c * temp2 + s * temp1;
}
}
run_lasr(rocblas_side_left, rocblas_pivot_variable, rocblas_forward_direction, n, nc, rots,
rots + n, C, ldc, tid, tid_inc);
}
}

Expand Down Expand Up @@ -587,6 +518,7 @@ ROCSOLVER_KERNEL void bdsqr_compute(const rocblas_int n,
rocblas_int* completed)
{
rocblas_int tid = hipThreadIdx_x;
rocblas_int tid_inc = hipBlockDim_x;
rocblas_int sid_start = hipBlockIdx_y;
rocblas_int bid = hipBlockIdx_z;

Expand Down Expand Up @@ -672,8 +604,8 @@ ROCSOLVER_KERNEL void bdsqr_compute(const rocblas_int n,
if(tid == 0)
splits[4 * sid] = (t2b ? 1 : -1);

bdsqr_QRstep(tid, t2b, k - i + 1, nv, nu, nc, D + i, E + i, V + i, ldv, U + i * ldu,
ldu, C + i, ldc, smin, rots + incW * i);
bdsqr_QRstep(tid, tid_inc, t2b, k - i + 1, nv, nu, nc, D + i, E + i, V + i, ldv,
U + i * ldu, ldu, C + i, ldc, smin, rots + incW * i);
}
else
{
Expand Down Expand Up @@ -710,6 +642,7 @@ ROCSOLVER_KERNEL void bdsqr_rotate(const rocblas_int n,
rocblas_int* completed)
{
rocblas_int tid = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
rocblas_int tid_inc = hipGridDim_x * hipBlockDim_x;
rocblas_int sid_start = hipBlockIdx_y;
rocblas_int bid = hipBlockIdx_z;

Expand Down Expand Up @@ -753,68 +686,24 @@ ROCSOLVER_KERNEL void bdsqr_rotate(const rocblas_int n,
{
S* rots = work + 4 + incW * k_start;

rocblas_int t2b = (dir > 0 ? 1 : 0);
rocblas_int b2t = 1 - t2b;

rocblas_int nn = k_end - k_start + 1;
rocblas_int nr = nv ? 2 * nn : 0;

if(V && tid < nv)
rocblas_direct direc = (dir > 0 ? rocblas_forward_direction : rocblas_backward_direction);
if(V && nv)
{
// rotate from the left
rocblas_int k = (t2b ? k_start : k_end);
rocblas_int rk = (t2b ? 0 : nn - 2);

temp1 = V[k + tid * ldv];
for(rocblas_int kk = k_start; kk < k_end; kk++)
{
temp2 = V[(k + dir) + tid * ldv];
c = rots[rk];
s = rots[rk + nn];
V[k + tid * ldv] = c * temp1 - s * temp2;
V[(k + dir) + tid * ldv] = temp1 = c * temp2 + s * temp1;

k += dir;
rk += dir;
}
run_lasr(rocblas_side_left, rocblas_pivot_variable, direc, nn, nv, rots, rots + nn,
V + k_start, ldv, tid, tid_inc);
}
if(U && tid < nu)
if(U && nu)
{
// rotate from the right
rocblas_int k = (t2b ? k_start : k_end);
rocblas_int rk = (t2b ? nr : (nn - 2) + nr);

temp1 = U[tid + k * ldu];
for(rocblas_int kk = k_start; kk < k_end; kk++)
{
temp2 = U[tid + (k + dir) * ldu];
c = rots[rk];
s = rots[rk + nn];
U[tid + k * ldu] = c * temp1 - s * temp2;
U[tid + (k + dir) * ldu] = temp1 = c * temp2 + s * temp1;

k += dir;
rk += dir;
}
run_lasr(rocblas_side_right, rocblas_pivot_variable, direc, nu, nn, rots + nr,
rots + nr + nn, U + k_start * ldu, ldu, tid, tid_inc);
}
if(C && tid < nc)
if(C && nc)
{
// rotate from the left
rocblas_int k = (t2b ? k_start : k_end);
rocblas_int rk = (t2b ? nr : (nn - 2) + nr);

temp1 = C[k + tid * ldc];
for(rocblas_int kk = k_start; kk < k_end; kk++)
{
temp2 = C[(k + dir) + tid * ldc];
c = rots[rk];
s = rots[rk + nn];
C[k + tid * ldc] = c * temp1 - s * temp2;
C[(k + dir) + tid * ldc] = temp1 = c * temp2 + s * temp1;

k += dir;
rk += dir;
}
run_lasr(rocblas_side_left, rocblas_pivot_variable, direc, nn, nc, rots + nr,
rots + nr + nn, C + k_start, ldc, tid, tid_inc);
}
}
}
Expand Down
51 changes: 35 additions & 16 deletions library/src/auxiliary/rocauxiliary_bdsqr_hybrid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,56 @@ ROCSOLVER_BEGIN_NAMESPACE
/************************************************************************************/

template <typename S, typename T, typename I>
static void swap_template(I const n, T* x, I const incx, T* y, I const incy, hipStream_t stream)
static void swap_template(rocblas_handle handle,
I const n,
T* x,
I const incx,
T* y,
I const incy,
hipStream_t stream)
{
auto nthreads = warpSize * 2;
auto nblocks = (n - 1) / nthreads + 1;

hipLaunchKernelGGL((swap_kernel<S, T, I>), dim3(nblocks, 1, 1), dim3(nthreads, 1, 1), 0, stream,
n, x, incx, y, incy);
ROCSOLVER_LAUNCH_KERNEL((swap_kernel<S, T, I>), dim3(nblocks, 1, 1), dim3(nthreads, 1, 1), 0,
stream, n, x, incx, y, incy);
}

template <typename S, typename T, typename I>
static void
rot_template(I const n, T* x, I const incx, T* y, I const incy, S const c, S const s, hipStream_t stream)
static void rot_template(rocblas_handle handle,
I const n,
T* x,
I const incx,
T* y,
I const incy,
S const c,
S const s,
hipStream_t stream)
{
auto nthreads = warpSize * 2;
auto nblocks = (n - 1) / nthreads + 1;

hipLaunchKernelGGL((rot_kernel<S, T, I>), dim3(nblocks, 1, 1), dim3(nthreads, 1, 1), 0, stream,
n, x, incx, y, incy, c, s);
ROCSOLVER_LAUNCH_KERNEL((rot_kernel<S, T, I>), dim3(nblocks, 1, 1), dim3(nthreads, 1, 1), 0,
stream, n, x, incx, y, incy, c, s);
}

template <typename S, typename T, typename I>
static void scal_template(I const n, S const da, T* const x, I const incx, hipStream_t stream)
static void scal_template(rocblas_handle handle,
I const n,
S const da,
T* const x,
I const incx,
hipStream_t stream)
{
auto nthreads = warpSize * 2;
auto nblocks = (n - 1) / nthreads + 1;

hipLaunchKernelGGL((scal_kernel<S, T, I>), dim3(nblocks, 1, 1), dim3(nthreads, 1, 1), 0, stream,
n, da, x, incx);
ROCSOLVER_LAUNCH_KERNEL((scal_kernel<S, T, I>), dim3(nblocks, 1, 1), dim3(nthreads, 1, 1), 0,
stream, n, da, x, incx);
}

/** Call to lasr functionality.
lasr_body can be executed as a host or device function **/
run_lasr can be executed as a host or device function **/
template <typename S, typename T, typename I>
static void call_lasr(rocblas_side& side,
rocblas_pivot& pivot,
Expand All @@ -90,7 +108,7 @@ static void call_lasr(rocblas_side& side,
I const tid = 0;
I const i_inc = 1;

lasr_body<T, S, I>(side, pivot, direct, m, n, &c, &s, &A, lda, tid, i_inc);
run_lasr<T, S, I>(side, pivot, direct, m, n, &c, &s, &A, lda, tid, i_inc);
}

/************************************************************************************/
Expand Down Expand Up @@ -121,15 +139,16 @@ static void bdsqr_single_template(rocblas_handle handle,
// Lambda expressions used as helpers
// -------------------------------------
auto call_swap_gpu = [=](I n, T& x, I incx, T& y, I incy) {
swap_template<S, T, I>(n, &x, incx, &y, incy, stream);
swap_template<S, T, I>(handle, n, &x, incx, &y, incy, stream);
};

auto call_rot_gpu = [=](I n, T& x, I incx, T& y, I incy, S cosl, S sinl) {
rot_template<S, T, I>(n, &x, incx, &y, incy, cosl, sinl, stream);
rot_template<S, T, I>(handle, n, &x, incx, &y, incy, cosl, sinl, stream);
};

auto call_scal_gpu
= [=](I n, auto da, T& x, I incx) { scal_template<S, T, I>(n, da, &x, incx, stream); };
auto call_scal_gpu = [=](I n, auto da, T& x, I incx) {
scal_template<S, T, I>(handle, n, da, &x, incx, stream);
};

auto call_lasr_gpu_nocopy
= [=](rocblas_side const side, rocblas_pivot const pivot, rocblas_direct const direct,
Expand Down
Loading