This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 6df3e55

drisspg authored and facebook-github-bot committed
Add option to use tensor hooks for Dynamic Linear (#198)
Summary: This is a duplicate of #170, with more testing. Ideally, I think we wouldn't keep the choice between hooks and modified forwards and would just use hooks; however, torch.compile does not appear to support this yet.

Pull Request resolved: #198

Reviewed By: wanchaol

Differential Revision: D53287660

Pulled By: drisspg

fbshipit-source-id: 727e43e8850f3a480ba87df80c0710516ef45f28
1 parent aa920be commit 6df3e55
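For context on what "tensor hooks" means here: instead of inlining the float8 casts in Linear.forward, the dynamic linear can rely on standard PyTorch forward hooks. The sketch below illustrates only the hook mechanism; the casting helpers are placeholders, not the functions this commit adds (those live in float8_dynamic_linear.py below).

import torch
import torch.nn as nn

def fake_cast_activation(t: torch.Tensor) -> torch.Tensor:
    # placeholder for the real cast to torch.float8_e4m3fn
    return t

def fake_cast_grad(t: torch.Tensor) -> torch.Tensor:
    # placeholder for the no-op-forward / float8_e5m2-backward cast
    return t

def pre_hook(module, args):
    # forward pre-hook: its return value replaces the positional inputs
    return fake_cast_activation(args[0])

def fwd_hook(module, args, output):
    # forward hook: its return value replaces the module output
    return fake_cast_grad(output)

lin = nn.Linear(16, 32)
lin.register_forward_pre_hook(pre_hook)
lin.register_forward_hook(fwd_hook)
y = lin(torch.randn(4, 16))  # both hooks fire without touching Linear.forward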

File tree: 8 files changed, +225 −44 lines changed

float8_experimental/config.py

Lines changed: 5 additions & 0 deletions
@@ -14,3 +14,8 @@
 # this doesn't work with autocast + torch.compile + FSDP. Enabling this
 # option is useful for safety, but not strictly necessary.
 enable_pre_and_post_forward = True
+
+# If True, dynamic linear uses hooks for activation casting
+# TODO(before land): add test coverage for both cases
+# dynamic_use_activation_hooks = True
+# dynamic_use_activation_hooks = False

float8_experimental/float8_dynamic_linear.py

Lines changed: 58 additions & 16 deletions
@@ -7,9 +7,10 @@
 A wrapper around a `torch.nn.Linear` module which does fp8 compute.
 """

+import float8_experimental.config as config
 import torch

-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated


@@ -31,13 +32,27 @@ def forward(

     @staticmethod
     def backward(ctx, gradY):
-        gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
-        gradY_scaled = gradY * gradY_scale
-        bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
-        return (
-            Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=ctx.emulate),
-            None,
-        )
+        fp8_tensor = to_fp8_no_autograd(gradY, torch.float8_e5m2, ctx.emulate)
+        return fp8_tensor, None
+
+
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8_e4m3fn(args[0])
+
+
+def cast_grad_to_float8_e5m2_backward_forward_hook(module, input, output):
+    """This is a forward hook that sends the output of the model through
+    a no-op in the forward but a cast to float8_e5m2 in the backward.
+
+    Args:
+        module (nn.Module): the module to cast the output of
+        input (Tensor): the input to the module forward call
+        output (Tensor): the output of the module forward
+    """
+    return module.cast_to_float8_e5m2_bw(output)


 class Float8DynamicLinear(torch.nn.Linear):
@@ -46,38 +61,65 @@ class Float8DynamicLinear(torch.nn.Linear):
     conversion to fp8 of the input and weight tensors.
     """

+    def __init__(self, use_activation_hooks: bool, **super_kwargs):
+        """
+        Args:
+            use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
+        """
+        super().__init__(**super_kwargs)
+
+        self.use_activation_hooks = use_activation_hooks
+
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
-        w_fp8 = self.cast_to_float8(self.weight)
+        # cast x to float8_e4m3fn if not using activation hooks
+        x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)
+
+        # cast w to float8_e4m3fn
+        w_fp8 = self.cast_to_float8_e4m3fn(self.weight)

         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

-        # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        # Cast gradY to float8_e5m2 during backward if not using activation hooks
+        if not self.use_activation_hooks:
+            y = self.cast_to_float8_e5m2_bw(y)

         return y

-    def cast_to_float8(self, inpt_tensor):
+    def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
         )

-    def cast_to_float8e5m2_bw(self, gradY):
+    def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
         return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)

     @classmethod
-    def from_float(cls, mod, emulate: bool = False):
+    def from_float(
+        cls, mod, emulate: bool = False, use_activation_hooks: bool = False
+    ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear

         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
+            use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
         """
         with torch.device("meta"):
-            new_mod = cls(mod.in_features, mod.out_features, bias=False)
+            super_kwargs = {
+                "in_features": mod.in_features,
+                "out_features": mod.out_features,
+                "bias": False,
+            }
+            new_mod = cls(use_activation_hooks, **super_kwargs)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_forward_hook(
+                cast_grad_to_float8_e5m2_backward_forward_hook
+            )
         return new_mod
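A usage sketch for the API added in this file (illustrative, not part of the commit); emulate=True is assumed so the example does not require fp8-capable hardware, and exact device/dtype requirements may differ:

import torch
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

m = nn.Linear(16, 32, bias=False)
# hooks installed by from_float: the pre-hook casts x, the forward hook handles gradY
m_fp8 = Float8DynamicLinear.from_float(m, emulate=True, use_activation_hooks=True)

x = torch.randn(4, 16, requires_grad=True)
y = m_fp8(x)        # activation cast happens in the forward pre-hook
y.sum().backward()  # gradY is cast to float8_e5m2 via the forward hook's autograd fn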

float8_experimental/float8_linear.py

Lines changed: 3 additions & 1 deletion
@@ -305,14 +305,16 @@ def forward(self, x):
         return y

     @classmethod
-    def from_float(cls, mod, emulate: bool = False):
+    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear

         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
+            use_activation_hooks (bool): whether to use activation hooks instead of inlining the casting logic
         """
+        assert not use_activation_hooks, "use_activation_hooks is not supported yet!"
         # TODO Follow up! This is a great idea but we need the mixin base to create real
         # Tensors and the Linear base to create empty params
         # with torch.device("meta"):
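For illustration only (not a test in this commit), the new assert means the delayed-scaling Float8Linear rejects the flag until it is supported:

import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear

# Default path: hooks off, conversion proceeds as before.
m = Float8Linear.from_float(nn.Linear(16, 32), emulate=True)

# Hooks requested: the assert added above fires.
try:
    Float8Linear.from_float(nn.Linear(16, 32), emulate=True, use_activation_hooks=True)
except AssertionError as err:
    print(err)  # use_activation_hooks is not supported yet!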

float8_experimental/float8_linear_utils.py

Lines changed: 10 additions & 3 deletions
@@ -23,23 +23,30 @@ class LinearType(Enum):


 def get_float8_linear(
-    linear_type: LinearType, linear_ref: torch.nn.Linear, emulate: bool = False
+    linear_type: LinearType,
+    linear_ref: torch.nn.Linear,
+    emulate: bool = False,
+    use_activation_hooks: bool = False,
 ):
     """Returns a Float8Linear module of the given type, initialized from linear_ref.
     Args:
         linear_type: The type of Float8Linear to return.
         linear_ref: The linear module to initialize from.
         emulate: Whether to emulate the fp8 matmul logic in float32.
+        use_activation_hooks: Whether to use activation hooks for dynamic linear.
     """
     LINEAR_TYPE_MAP = {
         LinearType.DELAYED: Float8Linear,
         LinearType.DYNAMIC: Float8DynamicLinear,
     }
     if linear_type not in LINEAR_TYPE_MAP:
         raise ValueError(f"linear_type must be one of {LINEAR_TYPE_MAP.keys()}")
-
+    if use_activation_hooks and linear_type != LinearType.DYNAMIC:
+        raise ValueError("use_activation_hooks is only supported for dynamic linear")
     return LINEAR_TYPE_MAP[linear_type].from_float(
-        copy.deepcopy(linear_ref), emulate=emulate
+        copy.deepcopy(linear_ref),
+        emulate=emulate,
+        use_activation_hooks=use_activation_hooks,
     )
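An illustrative sketch (not from the commit) of how the updated helper routes the new flag:

import torch.nn as nn
from float8_experimental.float8_linear_utils import LinearType, get_float8_linear

ref = nn.Linear(16, 32)

# Dynamic linear may opt into hook-based casting.
dyn = get_float8_linear(
    LinearType.DYNAMIC, ref, emulate=True, use_activation_hooks=True
)

# Delayed linear with hooks is rejected before from_float is called.
try:
    get_float8_linear(LinearType.DELAYED, ref, emulate=True, use_activation_hooks=True)
except ValueError as err:
    print(err)  # use_activation_hooks is only supported for dynamic linear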

float8_experimental/float8_tensor.py

Lines changed: 24 additions & 1 deletion
@@ -7,7 +7,11 @@

 import torch

-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
+from float8_experimental.float8_utils import (
+    tensor_to_amax,
+    tensor_to_scale,
+    to_fp8_saturated,
+)

 aten = torch.ops.aten

@@ -170,3 +174,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):

     # Do not force the Float8Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl
+
+
+def to_fp8_no_autograd(
+    x: torch.Tensor, float8_dtype: torch.dtype, emulate: bool
+) -> Float8Tensor:
+    """Convert a tensor to float8 without autograd
+    This is used in multiple places in the codebase to convert a tensor to float8
+
+    This function will calculate the scale, do the scaling, and then convert to a Float8Tensor
+    Args:
+        x: the tensor to convert
+        scale: the scale to use to convert the tensor
+        float8_dtype: the float8 dtype to use
+        emulate: whether to emulate the matmuls in fp32
+    """
+    x_scale = tensor_to_scale(x, float8_dtype)
+    x_scaled = x * x_scale
+    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
+    return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
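A small illustrative sketch of the new helper: it computes a scale for the input, scales and saturates to the target float8 dtype, and wraps the result in a Float8Tensor, all outside autograd:

import torch
from float8_experimental.float8_tensor import to_fp8_no_autograd

grad = torch.randn(16, 16)
fp8_grad = to_fp8_no_autograd(grad, torch.float8_e5m2, emulate=True)
print(type(fp8_grad))  # Float8Tensor carrying the e5m2 bits, the scale, and the original dtype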

test/conftest.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import pytest
+
+
+@pytest.fixture
+def x_fail_activation_hooks(request):
+    use_activation_hooks = request.getfixturevalue("use_activation_hooks")
+    if use_activation_hooks:
+        request.node.add_marker(
+            pytest.mark.xfail(reason="use_activation_hooks is not supported for AOT")
+        )
+
+
+@pytest.fixture
+def x_fail_activation_hooks_with_delayed(request):
+    linear_type = request.getfixturevalue("linear_type")
+    use_activation_hooks = request.getfixturevalue("use_activation_hooks")
+    if use_activation_hooks and linear_type == linear_type.DELAYED:
+        request.node.add_marker(
+            pytest.mark.xfail(reason="use_activation_hooks is not supported for AOT")
+        )
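These fixtures only take effect when a test both parametrizes the named arguments and requests the fixture via usefixtures, as the tests in test_base.py below do. A minimal, hypothetical opt-in (the test name and body are illustrative):

import pytest

@pytest.mark.parametrize("use_activation_hooks", [True, False])
@pytest.mark.usefixtures("x_fail_activation_hooks")
def test_example(use_activation_hooks):
    # the fixture marks the use_activation_hooks=True case as xfail
    assert use_activation_hooks in (True, False)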

test/test_base.py

Lines changed: 60 additions & 16 deletions
@@ -50,8 +50,15 @@ def test_preserves_dtype(self) -> None:


 class TestFloat8Linear:
-    def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
-        m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
+    def _test_linear_impl(
+        self,
+        x,
+        m_ref,
+        linear_type: LinearType,
+        emulate: bool,
+        use_activation_hooks: bool = False,
+    ):
+        m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
         for _ in range(2):
             if linear_requires_sync(linear_type):
                 sync_float8_amax_and_scale_history(m_fp8)
@@ -112,7 +119,15 @@ def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
+    def test_linear_nobias(
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        use_activation_hooks: bool,
+    ):
         if not emulate:
             if not torch.cuda.is_available():
                 warnings.warn("CUDA not available")
@@ -125,16 +140,23 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):

         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)

     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
     def test_linear_bias(
-        self, x_shape, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        use_activation_hooks: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -148,25 +170,52 @@ def test_linear_bias(

         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)

-        m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = Float8Linear.from_float(m, emulate)
+    @pytest.mark.parametrize("emulate", [True, False])
+    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
+    def test_autocast_outputs(
+        self,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        use_activation_hooks: bool,
+    ):
+        if not emulate:
+            if not torch.cuda.is_available():
+                warnings.warn("CUDA not available")
+                pytest.skip()
+            elif torch.cuda.get_device_capability() < (9, 0):
+                warnings.warn(
+                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
+                )
+                pytest.skip()
+
+        m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
+        m = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)

         # autocast off
         x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
-        sync_float8_amax_and_scale_history(m)
+        if linear_requires_sync(linear_type):
+            sync_float8_amax_and_scale_history(m)
         y = m(x)
         assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

         # autocast on
         with torch.autocast("cuda"):
-            sync_float8_amax_and_scale_history(m)
+            if linear_requires_sync(linear_type):
+                sync_float8_amax_and_scale_history(m)
             y = m(x)
         assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

         with torch.autocast("cuda", dtype=torch.bfloat16):
-            sync_float8_amax_and_scale_history(m)
+            if linear_requires_sync(linear_type):
+                sync_float8_amax_and_scale_history(m)
             y = m(x)
         assert (
             y.dtype == torch.bfloat16
@@ -180,11 +229,6 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
         emulate = (
             not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
         )
-        x_shape = (16, 16)
-
-        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)

         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         m = Float8Linear.from_float(m, emulate)
