@@ -8,6 +8,7 @@
 import itertools
 import random
 import re
+from typing import List, Tuple
 import unittest
 import warnings
 
@@ -27,6 +28,8 @@
     Float8LinearConfig,
     ScalingGranularity,
     ScalingType,
+    Float8LinearRecipeName,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -35,7 +38,10 @@
     sync_float8_amax_and_scale_history,
 )
 from torchao.float8.float8_python_api import addmm_float8_unwrapped
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
+from torchao.float8.float8_scaling_utils import (
+    hp_tensor_to_float8_dynamic,
+    get_maybe_axiswise_dim,
+)
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -51,6 +57,7 @@
     FP8_TYPES,
     tensor_to_scale,
 )
+from torchao.testing.float8.test_utils import get_test_float8_linear_config
 
 random.seed(0)
 torch.manual_seed(0)
@@ -59,6 +66,8 @@
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
+
+
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -205,9 +214,17 @@ def test_axiswise_reshape(self):
         a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)
 
     @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
+    @pytest.mark.parametrize(
+        "a_granularity,b_granularity",
+        [
+            (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE),
+            (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE),
+            (ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE),
+        ]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
-    def test_axiswise_gemm(self, a_shape):
+    def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
 
@@ -218,18 +235,20 @@ def test_axiswise_gemm(self, a_shape):
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-1,
+            scaling_granularity=a_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(-1, a_granularity),
         )
         a_fp8 = a_fp8.reshape(-1, a_shape[-1])
+
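+        # get_maybe_axiswise_dim is expected to pass the dim through for AXISWISE
+        # scaling and return None otherwise, so the TENSORWISE combinations fall
+        # back to tensorwise scales.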
         b_fp8 = hp_tensor_to_float8_dynamic(
             b,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-1,  # will be transposed
+            scaling_granularity=b_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(-1, b_granularity),
         )
+
         c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
         a = a.reshape(-1, a_shape[-1])
         c_ref = torch.mm(a, b.t())
@@ -322,79 +341,64 @@ def _test_linear_impl(
     )
     @pytest.mark.parametrize(
         "scaling_type_grad_output",
-        [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
-    )
-    @pytest.mark.parametrize(
-        "scaling_granularity",
-        [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
+        [ScalingType.DELAYED, ScalingType.DYNAMIC],
     )
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_linear(
+    def test_linear_from_config_params(
         self,
         x_shape,
         emulate: bool,
         scaling_type_input: ScalingType,
         scaling_type_weight: ScalingType,
         scaling_type_grad_output: ScalingType,
-        scaling_granularity: ScalingGranularity,
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
-        if scaling_granularity is ScalingGranularity.AXISWISE:
-            if (
-                scaling_type_input != ScalingType.DYNAMIC or
-                scaling_type_weight != ScalingType.DYNAMIC or
-                scaling_type_grad_output != ScalingType.DYNAMIC or
-                linear_dtype != torch.bfloat16 or
-                (not is_cuda_9_0)
-            ):
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
 
-        if scaling_type_input is ScalingType.STATIC:
-            cast_config_input = CastConfig(
-                scaling_type=scaling_type_input,
-                scaling_granularity=scaling_granularity,
-                static_scale=torch.tensor([1.0], device="cuda"),
-            )
-        else:
-            cast_config_input = CastConfig(
-                scaling_type=scaling_type_input,
-                scaling_granularity=scaling_granularity,
-            )
-        if scaling_type_weight is ScalingType.STATIC:
-            cast_config_weight = CastConfig(
-                scaling_type=scaling_type_weight,
-                scaling_granularity=scaling_granularity,
-                static_scale=torch.tensor([1.0], device="cuda"),
-            )
-        else:
-            cast_config_weight = CastConfig(
-                scaling_type=scaling_type_weight,
-                scaling_granularity=scaling_granularity,
-            )
-        if scaling_type_grad_output is ScalingType.STATIC:
-            cast_config_grad_output = CastConfig(
-                scaling_type=scaling_type_grad_output,
-                scaling_granularity=scaling_granularity,
-                static_scale=torch.tensor([1.0], device="cuda"),
-            )
-        else:
-            cast_config_grad_output = CastConfig(
-                scaling_type=scaling_type_grad_output,
-                scaling_granularity=scaling_granularity,
-            )
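+        # get_test_float8_linear_config builds the Float8LinearConfig (input,
+        # weight and grad_output cast configs) for this parametrization.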
+        config = get_test_float8_linear_config(
+            scaling_type_input,
+            scaling_type_weight,
+            scaling_type_grad_output,
+            emulate,
+        )
 
-        config = Float8LinearConfig(
-            cast_config_input=cast_config_input,
-            cast_config_weight=cast_config_weight,
-            cast_config_grad_output=cast_config_grad_output,
-            emulate=emulate,
+        self._test_linear_impl(
+            x,
+            m_ref,
+            config,
         )
+
+    # Note: there are now too many config combinations to test all of
+    # them, so this function factors out some of the recipes which are annoying
+    # to combine with the main testing function.
+    # TODO(future PR): make this cleaner.
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
+    )
+    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
+    @pytest.mark.parametrize("linear_bias", [True, False])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_linear_from_recipe(
+        self,
+        recipe_name,
+        x_shape,
+        linear_bias: bool,
+    ):
+        if torch.cuda.get_device_capability() < (9, 0):
+            warnings.warn(
+                f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
+            )
+            pytest.skip()
+
+        linear_dtype = torch.bfloat16
+        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
+        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
+        config = recipe_name_to_linear_config(recipe_name)
         self._test_linear_impl(
             x,
             m_ref,
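A minimal usage sketch of the recipe path that test_linear_from_recipe exercises above. Assumptions not shown in this diff: torchao.float8.convert_to_float8_training is the entry point that swaps nn.Linear modules for Float8Linear; the config helpers are the ones imported at the top of this change. This is an illustration, not code from the PR.

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# Map a named recipe (here: all-axiswise scaling) to a full Float8LinearConfig.
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)

# Swap eligible nn.Linear modules for Float8Linear, then train as usual.
# Per the tests above, axiswise scaling expects bfloat16 and CUDA capability >= 9.0.
m = nn.Sequential(nn.Linear(16, 32, device="cuda", dtype=torch.bfloat16))
m = convert_to_float8_training(m, config=config)
y = m(torch.randn(4, 16, device="cuda", dtype=torch.bfloat16))
y.sum().backward()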