Skip to content

Commit

Permalink
Hotfix: dgmm test fix (#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
daineAMD authored Aug 26, 2020
1 parent 5299dc5 commit 50b865f
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 57 deletions.
64 changes: 32 additions & 32 deletions clients/gtest/dgmm_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,22 @@ TEST_P(dgmm_gtest, dgmm_gtest_float)
// The Arguments data struture have physical meaning associated.
// while the tuple is non-intuitive.

// Arguments arg = setup_dgmm_arguments(GetParam());

// hipblasStatus_t status = testing_dgmm<float>(arg);

// // if not success, then the input argument is problematic, so detect the error message
// if(status != HIPBLAS_STATUS_SUCCESS)
// {
// if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.ldc < arg.M || arg.incx == 0)
// {
// EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
// }
// else
// {
// EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail
// }
// }
Arguments arg = setup_dgmm_arguments(GetParam());

hipblasStatus_t status = testing_dgmm<float>(arg);

// if not success, then the input argument is problematic, so detect the error message
if(status != HIPBLAS_STATUS_SUCCESS)
{
if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.ldc < arg.M)
{
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
}
else
{
EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail
}
}
}

TEST_P(dgmm_gtest, dgmm_gtest_float_complex)
Expand All @@ -147,22 +147,22 @@ TEST_P(dgmm_gtest, dgmm_gtest_float_complex)
// The Arguments data struture have physical meaning associated.
// while the tuple is non-intuitive.

// Arguments arg = setup_dgmm_arguments(GetParam());

// hipblasStatus_t status = testing_dgmm<hipblasComplex>(arg);

// // if not success, then the input argument is problematic, so detect the error message
// if(status != HIPBLAS_STATUS_SUCCESS)
// {
// if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.ldc < arg.M || arg.incx == 0)
// {
// EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
// }
// else
// {
// EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail
// }
// }
Arguments arg = setup_dgmm_arguments(GetParam());

hipblasStatus_t status = testing_dgmm<hipblasComplex>(arg);

// if not success, then the input argument is problematic, so detect the error message
if(status != HIPBLAS_STATUS_SUCCESS)
{
if(arg.M < 0 || arg.N < 0 || arg.lda < arg.M || arg.ldc < arg.M)
{
EXPECT_EQ(HIPBLAS_STATUS_INVALID_VALUE, status);
}
else
{
EXPECT_EQ(HIPBLAS_STATUS_SUCCESS, status); // fail
}
}
}

TEST_P(dgmm_gtest, dgmm_batched_gtest_float)
Expand Down
2 changes: 1 addition & 1 deletion clients/include/hipblas_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class device_batch_vector : private d_vector<T, PAD, U>
return data[n];
}

operator T**()
operator T* *()
{
return data;
}
Expand Down
8 changes: 5 additions & 3 deletions clients/include/testing_dgmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ hipblasStatus_t testing_dgmm(Arguments argus)
int C_size = size_t(ldc) * N;
int k = (side == HIPBLAS_SIDE_RIGHT ? N : M);
int X_size = size_t(incx) * k;
if(!X_size)
X_size = 1;

hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS;

// argument sanity check, quick return if input parameters are invalid before allocating invalid
// memory
if(M < 0 || N < 0 || lda < M || ldc < M || incx == 0)
if(M < 0 || N < 0 || lda < M || ldc < M)
{
status = HIPBLAS_STATUS_INVALID_VALUE;
return status;
Expand Down Expand Up @@ -110,11 +112,11 @@ hipblasStatus_t testing_dgmm(Arguments argus)
{
if(HIPBLAS_SIDE_RIGHT == side)
{
hC_gold[i1 + i2 * ldc] = hA_copy[i1 + i2 * lda] + hx_copy[i2 * incx];
hC_gold[i1 + i2 * ldc] = hA_copy[i1 + i2 * lda] * hx_copy[i2 * incx];
}
else
{
hC_gold[i1 + i2 * ldc] = hA_copy[i1 + i2 * lda] + hx_copy[i1 * incx];
hC_gold[i1 + i2 * ldc] = hA_copy[i1 + i2 * lda] * hx_copy[i1 * incx];
}
}
}
Expand Down
8 changes: 5 additions & 3 deletions clients/include/testing_dgmm_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ hipblasStatus_t testing_dgmm_batched(Arguments argus)
int C_size = size_t(ldc) * N;
int k = (side == HIPBLAS_SIDE_RIGHT ? N : M);
int X_size = size_t(incx) * k;
if(!X_size)
X_size = 1;

hipblasStatus_t status = HIPBLAS_STATUS_SUCCESS;

// argument sanity check, quick return if input parameters are invalid before allocating invalid
// memory
if(M < 0 || N < 0 || lda < M || ldc < M || incx == 0 || batch_count < 0)
if(M < 0 || N < 0 || lda < M || ldc < M || batch_count < 0)
{
status = HIPBLAS_STATUS_INVALID_VALUE;
return status;
Expand Down Expand Up @@ -142,12 +144,12 @@ hipblasStatus_t testing_dgmm_batched(Arguments argus)
if(HIPBLAS_SIDE_RIGHT == side)
{
hC_gold[b][i1 + i2 * ldc]
= hA_copy[b][i1 + i2 * lda] + hx_copy[b][i2 * incx];
= hA_copy[b][i1 + i2 * lda] * hx_copy[b][i2 * incx];
}
else
{
hC_gold[b][i1 + i2 * ldc]
= hA_copy[b][i1 + i2 * lda] + hx_copy[b][i1 * incx];
= hA_copy[b][i1 + i2 * lda] * hx_copy[b][i1 * incx];
}
}
}
Expand Down
8 changes: 5 additions & 3 deletions clients/include/testing_dgmm_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ hipblasStatus_t testing_dgmm_strided_batched(Arguments argus)
int stride_A = size_t(lda) * N * stride_scale;
int stride_x = size_t(incx) * k * stride_scale;
int stride_C = size_t(ldc) * N * stride_scale;
if(!stride_x)
stride_x = 1;

int A_size = stride_A * batch_count;
int C_size = stride_C * batch_count;
Expand All @@ -49,7 +51,7 @@ hipblasStatus_t testing_dgmm_strided_batched(Arguments argus)

// argument sanity check, quick return if input parameters are invalid before allocating invalid
// memory
if(M < 0 || N < 0 || lda < M || ldc < M || incx == 0 || batch_count < 0)
if(M < 0 || N < 0 || lda < M || ldc < M || batch_count < 0)
{
status = HIPBLAS_STATUS_INVALID_VALUE;
return status;
Expand Down Expand Up @@ -123,11 +125,11 @@ hipblasStatus_t testing_dgmm_strided_batched(Arguments argus)
{
if(HIPBLAS_SIDE_RIGHT == side)
{
hC_goldb[i1 + i2 * ldc] = hA_copyb[i1 + i2 * lda] + hx_copyb[i2 * incx];
hC_goldb[i1 + i2 * ldc] = hA_copyb[i1 + i2 * lda] * hx_copyb[i2 * incx];
}
else
{
hC_goldb[i1 + i2 * ldc] = hA_copyb[i1 + i2 * lda] + hx_copyb[i1 * incx];
hC_goldb[i1 + i2 * ldc] = hA_copyb[i1 + i2 * lda] * hx_copyb[i1 * incx];
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion clients/include/testing_geqrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using namespace std;
template <typename T, typename U>
hipblasStatus_t testing_geqrf(Arguments argus)
{
bool FORTRAN = argus.fortran;
bool FORTRAN = argus.fortran;
auto hipblasGeqrfFn = FORTRAN ? hipblasGeqrf<T, true> : hipblasGeqrf<T, false>;

int M = argus.M;
Expand Down
5 changes: 3 additions & 2 deletions clients/include/testing_geqrf_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ using namespace std;
template <typename T, typename U>
hipblasStatus_t testing_geqrf_batched(Arguments argus)
{
bool FORTRAN = argus.fortran;
auto hipblasGeqrfBatchedFn = FORTRAN ? hipblasGeqrfBatched<T, true> : hipblasGeqrfBatched<T, false>;
bool FORTRAN = argus.fortran;
auto hipblasGeqrfBatchedFn
= FORTRAN ? hipblasGeqrfBatched<T, true> : hipblasGeqrfBatched<T, false>;

int M = argus.M;
int N = argus.N;
Expand Down
5 changes: 3 additions & 2 deletions clients/include/testing_geqrf_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ using namespace std;
template <typename T, typename U>
hipblasStatus_t testing_geqrf_strided_batched(Arguments argus)
{
bool FORTRAN = argus.fortran;
auto hipblasGeqrfStridedBatchedFn = FORTRAN ? hipblasGeqrfStridedBatched<T, true> : hipblasGeqrfStridedBatched<T, false>;
bool FORTRAN = argus.fortran;
auto hipblasGeqrfStridedBatchedFn
= FORTRAN ? hipblasGeqrfStridedBatched<T, true> : hipblasGeqrfStridedBatched<T, false>;

int M = argus.M;
int N = argus.N;
Expand Down
2 changes: 1 addition & 1 deletion clients/include/testing_getrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using namespace std;
template <typename T, typename U>
hipblasStatus_t testing_getrf(Arguments argus)
{
bool FORTRAN = argus.fortran;
bool FORTRAN = argus.fortran;
auto hipblasGetrfFn = FORTRAN ? hipblasGetrf<T, true> : hipblasGetrf<T, false>;

int M = argus.N;
Expand Down
5 changes: 3 additions & 2 deletions clients/include/testing_getrf_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ using namespace std;
template <typename T, typename U>
hipblasStatus_t testing_getrf_batched(Arguments argus)
{
bool FORTRAN = argus.fortran;
auto hipblasGetrfBatchedFn = FORTRAN ? hipblasGetrfBatched<T, true> : hipblasGetrfBatched<T, false>;
bool FORTRAN = argus.fortran;
auto hipblasGetrfBatchedFn
= FORTRAN ? hipblasGetrfBatched<T, true> : hipblasGetrfBatched<T, false>;

int M = argus.N;
int N = argus.N;
Expand Down
5 changes: 3 additions & 2 deletions clients/include/testing_getrf_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ using namespace std;
template <typename T, typename U>
hipblasStatus_t testing_getrf_strided_batched(Arguments argus)
{
bool FORTRAN = argus.fortran;
auto hipblasGetrfStridedBatchedFn = FORTRAN ? hipblasGetrfStridedBatched<T, true> : hipblasGetrfStridedBatched<T, false>;
bool FORTRAN = argus.fortran;
auto hipblasGetrfStridedBatchedFn
= FORTRAN ? hipblasGetrfStridedBatched<T, true> : hipblasGetrfStridedBatched<T, false>;

int M = argus.N;
int N = argus.N;
Expand Down
2 changes: 1 addition & 1 deletion clients/include/testing_getrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using namespace std;
template <typename T, typename U>
hipblasStatus_t testing_getrs(Arguments argus)
{
bool FORTRAN = argus.fortran;
bool FORTRAN = argus.fortran;
auto hipblasGetrsFn = FORTRAN ? hipblasGetrs<T, true> : hipblasGetrs<T, false>;

int N = argus.N;
Expand Down
5 changes: 3 additions & 2 deletions clients/include/testing_getrs_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ using namespace std;
template <typename T, typename U>
hipblasStatus_t testing_getrs_batched(Arguments argus)
{
bool FORTRAN = argus.fortran;
auto hipblasGetrsBatchedFn = FORTRAN ? hipblasGetrsBatched<T, true> : hipblasGetrsBatched<T, false>;
bool FORTRAN = argus.fortran;
auto hipblasGetrsBatchedFn
= FORTRAN ? hipblasGetrsBatched<T, true> : hipblasGetrsBatched<T, false>;

int N = argus.N;
int lda = argus.lda;
Expand Down
5 changes: 3 additions & 2 deletions clients/include/testing_getrs_strided_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ using namespace std;
template <typename T, typename U>
hipblasStatus_t testing_getrs_strided_batched(Arguments argus)
{
bool FORTRAN = argus.fortran;
auto hipblasGetrsStridedBatchedFn = FORTRAN ? hipblasGetrsStridedBatched<T, true> : hipblasGetrsStridedBatched<T, false>;
bool FORTRAN = argus.fortran;
auto hipblasGetrsStridedBatchedFn
= FORTRAN ? hipblasGetrsStridedBatched<T, true> : hipblasGetrsStridedBatched<T, false>;

int N = argus.N;
int lda = argus.lda;
Expand Down

0 comments on commit 50b865f

Please sign in to comment.