    pytest.skip("Unsupported PyTorch version", allow_module_level=True)


-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingGranularity,
+    ScalingType,
+)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
    convert_to_float8_training,
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)
from torchao.float8.float8_python_api import addmm_float8_unwrapped
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import (
    Float8Tensor,
    GemmInputRole,


is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
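+# gates tests that require CUDA compute capability 9.0 (Hopper) or newer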
+is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
    assert torch.all(a._data == b._data).item(), "data is not identical"
    return True


-class TestFloat8Tensor(unittest.TestCase):
+class TestFloat8Tensor:
    def test_preserves_dtype(self) -> None:
        # hp means high precision, lp means low precision
        hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -68,7 +75,7 @@ def test_preserves_dtype(self) -> None:
            x1_s = tensor_to_scale(x1_hp, lp_dtype)
            x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
            x3_hp = x2_lp.to_original_precision()
-            self.assertTrue(x3_hp.dtype == hp_dtype)
+            assert x3_hp.dtype == hp_dtype

    def test_differentiable_casts(self) -> None:
        lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -103,7 +110,7 @@ def test_index_put(self):
        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
            b[index] = fp8_a
            fp8_b[index] = a
            fp8_b_bad[index] = fp8_a
@@ -117,7 +124,7 @@ def test_copy_(self):
        b = torch.empty(16, dtype=torch.bfloat16)
        b.copy_(fp8_a)  # Should work
        torch.testing.assert_close(b, fp8_a.to_original_precision())
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
            fp8_a.copy_(b)  # Should fail

        fp8_b = Float8Tensor(
@@ -129,6 +136,105 @@ def test_copy_(self):
        fp8_b.copy_(fp8_a)
        torch.testing.assert_close(fp8_a._data, fp8_b._data)

+    @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
+    @pytest.mark.parametrize("axiswise_dim", [0, -1])
+    def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
+        a = torch.randn(*shape, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
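+        # dynamic axiswise cast: the scale is computed by reducing over axiswise_dim,
+        # so the float8 round trip back to the original precision should keep the
+        # error small (SQNR of at least 25)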
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=axiswise_dim,
+        )
+        a_dq = a_fp8.to_original_precision()
+        sqnr = compute_error(a, a_dq)
+        assert sqnr >= 25.0
+
+    def test_axiswise_reshape(self):
+        a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
+
+        # if we scale across dim0, we can only reshape to [3, -1]
+        a_fp8_d0 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=0,
+        )
+        assert list(a_fp8_d0._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d0._scale.shape) == [1, 5, 7]
+
+        a_fp8_d0_r = a_fp8_d0.reshape(3, -1)
+        assert list(a_fp8_d0_r.shape) == [3, 5 * 7]
+        assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d0.to_original_precision(),
+            a_fp8_d0_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(RuntimeError):
+            a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)
+
+        # if we scale across dim2, we can only reshape to [-1, 7]
+        a_fp8_d2 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,
+        )
+        assert list(a_fp8_d2._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d2._scale.shape) == [3, 5, 1]
+
+        a_fp8_d2_r = a_fp8_d2.reshape(-1, 7)
+        assert list(a_fp8_d2_r.shape) == [3 * 5, 7]
+        assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d2.to_original_precision(),
+            a_fp8_d2_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(RuntimeError):
+            a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)
+
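+    # end-to-end check: a gemm on axiswise-scaled float8 operands should closely
+    # match the bf16 reference result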
+    @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    def test_axiswise_gemm(self, a_shape):
+        a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
+        b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
+
+        linear_mm_config = LinearMMConfig()
+
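+        # cast both operands with axiswise scales along their last dim, which is the
+        # contraction dim of the gemm below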
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.INPUT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,
+        )
+        a_fp8 = a_fp8.reshape(-1, a_shape[-1])
+        b_fp8 = hp_tensor_to_float8_dynamic(
+            b,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,  # will be transposed
+        )
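+        # float8 gemm on the scaled operands, compared against a bf16 reference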
+        c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
+        a = a.reshape(-1, a_shape[-1])
+        c_ref = torch.mm(a, b.t())
+        sqnr = compute_error(c_ref, c_fp8_compute)
+        assert sqnr >= 25.0


class TestFloat8Linear: