@@ -50,8 +50,15 @@ def test_preserves_dtype(self) -> None:
 
 
 class TestFloat8Linear:
-    def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
-        m_fp8 = get_float8_linear(linear_type, m_ref, emulate)
+    def _test_linear_impl(
+        self,
+        x,
+        m_ref,
+        linear_type: LinearType,
+        emulate: bool,
+        use_activation_hooks: bool = False,
+    ):
+        m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
         for _ in range(2):
             if linear_requires_sync(linear_type):
                 sync_float8_amax_and_scale_history(m_fp8)
@@ -112,7 +119,15 @@ def _test_linear_impl(self, x, m_ref, linear_type: LinearType, emulate: bool):
     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
+    def test_linear_nobias(
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        use_activation_hooks: bool,
+    ):
         if not emulate:
             if not torch.cuda.is_available():
                 warnings.warn("CUDA not available")
@@ -125,16 +140,23 @@ def test_linear_nobias(self, x_shape, linear_type: LinearType, emulate: bool):
 
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)
 
     @pytest.mark.parametrize("emulate", [True, False])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
     @pytest.mark.parametrize(
         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
     def test_linear_bias(
-        self, x_shape, linear_type: LinearType, emulate: bool, linear_dtype: torch.dtype
+        self,
+        x_shape,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        use_activation_hooks: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -148,25 +170,52 @@ def test_linear_bias(
 
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
+        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks)
 
-        m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = Float8Linear.from_float(m, emulate)
+    @pytest.mark.parametrize("emulate", [True, False])
+    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_activation_hooks", [True, False])
+    @pytest.mark.usefixtures("x_fail_activation_hooks_with_delayed")
+    def test_autocast_outputs(
+        self,
+        linear_type: LinearType,
+        emulate: bool,
+        linear_dtype: torch.dtype,
+        use_activation_hooks: bool,
+    ):
+        if not emulate:
+            if not torch.cuda.is_available():
+                warnings.warn("CUDA not available")
+                pytest.skip()
+            elif torch.cuda.get_device_capability() < (9, 0):
+                warnings.warn(
+                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
+                )
+                pytest.skip()
+
+        m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
+        m = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks)
 
         # autocast off
         x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
-        sync_float8_amax_and_scale_history(m)
+        if linear_requires_sync(linear_type):
+            sync_float8_amax_and_scale_history(m)
         y = m(x)
         assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"
 
         # autocast on
         with torch.autocast("cuda"):
-            sync_float8_amax_and_scale_history(m)
+            if linear_requires_sync(linear_type):
+                sync_float8_amax_and_scale_history(m)
             y = m(x)
         assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"
 
         with torch.autocast("cuda", dtype=torch.bfloat16):
-            sync_float8_amax_and_scale_history(m)
+            if linear_requires_sync(linear_type):
+                sync_float8_amax_and_scale_history(m)
             y = m(x)
         assert (
             y.dtype == torch.bfloat16
@@ -180,11 +229,6 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
         emulate = (
             not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
         )
-        x_shape = (16, 16)
-
-        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
-        self._test_linear_impl(x, m_ref, linear_type, emulate)
 
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         m = Float8Linear.from_float(m, emulate)
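Note: the `x_fail_activation_hooks_with_delayed` fixture pulled in via `pytest.mark.usefixtures` above is defined outside this diff. Below is a minimal sketch of what such a fixture could look like, assuming it xfails the unsupported combination of activation hooks with delayed scaling; this is an illustration, not the repository's implementation, and the `LinearType` import path is also an assumption.

import pytest

from float8_experimental.float8_linear_utils import LinearType  # assumed import path


@pytest.fixture
def x_fail_activation_hooks_with_delayed(request):
    # Parametrized values for the current test case are available on the callspec.
    params = request.node.callspec.params
    if params.get("use_activation_hooks") and params.get("linear_type") is LinearType.DELAYED:
        # Mark the unsupported combination as expected-to-fail instead of letting it error.
        request.node.add_marker(
            pytest.mark.xfail(reason="activation hooks not supported with delayed scaling")
        )

Registering a fixture like this in conftest.py (or in the test module itself) lets `usefixtures` apply the xfail per test case without changing any test signatures.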