@@ -50,8 +50,15 @@ def test_preserves_dtype(self) -> None:
 
 
 class TestFloat8Linear:
-    def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
-        m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
+    def _test_linear_impl(
+        self,
+        x,
+        m_ref,
+        linear_type: LinearType,
+        emulate: bool,
+        recompute_weight_cast: bool,
+    ):
+        m_fp8 = get_float8_linear(linear_type, m_ref, emulate, recompute_weight_cast)
         for _ in range(2):
             if linear_requires_sync(linear_type):
                 sync_float8_amax_and_scale_history(m_fp8)
@@ -112,7 +119,14 @@ def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
+    @pytest.mark.parametrize("recompute_weight_cast", [True, False])
+    def test_linear_nobias(
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        recompute_weight_cast: bool,
+    ):
         if not emulate:
             if not torch.cuda.is_available():
                 warnings.warn("CUDA not available")
@@ -125,16 +139,22 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
 
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, recompute_weight_cast)
 
     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
+    @pytest.mark.parametrize("recompute_weight_cast", [True, False])
     def test_linear_bias(
-        self, x_shape, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        recompute_weight_cast: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -148,10 +168,10 @@ def test_linear_bias(
 
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, recompute_weight_cast)
 
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = Float8Linear.from_float(m, emulate)
+        m = Float8Linear.from_float(m, emulate, recompute_weight_cast)
 
         # autocast off
         x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
@@ -184,7 +204,7 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
 
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, False)
 
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         m = Float8Linear.from_float(m, emulate)
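For orientation, here is a minimal usage sketch of the two call paths this diff extends with the new recompute_weight_cast flag. It is only a sketch: the import paths below are assumptions, and the positional-argument usage simply mirrors the call sites visible in the hunks above.

# Sketch only: import paths are assumed; the diff shows only the call sites.
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import LinearType, get_float8_linear

m_ref = nn.Linear(16, 32, bias=False, device="cuda")

# Path 1: the helper used by _test_linear_impl, now passing the extra flag
# (positional args: linear_type, module, emulate, recompute_weight_cast).
m_fp8 = get_float8_linear(LinearType.DYNAMIC, m_ref, False, True)

# Path 2: direct conversion, mirroring test_linear_bias
# (positional args: module, emulate, recompute_weight_cast).
m_fp8_direct = Float8Linear.from_float(nn.Linear(32, 16, device="cuda"), False, True)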