     ScalingType,
     CastConfig,
 )
+from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName


 class LNLinearSigmoid(torch.nn.Module):
@@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
         else:
             # cache does not exist yet, create it
             cache = dict()
+    else:
+        cache = dict()
     key = f"{M},{K},{N},{fast_accum}"
     if key in cache:
         return cache[key]
@@ -153,13 +156,18 @@ def do_matmul(A, B):
         )
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

+    scale_a = torch.ones(M, 1, device=device)
+    scale_b = torch.ones(1, N, device=device)
+    fast_accum = True  # for axiswise
+    f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s]
+        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
         with open(cache_filename, 'w') as f:
             json.dump(cache, f)

-    return bf16_time_s, f8_time_s
+    return bf16_time_s, f8_time_s, f8_axs_time_s

 def run(
     outfile: str,
@@ -231,13 +239,15 @@ def run(
     headers = [
         'fwd_M', 'fwd_K', 'fwd_N',
         # gemm microbenchmarks
-        'bf16_gemm_s', 'fp8_gemm_s',
+        'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s',
         # roofline memory overhead estimates
         'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
         'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
         # actual e2e measurements
-        'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
-        'fp8_dyn_speedup', 'fp8_del_speedup',
+        'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s',
+        # 'fp8_lw_s',
+        'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp',
+        # 'fp8_lw_sp',
     ]
     results = []

@@ -248,15 +258,18 @@ def run(
             break

         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
-            bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
-            bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
+            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
+            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
+            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            # for now, assume axiswise gemm is similar to tensorwise
+            fp8_axs_gemm_time_s = fp8_gemm_time_s

         fp8_mem_time_dyn_limit_s = \
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -291,23 +304,43 @@ def run(
             cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
             cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
         )
-        m_fp8_del = convert_to_float8_training(m_orig)
+        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
         m_fp8_del = torch.compile(m_fp8_del)
         fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)

+        # get the float8 dynamic axiswise scaling gpu kernel time
+        torch._dynamo.reset()
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
+        m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
+        fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+
+        # get the lw recipe scaling gpu kernel time
+        # TODO(future PR): enable below once basic performance issues
+        # are fixed
+        # torch._dynamo.reset()
+        # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
+        # m_fp8_lw = torch.compile(m_fp8_lw)
+        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
+
         results.append([
             M_val, K_val, N_val,
             # gemm microbenchmarks
-            bf16_time_val, fp8_gemm_time_s,
+            bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s,
             # roofline overhead estimates
             fp8_mem_time_dyn_limit_s,
             fp8_mem_time_dyn_nolimit_s,
             fp8_mem_time_del_limit_s,
             fp8_mem_time_del_nolimit_s,
             # e2e numbers
             bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
+            fp8_dyn_axs_time_actual_s,
+            # fp8_lw_time_actual_s,
             bf16_time_actual_s / fp8_dyn_time_actual_s,
             bf16_time_actual_s / fp8_del_time_actual_s,
+            bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
+            # bf16_time_actual_s / fp8_lw_time_actual_s,
         ])

     df = pd.DataFrame(results, columns=headers)
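
For reference, here is a minimal sketch of the new dynamic axiswise path that the hunks above add to the benchmark. It only relies on calls that appear in this diff (`recipe_name_to_linear_config`, `Float8LinearRecipeName.ALL_AXISWISE`, `convert_to_float8_training`, `torch.compile`); the toy `nn.Sequential` model, the shapes, and the CUDA/bfloat16 settings are illustrative assumptions standing in for the benchmark's `LNLinearSigmoid` model and its `get_gpu_kernel_time` harness.

```python
# Minimal sketch (not part of the PR): exercise the ALL_AXISWISE recipe the
# same way the new m_fp8_dyn_axs path in the benchmark does. The toy model,
# shapes, and device settings are assumptions for illustration only.
import copy

import torch
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# stand-in for the benchmark's LNLinearSigmoid model
m_orig = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
).cuda().to(torch.bfloat16)
x = torch.randn(16384, 4096, device="cuda", dtype=torch.bfloat16)

# build the dynamic axiswise scaling config and convert a copy of the model,
# mirroring the m_fp8_dyn_axs path added in the diff
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)

# one forward + backward pass; the benchmark times the GPU kernels of this
# same workload via its get_gpu_kernel_time helper
y = m_fp8_dyn_axs(x)
y.sum().backward()
```

As in the diff, `copy.deepcopy(m_orig)` is used so each converted variant starts from the unmodified bf16 model rather than from one whose `nn.Linear` modules have already been swapped to float8.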