@@ -125,7 +125,7 @@ def backward(ctx, go):
             fp8_scale_grad_output,
             e5m2_dtype,
             linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.DL_DY,
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -273,21 +273,21 @@ def convert_amax_buffer_to_float32(self):
             if self._buffers[key] is not None:
                 self._buffers[key] = self._buffers[key].to(torch.float32)

-    def cast_x_to_float8(
-        self, x: torch.Tensor, is_amax_initialized: bool
+    def cast_input_to_float8(
+        self, input: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
             autocast_dtype = torch.get_autocast_gpu_dtype()
-            x = x.to(autocast_dtype)
+            input = input.to(autocast_dtype)

         if self.scaling_type_input is TensorScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
-                x,
+                input,
                 self.fp8_amax_input,
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
@@ -296,29 +296,29 @@ def cast_x_to_float8(
                 is_amax_initialized,
                 reduce_amax=True,
             )
-            x_fp8 = Float8Tensor.to_float8(
-                x,
+            input_fp8 = Float8Tensor.to_float8(
+                input,
                 self.fp8_scale_input,
                 e4m3_dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
-                gemm_input_role=GemmInputRole.X,
+                gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
-        return x_fp8
+            input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
+        return input_fp8

-    def cast_w_to_float8(
-        self, w: torch.Tensor, is_amax_initialized: bool
+    def cast_weight_to_float8(
+        self, weight: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         if self.scaling_type_weight is TensorScalingType.DELAYED:
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                w_fp8 = self.weight
+                weight_fp8 = self.weight
             else:
                 scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
                 _maybe_initialize_amaxes_scales_for_float8_cast(
-                    w,
+                    weight,
                     self.fp8_amax_weight,
                     self.fp8_amax_history_weight,
                     self.fp8_scale_weight,
@@ -328,29 +328,31 @@ def cast_w_to_float8(
                     reduce_amax=False,
                 )

-                w_fp8 = Float8Tensor.to_float8(
-                    w,
+                weight_fp8 = Float8Tensor.to_float8(
+                    weight,
                     self.fp8_scale_weight,
                     e4m3_dtype,
                     self.fp8_amax_weight,
                     linear_mm_config=self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.W,
+                    gemm_input_role=GemmInputRole.WEIGHT,
                 )
         else:
             assert self.scaling_type_weight is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                w_fp8 = self.weight
+                weight_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(
-                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                weight_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight,
+                    self.linear_mm_config,
+                    gemm_input_role=GemmInputRole.WEIGHT,
                 )
-        return w_fp8
+        return weight_fp8

-    def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
+    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            y = NoopFwToFloat8E5M2Bw.apply(
-                y,
+            output = NoopFwToFloat8E5M2Bw.apply(
+                output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
                 self.fp8_scale_grad_output,
@@ -360,10 +362,10 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
             )
         else:
             assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
-        return y
+            output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
+        return output

-    def float8_pre_forward(self, x):
+    def float8_pre_forward(self, input):
         if not self.enable_pre_and_post_forward:
             return
         if (
@@ -374,7 +376,7 @@ def float8_pre_forward(self, x):
             raise AssertionError(
                 "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
             )
-        self.last_seen_input_dtype = x.dtype
+        self.last_seen_input_dtype = input.dtype

     def float8_post_forward(self):
         if not self.enable_pre_and_post_forward:
@@ -388,25 +390,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.has_any_delayed_scaling:
             self.float8_pre_forward(input)

-        x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
+        weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

-        y = torch.matmul(x_fp8, w_fp8.t())
+        output = torch.matmul(input_fp8, weight_fp8.t())

-        # Cast gradY to float8_e5m2 during backward
-        y = self.cast_y_to_float8_in_bw(y)
+        # Cast grad_output to float8_e5m2 during backward
+        output = self.cast_output_to_float8_in_bw(output)

         if self.bias is not None:
-            y = y + self.bias.to(y.dtype)
+            output = output + self.bias.to(output.dtype)

         if self.has_any_delayed_scaling:
             self.float8_post_forward()
-        return y
+        return output

     def scaling_repr(self):
         # add scaling settings without using too many characters
-        # example: "x:del,w:del,dldy:dyn"
-        return f"x:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},dldy:{self.scaling_type_grad_output.short_str()}"
+        # example: "i:del,w:del,go:dyn"
+        return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}"

     def extra_repr(self):
         s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
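For readers updating call sites, here is a minimal sketch of how the renamed pieces fit together. It restates the flow from the hunks above rather than the library's exact implementation; `m` stands for an already-constructed `Float8Linear`-style module with the attributes shown in this diff, and everything not appearing in the diff is an illustrative assumption.

```python
import torch

# Old -> new names introduced by this diff:
#   GemmInputRole.X / .W / .DL_DY -> GemmInputRole.INPUT / .WEIGHT / .GRAD_OUTPUT
#   cast_x_to_float8              -> cast_input_to_float8
#   cast_w_to_float8              -> cast_weight_to_float8
#   cast_y_to_float8_in_bw        -> cast_output_to_float8_in_bw


def sketch_forward(m, input: torch.Tensor) -> torch.Tensor:
    # Cast the activation and the weight to float8 (e4m3) with the new helper names.
    input_fp8 = m.cast_input_to_float8(input, m.is_amax_initialized)
    weight_fp8 = m.cast_weight_to_float8(m.weight, m.is_amax_initialized)

    # Float8 matmul for the forward pass.
    output = torch.matmul(input_fp8, weight_fp8.t())

    # Attach the hook that casts grad_output to float8 (e5m2) during backward.
    output = m.cast_output_to_float8_in_bw(output)

    if m.bias is not None:
        output = output + m.bias.to(output.dtype)
    return output
```

Per the updated example comment in the diff, a module with delayed scaling for input and weight and dynamic scaling for grad_output would now report `scaling="i:del,w:del,go:dyn"` in its repr instead of the old `"x:del,w:del,dldy:dyn"`.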