@@ -40,6 +40,27 @@ def backward(ctx, gradY):
         )
 
 
+def cast_weight_linear(
+    x_fp8: Float8Tensor, weight: torch.Tensor, bias, emulate: bool
+) -> torch.Tensor:
+    """Cast weight to fp8_e4m3fn and do linear.
+
+    Why a new function for something that can be inlined?
+    Because we want to call torch.utils.checkpoint on this function.
+    We always want to recompute the cast of the weight to fp8, since we can
+    trivially fuse this into the transpose/contiguous of the weight during the backward.
+
+    Args:
+        x_fp8 (Float8Tensor): input activation in fp8
+        weight: weight tensor in higher precision
+        bias: bias tensor in higher precision
+        emulate (bool): whether to emulate fp8 matmul logic in float32
+    """
+    scale = tensor_to_scale(weight, torch.float8_e4m3fn)
+    w_fp8 = Float8Tensor.to_float8(weight, scale, torch.float8_e4m3fn, emulate=emulate)
+    y = torch.nn.functional.linear(x_fp8, w_fp8, bias)
+    return y
+
+
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
@@ -48,9 +69,14 @@ class Float8DynamicLinear(torch.nn.Linear):
 
     def forward(self, x):
         x_fp8 = self.cast_to_float8(x)
-        w_fp8 = self.cast_to_float8(self.weight)
-
-        y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
+        y = torch.utils.checkpoint.checkpoint(
+            cast_weight_linear,
+            x_fp8,
+            self.weight,
+            self.bias,
+            self.emulate,
+            use_reentrant=False,
+        )
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_to_float8e5m2_bw(y)
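
Note: the checkpointing pattern used in this diff can be exercised with stock PyTorch APIs. Below is a minimal, self-contained sketch of the same idea, i.e. wrapping the weight cast plus matmul in torch.utils.checkpoint.checkpoint(..., use_reentrant=False) so the low-precision weight copy is recomputed during backward instead of being kept alive. It is an assumption-laden illustration, not the PR's code: it casts to bfloat16 as a stand-in for the fp8 cast so it runs without the Float8Tensor/tensor_to_scale machinery, and cast_weight_linear_sketch plus the tensor shapes are illustrative names only.

import torch
import torch.utils.checkpoint

def cast_weight_linear_sketch(x, weight, bias):
    # Stand-in for the fp8 cast: a bfloat16 copy of the weight is created here.
    # Under checkpoint(..., use_reentrant=False) this cast is re-run in backward
    # rather than the low-precision copy being saved between forward and backward.
    w_lp = weight.to(torch.bfloat16)
    return torch.nn.functional.linear(x.to(torch.bfloat16), w_lp, bias)

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(32, 16, requires_grad=True)
b = torch.zeros(32, dtype=torch.bfloat16)

y = torch.utils.checkpoint.checkpoint(
    cast_weight_linear_sketch, x, w, b, use_reentrant=False
)
y.sum().backward()   # the weight cast is recomputed here as part of the recompute pass
print(w.grad.shape)  # gradients still flow to the high-precision weight: torch.Size([32, 16])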