Skip to content

Commit

Permalink
fix(bugs): Support different WarpLayout in GEMMs.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Feb 18, 2025
1 parent 316d5a1 commit 8b89491
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 65 deletions.
2 changes: 1 addition & 1 deletion 3rd-party/cutlass
Submodule cutlass updated 2232 files
20 changes: 10 additions & 10 deletions include/cell/copy/shared_to_register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ struct RegToSharedStorerImpl<Reg_, Shared_, kRowExec_, kColExec_,

DEVICE void operator()(const Reg& src, DType* dst, int start_offset) {
#pragma unroll
for (int i = 0; i < kRowExec; ++i) {
for (int j = 0; j < kColExec; ++j) {
#pragma unroll
for (int j = 0; j < kColExec; ++j) {
for (int i = 0; i < kRowExec; ++i) {
int lane_row = this->lane_row_id();
int lane_col = this->lane_col_id();

Expand Down Expand Up @@ -372,30 +372,30 @@ struct RegToSharedStorerImpl<Reg_, Shared_, kRowExec_, kColExec_,

DEVICE void operator()(const Reg& src, DType* dst, int start_offset) {
#pragma unroll
for (int i = 0; i < kColExec; ++i) {
for (int j = 0; j < kColExec; ++j) {
#pragma unroll
for (int j = 0; j < kRowExec; ++j) {
for (int i = 0; i < kRowExec; ++i) {
int tile_offset =
start_offset + i * kColStride + j * kRowStride;
start_offset + j * kColStride + i * kRowStride;
int lane_row = this->lane_row_id();
int lane_col = this->lane_col_id();

int row = 0, col = 0;
#pragma unroll
for (int m = 0; m < StoreMat::kSegRows; ++m) {
row = StoreMat::kElemPerSeg *
(lane_row + i * StoreMat::kThreadRows);
(lane_row + m * StoreMat::kThreadRows);
#pragma unroll
for (int n = 0; n < StoreMat::kSegCols; ++n) {
col = lane_col + j * StoreMat::kThreadCols;
col = lane_col + n * StoreMat::kThreadCols;

int in_tile_offset = col * Shared::kColStride + row;
int offset = tile_offset + in_tile_offset;
int swizzled_offset = get_swizzle_offset(offset);

const PackedType* src_ptr =
reinterpret_cast<const PackedType*>(
src(j, i).data());
src(i, j).data());
PackedType* dst_ptr =
reinterpret_cast<PackedType*>(dst);
dst_ptr[swizzled_offset / StoreMat::kElemPerSeg] =
Expand Down Expand Up @@ -463,9 +463,9 @@ struct RegToSharedStorerImpl<Reg_, Shared_, kRowExec_, kColExec_,
auto swizzled_tile_id = get_swizzled_tile_id(offset);
auto in_swizzled_tile_id = get_in_swizzle_tile_id(offset);
int swizzled_tile_offset =
src_tile_(swizzled_tile_id.x, swizzled_tile_id.y);
dst_tile_(swizzled_tile_id.x, swizzled_tile_id.y);
int in_swizzled_tile_offset =
in_src_tile_(in_swizzled_tile_id.x, in_swizzled_tile_id.y);
in_dst_tile_(in_swizzled_tile_id.x, in_swizzled_tile_id.y);

int offset_ = swizzled_tile_offset + in_swizzled_tile_offset;

Expand Down
33 changes: 26 additions & 7 deletions include/cell/copy/warp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,19 @@ template <typename WarpLayout_, typename BaseShape_, typename Shared_,
struct SharedOffsetHelper<WarpLayout_, BaseShape_, Shared_, kMode_,
tl::Layout::kRowMajor, false> {
DEVICE int get_warp_offset() {
// TODO(KuangjuX): hotfix this.
return warp_row_id<WarpLayout>() * kRowStride * BaseShape::kRows *
Shared::kCols +
warp_col_id<WarpLayout>() * kColStride * BaseShape::kCols;
switch (kMode) {
case WarpReuse::kCont:
return warp_row_id<WarpLayout>() * kRowStride *
BaseShape::kRows * Shared::kCols +
warp_col_id<WarpLayout>() * kColStride *
BaseShape::kCols;
case WarpReuse::kRowReuseCont:
return warp_row_id<WarpLayout>() * kRowStride *
BaseShape::kRows * Shared::kCols;
default:
assert(false && "Not implemented yet.");
return -1;
}
}

private:
Expand All @@ -308,9 +317,19 @@ template <typename WarpLayout_, typename BaseShape_, typename Shared_,
struct SharedOffsetHelper<WarpLayout_, BaseShape_, Shared_, kMode_,
tl::Layout::kColMajor, false> {
DEVICE int get_warp_offset() {
return warp_row_id<WarpLayout>() * kRowStride * BaseShape::kRows +
warp_col_id<WarpLayout>() * kColStride * BaseShape::kCols *
Shared::kRows;
switch (kMode) {
case WarpReuse::kCont:
return warp_row_id<WarpLayout>() * kRowStride *
BaseShape::kRows +
warp_col_id<WarpLayout>() * kColStride *
BaseShape::kCols * Shared::kRows;
case WarpReuse::kColReuseCont:
return warp_col_id<WarpLayout>() * kColStride *
BaseShape::kCols * Shared::kRows;
default:
assert(false && "Not implemented yet.");
return -1;
}
}

private:
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/cell/test_g2s_load.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ TEST(GlobalToSharedLoad, test_row_major_load) {
run_test_row_major<float, tl::RowMajor<4, 1>, 64, 32, true>();
run_test_row_major<float, tl::RowMajor<2, 2>, 32, 64, true>();
run_test_row_major<float, tl::RowMajor<2, 4>, 32, 128, true>();

// To check correctness for next tests.
run_test_row_major<__half, tl::RowMajor<2, 1>, 64, 64, false>();
run_test_row_major<__half, tl::RowMajor<2, 1>, 64, 64, true>();
}

TEST(GlobalToSharedLoad, test_col_major_load) {
Expand Down
19 changes: 16 additions & 3 deletions tests/cpp/cell/test_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ void run_test() {
dim3 dim_block(config::kThreads, 1, 1);
int shm_size = (kM + kN) * kK * sizeof(Element);

printf("config::kWarpPerRow: %d, config::kWarpPerCol: %d\n",
config::kWarpPerRow, config::kWarpPerCol);

auto kernel = test_gemm<
Element, ElementAcc, typename config::GlobalA, typename config::SharedA,
typename config::LoadSharedA, typename config::GlobalB,
Expand Down Expand Up @@ -326,9 +329,19 @@ TEST(TestGemm, test) {
run_test<128, 64, 64, tl::RowMajor<1, 1>, 64>();

// 2 x 1 warps
// TODO(KuangjuX): fix different warp layout.
// run_test<128, 64, 128, tl::RowMajor<2, 1>, 64>();
// run_test<128, 128, 128, tl::RowMajor<2, 1>, 64>();
run_test<32, 64, 128, tl::RowMajor<2, 1>, 128>();
run_test<64, 64, 128, tl::RowMajor<2, 1>, 128>();
run_test<32, 128, 128, tl::RowMajor<2, 1>, 128>();

// 1 x 2 warps
run_test<32, 128, 128, tl::RowMajor<1, 2>, 128>();
// run_test<64, 128, 128, tl::RowMajor<1, 2>, 128>();

// 2 x 2 warps
run_test<64, 64, 128, tl::RowMajor<2, 2>, 128>();

// 4 x 1 warps
run_test<64, 16, 256, tl::RowMajor<4, 1>, 256>();
}

} // namespace tilefusion::testing
89 changes: 62 additions & 27 deletions tests/cpp/cell/test_single_wmma.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@ __device__ void naive_gemm(int kM, int kN, int kK, //
}

__device__ void check_results(const float* hc1, const float* hc2, int numel) {
for (int i = 0; i < numel; ++i) assert(abs(hc1[i] - hc2[i]) < 1e-3);
for (int i = 0; i < numel; ++i) {
if (fabs(hc1[i] - hc2[i]) > 1e-3) {
printf("error: %d, %.4f, %.4f\n", i, hc1[i], hc2[i]);
printf("test failed!\n");
return;
}
}

printf("test passed!\n");

#if defined(DEBUG)
if (thread0()) {
Expand All @@ -59,10 +67,10 @@ __device__ void init_values(Element* a, Element* b, ElementAcc* c, int M, int N,
if (!thread0()) return;

for (int i = 0; i < M * K; ++i)
a[i] = static_cast<Element>(i % 2048 / 1000.);
a[i] = static_cast<Element>(i % 2048 / 1000);

for (int i = 0; i < K * N; ++i)
b[i] = static_cast<Element>(i % 2048 / 1000.);
b[i] = static_cast<Element>(i % 2048 / 1000);

for (int i = 0; i < M * N; ++i) c[i] = 0.;
}
Expand All @@ -84,6 +92,8 @@ __global__ void test_wmma(LoadRegA& load_rA, LoadRegB& load_rB,

init_values<Element, ElementAcc>(shared_a, shared_b, shared_c, M, N, K);

__syncthreads();

SharedC sC(shared_c);

TileIteratorA sAs(shared_a);
Expand All @@ -99,10 +109,13 @@ __global__ void test_wmma(LoadRegA& load_rA, LoadRegB& load_rB,

load_rA(sA, rA);
load_rB(sB, rB);
__syncthreads();

compute::gemm(rA, rB, acc);
}

__syncthreads();

store_rC(acc, sC);
__syncthreads();

Expand All @@ -119,29 +132,36 @@ __global__ void test_wmma(LoadRegA& load_rA, LoadRegB& load_rB,
} // namespace

template <typename Element, typename ElementAcc, const int kM, const int kN,
const int kK>
const int kK, typename WarpLayout>
struct TestTraits {
using WarpLayout = tl::RowMajor<1, 1>;
static const int kThreads = tl::get_numel<WarpLayout> * 32;

static constexpr int kWarpPerRow = tl::num_rows<WarpLayout>;
static constexpr int kWarpPerCol = tl::num_cols<WarpLayout>;

// ============= shared to register loader =================
// TODO: whether BaseTileShape should depend on Element type?
using BaseShape = traits::BaseTileShape<Element>;
// how many elements a BaseTile are executed along the m, n, k dimension
static constexpr int kMs = kM / BaseShape::kTileSize;
static constexpr int kNs = kN / BaseShape::kTileSize;
static constexpr int kKs = kK / BaseShape::kTileSize;

static constexpr int kAMs = kM / kWarpPerRow / BaseShape::kTileSize;
static constexpr int kAKs = kK / BaseShape::kTileSize;

static constexpr int kBKs = kK / BaseShape::kTileSize;
static constexpr int kBNs = kN / kWarpPerCol / BaseShape::kTileSize;

static constexpr int kCMs = kM / kWarpPerRow / BaseShape::kTileSize;
static constexpr int kCNs = kN / kWarpPerCol / BaseShape::kTileSize;

using SharedA = SharedTile<Element, tl::RowMajor<kM, kK>>;
using TileIteratorA = STileIterator<SharedA, TileShape<kM, kK>>;

using RegA = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<kMs, kKs>>;
using RegA = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<kAMs, kAKs>>;
using LoadRegA =
SharedToRegLoader<RegA, WarpLayout, WarpReuse::kRowReuseCont>;

using SharedB = SharedTile<Element, tl::ColMajor<kK, kN>>;

using RegB = RegTile<BaseTileColMajor<Element>, tl::ColMajor<kKs, kNs>>;
using RegB = RegTile<BaseTileColMajor<Element>, tl::ColMajor<kBKs, kBNs>>;
using TileIteratorB = STileIterator<SharedB, TileShape<kK, kN>>;
using LoadRegB =
SharedToRegLoader<RegB, WarpLayout, WarpReuse::kColReuseCont>;
Expand All @@ -151,23 +171,21 @@ struct TestTraits {

// ============= register to shared storer =================
using SharedC = SharedTile<ElementAcc, tl::RowMajor<kM, kN>>;
using RegC = RegTile<BaseTileRowMajor<ElementAcc>, tl::RowMajor<kMs, kNs>>;
using RegC =
RegTile<BaseTileRowMajor<ElementAcc>, tl::RowMajor<kCMs, kCNs>>;
using StoreRegC = RegToSharedStorer<RegC, WarpLayout>;
};

template <const int kM, const int kN, const int kK>
template <const int kM, const int kN, const int kK, typename WarpLayout>
void run_test() {
using Element = cutlass::half_t;
using ElementAcc = float;

using config = TestTraits<Element, ElementAcc, kM, kN, kK>;
using config = TestTraits<Element, ElementAcc, kM, kN, kK, WarpLayout>;

dim3 dim_grid(1, 1, 1);
dim3 dim_block(config::kThreads, 1, 1);

int shm_size =
(kM + kN) * kK * sizeof(Element) + kM * kN * sizeof(ElementAcc);

typename config::LoadRegA load_rA;
typename config::LoadRegB load_rB;
typename config::StoreRegC store_rC;
Expand All @@ -181,23 +199,40 @@ void run_test() {
<< "RegB: " << RegB{} << std::endl
<< "RegC: " << RegC{} << std::endl;

test_wmma<Element, ElementAcc, kM, kN, kK, typename config::TileIteratorA,
RegA, typename config::LoadRegA, typename config::TileIteratorB,
RegB, typename config::LoadRegB, typename config::SharedC, RegC,
typename config::StoreRegC>
<<<dim_grid, dim_block, shm_size>>>(load_rA, load_rB, store_rC);
auto kernel =
test_wmma<Element, ElementAcc, kM, kN, kK,
typename config::TileIteratorA, RegA,
typename config::LoadRegA, typename config::TileIteratorB,
RegB, typename config::LoadRegB, typename config::SharedC,
RegC, typename config::StoreRegC>;

int shm_size =
(kM + kN) * kK * sizeof(Element) + 2 * kM * kN * sizeof(ElementAcc);

if (shm_size > 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
}

kernel<<<dim_grid, dim_block, shm_size>>>(load_rA, load_rB, store_rC);

cudaDeviceSynchronize();

LOG(INFO) << "[" << kM << ", " << kN << ", " << kK << "]. Test passed!"
<< std::endl;
}

TEST(TestWmma, test_m16n16k16_f) {
run_test<16, 16, 64>();
run_test<32, 32, 64>();
run_test<64, 64, 64>();
run_test<64, 128, 128>();
run_test<128, 128, 128>();
run_test<16, 16, 64, tl::RowMajor<1, 1>>();
run_test<32, 32, 64, tl::RowMajor<1, 1>>();
run_test<64, 64, 64, tl::RowMajor<1, 1>>();
// TODO(KuangjuX): It doesn't seem to be executed.
// run_test<64, 128, 128, tl::RowMajor<1, 1>>();

run_test<32, 64, 64, tl::RowMajor<2, 1>>();
run_test<32, 64, 64, tl::RowMajor<2, 1>>();
run_test<64, 64, 64, tl::RowMajor<2, 1>>();
run_test<128, 64, 64, tl::RowMajor<2, 1>>();
}

} // namespace tilefusion::testing
Loading

0 comments on commit 8b89491

Please sign in to comment.