     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
+HAS_CUDA_129 = (
+    torch.cuda.is_available() and torch.version.cuda and torch.version.cuda >= "12.9"
+)
+
 
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
@@ -63,6 +67,8 @@ def get_scaling_mode_int(scaling_mode: str) -> int:
         return ScalingMode.TENSOR
     elif scaling_mode == "row":
         return ScalingMode.ROW
+    elif scaling_mode == "deepseek":
+        return ScalingMode.DEEPSEEK
     else:
         raise ValueError(f"Invalid scaling mode: {scaling_mode}")
 
@@ -111,11 +117,40 @@ def _get_scale_per_row(
                 torch.float32
             )  # For row-wise scaling, kernel requires a float32 scale tensor
 
+        def _get_scale_deepseek(
+            x: torch.Tensor,
+            block_outer: int,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """
+            DeepSeek-style scaling on matmul A @ B uses a combination of block- and tile-wise scaling:
+            - activation tensor A: 1x128 tile-wise scaling
+            - weight tensor B: 128x128 block-wise scaling
+            """
+            block_inner = 128
+            x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+            amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+            scale = torch.finfo(torch.float8_e4m3fn).max / amax
+            x = (
+                x.mul(scale).flatten(2, 3).flatten(0, 1)
+            )  # scale input up to dynamic range of float8_e4m3fn
+            scale = scale.flatten(2, 3).flatten(0, 1)
+            return x, scale.to(torch.float32)
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.scaling_mode_int == ScalingMode.ROW:
+            if self.scaling_mode_int == ScalingMode.DEEPSEEK:
+                activations_block_outer = 1
+                weights_block_outer = 128
+
+                a, scale_a = _get_scale_deepseek(a, activations_block_outer)
+                b, scale_b = _get_scale_deepseek(b, weights_block_outer)
+
+                scale_a = (
+                    scale_a.t().contiguous().t()
+                )  # 1x128 blocks need scales to be outer-dim-major
+            elif self.scaling_mode_int == ScalingMode.ROW:
                scale_a = _get_scale_per_row(a)
                scale_b = _get_scale_per_row(b)
            else:  # self.scaling_mode_int == ScalingMode.TENSOR
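
To make the tiling concrete, here is a small standalone sketch (shapes are hypothetical, chosen as multiples of 128) tracing the reshapes that _get_scale_deepseek above performs, for both the activation case (block_outer=1) and the weight case (block_outer=128), plus the stride trick applied to scale_a:

import torch

# Hypothetical shapes for illustration; any multiples of 128 behave the same.
m, k = 256, 512
a = torch.randn(m, k)

# Activation case: 1x128 tiles (block_outer=1, block_inner=128).
t = a.unflatten(1, (-1, 128)).unflatten(0, (-1, 1))    # (256, 1, 4, 128)
amax = t.abs().amax(dim=[1, 3], keepdim=True).float()  # one amax per tile
scale = torch.finfo(torch.float8_e4m3fn).max / amax    # (256, 1, 4, 1)
scale_a = scale.flatten(2, 3).flatten(0, 1)            # (256, 4): one scale per 1x128 tile

# The kernel wants these scales outer-dim-major; t().contiguous().t()
# keeps the (m, k // 128) shape but makes dim 0 the fast-moving axis in memory.
scale_a = scale_a.t().contiguous().t()
print(scale_a.shape, scale_a.stride())  # torch.Size([256, 4]) (1, 256)

# Weight case: 128x128 blocks (block_outer=128).
b = torch.randn(256, 512)
tb = b.unflatten(1, (-1, 128)).unflatten(0, (-1, 128))           # (2, 128, 4, 128)
amax_b = tb.abs().amax(dim=[1, 3], keepdim=True).float()
scale_b = torch.finfo(torch.float8_e4m3fn).max / amax_b
scale_b = scale_b.flatten(2, 3).flatten(0, 1)                    # (2, 4): one scale per 128x128 block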
@@ -164,12 +199,22 @@ def get_x_val(self, example_inputs) -> float:
 
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
+        is_scaling_deepseek = self.scaling_mode_int == ScalingMode.DEEPSEEK
+
+        assert (
+            not is_scaling_deepseek or HAS_CUDA_129
+        ), "Deepseek-style scaling (BlockWise128x128) for scaled_gemm requires CUDA 12.9+"
+
+        use_fast_accum = (
+            False if is_scaling_deepseek else True
+        )  # blockwise scaled_gemm does not support use_fast_accum=True
+
         return lambda: torch._scaled_mm(
             a,
             b.t(),
             scale_a,
             scale_b.t(),
-            use_fast_accum=True,
+            use_fast_accum=use_fast_accum,
             out_dtype=self._get_dtype(),
         )
 
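
For reference, a minimal sketch of the blockwise call this branch ends up making; it assumes an fp8-capable GPU and a CUDA 12.9+ build (per the assert above), with hypothetical shapes and dummy unit scales in place of the real computed ones:

import torch

m, n, k = 256, 256, 512
device = "cuda"

# fp8 operands, as produced by the input iterator.
a = torch.randn(m, k, device=device).to(torch.float8_e4m3fn)
b = torch.randn(n, k, device=device).to(torch.float8_e4m3fn)

# Dummy float32 scales with the DeepSeek layouts: 1x128 tile scales for a
# (made outer-dim-major), 128x128 block scales for b.
scale_a = torch.ones(m, k // 128, device=device).t().contiguous().t()
scale_b = torch.ones(n // 128, k // 128, device=device)

out = torch._scaled_mm(
    a,
    b.t(),                 # b transposed to (k, n), column-major layout
    scale_a,
    scale_b.t(),
    use_fast_accum=False,  # blockwise scaled_gemm does not support fast accumulation
    out_dtype=torch.bfloat16,
)
print(out.shape)  # torch.Size([256, 256])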