@@ -56,6 +56,7 @@ class Experiment:
5656 dtype : torch .dtype
5757 compiled : bool = False
5858 float_8_dtype : Optional [torch .dtype ] = torch .float8_e4m3fn
59+ recompute_weight_cast : bool = False
5960
6061 # 3 Times since we are calculating forward backward
6162 @property
@@ -95,9 +96,14 @@ def main(
9596 }
9697 input_bias = False
9798 ref_dtypes = [torch .bfloat16 , torch .float16 ]
99+ recompute_weight_casts = [True , False ]
98100 experiment_list : List [Experiment ] = []
99- for idx , (dtype , (name , (K , N ))) in enumerate (
100- tqdm (list (product (ref_dtypes , name_to_shapes_70b .items ())))
101+ for idx , (dtype , (name , (K , N )), recompute_weight_cast ) in enumerate (
102+ tqdm (
103+ list (
104+ product (ref_dtypes , name_to_shapes_70b .items (), recompute_weight_casts )
105+ )
106+ )
101107 ):
102108 if n_limit is not None and idx >= n_limit :
103109 break
@@ -106,7 +112,9 @@ def main(
106112 )
107113
108114 linear_float8 = Float8Linear .from_float (
109- copy .deepcopy (linear_ref ), emulate = False
115+ copy .deepcopy (linear_ref ),
116+ emulate = False ,
117+ recompute_weight_cast = recompute_weight_cast ,
110118 )
111119
112120 bsz , seq_len = 4 , 4096
@@ -155,6 +163,7 @@ def wrapper(*args, **kwargs):
155163 float8_time ,
156164 dtype ,
157165 compile ,
166+ recompute_weight_cast = recompute_weight_cast ,
158167 )
159168 print (experiment )
160169 print ("float8 speedup" , experiment .ref_time_sec / experiment .float8_time_sec )
@@ -169,6 +178,7 @@ def wrapper(*args, **kwargs):
169178 "ref_dtype" ,
170179 "compiled" ,
171180 "fp8_dtype" ,
181+ "recompute_weight_cast" ,
172182 "ref_time_sec" ,
173183 "pt_fp8_time_sec" ,
174184 "ref_tops_sec" ,
@@ -187,6 +197,7 @@ def wrapper(*args, **kwargs):
187197 experiment .dtype ,
188198 experiment .compiled ,
189199 experiment .float_8_dtype ,
200+ experiment .recompute_weight_cast ,
190201 experiment .ref_time_sec ,
191202 experiment .float8_time_sec ,
192203 experiment .ref_tops_sec ,
@@ -214,6 +225,7 @@ def wrapper(*args, **kwargs):
214225 "shape" ,
215226 "ref_dtype" ,
216227 "compiled" ,
228+ "recompute_weight_cast" ,
217229 "ref_time_sec" ,
218230 "pt_fp8_time_sec" ,
219231 "pt_fp8_speedup" ,
0 commit comments