This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 855795c

[wip] hooks
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 31fba04 commit 855795c

File tree: 2 files changed (+46, -2 lines)


float8_experimental/config.py

Lines changed: 8 additions & 0 deletions

@@ -16,3 +16,11 @@
 # according to their microbatching/pipeline parallel setup.
 # Note: this is currently a global flag for simplicity and dynamo performance.
 weight_cache_enabled = False
+
+#
+# Other
+#
+
+# If True, dynamic linear uses hooks for activation casting
+dynamic_use_activation_hooks = True
+# dynamic_use_activation_hooks = False
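For context, a minimal usage sketch of how the new flag might be consumed, assuming the conversion entry point is `Float8DynamicLinear.from_float` as shown in the dynamic_float8_linear.py diff below; the layer sizes, `emulate=True`, and the final forward call are illustrative assumptions, not taken from this commit.

# Hedged sketch: the flag is read at construction / conversion time,
# so it is toggled before converting the module.
import torch

import float8_experimental.config as config
from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear

config.dynamic_use_activation_hooks = True  # or False to use the inline casts

m = torch.nn.Linear(16, 16)
m_fp8 = Float8DynamicLinear.from_float(m, emulate=True)

# with hooks enabled, forward() receives an already-cast x and the
# forward hook handles the e5m2 cast of the gradient in backward
y = m_fp8(torch.randn(4, 16))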

float8_experimental/dynamic_linear/dynamic_float8_linear.py

Lines changed: 38 additions & 2 deletions

@@ -11,6 +11,7 @@
 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,20 @@ def backward(ctx, gradY):
             None,
         )
 
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+
+def cast_dldy_to_float8_e5m2_forward_hook(module, args, output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    new_output = NoopFwToFloat8E5M2Bw.apply(output, module.emulate)
+    return new_output
+
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -48,9 +63,16 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
             w_fp8 = self._w_fp8
         else:
@@ -59,7 +81,10 @@ def forward(self, x):
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
@@ -69,6 +94,7 @@ def add_weight_tag(self):
         self.weight._is_fp8_weight = True
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +118,14 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            # TODO(future): figure out why using backward pre-hooks does not
+            # work here:
+            # 1. repro code: https://gist.github.com/vkuzo/27a3f6ca48e50ba1134b077f0dba254c
+            # 2. repro output: https://gist.github.com/vkuzo/728eae9dcc627e130829d122daa982e7
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_forward_hook(cast_dldy_to_float8_e5m2_forward_hook)
         return new_mod
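The hook-based path relies on standard torch.nn.Module hook semantics: a forward pre-hook that returns a value replaces the module's positional input, and a forward hook that returns a value replaces the module's output. A minimal self-contained sketch of that mechanism follows (plain tensors, no float8 types; the hook names are illustrative stand-ins for the ones added in this commit).

import torch
import torch.nn as nn

def double_input_pre_hook(module, args):
    # stand-in for cast_x_to_float8_e4m3fn_pre_hook: the returned value
    # replaces the positional input(s) seen by forward()
    return args[0] * 2.0

def add_one_forward_hook(module, args, output):
    # stand-in for cast_dldy_to_float8_e5m2_forward_hook: the returned
    # value replaces the module's output
    return output + 1.0

lin = nn.Linear(4, 4)
lin.register_forward_pre_hook(double_input_pre_hook)
lin.register_forward_hook(add_one_forward_hook)

x = torch.randn(2, 4)
y = lin(x)
expected = torch.nn.functional.linear(x * 2.0, lin.weight, lin.bias) + 1.0
assert torch.allclose(y, expected)

In the commit, the pre-hook performs the e4m3 cast of the activation before forward() runs, and the forward hook wraps the output in NoopFwToFloat8E5M2Bw so the e5m2 cast is applied to the gradient during backward; the TODO in from_float notes that a backward pre-hook did not work for this (see the linked gists).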
