Skip to content

Commit

Permalink
fix single wmma tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Feb 18, 2025
1 parent 469bc18 commit 75b1f33
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 145 deletions.
31 changes: 2 additions & 29 deletions include/cell/copy/shared_to_register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,10 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,
static constexpr int kRowExec = kRowExec_;
static constexpr int kColExec = kColExec_;

DEVICE void operator()(const DType* src, Reg& dst, int start_offset) {
__device__ void operator()(const DType* src, Reg& dst, int start_offset) {
int lane_row = this->lane_row_id();
int lane_col = this->lane_col_id() * LoadMat::kNumPerAccess;

if (thread0()) {
printf("kRowExec: %d, kColExec: %d\n", kRowExec, kColExec);
}

#pragma unroll
for (int i = 0; i < kRowExec; ++i) {
#pragma unroll
Expand All @@ -52,19 +48,9 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,
j * BaseShape::kCols + lane_row * kSharedCols + lane_col;
int offset = get_swizzle_offset(tile_offset);

// if (threadIdx.x == 32) {
// printf("thread(32) tile_offset: %d, offset: %d\n",
// tile_offset, offset);
// }

            // advance pointer to the 16x16 `BaseTile` indexed by (i, j).
// issue the hardware-backed memory access instruction.
this->ldmatrix(src + offset, dst(i, j).mutable_data());

if (threadIdx.x == 32) {
printf("\nthread(32) dst(%d, %d):\n", i, j);
dst(i, j).dump_value();
}
}
}
}
Expand Down Expand Up @@ -150,7 +136,7 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,
: base_tiles_(BaseTilesLayout{})
, in_base_tile_(BaseTileSharedLayout{}) {}

DEVICE void operator()(const DType* src, Reg& dst, int start_offset) {
__device__ void operator()(const DType* src, Reg& dst, int start_offset) {
// transpose the lane position if the shared memory is in
// column-major. 16 threads are mapped to the strided dimension
// of the data while the 2 threads are mapped to the contiguous
Expand All @@ -168,11 +154,6 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,

// issue the hardware-backed memory access instruction
this->ldmatrix(src + offset, dst(j, i).mutable_data());

if (threadIdx.x == 32) {
printf("\nthread(32) dst(%d, %d):\n", j, i);
dst(j, i).dump_value();
}
}
}
}
Expand Down Expand Up @@ -528,10 +509,6 @@ struct SharedToRegLoader {
// warp reuse mode.
int offset = shared_offset_.get_warp_offset();

if (threadIdx.x == 0 || threadIdx.x == 32) {
printf("s2r loader tid: %d, offset: %d\n", threadIdx.x, offset);
}

using Loader = detail::SharedToRegLoaderImpl<Shared, Reg, kRowExec,
kColExec, Shared::kType>;
Loader loader;
Expand Down Expand Up @@ -587,10 +564,6 @@ struct RegToSharedStorer {
SharedOffset shared_offset_;
int offset = shared_offset_.get_warp_offset();

if (threadIdx.x == 0 || threadIdx.x == 32) {
printf("r2s storer tid: %d, offset: %d\n", threadIdx.x, offset);
}

using Storer = detail::RegToSharedStorerImpl<Reg, Shared, kRowExec,
kColExec, Reg::kType>;
Storer storer;
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/cell/test_g2s_load.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ TEST(GlobalToSharedLoad, test_row_major_load) {
run_test_row_major<float, tl::RowMajor<4, 1>, 64, 32, true>();
run_test_row_major<float, tl::RowMajor<2, 2>, 32, 64, true>();
run_test_row_major<float, tl::RowMajor<2, 4>, 32, 128, true>();

    // To check correctness for the next tests.
run_test_row_major<__half, tl::RowMajor<2, 1>, 64, 64, false>();
run_test_row_major<__half, tl::RowMajor<2, 1>, 64, 64, true>();
}

TEST(GlobalToSharedLoad, test_col_major_load) {
Expand Down
6 changes: 5 additions & 1 deletion tests/cpp/cell/test_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ bool check_correctness(const half* hc1, const float* hc2, int row, int col) {

for (int i = 0; i < numel; ++i) {
diff = abs(__half2float(hc1[i]) - hc2[i]);
if (diff > eps) {
printf("error: %d, %.4f, %.4f\n", i, __half2float(hc1[i]), hc2[i]);
return false;
}
max_abs_diff = max_abs_diff < diff ? diff : max_abs_diff;
total_diff += diff;

Expand Down Expand Up @@ -342,7 +346,7 @@ TEST(TestGemm, test) {

// 2 x 1 warps
// TODO(KuangjuX): fix different warp layout.
run_test<128, 64, 64, tl::RowMajor<2, 1>, 64>();
run_test<32, 64, 64, tl::RowMajor<2, 1>, 64>();
// run_test<128, 128, 128, tl::RowMajor<2, 1>, 64>();
}

Expand Down
32 changes: 15 additions & 17 deletions tests/cpp/cell/test_single_wmma.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@ __device__ void naive_gemm(int kM, int kN, int kK, //
}

__device__ void check_results(const float* hc1, const float* hc2, int numel) {
// for (int i = 0; i < numel; ++i) assert(fabs(hc1[i] - hc2[i]) < 1e-3);
for (int i = 0; i < numel; ++i) {
if (fabs(hc1[i] - hc2[i]) > 1e-3) {
printf("error: %d, %.4f, %.4f\n", i, hc1[i], hc2[i]);
printf("test failed!\n");
return;
}
}

printf("test passed!\n");

#if defined(DEBUG)
if (thread0()) {
int cut_off = numel < 128 ? numel : 128;
Expand Down Expand Up @@ -89,6 +93,8 @@ __global__ void test_wmma(LoadRegA& load_rA, LoadRegB& load_rB,

init_values<Element, ElementAcc>(shared_a, shared_b, shared_c, M, N, K);

__syncthreads();

SharedC sC(shared_c);

TileIteratorA sAs(shared_a);
Expand All @@ -107,21 +113,9 @@ __global__ void test_wmma(LoadRegA& load_rA, LoadRegB& load_rB,

__syncthreads();

if (threadIdx.x == 32) {
printf("\nrA:\n");
rA.dump_value();

printf("\nrB:\n");
rB.dump_value();
}

compute::gemm(rA, rB, acc);
}

// if (threadIdx.x == 32) {
// acc.dump_value();
// }

__syncthreads();

store_rC(acc, sC);
Expand Down Expand Up @@ -234,16 +228,20 @@ void run_test() {

cudaDeviceSynchronize();

LOG(INFO) << "[" << kM << ", " << kN << ", " << kK << "]. Test passed!"
<< std::endl;
// LOG(INFO) << "[" << kM << ", " << kN << ", " << kK << "]. Test passed!"
// << std::endl;
}

TEST(TestWmma, test_m16n16k16_f) {
// run_test<16, 16, 64, tl::RowMajor<1, 1>>();
// run_test<32, 32, 64, tl::RowMajor<1, 1>>();
// run_test<64, 64, 64, tl::RowMajor<1, 1>>();
run_test<16, 16, 64, tl::RowMajor<1, 1>>();
run_test<32, 32, 64, tl::RowMajor<1, 1>>();
run_test<64, 64, 64, tl::RowMajor<1, 1>>();
// TODO(KuangjuX): It doesn't seem to be executed.
// run_test<64, 128, 128, tl::RowMajor<1, 1>>();

run_test<32, 64, 64, tl::RowMajor<2, 1>>();
run_test<32, 64, 64, tl::RowMajor<2, 1>>();
run_test<64, 64, 64, tl::RowMajor<2, 1>>();
run_test<128, 64, 64, tl::RowMajor<2, 1>>();
// run_test<128, 128, 128, tl::RowMajor<2, 1>>();

Expand Down
Loading

0 comments on commit 75b1f33

Please sign in to comment.