Skip to content

Commit

Permalink
ReImplement SMEM to RMEM Based on SwizzledLayout (#51)
Browse files Browse the repository at this point in the history
Progress:
- [x] Implemented loading from SMEM to RMEM and storage from RMEM to
SMEM based on `Swizzle<3, 3, 3>`.
- [x] Modified the `STileIterator` to enable traversal based on the
Swizzle layout.
- [x] Support `float` data type for R2S Storer.
- [x] Add support for different `WarpLayout` configurations.
- [x] Support ColMajor SMEM/RMEM Loader/Storer.
- [x] Support Single Warp MMA.
- [x] Support ColMajor GMEM/SMEM Loader/Storer.
- [x] Support Single Warp Whole GEMM.
- [x] Verify the performance.
  • Loading branch information
KuangjuX authored Feb 13, 2025
1 parent 0a1a287 commit c75a343
Show file tree
Hide file tree
Showing 23 changed files with 1,179 additions and 582 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
*.ii
*.gpu
*.ptx
*.sass
*.S
*.log
*.cubin
*.fatbin

Expand Down
19 changes: 19 additions & 0 deletions benchmarks/cpp/g2s_copy/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the
# MIT License.
# --------------------------------------------------------------------------

# Standalone build script for the global<->shared memory copy benchmark.
cmake_minimum_required(VERSION 3.25 FATAL_ERROR)
project(bench_g2s_copy LANGUAGES C CXX CUDA)

# Make the repository-level `cmake/` directory searchable by `include()`, and
# record where the vendored third-party sources live.
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH}
"${PROJECT_SOURCE_DIR}/../../../cmake")
set(THIRD_PARTY_DIR "${PROJECT_SOURCE_DIR}/../../../3rd-party")

# Shared project settings, resolved through CMAKE_MODULE_PATH above.
# NOTE(review): presumably sets compiler/CUDA flags — confirm in cmake/generic.
include(generic)

# Header search paths: the library's public headers, the benchmark utility
# headers, and the vendored CUTLASS headers.
include_directories("${PROJECT_SOURCE_DIR}/../../../include")
include_directories("${PROJECT_SOURCE_DIR}/../../utils/cpp")
include_directories("${THIRD_PARTY_DIR}/cutlass/include")

# The benchmark is a single translation unit.
add_executable(bench_g2s_copy main.cu)
17 changes: 17 additions & 0 deletions benchmarks/cpp/g2s_copy/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------


BUILD_DIR := build

# Number of parallel build jobs. Defaults to the number of available CPUs and
# can be overridden on the command line, e.g. `make build proc=8`.
# Without a default, `-j$(proc)` expands to a bare `-j`, which tells make to
# run an unlimited number of jobs at once.
proc ?= $(shell nproc)

.PHONY: build clean

# Configure with CMake and compile the benchmark out-of-tree in $(BUILD_DIR).
build:
	@mkdir -p $(BUILD_DIR)
	@cd $(BUILD_DIR) && cmake .. && make -j$(proc)

# Remove every build artifact.
clean:
	@rm -rf $(BUILD_DIR)

30 changes: 30 additions & 0 deletions benchmarks/cpp/g2s_copy/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
## Data Tile Transfer between Global and Shared Memory

### Overview
This preliminary test evaluates the performance of transferring a row-major data tile containing half-precision floating-point values between global memory and shared memory. The transfer process involves loading the data tile into shared memory and subsequently storing it back to global memory. This cycle is repeated 100 times to measure performance.

### Performance Evaluation
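Performance is reported as the time (in milliseconds) taken by a single kernel launch, i.e. the total time required to complete the 100 data tile transfers. Each reported number is the average over 100 timed launches that follow 20 warm-up launches.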

### Implementations
The test includes two implementations, one built with TileFusion and one built with CUTLASS. No bank conflicts are observed for either implementation when profiled with NVIDIA Nsight Compute. The CUTLASS implementation uses a copy plan that maximizes global memory coalescing so that global memory bandwidth is used as efficiently as possible.

### Test Environment
- **GPU**: NVIDIA Tesla A100
- **CUDA Version**: 12.6

### Results

|Shape|Warp Layout|tilefusion(ms)|cutlass(ms)|Ratio|
|:---|:---:|:---:|:---:|:---:|
|RowMajor(16, 64)|(1, 1)|0.02996|0.02957|1.013|
|RowMajor(64, 64)|(1, 1)|0.05073|0.05071|1|
|RowMajor(64, 64)|(2, 1)|0.05045|0.05068|0.9956|
|RowMajor(64, 64)|(4, 1)|0.05119|0.05145|0.995|
|RowMajor(128, 128)|(1, 1)|0.1369|0.154|0.8888|
|RowMajor(128, 128)|(2, 2)|0.1374|0.134|1.025|
|RowMajor(128, 128)|(4, 2)|0.138|0.1382|0.9984|
|RowMajor(128, 256)|(1, 1)|0.2464|0.3694|0.6671|
|RowMajor(128, 256)|(2, 2)|0.2471|0.2458|1.005|
|RowMajor(128, 256)|(2, 4)|0.2592|0.2511|1.032|
|RowMajor(128, 256)|(4, 4)|0.2543|0.2572|0.9889|
140 changes: 140 additions & 0 deletions benchmarks/cpp/g2s_copy/cutlass_copy.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "cutlass/copy.cuh"

#include <cute/swizzle.hpp>
#include <cute/tensor.hpp>

using namespace cute;
using namespace benchmarks;

namespace {
// NOTE: The current implementation of Loader/Storer supports only
// half-precision (FP16) RowMajor data tiles. It is not implemented for other
// data types or memory layouts. Be cautious when using it for other cases.
template <typename Element, //
const int kRows, const int kCols, //
const int kWarpRows, const int kWarpCols>
struct Loader {
DEVICE void operator()(const Element* src_, Element* dst_) {
int tid = threadIdx.x;

auto gtile = make_tensor(make_gmem_ptr(src_), src_layout_);
auto stile = make_tensor(make_smem_ptr(dst_), dst_layout_);

auto loader = tiled_copy_.get_thread_slice(tid);

auto src = loader.partition_S(gtile);
auto dst = loader.partition_D(stile);

#pragma unroll
for (int i = 0; i < int(size<1>(src)); ++i)
#pragma unroll
for (int j = 0; j < int(size<2>(src)); ++j)
cute::copy(tiled_copy_, src(cute::_, i, j), dst(cute::_, i, j));
}

private:
// source
using GlobalLayout =
cute::Layout<Shape<Int<kRows>, Int<kCols>>, Stride<Int<kCols>, _1>>;
GlobalLayout src_layout_;

// destination
using LayoutAtom =
decltype(composition(cute::Swizzle<2, 3, 3>{},
cute::Layout<Shape<_4, _64>, Stride<_64, _1>>{}));
using SharedLayout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<kRows>, Int<kCols>>{}, cute::Step<_2, _1>{}));
SharedLayout dst_layout_;

// tiled copy
static constexpr int kThreadCols = kCols * 16 / 128;
static constexpr int kThreadRows = kWarpRows * kWarpCols * 32 / kThreadCols;

using ThreadLayout = cute::Layout<Shape<Int<kThreadRows>, Int<kThreadCols>>,
Stride<Int<kThreadCols>, _1>>;
using ValueLayout = cute::Layout<Shape<_1, _8>>;

using CopyInst =
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>;
using TiledCopy =
decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{}));
TiledCopy tiled_copy_;
};

template <typename Element, //
const int kRows, const int kCols, //
const int kWarpRows, const int kWarpCols>
struct Storer {
DEVICE void operator()(const Element* src_, Element* dst_) {
int tid = threadIdx.x;

auto stile = make_tensor(make_smem_ptr(src_), src_layout_); // shared
auto gtile = make_tensor(make_gmem_ptr(dst_), dst_layout_); // global

auto loader = tiled_copy_.get_thread_slice(tid);

auto src = loader.partition_S(stile);
auto dst = loader.partition_D(gtile);

#pragma unroll
for (int i = 0; i < int(size<1>(src)); ++i)
#pragma unroll
for (int j = 0; j < int(size<2>(src)); ++j)
cute::copy(tiled_copy_, src(cute::_, i, j), dst(cute::_, i, j));
}

private:
// declare the source layout
using LayoutAtom =
decltype(composition(cute::Swizzle<2, 3, 3>{},
cute::Layout<Shape<_4, _64>, Stride<_64, _1>>{}));
using SharedLayout = decltype(tile_to_shape(
LayoutAtom{}, Shape<Int<kRows>, Int<kCols>>{}, cute::Step<_2, _1>{}));
SharedLayout src_layout_;

// declare the destination layout
using GlobalLayout =
cute::Layout<Shape<Int<kRows>, Int<kCols>>, Stride<Int<kCols>, _1>>;
GlobalLayout dst_layout_;

// declare the tiled copy
static constexpr int kThreadCols = kCols * 16 / 128;
static constexpr int kThreadRows = kWarpRows * kWarpCols * 32 / kThreadCols;
using ThreadLayout = cute::Layout<Shape<Int<kThreadRows>, Int<kThreadCols>>,
Stride<Int<kThreadCols>, _1>>;
using ValueLayout = cute::Layout<Shape<_1, _8>>;

using CopyInst = Copy_Atom<DefaultCopy, Element>;
using TiledCopy =
decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{}));
TiledCopy tiled_copy_;
};
} // namespace

/// Benchmark kernel: moves a `kRows x kCols` tile from global memory to
/// dynamic shared memory and back to global memory, `kRepeat` times.
///
/// The launch must provide at least `kRows * kCols * sizeof(Element)` bytes
/// of dynamic shared memory and `kWarpRow * kWarpCol * 32` threads.
///
/// @param src global-memory source tile (row-major).
/// @param dst global-memory destination tile (row-major).
template <typename Element, const int kRows, const int kCols,
          const int kWarpRow, const int kWarpCol, const int kRepeat>
__global__ void cutlass_g2s_data_transfer(const Element* src, Element* dst) {
    // The loader issues 128-bit (16-byte) asynchronous copies into this
    // buffer, so it must be 16-byte aligned; 8-byte alignment is not enough.
    // The array has a kernel-specific name so that its alignment cannot
    // conflict with an identically named declaration in another kernel.
    extern __shared__ __align__(16) unsigned char cutlass_smem_raw_[];
    auto* buf = reinterpret_cast<Element*>(cutlass_smem_raw_);

    using G2S = Loader<Element, kRows, kCols, kWarpRow, kWarpCol>;
    G2S loader;

    using S2G = Storer<Element, kRows, kCols, kWarpRow, kWarpCol>;
    S2G storer;

    for (int k = 0; k < kRepeat; ++k) {
        loader(src, buf);

        // NOTE(review): presumably commits and waits for the outstanding
        // asynchronous copies — confirm in cutlass/copy.cuh.
        cutlass_wrapper::__copy_async();
        // Make the freshly loaded tile visible to every thread in the block.
        __syncthreads();

        storer(buf, dst);
        // Do not overwrite the shared tile (next iteration) while other
        // threads are still reading it.
        __syncthreads();
    }
}
166 changes: 166 additions & 0 deletions benchmarks/cpp/g2s_copy/main.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "cell/copy/mod.hpp"
#include "cutlass_copy.cuh"
#include "tilefusion_copy.cuh"
#include "types/mod.hpp"
#include "util/cuda_timer.hpp"

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

using namespace tilefusion;
using namespace tilefusion::cell;
using namespace tilefusion::cell::copy;

// Number of untimed kernel launches issued before each measurement.
int warmup = 20;
// Number of timed kernel launches; the reported time is averaged over them.
int iters = 100;
// Number of load/store round trips performed inside a single kernel launch
// (forwarded to the kernels as a template argument).
const int kRepeat = 100;

/// Compares two host buffers element by element.
///
/// Two elements match when their signed difference, computed in `float`, lies
/// within `[-epsilon, epsilon]`. The comparison is symmetric and sign-aware,
/// so a buffer of zeros or a sign-flipped buffer is reported as a mismatch.
/// A NaN on either side is treated as a mismatch as well.
///
/// @param dst1   first buffer, at least `numel` elements.
/// @param dst2   second buffer, at least `numel` elements.
/// @param numel  number of elements to compare.
/// @return `true` when all `numel` elements match; otherwise prints the first
///         mismatching index and values to `std::cerr` and returns `false`.
template <typename Element>
bool check_results(const Element* dst1, const Element* dst2, int64_t numel) {
    const float epsilon = 1e-3;
    for (int64_t i = 0; i < numel; ++i) {
        const float v1 = static_cast<float>(dst1[i]);
        const float v2 = static_cast<float>(dst2[i]);
        const float diff = v1 - v2;
        // Written as a negated "in range" test so that NaN fails the check.
        if (!(diff <= epsilon && diff >= -epsilon)) {
            std::cerr << "Mismatch at " << i << ": " << v1 << " vs " << v2
                      << std::endl;
            return false;
        }
    }
    return true;
}

template <typename Element, typename Layout, typename WarpLayout,
const int kRepeat>
float test_tilefusion(const Element* src, Element* dst) {
using Global = GlobalTile<Element, Layout>;
using Shared = SharedTile<Element, Layout, true /*kSwizzled*/>;

using Loader = GlobalToSharedLoader<Shared, WarpLayout>;
Loader loader;

using Storer = SharedToGlobalStorer<Shared, WarpLayout>;
Storer storer;

auto kernel =
&g2s_data_transfer<Element, Global, Shared, Loader, Storer, kRepeat>;

static const int kThreads = WarpLayout::kNumel * 32;
int shm_size = Shared::kNumel * sizeof(Element);

if (shm_size > 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
}

dim3 grids(1, 1, 1);
dim3 blocks(kThreads);

for (int i = 0; i < warmup; ++i) // warm up
kernel<<<grids, blocks, shm_size>>>(src, dst, loader, storer);
cudaDeviceSynchronize();

CudaTimer timer;
timer.start();
for (int i = 0; i < iters; ++i)
kernel<<<grids, blocks, shm_size>>>(src, dst, loader, storer);
cudaDeviceSynchronize();
return timer.stop() / iters;
}

template <typename Element, typename Layout, typename WarpLayout,
const int kRepeat>
float test_cutlass(const Element* src, Element* dst) {
auto kernel = &cutlass_g2s_data_transfer<Element, Layout::kRows,
Layout::kCols, WarpLayout::kRows,
WarpLayout::kCols, kRepeat>;

int shm_size = Layout::kNumel * sizeof(Element);
int kThreads = WarpLayout::kNumel * 32;

if (shm_size > 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
}

dim3 grids(1, 1, 1);
dim3 blocks(kThreads);

for (int i = 0; i < warmup; ++i) {
kernel<<<grids, blocks, shm_size>>>(src, dst);
}
cudaDeviceSynchronize();

CudaTimer timer;
timer.start();
for (int i = 0; i < iters; ++i) {
kernel<<<grids, blocks, shm_size>>>(src, dst);
}
cudaDeviceSynchronize();
return timer.stop() / iters;
}

template <typename Element, typename Layout, typename WarpLayout>
void run_test_rowmajor() {
int numel = Layout::kNumel;

thrust::host_vector<Element> h_src(numel);
for (int i = 0; i < h_src.size(); ++i)
h_src[i] = static_cast<Element>(i % 2048);

thrust::device_vector<Element> d_src = h_src;
const Element* src = thrust::raw_pointer_cast(d_src.data());

thrust::device_vector<Element> d_dst1(numel);
thrust::fill(d_dst1.begin(), d_dst1.end(), static_cast<Element>(0.));
Element* dst1 = thrust::raw_pointer_cast(d_dst1.data());

thrust::device_vector<Element> d_dst2(numel);
thrust::fill(d_dst2.begin(), d_dst2.end(), static_cast<Element>(0.));
Element* dst2 = thrust::raw_pointer_cast(d_dst2.data());

float t1 = test_tilefusion<Element, Layout, WarpLayout, kRepeat>(src, dst1);
thrust::host_vector<Element> h_dst1 = d_dst1;

float t2 = test_cutlass<Element, Layout, WarpLayout, kRepeat>(src, dst2);
thrust::host_vector<Element> h_dst2 = d_dst2;

bool passed = check_results(thrust::raw_pointer_cast(h_dst1.data()),
thrust::raw_pointer_cast(h_dst2.data()), numel);
if (!passed) {
std::cerr << "Test failed" << std::endl;
return;
}

std::cout << "|RowMajor(" << Layout::kRows << ", " << Layout::kCols << ")|("
<< WarpLayout::kRows << ", " << WarpLayout::kCols << ")|" << t1
<< "|" << t2 << "|" << t1 / t2 << "|" << std::endl;
}

int main() {
std::cout << std::setprecision(4)
<< "|Shape|Warp Layout|tilefusion(ms)|cutlass(ms)|Ratio|"
<< std::endl
<< "|:---|:---:|:---:|:---:|:---:|" << std::endl;

using DType = __half;

run_test_rowmajor<DType, tl::RowMajor<16, 64>, tl::RowMajor<1, 1>>();
run_test_rowmajor<DType, tl::RowMajor<64, 64>, tl::RowMajor<1, 1>>();
run_test_rowmajor<DType, tl::RowMajor<64, 64>, tl::RowMajor<2, 1>>();
run_test_rowmajor<DType, tl::RowMajor<64, 64>, tl::RowMajor<4, 1>>();

run_test_rowmajor<DType, tl::RowMajor<128, 128>, tl::RowMajor<1, 1>>();
run_test_rowmajor<DType, tl::RowMajor<128, 128>, tl::RowMajor<2, 2>>();
run_test_rowmajor<DType, tl::RowMajor<128, 128>, tl::RowMajor<4, 2>>();

run_test_rowmajor<DType, tl::RowMajor<128, 256>, tl::RowMajor<1, 1>>();
run_test_rowmajor<DType, tl::RowMajor<128, 256>, tl::RowMajor<2, 2>>();
run_test_rowmajor<DType, tl::RowMajor<128, 256>, tl::RowMajor<2, 4>>();
run_test_rowmajor<DType, tl::RowMajor<128, 256>, tl::RowMajor<4, 4>>();

return 0;
}
Loading

0 comments on commit c75a343

Please sign in to comment.