
import torch
import torch.utils.benchmark as benchmark
+from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import (
    get_float8_linear,
+    linear_requires_sync,
    LinearType,
    sync_float8_amax_and_scale_history,
)
@@ -68,6 +70,7 @@ class Experiment:
    compiled: bool
    use_fast_accum: bool
    linear_type: str
+    scaling_repr: str

    # 3 Times since we are calculating forward backward
    @property
@@ -96,10 +99,17 @@ def main(
    fast_accum_filter: Optional[bool] = None,
    shape_name_filter: Optional[str] = None,
    linear_type_filter: Optional[str] = None,
+    scaling_type_x: str = "delayed",
+    scaling_type_w: str = "delayed",
+    scaling_type_dL_dY: str = "delayed",
):
    device = "cuda"
    print(f"Compile is set to | {compile}")

+    scaling_type_x = TensorScalingType(scaling_type_x)
+    scaling_type_w = TensorScalingType(scaling_type_w)
+    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+
    # LLaMa 2 70B single-node weight shapes
    # assumes fused attn.wqkv and ffn.w13
    name_to_shapes_70b = {
@@ -134,9 +144,24 @@ def main(
        LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
    )

-    linear_float8 = get_float8_linear(
-        linear_type_enum, copy.deepcopy(linear_ref), emulate=False
-    )
+    if linear_type == "delayed":
+        linear_float8 = get_float8_linear(
+            linear_type_enum,
+            copy.deepcopy(linear_ref),
+            emulate=False,
+            scaling_type_x=scaling_type_x,
+            scaling_type_w=scaling_type_w,
+            scaling_type_dL_dY=scaling_type_dL_dY,
+        )
+        scaling_repr = linear_float8.scaling_repr()
+    else:
+        linear_float8 = get_float8_linear(
+            linear_type_enum,
+            copy.deepcopy(linear_ref),
+            emulate=False,
+        )
+        scaling_repr = None
+
    if fast_accum:
        linear_float8.forward_config = ScaledMMConfig(False, True, False)
    else:
@@ -150,7 +175,10 @@ def main(
    if linear_type_enum == LinearType.DELAYED:

        def float8_forw_backward():
-            sync_float8_amax_and_scale_history(linear_float8)
+            if linear_requires_sync(
+                linear_type_enum, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+            ):
+                sync_float8_amax_and_scale_history(linear_float8)
            linear_float8(input_tensor).sum().backward()

    else:
@@ -197,6 +225,7 @@ def wrapper(*args, **kwargs):
            compile,
            use_fast_accum=fast_accum,
            linear_type=linear_type,
+            scaling_repr=scaling_repr,
        )
        print(experiment)
        print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -209,6 +238,7 @@ def wrapper(*args, **kwargs):
        "K",
        "N",
        "linear_type",
+        "scaling_repr",
        "ref_dtype",
        "compiled",
        "use_fast_accum",
@@ -228,6 +258,7 @@ def wrapper(*args, **kwargs):
                experiment.shape[1],
                experiment.shape[2],
                experiment.linear_type,
+                experiment.scaling_repr,
                experiment.dtype,
                experiment.compiled,
                experiment.use_fast_accum,
@@ -257,6 +288,7 @@ def wrapper(*args, **kwargs):
            "name",
            "shape",
            "linear_type",
+            "scaling_repr",
            "compiled",
            "use_fast_accum",
            "ref_time_sec",
@@ -280,15 +312,26 @@ def invoke_main() -> None:
    parser.add_argument("--fast_accum_filter", type=bool, required=False)
    parser.add_argument("--shape_name_filter", type=str, required=False)
    parser.add_argument("--linear_type_filter", type=str, required=False)
+    parser.add_argument("--scaling_type_x", type=str, required=False)
+    parser.add_argument("--scaling_type_w", type=str, required=False)
+    parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
    args = parser.parse_args()
    output_path = Path(args.output_path) if args.output_path is not None else None
+    kwargs = {}
+    if args.scaling_type_x is not None:
+        kwargs["scaling_type_x"] = args.scaling_type_x
+    if args.scaling_type_w is not None:
+        kwargs["scaling_type_w"] = args.scaling_type_w
+    if args.scaling_type_dL_dY is not None:
+        kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
    main(
        output_path,
        args.compile,
        args.n_limit,
        args.fast_accum_filter,
        args.shape_name_filter,
        args.linear_type_filter,
+        **kwargs,
    )
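The new guard means the per-iteration amax/scale sync is only issued when some tensor actually uses delayed scaling. The exact body of linear_requires_sync is not shown in this diff; the following is a minimal, self-contained sketch of the gating logic it implies, using stand-in enums rather than the real float8_experimental imports:

    from enum import Enum

    # Stand-ins for the library enums so this sketch runs without
    # float8_experimental installed; the names mirror the real ones.
    class LinearType(Enum):
        DELAYED = "delayed"
        DYNAMIC = "dynamic"

    class TensorScalingType(Enum):
        DELAYED = "delayed"
        DYNAMIC = "dynamic"

    def linear_requires_sync(linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY):
        # Sync is only needed while at least one tensor keeps delayed
        # scaling, i.e. there are amax/scale history buffers to update.
        return linear_type == LinearType.DELAYED and any(
            s == TensorScalingType.DELAYED
            for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)
        )

    # With all-dynamic scaling the per-iteration sync call is skipped.
    assert not linear_requires_sync(
        LinearType.DELAYED,
        TensorScalingType.DYNAMIC,
        TensorScalingType.DYNAMIC,
        TensorScalingType.DYNAMIC,
    )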
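For completeness, here is a hypothetical direct call to main() with mixed scaling types. The positional order follows the call in invoke_main() above; only explicitly passed kwargs override the "delayed" defaults, and "dynamic" is assumed to be the other accepted TensorScalingType value (the CLI equivalent would be flags such as --scaling_type_x dynamic):

    # Hypothetical invocation; positional order matches invoke_main() above.
    main(
        None,        # output_path: don't write a CSV
        True,        # compile
        None,        # n_limit
        None,        # fast_accum_filter
        None,        # shape_name_filter
        "delayed",   # linear_type_filter: only benchmark delayed linears
        scaling_type_x="dynamic",    # per-tensor overrides; anything
        scaling_type_w="delayed",    # omitted keeps the "delayed" default
        scaling_type_dL_dY="dynamic",
    )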