@@ -201,21 +201,13 @@ def forward(
 
         ctx.recompute_float8_weight = recompute_float8_weight
         ctx.emulate = emulate
-        orig_shape = x_fp8._data.shape
-        x_fp8_reshaped = Float8Tensor(
-            x_fp8._data.reshape(-1, orig_shape[-1]),
-            x_fp8._scale,
-            x_fp8._orig_dtype,
-            emulate=emulate,
-        )
+        x_fp8_reshaped = x_fp8.reshape(-1, x_fp8.size(-1))
 
-        w_fp8_t = Float8Tensor(
-            w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, emulate=emulate
-        )
+        w_fp8_t = w_fp8.t()
 
         res_bits = float8_mm_helper(x_fp8_reshaped, w_fp8_t)
 
-        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        res_bits = res_bits.reshape(*x_fp8.shape[:-1], res_bits.size(-1))
         return res_bits
 
     @staticmethod
@@ -227,48 +219,21 @@ def backward(ctx, go_fp8: torch.Tensor):
                 weight_scale,
                 torch.float8_e4m3fn,
                 weight_amax_buffer,
-                emulate=emulate,
+                emulate=ctx.emulate,
             )
         else:
             x_fp8, w_fp8 = ctx.saved_tensors
 
-        emulate = ctx.emulate
-
-        go_fp8_orig_shape = go_fp8._data.shape
-        go_fp8_reshaped = Float8Tensor(
-            go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
-            go_fp8._scale,
-            go_fp8._orig_dtype,
-            emulate=emulate,
-        )
-
-        w_fp8_t_c_t = Float8Tensor(
-            w_fp8._data.t().contiguous().t(),
-            w_fp8._scale,
-            w_fp8._orig_dtype,
-            emulate=emulate,
-        )
-
         # calculate dL/dX
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
+        w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
-        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
-
-        x_fp8_orig_shape = x_fp8._data.shape
-        x_fp8_reshaped_t_c = Float8Tensor(
-            x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
-            x_fp8._scale,
-            x_fp8._orig_dtype,
-            emulate=emulate,
-        )
-
-        go_fp8_reshaped_t_c_t = Float8Tensor(
-            go_fp8_reshaped._data.t().contiguous().t(),
-            go_fp8_reshaped._scale,
-            go_fp8_reshaped._orig_dtype,
-            emulate=emulate,
-        )
+        dL_dX = dL_dX.reshape(*go_fp8.shape[:-1], dL_dX.size(-1))
 
         # calculate dL/dW
+        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8.size(-1)).t().contiguous()
+        go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
+
        dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t)
        dL_dW = dL_dW.t()
 
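The call sites shrink because shape-only operations (reshape, t, contiguous, size) are now invoked directly on the Float8Tensor wrapper instead of on its raw _data followed by a manual re-wrap with the same scale and dtype. Below is a minimal, self-contained sketch of that pattern; MiniFloat8 and its _wrap helper are hypothetical stand-ins for illustration only, not the library's actual Float8Tensor (which presumably routes these ops through PyTorch's tensor-subclass dispatch).

import torch


class MiniFloat8:
    """Toy float8-style wrapper: a raw payload plus a single shared scale."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        self._data = data          # uint8 stands in for the fp8 payload in this demo
        self._scale = scale        # scalar scale shared by all elements
        self._orig_dtype = orig_dtype

    def _wrap(self, new_data: torch.Tensor) -> "MiniFloat8":
        # Shape-only ops never change element values, so the scale and
        # original dtype carry over unchanged.
        return MiniFloat8(new_data, self._scale, self._orig_dtype)

    def reshape(self, *shape):
        return self._wrap(self._data.reshape(*shape))

    def t(self):
        return self._wrap(self._data.t())

    def contiguous(self):
        return self._wrap(self._data.contiguous())

    def size(self, dim=None):
        return self._data.size() if dim is None else self._data.size(dim)

    @property
    def shape(self):
        return self._data.shape


# With the ops on the wrapper, a call site reduces to the one-liners in the diff:
x = MiniFloat8(torch.zeros(2, 3, 4, dtype=torch.uint8), torch.tensor(1.0), torch.float32)
x2d = x.reshape(-1, x.size(-1))   # replaces the manual Float8Tensor(...) rebuild
print(x2d.shape)                  # torch.Size([6, 4])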