@@ -8,7 +8,7 @@
 import torch
 
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, re_construct_float8_weight
 from float8_experimental.float8_utils import is_row_major
 from torch.utils._pytree import tree_map
 
@@ -191,9 +191,7 @@ def forward(
         if recompute_float8_weight:
             # This should be set to True when using traditional fsdp to avoid
             # saving the unsharded weight for backwards
-            ctx.save_for_backward(
-                x_fp8, original_weight, weight_scale, weight_amax_buffer
-            )
+            ctx.save_for_backward(x_fp8, original_weight, weight_scale)
         else:
             # Does this interact properly with activation checkpointing?
             ctx.save_for_backward(x_fp8, w_fp8)
@@ -211,19 +209,15 @@ def forward(
     @staticmethod
     def backward(ctx, go_fp8: torch.Tensor):
         if ctx.recompute_float8_weight:
-            x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors
-            w_fp8 = Float8Tensor.to_float8(
-                original_weight,
-                weight_scale,
-                torch.float8_e4m3fn,
-                weight_amax_buffer,
-                emulate=ctx.emulate,
+            x_fp8, original_weight, weight_scale = ctx.saved_tensors
+            w_fp8 = re_construct_float8_weight(
+                original_weight, weight_scale, torch.float8_e4m3fn, emulate=ctx.emulate
             )
         else:
             x_fp8, w_fp8 = ctx.saved_tensors
 
         # calculate dL/dX
-        go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1))
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
         dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1))
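Side note on the new helper: re_construct_float8_weight is only imported and called in this diff, so below is a minimal sketch of what it plausibly does, not the actual implementation in float8_tensor.py. Float8Tensor.to_float8 and its argument order come from the removed lines above; the function body, the explicit None amax buffer, and the default for emulate are assumptions. The idea is to re-cast the saved high-precision weight with the scale already computed in forward, without updating the amax buffer again, so backward can rebuild w_fp8 instead of keeping the (unsharded) fp8 weight alive.

import torch
from float8_experimental.float8_tensor import Float8Tensor


def re_construct_float8_weight(original_weight, weight_scale, float8_dtype, emulate=False):
    # Hypothetical sketch: re-cast the high-precision weight using the scale
    # captured during forward. Passing None for the amax buffer (assumed to be
    # accepted, mirroring the positional order of the removed call above) means
    # delayed-scaling statistics are not updated a second time here.
    return Float8Tensor.to_float8(
        original_weight, weight_scale, float8_dtype, None, emulate=emulate
    )

With a helper shaped like this, backward only needs x_fp8, original_weight, and weight_scale from save_for_backward, which is exactly what the first hunk above trims down to.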