This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 32fbe21

Commit message: existing eager tests pass

1 parent: 713d2db

File tree

float8_experimental/float8_dynamic_linear.py
float8_experimental/float8_ops.py

2 files changed: +143 −4 lines


float8_experimental/float8_dynamic_linear.py

Lines changed: 8 additions & 3 deletions

@@ -8,6 +8,7 @@
 """
 
 import torch
+from float8_experimental.float8_ops import float8_linear
 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated

@@ -48,12 +49,16 @@ class Float8DynamicLinear(torch.nn.Linear):
 
     def forward(self, x):
         x_fp8 = self.cast_to_float8(x)
-        w_fp8 = self.cast_to_float8(self.weight)
-
-        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
+        # w_fp8 = self.cast_to_float8(self.weight)
 
+        # y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
+        weight_scale = tensor_to_scale(self.weight, torch.float8_e4m3fn)
+        y = float8_linear.apply(
+            x_fp8, self.weight, weight_scale, None, self.emulate, False
+        )
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_to_float8e5m2_bw(y)
+        y = y + self.bias if self.bias is not None else y
 
         return y
 
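The new forward computes the weight scale dynamically with tensor_to_scale and adds the bias only after the float8 matmul. Below is a minimal, self-contained sketch of what per-tensor dynamic scaling to torch.float8_e4m3fn looks like; naive_tensor_to_scale is a hypothetical stand-in for the library's tensor_to_scale, not its actual implementation.

import torch

# Hypothetical stand-in for float8_experimental.float8_utils.tensor_to_scale:
# a per-tensor scale that maps the tensor's amax onto the fp8 dtype's max value.
def naive_tensor_to_scale(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    fp8_max = torch.finfo(float8_dtype).max  # 448.0 for e4m3fn
    amax = x.abs().max().clamp(min=1e-12)
    return fp8_max / amax

x = torch.randn(16, 32)
scale = naive_tensor_to_scale(x)
# Saturated cast: scale up, clamp to the representable range, cast to fp8.
x_fp8 = (x * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
# Dequantizing recovers the original values up to fp8 quantization error.
err = (x - x_fp8.to(torch.float32) / scale).abs().max()
print(f"max round-trip error: {err:.4f}")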
float8_experimental/float8_ops.py

Lines changed: 135 additions & 1 deletion

@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import torch
 
@@ -75,11 +75,33 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     return a_data, a_scale, b_data, b_scale
 
 
+def float8_mm_helper(a: Float8Tensor, b: Float8Tensor) -> torch.Tensor:
+    """This is a helper function for float8_mm.
+
+    Args:
+        a: The first matrix multiplication term.
+        b: The second matrix multiplication term.
+
+    Returns:
+        torch.Tensor: The result of the matrix multiplication.
+    """
+    a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
+    output_dtype = a._orig_dtype
+    if a._emulate:
+        assert a._emulate == b._emulate
+        return torch.ops.aten.mm_float8_emulated(
+            a._data, a._scale, b._data, b._scale, output_dtype
+        )[0]
+    tensor_out, amax = addmm_float8_unwrapped(
+        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None, bias=None
+    )
+    return tensor_out
+
+
 @implements([aten.mm.default])
 def float8_mm(aten_op, args, kwargs=None):
     assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
     a = args[0]
     b = args[1]
+    return float8_mm_helper(a, b)
     a_data, a_scale, b_data, b_scale = preprocess_addmm(a, b)
     output_dtype = a._orig_dtype
     if a._emulate:
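float8_mm_helper relies on the usual scaled-matmul identity: each fp8 payload, divided by its scale, approximates the original-precision operand, so the scaled product approximates the original product. The sketch below is a plain-torch reference of that math only; it is not the library's addmm_float8_unwrapped or mm_float8_emulated kernels.

import torch

# Reference of what a scaled fp8 matmul computes: dequantize each operand by its
# scale and multiply in the output precision. Illustration only, not a fused kernel.
def reference_scaled_mm(a_data, a_scale, b_data, b_scale, output_dtype):
    a_hp = a_data.to(output_dtype) / a_scale
    b_hp = b_data.to(output_dtype) / b_scale
    return torch.mm(a_hp, b_hp)

a = torch.randn(8, 16)
b = torch.randn(16, 4)
a_scale = 448.0 / a.abs().max()
b_scale = 448.0 / b.abs().max()
a_data = (a * a_scale).to(torch.float8_e4m3fn)
b_data = (b * b_scale).to(torch.float8_e4m3fn)

out = reference_scaled_mm(a_data, a_scale, b_data, b_scale, torch.float32)
print((out - a @ b).abs().max())  # small, dominated by fp8 quantization error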
@@ -140,3 +162,115 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     return Float8Tensor(
         args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
     )
+
+
+class float8_linear(torch.autograd.Function):
+    """Custom autograd function for computing torch.nn.Linear on Float8Tensor.
+
+    This is needed for a couple of reasons; in particular, we want fine-grained
+    control over the recomputation of casted values for backward.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8: torch.Tensor,
+        original_weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        weight_amax_buffer: Optional[torch.Tensor],
+        emulate: bool,
+        recompute_float8_weight: bool,
+    ):
+        w_fp8 = Float8Tensor.to_float8(
+            original_weight,
+            weight_scale,
+            torch.float8_e4m3fn,
+            weight_amax_buffer,
+            emulate=emulate,
+        )
+        if recompute_float8_weight:
+            # This should be set to True when using traditional FSDP, to avoid
+            # saving the unsharded weight for backward.
+            ctx.save_for_backward(
+                x_fp8, original_weight, weight_scale, weight_amax_buffer
+            )
+        else:
+            # Does this interact properly with activation checkpointing?
+            ctx.save_for_backward(x_fp8, w_fp8)
+
+        ctx.recompute_float8_weight = recompute_float8_weight
+        ctx.emulate = emulate
+        orig_shape = x_fp8._data.shape
+        x_fp8_reshaped = Float8Tensor(
+            x_fp8._data.reshape(-1, orig_shape[-1]),
+            x_fp8._scale,
+            x_fp8._orig_dtype,
+            emulate=emulate,
+        )
+
+        w_fp8_t = Float8Tensor(
+            w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, emulate=emulate
+        )
+
+        res_bits = float8_mm_helper(x_fp8_reshaped, w_fp8_t)
+
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8: torch.Tensor):
+        emulate = ctx.emulate
+        if ctx.recompute_float8_weight:
+            x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors
+            w_fp8 = Float8Tensor.to_float8(
+                original_weight,
+                weight_scale,
+                torch.float8_e4m3fn,
+                weight_amax_buffer,
+                emulate=emulate,
+            )
+        else:
+            x_fp8, w_fp8 = ctx.saved_tensors
+
+        go_fp8_orig_shape = go_fp8._data.shape
+        go_fp8_reshaped = Float8Tensor(
+            go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
+            go_fp8._scale,
+            go_fp8._orig_dtype,
+            emulate=emulate,
+        )
+
+        w_fp8_t_c_t = Float8Tensor(
+            w_fp8._data.t().contiguous().t(),
+            w_fp8._scale,
+            w_fp8._orig_dtype,
+            emulate=emulate,
+        )
+
+        # calculate dL/dX
+        dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        x_fp8_orig_shape = x_fp8._data.shape
+        x_fp8_reshaped_t_c = Float8Tensor(
+            x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
+            x_fp8._scale,
+            x_fp8._orig_dtype,
+            emulate=emulate,
+        )
+
+        go_fp8_reshaped_t_c_t = Float8Tensor(
+            go_fp8_reshaped._data.t().contiguous().t(),
+            go_fp8_reshaped._scale,
+            go_fp8_reshaped._orig_dtype,
+            emulate=emulate,
+        )
+
+        # calculate dL/dW
+        dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t)
+        dL_dW = dL_dW.t()
+
+        # No gradients for the remaining forward inputs; autograd accepts extra
+        # trailing Nones beyond the number of forward inputs.
+        empty_grads = None, None, None, None, None, None, None, None, None
+        return dL_dX, dL_dW, *empty_grads
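The backward above is the standard linear-layer gradient math expressed as float8 matmuls (the .t().contiguous().t() calls presumably only massage memory layout for the fp8 GEMM). A plain-fp32 sanity check of those formulas against autograd, with made-up shapes:

import torch

# For y = x @ w.t():  dL/dx = dL/dy @ w  and  dL/dw = (x.t() @ dL/dy).t()
x = torch.randn(5, 8, requires_grad=True)
w = torch.randn(3, 8, requires_grad=True)
go = torch.randn(5, 3)  # incoming gradient dL/dy

y = x @ w.t()
y.backward(go)

dL_dx_manual = go @ w            # matches dL_dX in backward
dL_dw_manual = (x.t() @ go).t()  # matches dL_dW = mm(x^T, dL/dy) followed by .t()

print(torch.allclose(x.grad, dL_dx_manual))  # True
print(torch.allclose(w.grad, dL_dw_manual))  # True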
