@@ -8,7 +8,7 @@
 import torch

 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, re_construct_float8_weight
 from float8_experimental.float8_utils import is_row_major
 from torch.utils._pytree import tree_map

@@ -192,7 +192,7 @@ def forward(
             # This should be set to True when using traditional fsdp to avoid
             # saving the unsharded weight for backwards
             ctx.save_for_backward(
-                x_fp8, original_weight, weight_scale, weight_amax_buffer
+                x_fp8, original_weight, weight_scale
             )
         else:
             # Does this interact properly with activation checkpointing?
@@ -211,19 +211,13 @@ def forward(
     @staticmethod
     def backward(ctx, go_fp8: torch.Tensor):
         if ctx.recompute_float8_weight:
-            x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors
-            w_fp8 = Float8Tensor.to_float8(
-                original_weight,
-                weight_scale,
-                torch.float8_e4m3fn,
-                weight_amax_buffer,
-                emulate=ctx.emulate,
-            )
+            x_fp8, original_weight, weight_scale = ctx.saved_tensors
+            w_fp8 = re_construct_float8_weight(original_weight, weight_scale, torch.float8_e4m3fn, emulate=ctx.emulate)
         else:
             x_fp8, w_fp8 = ctx.saved_tensors

         # calculate dL/dX
-        go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1))
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
         dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1))
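
Note: `re_construct_float8_weight` is only imported by this diff; its definition in `float8_experimental.float8_tensor` is not shown here. Below is a minimal sketch of what such a helper could look like, assuming it simply re-casts the saved high-precision weight with the scale already computed during forward and skips the amax-buffer update (which is why `weight_amax_buffer` no longer needs to be saved for backward). Passing `None` for the amax buffer to `Float8Tensor.to_float8` is an assumption, not a confirmed signature.

from float8_experimental.float8_tensor import Float8Tensor

def re_construct_float8_weight(original_weight, weight_scale, float8_dtype, emulate=False):
    # Hypothetical sketch, not the actual library implementation: rebuild the
    # fp8 weight from the scale computed during forward, without touching any
    # amax history, so backward only needs (original_weight, weight_scale).
    return Float8Tensor.to_float8(
        original_weight,
        weight_scale,
        float8_dtype,
        None,  # assumed: amax buffer can be omitted when only re-casting
        emulate=emulate,
    )

The only other behavioral change in backward is `go_fp8.reshape(-1, go_fp8.size(-1))`, which tolerates a non-contiguous incoming gradient where `view` would raise.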