 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 @torch._dynamo.allow_in_graph
@@ -39,25 +40,54 @@ def backward(ctx, gradY):
             None,
         )
 
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`.
+    """
+    return module.cast_to_float8(args[0])
+
+
+def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`.
+    """
+    gradY = grad_output[0]
+    # scale, saturate-cast, and wrap so the gradient flows onward in fp8
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    return (tensor_fp8,)
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute,
     with on-the-fly conversion of the input and weight tensors to fp8.
     """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn; when activation hooks are enabled,
+        # the forward pre-hook has already performed this cast
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         w_fp8 = self.cast_to_float8(self.weight)
 
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if not self.use_activation_hooks:
+            # with hooks enabled, the full-backward pre-hook performs this cast
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
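
Both cast sites above follow the same dynamic-scaling recipe: compute an amax-based scale for the target fp8 dtype, multiply, then saturate-cast. Below is a minimal sketch of that math, assuming `tensor_to_scale` divides the dtype's largest representable value by the tensor's amax and `to_fp8_saturated` clamps to that range before casting; this is a hypothetical re-implementation for illustration, not the library's actual helpers.

```python
import torch

def amax_based_scale(t: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # e.g. finfo(float8_e4m3fn).max == 448.0, finfo(float8_e5m2).max == 57344.0
    amax = t.abs().max()
    return torch.finfo(float8_dtype).max / amax.clamp(min=1e-12)

def saturated_cast(t: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # clamp to the representable range so outliers saturate instead of overflowing
    max_val = torch.finfo(float8_dtype).max
    return t.clamp(min=-max_val, max=max_val).to(float8_dtype)

gradY = torch.randn(16, 32)
scale = amax_based_scale(gradY, torch.float8_e5m2)
bits_fp8 = saturated_cast(gradY * scale, torch.float8_e5m2)  # raw fp8 payload
```
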
@@ -80,4 +110,9 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(
+                cast_dldy_to_float8_e5m2_backward_pre_hook
+            )
         return new_mod
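
A minimal usage sketch of the hook path added here, assuming the module lives at `float8_experimental.float8_dynamic_linear` and that `config.dynamic_use_activation_hooks` is read at swap time as in `from_float` above; the shapes and `emulate=True` are illustrative:

```python
import torch

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# opt in to the hook-based casts before swapping modules
config.dynamic_use_activation_hooks = True

mod = torch.nn.Linear(32, 64)
fp8_mod = Float8DynamicLinear.from_float(mod, emulate=True)

x = torch.randn(8, 32, requires_grad=True)
y = fp8_mod(x)       # forward pre-hook casts x to float8_e4m3fn
y.sum().backward()   # full-backward pre-hook casts dL/dY to float8_e5m2
```

With hooks enabled the casts run outside `forward`, keeping the module body free of cast logic; with them disabled, the original in-line cast path is used.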