Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 27a6a7f

Browse files
committed
fix multiple calls to save_for_backward
1 parent 67dca36 commit 27a6a7f

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

float8_experimental/float8_ops.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ def forward(
181181
emulate: bool,
182182
recompute_float8_weight: bool,
183183
):
184-
ctx.save_for_backward(x_fp8)
185184
w_fp8 = Float8Tensor.to_float8(
186185
original_weight,
187186
weight_scale,
@@ -191,7 +190,7 @@ def forward(
191190
)
192191
if recompute_float8_weight:
193192
# This should be set to True when using traditional fsdp to avoid saving
194-
# saving the unsharded weight for
193+
# saving the unsharded weight for backwards
195194
ctx.save_for_backward(
196195
x_fp8, original_weight, weight_scale, weight_amax_buffer
197196
)

0 commit comments

Comments
 (0)