 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,20 @@ def backward(ctx, gradY):
             None,
         )
 
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+
+def cast_dldy_to_float8_e5m2_forward_hook(module, args, output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    new_output = NoopFwToFloat8E5M2Bw.apply(output, module.emulate)
+    return new_output
+
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -48,9 +63,16 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
             w_fp8 = self._w_fp8
         else:
@@ -59,7 +81,10 @@ def forward(self, x):
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
@@ -69,6 +94,7 @@ def add_weight_tag(self):
         self.weight._is_fp8_weight = True
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +118,14 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            # TODO(future): figure out why using backward pre-hooks does not
+            # work here:
+            # 1. repro code: https://gist.github.com/vkuzo/27a3f6ca48e50ba1134b077f0dba254c
+            # 2. repro output: https://gist.github.com/vkuzo/728eae9dcc627e130829d122daa982e7
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_forward_hook(cast_dldy_to_float8_e5m2_forward_hook)
         return new_mod
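
For context, a minimal usage sketch of the hook-based path (not part of this diff). It assumes the class lives in float8_experimental.float8_dynamic_linear and that config.dynamic_use_activation_hooks is a plain module-level boolean read at conversion time, as the diff suggests; emulate=True is used so the sketch does not require fp8 hardware.

# hypothetical usage sketch, not from this PR
import torch
import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

config.dynamic_use_activation_hooks = True  # route casts through module hooks

m = Float8DynamicLinear.from_float(torch.nn.Linear(16, 32), emulate=True)

x = torch.randn(4, 16, requires_grad=True)
y = m(x)            # forward pre-hook casts x to float8_e4m3fn
y.sum().backward()  # forward hook wraps y so dL/dY is cast to float8_e5m2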