
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config


class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,22 @@ def backward(ctx, gradY):
            None,
        )

+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    return (tensor_fp8,)

class Float8DynamicLinear(torch.nn.Linear):
    """
@@ -48,9 +65,16 @@ class Float8DynamicLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks

    def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
        if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
            w_fp8 = self._w_fp8
        else:
@@ -59,7 +83,10 @@ def forward(self, x):
        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

        # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)

        return y

@@ -69,6 +96,7 @@ def add_weight_tag(self):
        self.weight._is_fp8_weight = True

    def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
        return Float8Tensor.to_float8(
            inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +120,10 @@ def from_float(cls, mod, emulate: bool = False):
        new_mod.bias = mod.bias
        new_mod.emulate = emulate
        new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_backward_pre_hook)
        return new_mod
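
A minimal usage sketch of the hook-based path added in this diff (not part of the change itself). It assumes the class is importable from float8_experimental.float8_dynamic_linear and uses emulate=True so it runs without native float8 kernels; the import path and toggle name are taken from the code above, adjust if the module layout differs.

import torch

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# Route activation/gradient casts through the pre-hooks added above
# instead of the in-forward casts.
config.dynamic_use_activation_hooks = True

m = torch.nn.Linear(16, 32)
m_fp8 = Float8DynamicLinear.from_float(m, emulate=True)

x = torch.randn(4, 16)
y = m_fp8(x)           # forward pre-hook casts x to float8_e4m3fn
y.sum().backward()     # full backward pre-hook casts dL/dY to float8_e5m2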