@@ -10,10 +10,13 @@
 from pathlib import Path
 from typing import Callable, List, Optional, Tuple
 
+import bench_constants as bc
+
 import pandas as pd
 
 import torch
 import torch.utils.benchmark as benchmark
+from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
 from tqdm import tqdm
@@ -28,22 +31,6 @@
 except ImportError:
     print("transformer_engine not installed and we won't compare against this")
 
-# estimating TOPs for matmuls in fp32, fp16, fp8
-# assuming A * B = C, with A being M * K, B being K * N, C being M * N
-
-# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
-h100_peak_flops_float32 = 67e12
-h100_peak_flops_fp16_tc = 1979e12
-h100_peak_tops_float8_tc = 3958e12
-
-dtype_to_peak_tops = {
-    torch.float32: h100_peak_flops_float32,
-    torch.float16: h100_peak_flops_fp16_tc,
-    torch.bfloat16: h100_peak_flops_fp16_tc,
-    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
-    torch.float8_e5m2: h100_peak_tops_float8_tc,
-}
-
 
 def benchmark_torch_function_in_microseconds(
     func: Callable,
@@ -63,6 +50,7 @@ class Experiment:
     shape: Tuple[int, int, int]
     ref_time_sec: float
     float8_time_sec: float
+    float8_dynamic_time_sec: float
     dtype: torch.dtype
     compiled: bool = False
     float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
@@ -76,7 +64,7 @@ def ref_tops_sec(self):
 
     @property
     def ref_pct_top_peak(self):
-        return self.ref_tops_sec / dtype_to_peak_tops[self.dtype]
+        return self.ref_tops_sec / bc.dtype_to_peak_tops[self.dtype]
 
     @property
     def float8_tops_sec(self):
@@ -85,7 +73,7 @@ def float8_tops_sec(self):
 
     @property
     def float8_pct_top_peak(self):
-        return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
+        return self.float8_tops_sec / bc.dtype_to_peak_tops[self.float_8_dtype]
 
     @property
     def te_tops_sec(self):
@@ -98,7 +86,7 @@ def te_tops_sec(self):
     @property
     def te_pct_top_peak(self):
         if self.te_tops_sec is not None:
-            return self.te_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
+            return self.te_tops_sec / bc.dtype_to_peak_tops[self.float_8_dtype]
         else:
             return None
 
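For context on what `ref_pct_top_peak` and its siblings report: each property divides achieved TOPs/sec by the peak throughput for the relevant dtype from `bc.dtype_to_peak_tops`. A minimal sketch of the assumed accounting, with illustrative helper names that are not in the file, using the conventional 2*M*K*N FLOPs per matmul and three matmuls (output, grad_input, grad_weight) for a linear forward + backward:

def approx_linear_fwd_bwd_flops(M: int, K: int, N: int) -> float:
    # assumption: 2*M*K*N FLOPs per matmul, three matmuls per fwd + bwd
    return 3 * 2 * M * K * N

def pct_of_peak(flops: float, time_sec: float, peak_flops_per_sec: float) -> float:
    # achieved throughput as a fraction of the hardware peak
    return (flops / time_sec) / peak_flops_per_sec
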
@@ -107,24 +95,27 @@ def main(
     sweep_path: Path,
     compile: bool,
     n_limit: Optional[int] = None,
+    llama_model_size: str = "70B",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
+    print("model size:", llama_model_size)
+
+    name_to_shapes = bc.name_to_shapes[llama_model_size]
+    if llama_model_size == "70B":
+        # common distributed setup, single GPU numbers
+        bsz, seq_len = 4, 4096
+    else:
+        # debug single gpu setup
+        bsz, seq_len = 1, 4096
 
-    # LLaMa 2 70B single-node weight shapes
-    # assumes fused attn.wqkv and ffn.w13
-    # source: https://fburl.com/gsheet/g8onr7rh
-    name_to_shapes_70b = {
-        "attn.wqkv": (8192, 1280),
-        "attn.w0": (1024, 8192),
-        "ffn.w13": (8192, 7168),
-        "ffn.w2": (3584, 8192),
-    }
     input_bias = False
-    ref_dtypes = [torch.bfloat16, torch.float16]
+    ref_dtypes = [
+        torch.bfloat16,
+    ]
     experiment_list: List[Experiment] = []
     for idx, (dtype, (name, (K, N))) in enumerate(
-        tqdm(list(product(ref_dtypes, name_to_shapes_70b.items())))
+        tqdm(list(product(ref_dtypes, name_to_shapes.items())))
     ):
         if n_limit is not None and idx >= n_limit:
             break
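The `bench_constants` module itself is not part of this diff. A plausible sketch of what it provides, reconstructed from the constants removed above; the 7B/13B shape entries are placeholders implied by the CLI choices, not values taken from the source:

# bench_constants.py (hypothetical sketch)
import torch

# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
h100_peak_flops_fp16_tc = 1979e12
h100_peak_tops_float8_tc = 3958e12

dtype_to_peak_tops = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}

# (K, N) weight shapes per fused linear, keyed by LLaMa model size
name_to_shapes = {
    "70B": {
        "attn.wqkv": (8192, 1280),
        "attn.w0": (1024, 8192),
        "ffn.w13": (8192, 7168),
        "ffn.w2": (3584, 8192),
    },
    # "7B": {...}, "13B": {...}  # shapes not shown in this diff
}
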
@@ -136,7 +127,10 @@ def main(
             copy.deepcopy(linear_ref), emulate=False
         )
 
-        bsz, seq_len = 4, 4096
+        linear_dynamic_float8 = Float8DynamicLinear.from_float(
+            copy.deepcopy(linear_ref), emulate=False
+        )
+
         M = bsz * seq_len
         input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
         ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
@@ -145,6 +139,10 @@ def float8_forw_backward():
             sync_float8_amax_and_scale_history(linear_float8)
             linear_float8(input_tensor).sum().backward()
 
+        float8_dynamic_forw_backward = (
+            lambda: linear_dynamic_float8(input_tensor).sum().backward()
+        )
+
         if transformer_engine_installed:
             # Use the same recipe as float8_linear.DelayedScalingRecipe
             fp8_format = recipe.Format.HYBRID
@@ -169,19 +167,23 @@ def wrapper(*args, **kwargs):
 
         ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
         float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)
+        float8_dynamic_forw_backward = n_times(REPEAT_N, float8_dynamic_forw_backward)
         if transformer_engine_installed:
             te_forw_backward = n_times(REPEAT_N, te_forw_backward)
 
         if compile:
             ref_forw_backward = torch.compile(ref_forw_backward)
             float8_forw_backward = torch.compile(float8_forw_backward)
+            float8_dynamic_forw_backward = torch.compile(float8_dynamic_forw_backward)
             # Compiling TE_linear fails but they are already compiling under the hood
             # if transformer_engine_installed:
             #     te_forw_backward = torch.compile(te_forw_backward)
 
+        # warmup
         for _ in range(5):
             ref_forw_backward()
             float8_forw_backward()
+            float8_dynamic_forw_backward()
             if transformer_engine_installed:
                 te_forw_backward()
 
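`n_times` and `REPEAT_N` are defined earlier in the file, outside this diff; the `def wrapper(*args, **kwargs):` context in the hunk headers suggests a decorator-style helper. A minimal sketch of the assumed behavior, which is what makes the later `/ REPEAT_N` normalization of the measured time read naturally:

def n_times(n, fn):
    # assumed helper: call fn n times per invocation so one timed call
    # covers n iterations; the caller divides the measured time by n
    def wrapper(*args, **kwargs):
        for _ in range(n):
            fn(*args, **kwargs)
    return wrapper
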
@@ -195,6 +197,11 @@ def wrapper(*args, **kwargs):
             * 1e-6
             / REPEAT_N
         )
+        float8_dynamic_time = (
+            benchmark_torch_function_in_microseconds(float8_dynamic_forw_backward)
+            * 1e-6
+            / REPEAT_N
+        )
         if transformer_engine_installed:
             te_time_sec = (
                 benchmark_torch_function_in_microseconds(te_forw_backward)
@@ -208,12 +215,17 @@ def wrapper(*args, **kwargs):
             (M, K, N),
             ref_time,
             float8_time,
+            float8_dynamic_time,
             dtype,
             compile,
             te_time_sec=te_time_sec,
         )
         print(experiment)
         print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
+        print(
+            "float8 dynamic speedup",
+            experiment.ref_time_sec / experiment.float8_dynamic_time_sec,
+        )
         if transformer_engine_installed:
             print("te speedup", experiment.ref_time_sec / experiment.te_time_sec)
         experiment_list.append(experiment)
@@ -229,6 +241,7 @@ def wrapper(*args, **kwargs):
229241 "fp8_dtype" ,
230242 "ref_time_sec" ,
231243 "pt_fp8_time_sec" ,
244+ "pt_fp8_dynamic_time_sec" ,
232245 "te_fp8_time_sec" ,
233246 "ref_tops_sec" ,
234247 "ref_pct_top_peak" ,
@@ -250,6 +263,7 @@ def wrapper(*args, **kwargs):
                 experiment.float_8_dtype,
                 experiment.ref_time_sec,
                 experiment.float8_time_sec,
+                experiment.float8_dynamic_time_sec,
                 experiment.te_time_sec,
                 experiment.ref_tops_sec,
                 experiment.ref_pct_top_peak,
@@ -262,6 +276,9 @@ def wrapper(*args, **kwargs):
 
     data_pd = pd.DataFrame(data, columns=headers)
     data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
+    data_pd["pt_fp8_dynamic_speedup"] = (
+        data_pd["ref_time_sec"] / data_pd["pt_fp8_dynamic_time_sec"]
+    )
     if transformer_engine_installed:
         data_pd["te_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["te_fp8_time_sec"]
     else:
@@ -280,12 +297,13 @@ def wrapper(*args, **kwargs):
         [
             "name",
             "shape",
-            "ref_dtype",
             "compiled",
             "ref_time_sec",
             "pt_fp8_time_sec",
+            "pt_fp8_dynamic_time_sec",
             "te_fp8_time_sec",
             "pt_fp8_speedup",
+            "pt_fp8_dynamic_speedup",
             "te_fp8_speedup",
         ]
     ]
@@ -301,6 +319,9 @@ def wrapper(*args, **kwargs):
     parser.add_argument("-o", "--output_path", type=str, required=True)
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("-n", "--n_limit", type=int, required=False)
+    parser.add_argument(
+        "--llama_model_size", type=str, required=True, choices=["70B", "7B", "13B"]
+    )
     args = parser.parse_args()
     output_path = Path(args.output_path)
-    main(output_path, args.compile, args.n_limit)
+    main(output_path, args.compile, args.n_limit, args.llama_model_size)
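With the new required flag, an invocation might look like this (the script filename is illustrative; the flags come from the argparse setup above):

python bench_linear_float8.py -o /tmp/float8_sweep.csv --compile --llama_model_size 70B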