Skip to content

Commit

Permalink
chore: Add kSwizzled flag in test_gemm.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Feb 18, 2025
1 parent 25e31ec commit bbb78c8
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions tests/cpp/cell/test_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ void cublas_hgemm(int m, int n, int k, const __half* A, const __half* B,
// @param kChunkK: chunk size to partition the k dimension of the shared
// memory tile.
template <typename Element, typename ElementAcc, const int kM, const int kN,
const int kK, typename WarpLayout_, const int kChunkK>
const int kK, typename WarpLayout_, const int kChunkK,
const bool kSwizzled>
struct TestTraits {
using BaseShape = traits::BaseTileShape<Element>;

Expand All @@ -118,7 +119,6 @@ struct TestTraits {

/// == 2. configure tile transfer between global and shared using CuTe ==
using GlobalA = GlobalTile<Element, tl::RowMajor<kM, kK>>;
static const bool kSwizzled = true;
using SharedA = SharedTile<Element, tl::RowMajor<kM, kK>, kSwizzled>;
using LoadSharedA = GlobalToSharedLoader<SharedA, WarpLayout>;

Expand Down Expand Up @@ -216,7 +216,7 @@ __global__ void test_gemm(const Element* ga, const Element* gb,
} // namespace

template <const int kM, const int kN, const int kK, typename WarpLayout,
const int kChunkK>
const int kChunkK, const bool kSwizzled>
void run_test() {
// unittest for register-level gemm by calling into wmma PTX
using Element = __half;
Expand Down Expand Up @@ -249,8 +249,8 @@ void run_test() {
thrust::device_vector<ElementAcc> d_c = h_c;

// define the configuration of the test
using config =
TestTraits<Element, ElementAcc, kM, kN, kK, WarpLayout, kChunkK>;
using config = TestTraits<Element, ElementAcc, kM, kN, kK, WarpLayout,
kChunkK, kSwizzled>;

LOG(INFO) << "[" << kM << ", " << kN << ", " << kK << "], warps: ["
<< config::kWarpPerRow << ", " << config::kWarpPerCol
Expand All @@ -276,9 +276,6 @@ void run_test() {
dim3 dim_block(config::kThreads, 1, 1);
int shm_size = (kM + kN) * kK * sizeof(Element);

printf("config::kWarpPerRow: %d, config::kWarpPerCol: %d\n",
config::kWarpPerRow, config::kWarpPerCol);

auto kernel = test_gemm<
Element, ElementAcc, typename config::GlobalA, typename config::SharedA,
typename config::LoadSharedA, typename config::GlobalB,
Expand Down Expand Up @@ -321,27 +318,29 @@ TEST(TestGemm, test) {
// as this will cause a shared memory overflow.

// 1 warp
run_test<16, 16, 64, tl::RowMajor<1, 1>, 64>(); // minimal shape
run_test<32, 16, 64, tl::RowMajor<1, 1>, 64>();
run_test<16, 32, 64, tl::RowMajor<1, 1>, 64>();
run_test<32, 32, 64, tl::RowMajor<1, 1>, 64>();
run_test<64, 64, 64, tl::RowMajor<1, 1>, 64>();
run_test<128, 64, 64, tl::RowMajor<1, 1>, 64>();
run_test<16, 16, 64, tl::RowMajor<1, 1>, 64, true>(); // minimal shape
run_test<32, 16, 64, tl::RowMajor<1, 1>, 64, true>();
run_test<16, 32, 64, tl::RowMajor<1, 1>, 64, true>();
run_test<32, 32, 64, tl::RowMajor<1, 1>, 64, true>();
run_test<64, 64, 64, tl::RowMajor<1, 1>, 64, true>();
run_test<128, 64, 64, tl::RowMajor<1, 1>, 64, true>();

// 2 x 1 warps
run_test<32, 64, 128, tl::RowMajor<2, 1>, 128>();
run_test<64, 64, 128, tl::RowMajor<2, 1>, 128>();
run_test<32, 128, 128, tl::RowMajor<2, 1>, 128>();
run_test<32, 64, 128, tl::RowMajor<2, 1>, 128, true>();
run_test<64, 64, 128, tl::RowMajor<2, 1>, 128, true>();
run_test<32, 128, 128, tl::RowMajor<2, 1>, 128, true>();

// 1 x 2 warps
run_test<32, 128, 128, tl::RowMajor<1, 2>, 128>();
// run_test<64, 128, 128, tl::RowMajor<1, 2>, 128>();
run_test<32, 128, 128, tl::RowMajor<1, 2>, 128, true>();
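// TODO(KuangjuX): the case below is disabled because it currently fails with
// "CUDA free failed: cudaErrorMisalignedAddress: misaligned address".
// Re-enable it once the misaligned access is fixed.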
// run_test<64, 128, 128, tl::RowMajor<1, 2>, 128, true>();

// 2 x 2 warps
run_test<64, 64, 128, tl::RowMajor<2, 2>, 128>();
run_test<64, 64, 128, tl::RowMajor<2, 2>, 128, true>();

// 4 x 1 warps
run_test<64, 16, 256, tl::RowMajor<4, 1>, 256>();
run_test<64, 16, 256, tl::RowMajor<4, 1>, 256, true>();
}

} // namespace tilefusion::testing

0 comments on commit bbb78c8

Please sign in to comment.