
Commit e7a6aa3

do less in the backwards

1 parent 20da1c0 commit e7a6aa3

File tree

float8_experimental/float8_ops.py
float8_experimental/float8_tensor.py

2 files changed: +21 −11 lines changed

float8_experimental/float8_ops.py

Lines changed: 5 additions & 11 deletions
@@ -8,7 +8,7 @@
 import torch

 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, re_construct_float8_weight
 from float8_experimental.float8_utils import is_row_major
 from torch.utils._pytree import tree_map

@@ -192,7 +192,7 @@ def forward(
             # This should be set to True when using traditional fsdp to avoid
             # saving the unsharded weight for backwards
             ctx.save_for_backward(
-                x_fp8, original_weight, weight_scale, weight_amax_buffer
+                x_fp8, original_weight, weight_scale
             )
         else:
             # Does this interact properly with activation checkpointing?
@@ -211,19 +211,13 @@ def forward(
     @staticmethod
     def backward(ctx, go_fp8: torch.Tensor):
         if ctx.recompute_float8_weight:
-            x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors
-            w_fp8 = Float8Tensor.to_float8(
-                original_weight,
-                weight_scale,
-                torch.float8_e4m3fn,
-                weight_amax_buffer,
-                emulate=ctx.emulate,
-            )
+            x_fp8, original_weight, weight_scale = ctx.saved_tensors
+            w_fp8 = re_construct_float8_weight(original_weight, weight_scale, torch.float8_e4m3fn, emulate=ctx.emulate)
         else:
             x_fp8, w_fp8 = ctx.saved_tensors

         # calculate dL/dX
-        go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1))
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
         dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1))
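For context, a minimal standalone sketch of the pattern this hunk switches to: compute the weight's scale from its amax once in the forward, then recast the saved high-precision weight with that same scale in the backward, without writing to the amax buffer again. The helper `to_fp8_saturated_sketch` and the scale formula below are illustrative stand-ins for the repository's float8_utils helpers, and the snippet assumes a PyTorch build with torch.float8_e4m3fn support.

import torch

def to_fp8_saturated_sketch(t: torch.Tensor, float8_dtype) -> torch.Tensor:
    # Clamp to the dtype's representable range before casting (a saturated
    # cast), standing in for the repo's float8_utils.to_fp8_saturated.
    max_val = torch.finfo(float8_dtype).max
    return t.clamp(min=-max_val, max=max_val).to(float8_dtype)

# "Forward": derive the scale from the weight's amax once.
weight = torch.randn(16, 16)
amax = weight.abs().max()
scale = torch.finfo(torch.float8_e4m3fn).max / amax.clamp(min=1e-12)

# "Backward": reuse the saved scale verbatim; no amax-buffer update is
# needed, which is the point of the change above.
w_fp8_bits = to_fp8_saturated_sketch(weight * scale, torch.float8_e4m3fn)
w_roundtrip = w_fp8_bits.to(torch.float32) / scale
print((weight - w_roundtrip).abs().max())  # small quantization error

Separately, the view-to-reshape swap in the last hunk matters because go_fp8 may arrive non-contiguous in backward: Tensor.view raises on non-contiguous layouts, while reshape falls back to a copy.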

float8_experimental/float8_tensor.py

Lines changed: 16 additions & 0 deletions
@@ -42,6 +42,22 @@ def backward(ctx, g):
         return g, None, None, None, None


+@torch._dynamo.allow_in_graph
+def re_construct_float8_weight(tensor: torch.Tensor, scale: torch.Tensor, float8_dtype, emulate: bool = False):
+    """In the backwards of float8_linear we don't need to fill the amax buffer
+    for the weight tensor, since that was done during the forward; we just need
+    to recast the original-precision tensor using the scale from the forward.
+
+    Args:
+        tensor: the tensor to convert
+        scale: the scale to use to convert the tensor, from the forward
+        float8_dtype: the float8 dtype to use
+        emulate: if true, use fp32 emulation for the matmuls; helpful
+            if you don't have access to H100 hardware.
+    """
+    tensor_scaled = tensor * scale
+    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+    return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)
 @torch._dynamo.allow_in_graph
 class FromFloat8ConstrFunc(torch.autograd.Function):
     """
