
Commit 3fe43d4

[wip] make scaling configurable by gemm-argument
Summary: My brain hurts from so many long identifiers...

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 816eaa1
ghstack-comment-id: 2372563439
Pull Request resolved: #940
1 parent 9429eba commit 3fe43d4
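
For orientation, the configuration surface this stack refactors is visible in the code removed further down: each tensor feeding the float8 gemms (input, weight, grad_output) gets its own CastConfig inside a Float8LinearConfig. A minimal sketch follows, assuming these names are importable from torchao.float8.config (the hunks below do not show the import block's `from` line):

# Hedged sketch of the per-tensor cast configuration, mirroring the benchmark
# code removed below; the torchao.float8.config import path is an assumption.
from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingGranularity,
    ScalingType,
)

# one CastConfig per gemm input: input, weight, grad_output
config = Float8LinearConfig(
    cast_config_input=CastConfig(
        scaling_type=ScalingType.DYNAMIC,
        scaling_granularity=ScalingGranularity.TENSORWISE,
    ),
    cast_config_weight=CastConfig(
        scaling_type=ScalingType.DYNAMIC,
        scaling_granularity=ScalingGranularity.TENSORWISE,
    ),
    cast_config_grad_output=CastConfig(
        scaling_type=ScalingType.DYNAMIC,
        scaling_granularity=ScalingGranularity.TENSORWISE,
    ),
)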

File tree

10 files changed (+556, -401 lines)


benchmarks/float8/profile_linear_float8.py

Lines changed: 13 additions & 40 deletions
@@ -27,12 +27,15 @@
     Float8LinearConfig,
     ScalingType,
     ScalingGranularity,
+    Float8LinearRecipeName,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
 )
+from torchao.testing.float8.test_utils import get_test_float8_linear_config
 from torch.profiler import profile, ProfilerActivity, record_function
 from utils import (
     kernel_name_to_category,
@@ -257,7 +260,7 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
-    scaling_granularity: str = "tensorwise",
+    recipe_name: Optional[str] = None,
     model_type: str = "linear",
     dtype_filter: str = "both",
     add_inductor_metadata_to_trace: bool = True,
@@ -269,47 +272,17 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
-    scaling_granularity = ScalingGranularity(scaling_granularity)

-    if scaling_type_input is ScalingType.STATIC:
-        cast_config_input=CastConfig(
-            scaling_type=scaling_type_input,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
+    if recipe_name is None:
+        config = get_test_float8_linear_config(
+            scaling_type_input,
+            scaling_type_weight,
+            scaling_type_grad_output,
+            emulate=False,
         )
-    else:
-        cast_config_input=CastConfig(
-            scaling_type=scaling_type_input,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_weight is ScalingType.STATIC:
-        cast_config_weight=CastConfig(
-            scaling_type=scaling_type_weight,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_weight=CastConfig(
-            scaling_type=scaling_type_weight,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_grad_output is ScalingType.STATIC:
-        cast_config_grad_output=CastConfig(
-            scaling_type=scaling_type_grad_output,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_grad_output=CastConfig(
-            scaling_type=scaling_type_grad_output,
-            scaling_granularity=scaling_granularity,
-        )
-
-    config = Float8LinearConfig(
-        cast_config_input=cast_config_input,
-        cast_config_weight=cast_config_weight,
-        cast_config_grad_output=cast_config_grad_output,
-    )
+    elif recipe_name is not None:
+        recipe_name = Float8LinearRecipeName(recipe_name)
+        config = recipe_name_to_linear_config(recipe_name)

     scaling_repr = "_".join(
         [
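
Net effect of the benchmark change above, as a standalone sketch: the config is now selected either from the three per-gemm-argument scaling types via the shared test helper, or by recipe name. The wrapper name select_config and the torchao.float8.config import path are illustrative assumptions; the hunks only show the helper import from torchao.testing.float8.test_utils and the in-function logic.

# Sketch of the new config selection; select_config is a hypothetical wrapper
# around the logic added to main() above.
from typing import Optional

from torchao.float8.config import (
    Float8LinearRecipeName,
    ScalingType,
    recipe_name_to_linear_config,
)
from torchao.testing.float8.test_utils import get_test_float8_linear_config


def select_config(
    scaling_type_input: str = "dynamic",
    scaling_type_weight: str = "dynamic",
    scaling_type_grad_output: str = "dynamic",
    recipe_name: Optional[str] = None,
):
    if recipe_name is None:
        # build a config from the three per-gemm-argument scaling types
        return get_test_float8_linear_config(
            ScalingType(scaling_type_input),
            ScalingType(scaling_type_weight),
            ScalingType(scaling_type_grad_output),
            emulate=False,
        )
    # otherwise resolve a pre-baked recipe, e.g. Float8LinearRecipeName.ALL_AXISWISE
    return recipe_name_to_linear_config(Float8LinearRecipeName(recipe_name))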

test/float8/test_base.py

Lines changed: 65 additions & 61 deletions
@@ -8,6 +8,7 @@
 import itertools
 import random
 import re
+from typing import List, Tuple
 import unittest
 import warnings

@@ -27,6 +28,8 @@
     Float8LinearConfig,
     ScalingGranularity,
     ScalingType,
+    Float8LinearRecipeName,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -35,7 +38,10 @@
     sync_float8_amax_and_scale_history,
 )
 from torchao.float8.float8_python_api import addmm_float8_unwrapped
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
+from torchao.float8.float8_scaling_utils import (
+    hp_tensor_to_float8_dynamic,
+    get_maybe_axiswise_dim,
+)
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -51,6 +57,7 @@
     FP8_TYPES,
     tensor_to_scale,
 )
+from torchao.testing.float8.test_utils import get_test_float8_linear_config

 random.seed(0)
 torch.manual_seed(0)
@@ -59,6 +66,8 @@
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

+
+
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -205,9 +214,17 @@ def test_axiswise_reshape(self):
         a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)

     @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
+    @pytest.mark.parametrize(
+        "a_granularity,b_granularity",
+        [
+            (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE),
+            (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE),
+            (ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE),
+        ]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
-    def test_axiswise_gemm(self, a_shape):
+    def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

@@ -218,18 +235,20 @@ def test_axiswise_gemm(self, a_shape):
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-1,
+            scaling_granularity=a_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(-1, a_granularity),
         )
         a_fp8 = a_fp8.reshape(-1, a_shape[-1])
+
         b_fp8 = hp_tensor_to_float8_dynamic(
             b,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-1, # will be transposed
+            scaling_granularity=b_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(-1, b_granularity),
         )
+
         c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
         a = a.reshape(-1, a_shape[-1])
         c_ref = torch.mm(a, b.t())
@@ -322,79 +341,64 @@ def _test_linear_impl(
     )
     @pytest.mark.parametrize(
         "scaling_type_grad_output",
-        [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
-    )
-    @pytest.mark.parametrize(
-        "scaling_granularity",
-        [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
+        [ScalingType.DELAYED, ScalingType.DYNAMIC],
     )
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_linear(
+    def test_linear_from_config_params(
         self,
         x_shape,
         emulate: bool,
         scaling_type_input: ScalingType,
         scaling_type_weight: ScalingType,
         scaling_type_grad_output: ScalingType,
-        scaling_granularity: ScalingGranularity,
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
-        if scaling_granularity is ScalingGranularity.AXISWISE:
-            if (
-                scaling_type_input != ScalingType.DYNAMIC or
-                scaling_type_weight != ScalingType.DYNAMIC or
-                scaling_type_grad_output != ScalingType.DYNAMIC or
-                linear_dtype != torch.bfloat16 or
-                (not is_cuda_9_0)
-            ):
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

-        if scaling_type_input is ScalingType.STATIC:
-            cast_config_input = CastConfig(
-                scaling_type=scaling_type_input,
-                scaling_granularity=scaling_granularity,
-                static_scale=torch.tensor([1.0], device="cuda"),
-            )
-        else:
-            cast_config_input = CastConfig(
-                scaling_type=scaling_type_input,
-                scaling_granularity=scaling_granularity,
-            )
-        if scaling_type_weight is ScalingType.STATIC:
-            cast_config_weight = CastConfig(
-                scaling_type=scaling_type_weight,
-                scaling_granularity=scaling_granularity,
-                static_scale=torch.tensor([1.0], device="cuda"),
-            )
-        else:
-            cast_config_weight = CastConfig(
-                scaling_type=scaling_type_weight,
-                scaling_granularity=scaling_granularity,
-            )
-        if scaling_type_grad_output is ScalingType.STATIC:
-            cast_config_grad_output = CastConfig(
-                scaling_type=scaling_type_grad_output,
-                scaling_granularity=scaling_granularity,
-                static_scale=torch.tensor([1.0], device="cuda"),
-            )
-        else:
-            cast_config_grad_output = CastConfig(
-                scaling_type=scaling_type_grad_output,
-                scaling_granularity=scaling_granularity,
-            )
+        config = get_test_float8_linear_config(
+            scaling_type_input,
+            scaling_type_weight,
+            scaling_type_grad_output,
+            emulate,
+        )

-        config = Float8LinearConfig(
-            cast_config_input=cast_config_input,
-            cast_config_weight=cast_config_weight,
-            cast_config_grad_output=cast_config_grad_output,
-            emulate=emulate,
+        self._test_linear_impl(
+            x,
+            m_ref,
+            config,
         )
+
+    # Note: there are now too many config combinations to test all of
+    # them, so this function factors out some of the recipes which are annoying
+    # to combine with the main testing function.
+    # TODO(future PR): make this cleaner.
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
+    )
+    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
+    @pytest.mark.parametrize("linear_bias", [True, False])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    def test_linear_from_recipe(
+        self,
+        recipe_name,
+        x_shape,
+        linear_bias: bool,
+    ):
+        if torch.cuda.get_device_capability() < (9, 0):
+            warnings.warn(
+                f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
+            )
+            pytest.skip()
+
+        linear_dtype = torch.bfloat16
+        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
+        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
+        config = recipe_name_to_linear_config(recipe_name)
        self._test_linear_impl(
             x,
             m_ref,