// Asserts that tensor `x` has exactly the given sizes.
// BUG FIX: the space after the macro name (`CHECK_SHAPE (x, ...)`) made this an
// object-like macro whose body begins with the literal text "(x, ...)"; the
// parameter list must be flush against the name to be function-like.
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ " )" )
// Asserts that tensor `x` is contiguous in memory.
// BUG FIX: `CHECK_CONTIGUOUS (x )` — the space before `(` made this an
// object-like macro; removed so it takes `x` as a parameter.
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous" )
2222
23+ inline float get_scalar_f32_cpu_only (c10::optional<const at::Tensor> & scale,
24+ const char * name = " dequant scale" ) {
25+ TORCH_CHECK (scale.has_value (),
26+ name, " is None (optional has no value)" );
27+ const at::Tensor& t = *scale;
28+ TORCH_CHECK (!t.device ().is_cuda (),
29+ " descale_q / descale_k must be on CPU, but got " ,
30+ t.device ().type (), " device" );
31+ TORCH_CHECK (t.scalar_type () == torch::kFloat32 ,
32+ " descale_q / descale_k must be float32, but got " ,
33+ t.scalar_type ());
34+ TORCH_CHECK (t.numel () == 1 ,
35+ " descale_q / descale_k must be a scalar, but got " ,
36+ t.numel (), " elements" );
37+ return t.item <float >();
38+ }
39+
2340std::vector<at::Tensor>
2441get_mla_metadata (
2542 at::Tensor &seqlens_k,
@@ -68,16 +85,19 @@ mha_fwd_kvcache_mla(
6885 const float softmax_scale,
6986 bool is_causal,
7087 const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
71- const at::Tensor &num_splits // batch_size + 1
88+ const at::Tensor &num_splits, // batch_size + 1
89+ c10::optional<const at::Tensor> &descale_q, // batch_size
90+ c10::optional<const at::Tensor> &descale_k // batch_size
7291) {
7392 // Check the architecture
7493 auto dprops = at::cuda::getCurrentDeviceProperties ();
7594 bool is_sm90 = dprops->major == 9 && dprops->minor == 0 ;
7695 TORCH_CHECK (is_sm90);
7796
7897 // Check data types
79- auto q_dtype = q.dtype ();
80- TORCH_CHECK (q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf );
98+ auto q_dtype = q.scalar_type ();
99+ TORCH_CHECK (q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf ||
100+ q_dtype == torch::kFloat8_e4m3fn , " Unsupported dtype for query tensor" );
81101 TORCH_CHECK (kcache.dtype () == q_dtype, " query and key must have the same dtype" );
82102 TORCH_CHECK (seqlens_k.dtype () == torch::kInt32 , " seqlens_k must have dtype int32" );
83103 TORCH_CHECK (block_table.dtype () == torch::kInt32 , " block_table must have dtype torch.int32" );
@@ -106,7 +126,7 @@ mha_fwd_kvcache_mla(
106126 const int num_heads_q = sizes[2 ];
107127 const int head_size_k = sizes[3 ];
108128 TORCH_CHECK (head_size_k == 576 , " Only head_size_k == 576 is supported" );
109- TORCH_CHECK (head_size_v == 512 , " Only head_size_v == 576 is supported" );
129+ TORCH_CHECK (head_size_v == 512 , " Only head_size_v == 512 is supported" );
110130
111131 const int max_num_blocks_per_seq = block_table.size (1 );
112132 const int num_blocks = kcache.size (0 );
@@ -133,7 +153,9 @@ mha_fwd_kvcache_mla(
133153 at::cuda::CUDAGuard device_guard{(char )q.get_device ()};
134154
135155 auto opts = q.options ();
136- at::Tensor out = torch::empty ({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
156+ auto out_type = (q_dtype == torch::kFloat8_e4m3fn ) ? torch::kBFloat16 : q_dtype;
157+ at::Tensor out = torch::empty ({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype (out_type));
158+
137159 at::Tensor softmax_lse = torch::empty ({batch_size, num_heads, q_seq_per_hk}, opts.dtype (at::kFloat ));
138160 CHECK_CONTIGUOUS (softmax_lse);
139161
@@ -152,6 +174,11 @@ mha_fwd_kvcache_mla(
152174 params.d_v = head_size_v;
153175 params.scale_softmax = softmax_scale;
154176 params.scale_softmax_log2 = float (softmax_scale * M_LOG2E);
177+ if (q_dtype == torch::kFloat8_e4m3fn ) {
178+ params.descale_q = get_scalar_f32_cpu_only (descale_q);
179+ params.descale_k = get_scalar_f32_cpu_only (descale_q);
180+ }
181+
155182 // Set the pointers and strides.
156183 params.q_ptr = q.data_ptr ();
157184 params.k_ptr = kcache.data_ptr ();
@@ -188,15 +215,19 @@ mha_fwd_kvcache_mla(
188215 auto stream = at::cuda::getCurrentCUDAStream ().stream ();
189216 TORCH_CHECK (head_size_k == 576 );
190217 if (q_dtype == torch::kBFloat16 ) {
191- run_flash_splitkv_mla_kernel<cutlass::bfloat16_t >(params, stream);
192- run_flash_mla_combine_kernel<cutlass::bfloat16_t >(params, stream);
218+ TORCH_CHECK (false , " Unsupported tensor dtype for query" );
219+ // run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
220+ // run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
193221 } else if (q_dtype == torch::kHalf ) {
222+ TORCH_CHECK (false , " Unsupported tensor dtype for query" );
194223#ifdef FLASH_MLA_DISABLE_FP16
195224 TORCH_CHECK (false , " FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA." );
196225#else
197- run_flash_splitkv_mla_kernel<cutlass::half_t >(params, stream);
198- run_flash_mla_combine_kernel<cutlass::half_t >(params, stream);
226+ // run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
227+ // run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
199228#endif
229+ } else if (q_dtype == torch::kFloat8_e4m3fn ) {
230+ run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t , cutlass::bfloat16_t >(params, stream);
200231 } else {
201232 TORCH_CHECK (false , " Unsupported tensor dtype for query" );
202233 }
0 commit comments