This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 67dca36 (1 parent: 32fbe21)

gotta figure out a way to call shape that is compile friendly
1 file changed
float8_experimental/float8_ops.py

Lines changed: 10 additions & 45 deletions
@@ -201,21 +201,13 @@ def forward(
 
         ctx.recompute_float8_weight = recompute_float8_weight
         ctx.emulate = emulate
-        orig_shape = x_fp8._data.shape
-        x_fp8_reshaped = Float8Tensor(
-            x_fp8._data.reshape(-1, orig_shape[-1]),
-            x_fp8._scale,
-            x_fp8._orig_dtype,
-            emulate=emulate,
-        )
+        x_fp8_reshaped = x_fp8.reshape(-1, x_fp8.size(-1))
 
-        w_fp8_t = Float8Tensor(
-            w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, emulate=emulate
-        )
+        w_fp8_t = w_fp8.t()
 
         res_bits = float8_mm_helper(x_fp8_reshaped, w_fp8_t)
 
-        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        res_bits = res_bits.reshape(*x_fp8.shape[:-1], res_bits.size(-1))
         return res_bits
 
     @staticmethod
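The forward hunk above works because Float8Tensor itself handles `reshape` and `t`, so call sites no longer unpack `_data` and rebuild the wrapper by hand. Below is a minimal sketch of that pattern, assuming a wrapper subclass with `_data`/`_scale`/`_orig_dtype`/`_emulate` fields; `Fp8Tensor` and its op handlers are illustrative stand-ins, not this repo's actual `Float8Tensor` implementation:

import torch

aten = torch.ops.aten


class Fp8Tensor(torch.Tensor):
    """Illustrative stand-in for Float8Tensor: raw low-precision bits plus a scale."""

    @staticmethod
    def __new__(cls, data, scale, orig_dtype, emulate=False):
        # Wrapper subclass: advertises the original dtype and the payload's
        # shape while keeping the raw bits in _data.
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )
        self._data = data
        self._scale = scale
        self._orig_dtype = orig_dtype
        self._emulate = emulate
        return self

    def __repr__(self):
        return f"Fp8Tensor(shape={tuple(self.shape)}, scale={self._scale})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Shape-only ops act on the raw bits and rewrap the result, so call
        # sites can write x.reshape(...) / x.t() instead of rebuilding the
        # subclass (reshape on a viewable input lowers to aten.view here).
        if func in (aten.view.default, aten._unsafe_view.default):
            x, size = args
            return Fp8Tensor(x._data.view(size), x._scale, x._orig_dtype, x._emulate)
        if func is aten.t.default:
            (x,) = args
            return Fp8Tensor(x._data.t(), x._scale, x._orig_dtype, x._emulate)
        raise NotImplementedError(f"Fp8Tensor has no handler for {func}")


# Usage: uint8 stands in for the fp8 payload so the sketch runs anywhere.
bits = torch.arange(24, dtype=torch.uint8).reshape(4, 6)
x = Fp8Tensor(bits, torch.tensor(1.0), torch.float32)
y = x.reshape(-1, x.size(-1)).t()  # dispatches to the handlers above
print(y.shape)  # torch.Size([6, 4])

Keeping the rewrapping inside `__torch_dispatch__` means every `reshape`/`t` call site stays an ordinary tensor op, which is the property the commit message is after for compile friendliness.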
@@ -227,48 +219,21 @@ def backward(ctx, go_fp8: torch.Tensor):
                 weight_scale,
                 torch.float8_e4m3fn,
                 weight_amax_buffer,
-                emulate=emulate,
+                emulate=ctx.emulate,
             )
         else:
             x_fp8, w_fp8 = ctx.saved_tensors
 
-        emulate = ctx.emulate
-
-        go_fp8_orig_shape = go_fp8._data.shape
-        go_fp8_reshaped = Float8Tensor(
-            go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
-            go_fp8._scale,
-            go_fp8._orig_dtype,
-            emulate=emulate,
-        )
-
-        w_fp8_t_c_t = Float8Tensor(
-            w_fp8._data.t().contiguous().t(),
-            w_fp8._scale,
-            w_fp8._orig_dtype,
-            emulate=emulate,
-        )
-
         # calculate dL/dX
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
+        w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
-        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
-
-        x_fp8_orig_shape = x_fp8._data.shape
-        x_fp8_reshaped_t_c = Float8Tensor(
-            x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
-            x_fp8._scale,
-            x_fp8._orig_dtype,
-            emulate=emulate,
-        )
-
-        go_fp8_reshaped_t_c_t = Float8Tensor(
-            go_fp8_reshaped._data.t().contiguous().t(),
-            go_fp8_reshaped._scale,
-            go_fp8_reshaped._orig_dtype,
-            emulate=emulate,
-        )
+        dL_dX = dL_dX.reshape(*go_fp8.shape[:-1], dL_dX.size(-1))
 
         # calculate dL/dW
+        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8.size(-1)).t().contiguous()
+        go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
+
         dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t)
         dL_dW = dL_dW.t()
 
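Shape-wise, the rewritten backward still computes dL/dX = go @ w and dL/dW = (x^T @ go)^T on the flattened 2-D views; the `.t().contiguous().t()` idiom changes only memory layout, keeping the logical shape while producing a column-major operand (presumably to satisfy the fp8 matmul's layout requirements). A quick plain-torch sanity check of those shapes, with made-up sizes:

import torch

M, K, N = 6, 4, 3          # flattened batch, in_features, out_features
x = torch.randn(M, K)      # x_fp8 reshaped to 2-D
w = torch.randn(N, K)      # linear weight, (out_features, in_features)
go = torch.randn(M, N)     # incoming gradient of the output

out = x @ w.t()                        # forward mm
dL_dX = go @ w.t().contiguous().t()    # same values as go @ w; layout-only change
dL_dW = (x.t().contiguous() @ go.t().contiguous().t()).t()

assert out.shape == (M, N)
assert dL_dX.shape == x.shape          # (M, K)
assert dL_dW.shape == w.shape          # (N, K)
assert torch.allclose(dL_dX, go @ w)
assert torch.allclose(dL_dW, (x.t() @ go).t())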
