
Commit d027e3d

Fix implementations of loading data.
1 parent bbc6ae1 commit d027e3d

13 files changed: +362 −233 lines

.vscode/settings.json

+1 −3
@@ -4,8 +4,6 @@
   ],
   "files.associations": {
     "optional": "cpp",
-    "system_error": "cpp",
-    "array": "cpp",
-    "string": "cpp"
+    "cstdint": "cpp"
   }
 }

csrc/kernels/copy/atom.cuh

+144
@@ -0,0 +1,144 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#include <stdint.h>
+
+namespace vptq::kernels::copy {
+
+namespace {
+/// ld.shared
+template <const int kBytes>
+DEVICE void ld_shared(void* dst, uint32_t src);
+
+/// ld.shared - 16b
+template <>
+DEVICE void ld_shared<2>(void* dst, uint32_t src) {
+  asm volatile("ld.shared.u16 %0, [%1];\n"
+               : "=h"(*reinterpret_cast<uint16_t*>(dst))
+               : "r"(src));
+}
+
+/// ld.shared - 32b
+template <>
+DEVICE void ld_shared<4>(void* dst, uint32_t src) {
+  asm volatile("ld.shared.u32 %0, [%1];\n"
+               : "=r"(*reinterpret_cast<uint32_t*>(dst))
+               : "r"(src));
+}
+
+/// ld.shared - 64b
+template <>
+DEVICE void ld_shared<8>(void* dst, uint32_t src) {
+  uint2* dst_u64 = reinterpret_cast<uint2*>(dst);
+  asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n"
+               : "=r"(dst_u64->x), "=r"(dst_u64->y)
+               : "r"(src));
+}
+
+/// ld.shared - 128b
+template <>
+DEVICE void ld_shared<16>(void* dst, uint32_t src) {
+  uint4* dst_u128 = reinterpret_cast<uint4*>(dst);
+  asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+               : "=r"(dst_u128->x), "=r"(dst_u128->y), "=r"(dst_u128->z),
+                 "=r"(dst_u128->w)
+               : "r"(src));
+}
+
+/// st.shared
+template <int kBytes>
+DEVICE void st_shared(uint32_t dst, void const* src);
+
+/// st.shared - 16b
+template <>
+DEVICE void st_shared<2>(uint32_t dst, void const* src) {
+  asm volatile("st.shared.u16 [%0], %1;\n"
+               :
+               : "r"(dst), "h"(*reinterpret_cast<uint16_t const*>(src)));
+}
+
+/// st.shared - 32b
+template <>
+DEVICE void st_shared<4>(uint32_t dst, void const* src) {
+  asm volatile("st.shared.u32 [%0], %1;\n"
+               :
+               : "r"(dst), "r"(*reinterpret_cast<uint32_t const*>(src)));
+}
+
+/// st.shared - 64b
+template <>
+DEVICE void st_shared<8>(uint32_t dst, void const* src) {
+  uint2 const* src_u64 = reinterpret_cast<uint2 const*>(src);
+  asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n"
+               :
+               : "r"(dst), "r"(src_u64->x), "r"(src_u64->y));
+}
+
+/// st.shared - 128b
+template <>
+DEVICE void st_shared<16>(uint32_t dst, void const* src) {
+  uint4 const* src_u128 = reinterpret_cast<uint4 const*>(src);
+  asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
+               :
+               : "r"(dst), "r"(src_u128->x), "r"(src_u128->y), "r"(src_u128->z),
+                 "r"(src_u128->w));
+}
+
+/// st.global
+template <int kBytes>
+DEVICE void st_global(void* dst, const void* src);
+
+template <>
+DEVICE void st_global<16>(void* dst, const void* src) {
+  uint4 const* src_u128 = reinterpret_cast<uint4 const*>(src);
+  asm volatile("st.global.v4.b32 [%0], {%1, %2, %3, %4};\n"
+               :
+               : "l"(dst), "r"(src_u128->x), "r"(src_u128->y), "r"(src_u128->z),
+                 "r"(src_u128->w));
+}
+}  // namespace
+
+template <int kBytes>
+DEVICE void ld_shared_st_global(void* dst, uint32_t src);
+
+template <>
+DEVICE void ld_shared_st_global<16>(void* dst, uint32_t src) {
+  unsigned tmp[4];
+  ld_shared<16>(tmp, src);
+  st_global<16>(dst, tmp);
+}
+
+template <const int kBytes>
+DEVICE void ld_global_st_shared(uint32_t dst, void const* src) {
+  static_assert(kBytes == 4 || kBytes == 8 || kBytes == 16);
+
+#if (__CUDA_ARCH__ >= 800)
+  // cp.async requires Ampere (SM80, SM86) or newer, e.g. Hopper (SM90).
+
+  // TODO(ying): add a wrapper to allow choosing between different caching
+  // policies (e.g. "cache all levels").
+  asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n" ::"r"(dst),
+               "l"(src), "n"(kBytes));
+#else
+  unsigned tmp[kBytes / 4];
+  if constexpr (kBytes == 16) {
+    asm volatile("ld.global.v4.b32 {%0, %1, %2, %3}, [%4];\n"
+                 : "=r"(tmp[0]), "=r"(tmp[1]), "=r"(tmp[2]), "=r"(tmp[3])
+                 : "l"(src));
+    asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" ::"r"(dst),
+                 "r"(tmp[0]), "r"(tmp[1]), "r"(tmp[2]), "r"(tmp[3]));
+  } else if constexpr (kBytes == 8) {
+    asm volatile("ld.global.v2.b32 {%0, %1}, [%2];\n"
+                 : "=r"(tmp[0]), "=r"(tmp[1])
+                 : "l"(src));
+    asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" ::"r"(dst), "r"(tmp[0]),
+                 "r"(tmp[1]));
+  } else if constexpr (kBytes == 4) {
+    asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(tmp[0]) : "l"(src));
+    asm volatile("st.shared.b32 [%0], %1;\n" ::"r"(dst), "r"(tmp[0]));
+  }
+#endif
+}
+
+} // namespace vptq::kernels::copy
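
For orientation, a minimal usage sketch of these atoms follows; it is not part of the commit, and the kernel name roundtrip_copy and the launch shape are hypothetical. Note that the cp.async path taken on SM80+ is asynchronous, so the copy must be committed and awaited before the shared-memory data is read.

// Usage sketch (assumed, not from this commit): each of 128 threads stages
// one 16-byte vector from global into shared memory and writes it back.
__global__ void roundtrip_copy(const float* src, float* dst) {
  __shared__ float buf[128 * 4];  // 4 floats = 16 bytes per thread
  int offset = threadIdx.x * 4;

  // The atoms take a 32-bit shared-state-space address, so a generic
  // shared-memory pointer is converted with __cvta_generic_to_shared.
  uint32_t s_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(buf + offset));

  vptq::kernels::copy::ld_global_st_shared<16>(s_addr, src + offset);
#if (__CUDA_ARCH__ >= 800)
  // cp.async completes asynchronously: commit the group and wait on it.
  asm volatile("cp.async.commit_group;\n" ::);
  asm volatile("cp.async.wait_group 0;\n" ::);
#endif
  __syncthreads();
  vptq::kernels::copy::ld_shared_st_global<16>(dst + offset, s_addr);
}
// Hypothetical launch: roundtrip_copy<<<1, 128>>>(d_src, d_dst);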

csrc/kernels/copy/copy.cuh

+41 −136
@@ -2,149 +2,19 @@
 // Licensed under the MIT License.
 #pragma once
 
+/// The Loader and Storer in this file use all collaborative threads in a thread
+/// block to transfer data tiles between global memory and shared memory.
+
+#include "kernels/copy/atom.cuh"
 #include "kernels/copy/copy_traits.cuh"
+#include "kernels/copy/warp.cuh"
 
 #include <cute/tensor.hpp>
 
 namespace vptq::kernels::copy {
-
+namespace tl = vptq::tile_layout;
 using namespace cute;
 
-namespace {
-
-  [old lines 13-147 elided: the ld_shared / st_shared / st_global PTX wrappers
-   and the ld_shared_st_global / ld_global_st_shared helpers, deleted here and
-   moved into csrc/kernels/copy/atom.cuh, the new file shown above]
-
 template <typename DType, const int kNumPerAccess, typename ThreadLayout,
          typename GlobalLayout /*src*/, typename SharedLayout /*dst*/>
 struct GlobalToSharedLoader {
@@ -213,4 +83,39 @@ private:
   TiledCopy tiled_copy_;
 };
 
+/// NOTE: This configuration is specialized for copying a small tile whose size
+/// is smaller than the data size accessed by all threads in a CTA concurrently.
+template <typename DType, const int kNumel, typename Base = AccessInfo<DType>>
+struct GlobalToSharedInputLoader : public Base {
+  static constexpr int kWarpTileShape = Base::kNumPerAccess * WARP_SIZE;
+  static constexpr int kThreads = kNumel / kWarpTileShape * WARP_SIZE;
+
+  DEVICE void operator()(const DType* src_, DType* dst_, int start_warp = 0) {
+    int warp_id = threadIdx.x / WARP_SIZE - start_warp;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int offset = warp_id * kWarpTileShape + lane_id * Base::kNumPerAccess;
+
+    ld_global_st_shared<Base::kAccessInBytes>(
+        __cvta_generic_to_shared(dst_ + offset), src_ + offset);
+  }
+};
+
+/// NOTE: This configuration is specialized for copying a small tile whose size
+/// is smaller than the data size accessed by all threads in a CTA concurrently.
+template <typename DType, const int kNumel, typename Base = AccessInfo<DType>>
+struct SharedToGlobalInputStorer : public Base {
+  static constexpr int kWarpTileShape = Base::kNumPerAccess * WARP_SIZE;
+  static constexpr int kThreads = kNumel / kWarpTileShape * WARP_SIZE;
+
+  DEVICE void operator()(const DType* src_, DType* dst_, int start_warp = 0) {
+    int warp_id = threadIdx.x / WARP_SIZE - start_warp;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int offset = warp_id * kWarpTileShape + lane_id * Base::kNumPerAccess;
+
+    ld_shared_st_global<Base::kAccessInBytes>(
+        dst_ + offset,
+        static_cast<uint32_t>(__cvta_generic_to_shared(src_ + offset)));
+  }
+};
+
 } // namespace vptq::kernels::copy
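
To make the new loaders' indexing concrete, here is a hedged instantiation sketch; it assumes AccessInfo<__half> yields kNumPerAccess = 8 and kAccessInBytes = 16 (inferred from the 16-byte atoms above, not stated in this diff).

// Sketch (assumed values): cooperatively load a 512-element half tile.
// With kNumPerAccess = 8, one warp covers kWarpTileShape = 8 * 32 = 256
// elements per access, so the tile wants kThreads = 512 / 256 * 32 = 64
// threads, i.e. two warps.
using Loader = vptq::kernels::copy::GlobalToSharedInputLoader<__half, 512>;

__global__ void load_tile(const __half* g_src) {
  __shared__ __half s_tile[512];
  Loader loader;
  // Warp 0 covers elements [0, 256); warp 1 covers [256, 512). Within a
  // warp, lane k moves the 8 contiguous elements starting at its offset.
  loader(g_src, s_tile, /*start_warp=*/0);
  __syncthreads();
}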

csrc/kernels/copy/copy_traits.cuh

+3
@@ -2,7 +2,10 @@
 // Licensed under the MIT License.
 #pragma once
 
+#include "kernels/copy/layout.cuh"
+
 namespace vptq::kernels::copy {
+namespace tl = vptq::tile_layout;
 
 template <typename DType>
 struct AccessInfo {
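
The hunk cuts off at the opening of AccessInfo. Judging from how copy.cuh consumes the base class (Base::kNumPerAccess and Base::kAccessInBytes), a plausible shape is sketched below; this is an assumption, not the file's actual body.

// Assumed sketch of the traits the loaders rely on; the real member
// values live outside this hunk.
template <typename DType>
struct AccessInfo {
  // Widest vectorized access: 128 bits per thread per instruction.
  static constexpr int kAccessInBits = 128;
  static constexpr int kAccessInBytes = kAccessInBits / 8;
  // Elements of DType moved per access, e.g. 8 halves or 4 floats.
  static constexpr int kNumPerAccess = kAccessInBytes / sizeof(DType);
};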

csrc/kernels/copy/layout.cuh

+35
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+namespace vptq::tile_layout {
+
+enum class Layout { kRowMajor = 0, kColMajor = 1 };
+
+template <const int kRows_, const int kCols_, const int kRowStride_,
+          const int kColStride_>
+struct MatrixLayout {
+  static constexpr int kRows = kRows_;
+  static constexpr int kCols = kCols_;
+
+  static constexpr int kRowStride = kRowStride_;
+  static constexpr int kColStride = kColStride_;
+
+  static constexpr int kNumel = kRows * kCols;
+
+  // FIXME(ying): The current method to determine if the layout is row-major or
+  // column-major may not be accurate for a matrix of shape (1, 1).
+  static constexpr Layout kType =
+      kColStride == 1 ? Layout::kRowMajor : Layout::kColMajor;
+
+  HOST_DEVICE int operator()(int i, int j) const {
+    return i * kRowStride + j * kColStride;
+  }
+};
+
+template <const int kRow, const int kCol, const int kStride = kCol>
+using RowMajor = MatrixLayout<kRow, kCol, kStride, 1>;
+template <const int kRow, const int kCol, const int kStride = kRow>
+using ColMajor = MatrixLayout<kRow, kCol, 1, kStride>;
+
+} // namespace vptq::tile_layout
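
As a quick illustration of the stride arithmetic (an example, not part of the commit): RowMajor<2, 3> gives strides (3, 1) and ColMajor<2, 3> gives strides (1, 2), so the same coordinate maps to different linear offsets in the two layouts.

// Illustration of MatrixLayout indexing with hypothetical shapes.
using Row = vptq::tile_layout::RowMajor<2, 3>;  // kRowStride = 3, kColStride = 1
using Col = vptq::tile_layout::ColMajor<2, 3>;  // kRowStride = 1, kColStride = 2

static_assert(Row::kNumel == 6 && Col::kNumel == 6);
static_assert(Row::kType == vptq::tile_layout::Layout::kRowMajor);
static_assert(Col::kType == vptq::tile_layout::Layout::kColMajor);

// operator() is HOST_DEVICE but not constexpr, so index at run time:
//   Row()(1, 0) == 1 * 3 + 0 == 3    Row()(0, 1) == 0 * 3 + 1 == 1
//   Col()(1, 0) == 1 * 1 + 0 == 1    Col()(0, 1) == 0 + 1 * 2 == 2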
