Skip to content

Commit 771f834

Browse files
authored
Merge pull request #154 from QingleiCao/qinglei/trmm_gpu
Qinglei/trmm gpu
2 parents ebc116a + b2589e5 commit 771f834

File tree

11 files changed

+527
-14
lines changed

11 files changed

+527
-14
lines changed

src/ztrmm_LLN.jdf

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ extern "C" %{
88
* @precisions normal z -> s d c
99
*
1010
*/
11+
#include "dplasma/config.h"
12+
#if defined(DPLASMA_HAVE_CUDA)
13+
#include <cublas.h>
14+
#endif /* defined(DPLASMA_HAVE_CUDA) */
1115
#include "dplasmajdf.h"
1216
#include "parsec/data_dist/matrix/matrix.h"
1317

@@ -54,6 +58,9 @@ descA [type = "const parsec_tiled_matrix_t*" hidden = on default = "((dplas
5458
ddescB [type = "dplasma_data_collection_t*"]
5559
descB [type = "parsec_tiled_matrix_t*" hidden = on default = "((dplasma_data_collection_t*)ddescB)->dc_original" aligned=ddescB]
5660

61+
hip_handles_infokey [type = "int" hidden = on default = -1 ]
62+
63+
5764
read_A(m, k) [profile = off]
5865
/* Execution Space */
5966
m = 0..(descB->mt-1)
@@ -153,6 +160,61 @@ loc_C = %{ return LOC(descB, (descB->mt-1)-m, n); %}
153160
type_data = %{ return ADTT_DC(ddescB, loc_C, B_SHAPE, LAPACK); %} ]
154161
-> ((k+m) < (descB->mt-2)) ? C zgemm(m, n, k+1) /* dep OUT: rely on datacopy dtt for sending */
155162

163+
164+
BODY [type=CUDA]
165+
{ /* CUDA variant of this task's update: a single GEMM, C = lalpha * op(A) * B + lbeta * C with lbeta == 1, issued through the legacy cuBLAS API. */
166+
#if defined(PRECISION_z) || defined(PRECISION_c)
167+
cuDoubleComplex lalpha = make_cuDoubleComplex(creal(alpha), cimag(alpha)); /* complex builds: repack alpha from its real and imaginary parts */
168+
cuDoubleComplex lbeta = make_cuDoubleComplex( 1., 0.); /* beta = 1 + 0i, so the product is added to what C already holds */
169+
#else /* neither PRECISION_z nor PRECISION_c is defined */
170+
double lalpha = alpha;
171+
double lbeta = 1.0; /* beta = 1: accumulate into C */
172+
#endif
173+
int tempmm = (((descB->mt-1)-m)==(descB->mt-1)) ? (descB->m-(((descB->mt-1)-m)*descB->mb)) : descB->mb; /* rows of the tile at tile row (descB->mt-1)-m: the remainder of descB->m if that is the last tile row, a full descB->mb otherwise */
174+
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb; /* columns: remainder of descB->n for the last tile column, a full descB->nb otherwise */
175+
int lda = LDA(ddescA, A); /* LDA() results are passed below as the leading dimensions of A, B and C */
176+
int ldb = LDA(ddescB, B);
177+
int ldc = LDA(ddescB, C);
178+
179+
cublasStatus_t status;
180+
cublasSetKernelStream( parsec_body.stream ); /* legacy cuBLAS: run the following call on the stream held in parsec_body */
181+
cublasZgemm( dplasma_lapack_const(trans), 'N', /* op(A) follows trans; B is used untransposed. presumably dplasma_lapack_const() yields the character code this API expects - confirm */
182+
tempmm, tempnn, descB->mb, /* m, n, k: the inner dimension is always a full descB->mb in this variant */
183+
lalpha, (cuDoubleComplex*)A, lda, /* scalars are passed by value in the legacy API */
184+
(cuDoubleComplex*)B, ldb,
185+
lbeta, (cuDoubleComplex*)C, ldc );
186+
status = cublasGetError(); /* the gemm call's result is not used; the status is fetched separately */
187+
PARSEC_CUDA_CHECK_ERROR( "cublasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} ); /* the supplied block returns PARSEC_HOOK_RETURN_ERROR; presumably executed by the macro when status reports a failure */
188+
}
189+
END
190+
191+
BODY [type=HIP]
192+
{ /* HIP variant of this task's update: a single hipBLAS GEMM, C = lalpha * op(A) * B + lbeta * C with lbeta == 1. */
193+
#if defined(PRECISION_z) || defined(PRECISION_c)
194+
hipDoubleComplex lalpha = make_hipDoubleComplex(creal(alpha), cimag(alpha)); /* complex builds: repack alpha from its real and imaginary parts */
195+
hipDoubleComplex lbeta = make_hipDoubleComplex( 1., 0.); /* beta = 1 + 0i, so the product is added to what C already holds */
196+
#else /* neither PRECISION_z nor PRECISION_c is defined */
197+
double lalpha = alpha;
198+
double lbeta = 1.0; /* beta = 1: accumulate into C */
199+
#endif
200+
int tempmm = (((descB->mt-1)-m)==(descB->mt-1)) ? (descB->m-(((descB->mt-1)-m)*descB->mb)) : descB->mb; /* rows of the tile at tile row (descB->mt-1)-m: the remainder of descB->m if that is the last tile row, a full descB->mb otherwise */
201+
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb; /* columns: remainder of descB->n for the last tile column, a full descB->nb otherwise */
202+
int lda = LDA(ddescA, A); /* LDA() results are passed below as the leading dimensions of A, B and C */
203+
int ldb = LDA(ddescB, B);
204+
int ldc = LDA(ddescB, C);
205+
206+
hipblasStatus_t status;
207+
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey); /* handle bundle looked up in gpu_stream->infos under the hip_handles_infokey global */
208+
assert(NULL != handles); /* NOTE(review): only enforced when assertions are enabled; handles is dereferenced on the next line */
209+
status = hipblasZgemm( handles->hipblas_handle, dplasma_hipblas_op(trans), HIPBLAS_OP_N, /* op(A) follows trans; B is used untransposed. presumably dplasma_hipblas_op() converts trans to a hipBLAS operation value - confirm */
210+
tempmm, tempnn, descB->mb, /* m, n, k: the inner dimension is always a full descB->mb in this variant */
211+
&lalpha, (hipDoubleComplex*)A, lda, /* hipBLAS takes the scalars by pointer */
212+
(hipDoubleComplex*)B, ldb,
213+
&lbeta, (hipDoubleComplex*)C, ldc );
214+
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} ); /* the supplied block returns PARSEC_HOOK_RETURN_ERROR; presumably executed by the macro when status reports a failure */
215+
}
216+
END
217+
156218
BODY
157219
{
158220
int tempmm = (((descB->mt-1)-m)==(descB->mt-1)) ? (descB->m-(((descB->mt-1)-m)*descB->mb)) : descB->mb;

src/ztrmm_LLT.jdf

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ extern "C" %{
88
* @precisions normal z -> s d c
99
*
1010
*/
11+
#include "dplasma/config.h"
12+
#if defined(DPLASMA_HAVE_CUDA)
13+
#include <cublas.h>
14+
#endif /* defined(DPLASMA_HAVE_CUDA) */
1115
#include "dplasmajdf.h"
1216
#include "parsec/data_dist/matrix/matrix.h"
1317

@@ -54,6 +58,8 @@ descA [type = "const parsec_tiled_matrix_t*" hidden = on default = "((dplas
5458
ddescB [type = "dplasma_data_collection_t*"]
5559
descB [type = "parsec_tiled_matrix_t*" hidden = on default = "((dplasma_data_collection_t*)ddescB)->dc_original" aligned=ddescB]
5660

61+
hip_handles_infokey [type = "int" hidden = on default = -1 ]
62+
5763
read_A(m, k) [profile = off]
5864
/* Execution Space */
5965
m = 0 .. (descB->mt-1)
@@ -154,6 +160,62 @@ loc_C = %{ return LOC(descB, m, n); %}
154160
type_data = %{ return ADTT_DC(ddescB, loc_C, B_SHAPE, LAPACK); %} ]
155161
-> (k < (descB->mt-1)) ? C zgemm(m, n, k+1) /* dep OUT: rely on datacopy dtt for sending */
156162

163+
BODY [type=CUDA]
164+
{ /* CUDA variant of this task's update: a single GEMM, C = lalpha * op(A) * B + lbeta * C with lbeta == 1, issued through the legacy cuBLAS API. */
165+
#if defined(PRECISION_z) || defined(PRECISION_c)
166+
cuDoubleComplex lalpha = make_cuDoubleComplex(creal(alpha), cimag(alpha)); /* complex builds: repack alpha from its real and imaginary parts */
167+
cuDoubleComplex lbeta = make_cuDoubleComplex( 1., 0.); /* beta = 1 + 0i, so the product is added to what C already holds */
168+
#else /* neither PRECISION_z nor PRECISION_c is defined */
169+
double lalpha = alpha;
170+
double lbeta = 1.0; /* beta = 1: accumulate into C */
171+
#endif
172+
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb; /* rows: remainder of descB->m for the last tile row, a full descB->mb otherwise */
173+
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb; /* columns: remainder of descB->n for the last tile column, a full descB->nb otherwise */
174+
int tempkm = ((k)==(descA->mt-1)) ? (descA->m-(k*descA->mb)) : descA->mb; /* inner dimension: remainder of descA->m when k is A's last tile row, a full descA->mb otherwise */
175+
int lda = LDA(ddescA, A); /* LDA() results are passed below as the leading dimensions of A, B and C */
176+
int ldb = LDA(ddescB, B);
177+
int ldc = LDA(ddescB, C);
178+
179+
cublasStatus_t status;
180+
cublasSetKernelStream( parsec_body.stream ); /* legacy cuBLAS: run the following call on the stream held in parsec_body */
181+
cublasZgemm( dplasma_lapack_const(trans), 'N', /* op(A) follows trans; B is used untransposed. presumably dplasma_lapack_const() yields the character code this API expects - confirm */
182+
tempmm, tempnn, tempkm, /* m, n, k */
183+
lalpha, (cuDoubleComplex*)A, lda, /* scalars are passed by value in the legacy API */
184+
(cuDoubleComplex*)B, ldb,
185+
lbeta, (cuDoubleComplex*)C, ldc );
186+
status = cublasGetError(); /* the gemm call's result is not used; the status is fetched separately */
187+
PARSEC_CUDA_CHECK_ERROR( "cublasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} ); /* the supplied block returns PARSEC_HOOK_RETURN_ERROR; presumably executed by the macro when status reports a failure */
188+
}
189+
END
190+
191+
BODY [type=HIP]
192+
{ /* HIP variant of this task's update: a single hipBLAS GEMM, C = lalpha * op(A) * B + lbeta * C with lbeta == 1. */
193+
#if defined(PRECISION_z) || defined(PRECISION_c)
194+
hipDoubleComplex lalpha = make_hipDoubleComplex(creal(alpha), cimag(alpha)); /* complex builds: repack alpha from its real and imaginary parts */
195+
hipDoubleComplex lbeta = make_hipDoubleComplex( 1., 0.); /* beta = 1 + 0i, so the product is added to what C already holds */
196+
#else /* neither PRECISION_z nor PRECISION_c is defined */
197+
double lalpha = alpha;
198+
double lbeta = 1.0; /* beta = 1: accumulate into C */
199+
#endif
200+
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb; /* rows: remainder of descB->m for the last tile row, a full descB->mb otherwise */
201+
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb; /* columns: remainder of descB->n for the last tile column, a full descB->nb otherwise */
202+
int tempkm = ((k)==(descA->mt-1)) ? (descA->m-(k*descA->mb)) : descA->mb; /* inner dimension: remainder of descA->m when k is A's last tile row, a full descA->mb otherwise */
203+
int lda = LDA(ddescA, A); /* LDA() results are passed below as the leading dimensions of A, B and C */
204+
int ldb = LDA(ddescB, B);
205+
int ldc = LDA(ddescB, C);
206+
207+
hipblasStatus_t status;
208+
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey); /* handle bundle looked up in gpu_stream->infos under the hip_handles_infokey global */
209+
assert(NULL != handles); /* NOTE(review): only enforced when assertions are enabled; handles is dereferenced on the next line */
210+
status = hipblasZgemm( handles->hipblas_handle, dplasma_hipblas_op(trans), HIPBLAS_OP_N, /* op(A) follows trans; B is used untransposed. presumably dplasma_hipblas_op() converts trans to a hipBLAS operation value - confirm */
211+
tempmm, tempnn, tempkm, /* m, n, k */
212+
&lalpha, (hipDoubleComplex*)A, lda, /* hipBLAS takes the scalars by pointer */
213+
(hipDoubleComplex*)B, ldb,
214+
&lbeta, (hipDoubleComplex*)C, ldc );
215+
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} ); /* the supplied block returns PARSEC_HOOK_RETURN_ERROR; presumably executed by the macro when status reports a failure */
216+
}
217+
END
218+
157219
BODY
158220
{
159221
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;

src/ztrmm_LUN.jdf

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ extern "C" %{
88
* @precisions normal z -> s d c
99
*
1010
*/
11+
#include "dplasma/config.h"
12+
#if defined(DPLASMA_HAVE_CUDA)
13+
#include <cublas.h>
14+
#endif /* defined(DPLASMA_HAVE_CUDA) */
1115
#include "dplasmajdf.h"
1216
#include "parsec/data_dist/matrix/matrix.h"
1317

@@ -54,6 +58,8 @@ descA [type = "const parsec_tiled_matrix_t*" hidden = on default = "((dplas
5458
ddescB [type = "dplasma_data_collection_t*"]
5559
descB [type = "parsec_tiled_matrix_t*" hidden = on default = "((dplasma_data_collection_t*)ddescB)->dc_original" aligned=ddescB]
5660

61+
hip_handles_infokey [type = "int" hidden = on default = -1 ]
62+
5763
read_A(m, k) [profile = off]
5864
/* Execution Space */
5965
m = 0 .. (descB->mt-1)
@@ -154,6 +160,62 @@ loc_C = %{ return LOC(descB, m, n); %}
154160
type_data = %{ return ADTT_DC(ddescB, loc_C, B_SHAPE, LAPACK); %} ]
155161
-> (k < (descB->mt-1)) ? C zgemm(m, n, k+1) /* dep OUT: rely on datacopy dtt for sending */
156162

163+
BODY [type=CUDA]
164+
{ /* CUDA variant of this task's update: a single GEMM, C = lalpha * op(A) * B + lbeta * C with lbeta == 1, issued through the legacy cuBLAS API. */
165+
#if defined(PRECISION_z) || defined(PRECISION_c)
166+
cuDoubleComplex lalpha = make_cuDoubleComplex(creal(alpha), cimag(alpha)); /* complex builds: repack alpha from its real and imaginary parts */
167+
cuDoubleComplex lbeta = make_cuDoubleComplex( 1., 0.); /* beta = 1 + 0i, so the product is added to what C already holds */
168+
#else /* neither PRECISION_z nor PRECISION_c is defined */
169+
double lalpha = alpha;
170+
double lbeta = 1.0; /* beta = 1: accumulate into C */
171+
#endif
172+
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb; /* rows: remainder of descB->m for the last tile row, a full descB->mb otherwise */
173+
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb; /* columns: remainder of descB->n for the last tile column, a full descB->nb otherwise */
174+
int tempkn = ((k)==(descA->nt-1)) ? (descA->n-(k*descA->nb)) : descA->nb; /* inner dimension: remainder of descA->n when k is A's last tile column, a full descA->nb otherwise */
175+
int lda = LDA(ddescA, A); /* LDA() results are passed below as the leading dimensions of A, B and C */
176+
int ldb = LDA(ddescB, B);
177+
int ldc = LDA(ddescB, C);
178+
179+
cublasStatus_t status;
180+
cublasSetKernelStream( parsec_body.stream ); /* legacy cuBLAS: run the following call on the stream held in parsec_body */
181+
cublasZgemm( dplasma_lapack_const(trans), 'N', /* op(A) follows trans; B is used untransposed. presumably dplasma_lapack_const() yields the character code this API expects - confirm */
182+
tempmm, tempnn, tempkn, /* m, n, k */
183+
lalpha, (cuDoubleComplex*)A, lda, /* scalars are passed by value in the legacy API */
184+
(cuDoubleComplex*)B, ldb,
185+
lbeta, (cuDoubleComplex*)C, ldc );
186+
status = cublasGetError(); /* the gemm call's result is not used; the status is fetched separately */
187+
PARSEC_CUDA_CHECK_ERROR( "cublasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} ); /* the supplied block returns PARSEC_HOOK_RETURN_ERROR; presumably executed by the macro when status reports a failure */
188+
}
189+
END
190+
191+
BODY [type=HIP]
192+
{ /* HIP variant of this task's update: a single hipBLAS GEMM, C = lalpha * op(A) * B + lbeta * C with lbeta == 1. */
193+
#if defined(PRECISION_z) || defined(PRECISION_c)
194+
hipDoubleComplex lalpha = make_hipDoubleComplex(creal(alpha), cimag(alpha)); /* complex builds: repack alpha from its real and imaginary parts */
195+
hipDoubleComplex lbeta = make_hipDoubleComplex( 1., 0.); /* beta = 1 + 0i, so the product is added to what C already holds */
196+
#else /* neither PRECISION_z nor PRECISION_c is defined */
197+
double lalpha = alpha;
198+
double lbeta = 1.0; /* beta = 1: accumulate into C */
199+
#endif
200+
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb; /* rows: remainder of descB->m for the last tile row, a full descB->mb otherwise */
201+
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb; /* columns: remainder of descB->n for the last tile column, a full descB->nb otherwise */
202+
int tempkn = ((k)==(descA->nt-1)) ? (descA->n-(k*descA->nb)) : descA->nb; /* inner dimension: remainder of descA->n when k is A's last tile column, a full descA->nb otherwise */
203+
int lda = LDA(ddescA, A); /* LDA() results are passed below as the leading dimensions of A, B and C */
204+
int ldb = LDA(ddescB, B);
205+
int ldc = LDA(ddescB, C);
206+
207+
hipblasStatus_t status;
208+
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey); /* handle bundle looked up in gpu_stream->infos under the hip_handles_infokey global */
209+
assert(NULL != handles); /* NOTE(review): only enforced when assertions are enabled; handles is dereferenced on the next line */
210+
status = hipblasZgemm( handles->hipblas_handle, dplasma_hipblas_op(trans), HIPBLAS_OP_N, /* op(A) follows trans; B is used untransposed. presumably dplasma_hipblas_op() converts trans to a hipBLAS operation value - confirm */
211+
tempmm, tempnn, tempkn, /* m, n, k */
212+
&lalpha, (hipDoubleComplex*)A, lda, /* hipBLAS takes the scalars by pointer */
213+
(hipDoubleComplex*)B, ldb,
214+
&lbeta, (hipDoubleComplex*)C, ldc );
215+
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} ); /* the supplied block returns PARSEC_HOOK_RETURN_ERROR; presumably executed by the macro when status reports a failure */
216+
}
217+
END
218+
157219
BODY
158220
{
159221
int tempmm = ((m)==(descB->mt-1)) ? (descB->m-(m*descB->mb)) : descB->mb;

src/ztrmm_LUT.jdf

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ extern "C" %{
88
* @precisions normal z -> s d c
99
*
1010
*/
11+
#include "dplasma/config.h"
12+
#if defined(DPLASMA_HAVE_CUDA)
13+
#include <cublas.h>
14+
#endif /* defined(DPLASMA_HAVE_CUDA) */
1115
#include "dplasmajdf.h"
1216
#include "parsec/data_dist/matrix/matrix.h"
1317

@@ -54,6 +58,8 @@ descA [type = "const parsec_tiled_matrix_t*" hidden = on default = "((dplas
5458
ddescB [type = "dplasma_data_collection_t*"]
5559
descB [type = "parsec_tiled_matrix_t*" hidden = on default = "((dplasma_data_collection_t*)ddescB)->dc_original" aligned=ddescB]
5660

61+
hip_handles_infokey [type = "int" hidden = on default = -1 ]
62+
5763
read_A(m, k) [profile = off]
5864
/* Execution Space */
5965
m = 0..(descB->mt-1)
@@ -153,6 +159,60 @@ loc_C = %{ return LOC(descB, (descB->mt-1)-m, n); %}
153159
type_data = %{ return ADTT_DC(ddescB, loc_C, B_SHAPE, LAPACK); %} ]
154160
-> ((k+m) < (descB->mt-2)) ? C zgemm(m, n, k+1) /* dep OUT: rely on datacopy dtt for sending */
155161

162+
BODY [type=CUDA]
163+
{ /* CUDA variant of this task's update: a single GEMM, C = lalpha * op(A) * B + lbeta * C with lbeta == 1, issued through the legacy cuBLAS API. */
164+
#if defined(PRECISION_z) || defined(PRECISION_c)
165+
cuDoubleComplex lalpha = make_cuDoubleComplex(creal(alpha), cimag(alpha)); /* complex builds: repack alpha from its real and imaginary parts */
166+
cuDoubleComplex lbeta = make_cuDoubleComplex( 1., 0.); /* beta = 1 + 0i, so the product is added to what C already holds */
167+
#else /* neither PRECISION_z nor PRECISION_c is defined */
168+
double lalpha = alpha;
169+
double lbeta = 1.0; /* beta = 1: accumulate into C */
170+
#endif
171+
int tempmm = (((descB->mt-1)-m)==(descB->mt-1)) ? (descB->m-(((descB->mt-1)-m)*descB->mb)) : descB->mb; /* rows of the tile at tile row (descB->mt-1)-m: the remainder of descB->m if that is the last tile row, a full descB->mb otherwise */
172+
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb; /* columns: remainder of descB->n for the last tile column, a full descB->nb otherwise */
173+
int lda = LDA(ddescA, A); /* LDA() results are passed below as the leading dimensions of A, B and C */
174+
int ldb = LDA(ddescB, B);
175+
int ldc = LDA(ddescB, C);
176+
177+
cublasStatus_t status;
178+
cublasSetKernelStream( parsec_body.stream ); /* legacy cuBLAS: run the following call on the stream held in parsec_body */
179+
cublasZgemm( dplasma_lapack_const(trans), 'N', /* op(A) follows trans; B is used untransposed. presumably dplasma_lapack_const() yields the character code this API expects - confirm */
180+
tempmm, tempnn, descB->mb, /* m, n, k: the inner dimension is always a full descB->mb in this variant */
181+
lalpha, (cuDoubleComplex*)A, lda, /* scalars are passed by value in the legacy API */
182+
(cuDoubleComplex*)B, ldb,
183+
lbeta, (cuDoubleComplex*)C, ldc );
184+
status = cublasGetError(); /* the gemm call's result is not used; the status is fetched separately */
185+
PARSEC_CUDA_CHECK_ERROR( "cublasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} ); /* the supplied block returns PARSEC_HOOK_RETURN_ERROR; presumably executed by the macro when status reports a failure */
186+
}
187+
END
188+
189+
BODY [type=HIP]
190+
{ /* HIP variant of this task's update: a single hipBLAS GEMM, C = lalpha * op(A) * B + lbeta * C with lbeta == 1. */
191+
#if defined(PRECISION_z) || defined(PRECISION_c)
192+
hipDoubleComplex lalpha = make_hipDoubleComplex(creal(alpha), cimag(alpha)); /* complex builds: repack alpha from its real and imaginary parts */
193+
hipDoubleComplex lbeta = make_hipDoubleComplex( 1., 0.); /* beta = 1 + 0i, so the product is added to what C already holds */
194+
#else /* neither PRECISION_z nor PRECISION_c is defined */
195+
double lalpha = alpha;
196+
double lbeta = 1.0; /* beta = 1: accumulate into C */
197+
#endif
198+
int tempmm = (((descB->mt-1)-m)==(descB->mt-1)) ? (descB->m-(((descB->mt-1)-m)*descB->mb)) : descB->mb; /* rows of the tile at tile row (descB->mt-1)-m: the remainder of descB->m if that is the last tile row, a full descB->mb otherwise */
199+
int tempnn = ((n)==(descB->nt-1)) ? (descB->n-(n*descB->nb)) : descB->nb; /* columns: remainder of descB->n for the last tile column, a full descB->nb otherwise */
200+
int lda = LDA(ddescA, A); /* LDA() results are passed below as the leading dimensions of A, B and C */
201+
int ldb = LDA(ddescB, B);
202+
int ldc = LDA(ddescB, C);
203+
204+
hipblasStatus_t status;
205+
dplasma_hip_handles_t *handles = parsec_info_get(&gpu_stream->infos, hip_handles_infokey); /* handle bundle looked up in gpu_stream->infos under the hip_handles_infokey global */
206+
assert(NULL != handles); /* NOTE(review): only enforced when assertions are enabled; handles is dereferenced on the next line */
207+
status = hipblasZgemm( handles->hipblas_handle, dplasma_hipblas_op(trans), HIPBLAS_OP_N, /* op(A) follows trans; B is used untransposed. presumably dplasma_hipblas_op() converts trans to a hipBLAS operation value - confirm */
208+
tempmm, tempnn, descB->mb, /* m, n, k: the inner dimension is always a full descB->mb in this variant */
209+
&lalpha, (hipDoubleComplex*)A, lda, /* hipBLAS takes the scalars by pointer */
210+
(hipDoubleComplex*)B, ldb,
211+
&lbeta, (hipDoubleComplex*)C, ldc );
212+
DPLASMA_HIPBLAS_CHECK_ERROR( "hipblasZgemm ", status, {return PARSEC_HOOK_RETURN_ERROR;} ); /* the supplied block returns PARSEC_HOOK_RETURN_ERROR; presumably executed by the macro when status reports a failure */
213+
}
214+
END
215+
156216
BODY
157217
{
158218
int tempmm = (((descB->mt-1)-m)==(descB->mt-1)) ? (descB->m-(((descB->mt-1)-m)*descB->mb)) : descB->mb;

0 commit comments

Comments
 (0)