 import re
 import unittest
 import warnings
+from itertools import product
 
 import pytest
 
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
+def filtered_parametrize(param_list, filter_func=None):
+    """
+    A decorator that works like pytest.mark.parametrize but filters out
+    unwanted parameter combinations.
+
+    :param param_list: A list of tuples, each containing (arg_name, [arg_values])
+    :param filter_func: A function that takes a dictionary of parameter names and
+        values, and returns True for valid combinations, False otherwise
+    """
+
+    def decorator(func):
+        arg_names = [param[0] for param in param_list]
+        arg_values = [param[1] for param in param_list]
+
+        all_combinations = product(*arg_values)
+        if filter_func:
+            valid_combinations = [
+                combo
+                for combo in all_combinations
+                if filter_func(dict(zip(arg_names, combo)))
+            ]
+        else:
+            valid_combinations = list(all_combinations)
+
+        return pytest.mark.parametrize(
+            argnames=arg_names, argvalues=valid_combinations
+        )(func)
+
+    return decorator
+
+
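+# Illustrative usage sketch (hypothetical names, not used elsewhere in this
+# file): filter_func receives a dict mapping argument names to one candidate
+# combination and drops the combinations it rejects before pytest sees them.
+#
+#     @filtered_parametrize(
+#         [("a", [1, 2]), ("b", [3, 4])],
+#         filter_func=lambda p: not (p["a"] == 1 and p["b"] == 3),
+#     )
+#     def test_pair(a, b):
+#         ...
+#
+# This parametrizes test_pair over (1, 4), (2, 3) and (2, 4); (1, 3) is dropped.
+
+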
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -230,17 +262,35 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    @staticmethod
+    def is_valid_combination(params):
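+        # Without emulation, real float8 compute requires CUDA with compute
+        # capability >= (9, 0), i.e. H100-class hardware.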
+        if not params["emulate"]:
+            if not torch.cuda.is_available():
+                return False
+            if torch.cuda.get_device_capability() < (9, 0):
+                return False
+
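+        # The scaling types are a no-op for Float8DynamicLinear, so only the
+        # all-DYNAMIC combination is worth testing for that linear type.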
+        if params["linear_type"] == LinearType.DYNAMIC:
+            return all(
+                params[key] == TensorScalingType.DYNAMIC
+                for key in ["scaling_type_x", "scaling_type_w", "scaling_type_dL_dY"]
+            )
+
+        return True
+
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_nobias(
@@ -252,28 +302,6 @@ def test_linear_nobias(
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
         self._test_linear_impl(
@@ -286,20 +314,20 @@ def test_linear_nobias(
             scaling_type_dL_dY,
         )
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_bias(
@@ -312,28 +340,6 @@ def test_linear_bias(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
         self._test_linear_impl(