
Commit a670442

Merge remote-tracking branch 'jg/cuda-fa-mma-17' into debug4

2 parents: 6fa50f7 + 727db80

File tree

6 files changed: +724 −719 lines


ggml/src/ggml-cuda/common.cuh (+15 −6)
@@ -41,12 +41,13 @@
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
-#define GGML_CUDA_CC_PASCAL     600
-#define GGML_CUDA_CC_DP4A       610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA      700
-#define GGML_CUDA_CC_TURING     750
-#define GGML_CUDA_CC_AMPERE     800
-#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
+#define GGML_CUDA_CC_PASCAL       600
+#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA        700
+#define GGML_CUDA_CC_TURING       750
+#define GGML_CUDA_CC_AMPERE       800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_OFFSET_AMD   0x1000000
 
 // GCN/CNDA, wave size is 64
 #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
@@ -199,6 +200,10 @@ typedef float2 dfloat2;
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
 #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool cp_async_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return __AMDGCN_WAVEFRONT_SIZE;
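
For orientation (not part of the diff): the new compile-time macro CP_ASYNC_AVAILABLE gates device code per compiled architecture, while the runtime helper cp_async_available(cc) mirrors that check on the host so a launcher can pick a kernel variant for the current device. A minimal, self-contained sketch of that pairing; the kernel names and the simplified cc >= 800 check are illustrative, not from this commit:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-ins for a cp.async-based kernel and its fallback.
__global__ void my_kernel_cp_async(float * dst) { dst[threadIdx.x] = 1.0f; }
__global__ void my_kernel_fallback(float * dst) { dst[threadIdx.x] = 1.0f; }

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const int cc = 100*prop.major + 10*prop.minor; // e.g. 860 for an RTX 3090

    float * dst;
    cudaMalloc(&dst, 32*sizeof(float));

    // Same idea as cp_async_available(cc): only Ampere (cc >= 800) and newer
    // NVIDIA GPUs take the cp.async path. (The real helper also excludes AMD
    // devices and uses the highest compiled arch rather than the raw cc.)
    if (cc >= 800) {
        my_kernel_cp_async<<<1, 32>>>(dst);
    } else {
        my_kernel_fallback<<<1, 32>>>(dst);
    }
    cudaDeviceSynchronize();
    cudaFree(dst);
    printf("compute capability: %d\n", cc);
    return 0;
}
```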

ggml/src/ggml-cuda/cp-async.cuh (new file, +46)
@@ -0,0 +1,46 @@
+// Simplified API for asynchronous data loading.
+
+#include "common.cuh"
+
+// Copies data from global to shared memory, cg == cache global.
+// Both the src and dst pointers must be aligned to 16 bytes.
+// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
+// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
+// Only the 16 byte copy is exposed because 4 and 8 byte copies did not yield performance improvements.
+template <int preload>
+static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
+    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
+#ifdef CP_ASYNC_AVAILABLE
+#if CUDART_VERSION >= 11040
+    if (preload == 256) {
+        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 128) {
+        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 64) {
+        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else
+#endif // CUDART_VERSION >= 11040
+    {
+        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    }
+#else
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src);
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Makes each thread wait until its asynchronous data copies are done.
+// This does NOT provide any additional synchronization.
+// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
+static __device__ __forceinline__ void cp_async_wait_all() {
+#ifdef CP_ASYNC_AVAILABLE
+    asm volatile("cp.async.wait_all;");
+#else
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
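
To make the intended usage concrete, here is a hypothetical kernel that is not part of the commit; it assumes it is compiled alongside the ggml CUDA sources so that cp-async.cuh and its dependencies are available, and it assumes src points to at least 256 halves with 16-byte alignment, launched as load_tile_example<<<1, 32>>>(src, dst). It shows the pattern the comments describe: convert the shared-memory pointer with __cvta_generic_to_shared, issue 16-byte copies, call cp_async_wait_all, and still use __syncthreads before reading data copied by other threads.

```cuda
#include "cp-async.cuh"

// Illustrative kernel: each of 32 threads asynchronously copies one 16-byte
// chunk (8 halves) of a 256-element tile from global to shared memory.
__global__ void load_tile_example(const half * __restrict__ src, half * __restrict__ dst) {
    __shared__ __align__(16) half tile[256]; // dst of cp.async must be 16-byte aligned

    // Convert the generic shared-memory pointer to a 32 bit shared address.
    const unsigned int tile_smem = __cvta_generic_to_shared(tile + 8*threadIdx.x);

    // Asynchronous 16-byte copy, no L2 prefetch hint (preload == 0).
    cp_async_cg_16<0>(tile_smem, src + 8*threadIdx.x);

    cp_async_wait_all(); // this thread's copies are done ...
    __syncthreads();     // ... but a barrier is still needed before reading other threads' data.

    // Read back elements copied by a different thread to show why the barrier matters.
    for (int i = 0; i < 8; ++i) {
        dst[8*threadIdx.x + i] = tile[(8*threadIdx.x + i + 128) % 256];
    }
}
```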

ggml/src/ggml-cuda/fattn-common.cuh (+9 −6)
@@ -716,7 +716,9 @@ void launch_fattn(
 
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int id  = ggml_cuda_get_device();
+    const int cc  = ggml_cuda_info().devices[id].cc;
+    const int nsm = ggml_cuda_info().devices[id].nsm;
 
     ggml_cuda_pool_alloc<half> K_f16(pool);
     ggml_cuda_pool_alloc<half> V_f16(pool);
@@ -768,13 +770,14 @@ void launch_fattn(
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
-        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
-        const bool short_context = K->ne[1] < 4096;
+        const int tiles_nwaves             = (ntiles_total + 2*nsm - 1) / (2*nsm);
+        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
 
         const int nblocks_stream_k = 2*nsm;
 
-        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
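A rough worked example of the new heuristic, not part of the diff and with illustrative numbers (nsm = 108 roughly matches an A100): stream-k is chosen whenever the last wave of whole tiles would be under 75% utilized, or unconditionally on Ada Lovelace and newer.

```cpp
#include <cstdio>

int main() {
    const int nsm          = 108; // illustrative SM count (e.g. A100)
    const int ntiles_total = 500; // illustrative number of FlashAttention tiles

    // Same arithmetic as the new launch_fattn() heuristic:
    const int tiles_nwaves             = (ntiles_total + 2*nsm - 1) / (2*nsm);      // 3 waves of up to 216 tiles
    const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves); // 500/648 -> 77%

    const bool use_stream_k = tiles_efficiency_percent < 75; // the real code also checks cc >= GGML_CUDA_CC_ADA_LOVELACE
    printf("waves=%d efficiency=%d%% stream_k=%d\n", tiles_nwaves, tiles_efficiency_percent, use_stream_k);
    return 0;
}
```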

@@ -827,7 +830,7 @@ void launch_fattn(
     CUDA_CHECK(cudaGetLastError());
 
     if constexpr (parallel_blocks == 0) {
-        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = blocks_num;
 
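A note on the last hunk (not part of the diff): with stream-k, blocks_num.x is 2*nsm, so the combine/fixup pass is needed exactly when ntiles_total is not a multiple of blocks_num.x. The old expression tested divisibility the wrong way around; for example, with 216 stream-k blocks and 4 tiles it would skip the fixup even though each tile is split across many blocks.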
