Commit bf817b1

[wip] make scaling configurable by gemm-argument
Summary: My brain hurts from so many long identifiers...

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 0cfb3bb
ghstack-comment-id: 2372563439
Pull Request resolved: #940
1 parent 008ed5b commit bf817b1
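
The commit title refers to the three gemms behind a linear layer's forward and backward (output, grad_input, grad_weight), each of which takes two arguments that can now be scaled with their own granularity. As a rough illustration only, inferred from how test_linear unpacks scaling_granularities_by_gemm in the test diff further down and not code from this commit, one parametrization entry has roughly the shape below; the import path and the meaning of the two trailing boolean fields (read here as "keep this argument in original precision") are assumptions.

# Hypothetical sketch, not from this commit: the shape of one entry of
# `scaling_granularities_by_gemm`, inferred from the unpacking in test_linear.
from torchao.float8.config import ScalingGranularity  # import path assumed

AXISWISE = ScalingGranularity.AXISWISE
TENSORWISE = ScalingGranularity.TENSORWISE

example_entry = (
    # output gemm: input and weight
    #   (granularity of arg 1, granularity of arg 2, arg 1 orig prec?, arg 2 orig prec?)
    (AXISWISE, AXISWISE, False, False),
    # grad_input gemm: grad_output and weight
    (AXISWISE, TENSORWISE, False, False),
    # grad_weight gemm: input and grad_output
    (TENSORWISE, AXISWISE, False, False),
)

# As in the diffs below, the test helper then consumes such an entry together
# with the three scaling types and an emulate flag:
# config = get_test_float8_linear_config(
#     scaling_type_input, scaling_type_weight, scaling_type_grad_output,
#     example_entry, False,  # emulate
# )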

File tree

7 files changed: +506, -226 lines

benchmarks/float8/profile_linear_float8.py

Lines changed: 55 additions & 37 deletions
@@ -33,6 +33,10 @@
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
 )
+from torchao.testing.float8.test_utils import (
+    scaling_granularities_by_gemm_lcw_recipe,
+    get_test_float8_linear_config,
+)
 from torch.profiler import profile, ProfilerActivity, record_function
 from utils import (
     kernel_name_to_category,
@@ -258,6 +262,8 @@ def main(
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
     scaling_granularity: str = "tensorwise",
+    # TODO(future PR): clean up the override, it's confusing
+    recipe_override: Optional[str] = None,
     model_type: str = "linear",
     dtype_filter: str = "both",
     add_inductor_metadata_to_trace: bool = True,
@@ -271,45 +277,57 @@ def main(
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
     scaling_granularity = ScalingGranularity(scaling_granularity)

-    if scaling_type_input is ScalingType.STATIC:
-        cast_config_input=CastConfig(
-            scaling_type=scaling_type_input,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_input=CastConfig(
-            scaling_type=scaling_type_input,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_weight is ScalingType.STATIC:
-        cast_config_weight=CastConfig(
-            scaling_type=scaling_type_weight,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_weight=CastConfig(
-            scaling_type=scaling_type_weight,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_grad_output is ScalingType.STATIC:
-        cast_config_grad_output=CastConfig(
-            scaling_type=scaling_type_grad_output,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_grad_output=CastConfig(
-            scaling_type=scaling_type_grad_output,
-            scaling_granularity=scaling_granularity,
+    if recipe_override is None:
+
+        if scaling_type_input is ScalingType.STATIC:
+            cast_config_input=CastConfig(
+                scaling_type=scaling_type_input,
+                static_scale=torch.tensor([1.0], device="cuda"),
+                scaling_granularity=scaling_granularity,
+            )
+        else:
+            cast_config_input=CastConfig(
+                scaling_type=scaling_type_input,
+                scaling_granularity=scaling_granularity,
+            )
+        if scaling_type_weight is ScalingType.STATIC:
+            cast_config_weight=CastConfig(
+                scaling_type=scaling_type_weight,
+                static_scale=torch.tensor([1.0], device="cuda"),
+                scaling_granularity=scaling_granularity,
+            )
+        else:
+            cast_config_weight=CastConfig(
+                scaling_type=scaling_type_weight,
+                scaling_granularity=scaling_granularity,
+            )
+        if scaling_type_grad_output is ScalingType.STATIC:
+            cast_config_grad_output=CastConfig(
+                scaling_type=scaling_type_grad_output,
+                static_scale=torch.tensor([1.0], device="cuda"),
+                scaling_granularity=scaling_granularity,
+            )
+        else:
+            cast_config_grad_output=CastConfig(
+                scaling_type=scaling_type_grad_output,
+                scaling_granularity=scaling_granularity,
+            )
+
+        config = Float8LinearConfig(
+            cast_config_input=cast_config_input,
+            cast_config_weight=cast_config_weight,
+            cast_config_grad_output=cast_config_grad_output,
         )

-    config = Float8LinearConfig(
-        cast_config_input=cast_config_input,
-        cast_config_weight=cast_config_weight,
-        cast_config_grad_output=cast_config_grad_output,
-    )
+    elif recipe_override == "lcw":
+        scaling_granularities_by_gemm = scaling_granularities_by_gemm_lcw_recipe
+        config = get_test_float8_linear_config(
+            scaling_type_input,
+            scaling_type_weight,
+            scaling_type_grad_output,
+            scaling_granularities_by_gemm,
+            False, # emulate
+        )

     scaling_repr = "_".join(
         [
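
In the default (recipe_override is None) path above, the only difference between the STATIC branch and the others is the extra static_scale argument to CastConfig. For reference, a compact equivalent of that repeated pattern is sketched below; this is not what the commit does, and it assumes the same CastConfig/ScalingType imports the benchmark already uses.

# Sketch only, assuming CastConfig and ScalingType come from torchao.float8.config
import torch
from torchao.float8.config import CastConfig, ScalingType

def make_cast_config(scaling_type, scaling_granularity):
    # Only STATIC scaling needs a pre-computed scale; the benchmark above
    # hardcodes 1.0 on CUDA for that case.
    kwargs = {}
    if scaling_type is ScalingType.STATIC:
        kwargs["static_scale"] = torch.tensor([1.0], device="cuda")
    return CastConfig(
        scaling_type=scaling_type,
        scaling_granularity=scaling_granularity,
        **kwargs,
    )

With such a helper, cast_config_input, cast_config_weight and cast_config_grad_output would each be a single call.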

test/float8/test_base.py

Lines changed: 74 additions & 55 deletions
@@ -8,6 +8,7 @@
 import itertools
 import random
 import re
+from typing import List, Tuple
 import unittest
 import warnings

@@ -51,6 +52,10 @@
     FP8_TYPES,
     tensor_to_scale,
 )
+from torchao.testing.float8.test_utils import (
+    scaling_granularities_by_gemm,
+    get_test_float8_linear_config,
+)

 random.seed(0)
 torch.manual_seed(0)
@@ -59,6 +64,8 @@
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

+
+
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._data == b._data).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -211,31 +218,52 @@ def test_axiswise_reshape(self):
         a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)

     @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
+    @pytest.mark.parametrize(
+        "a_granularity,b_granularity",
+        [
+            (ScalingGranularity.AXISWISE, ScalingGranularity.AXISWISE),
+            (ScalingGranularity.AXISWISE, ScalingGranularity.TENSORWISE),
+            (ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE),
+        ]
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
-    def test_axiswise_gemm(self, a_shape):
+    def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
         b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

         linear_mm_config = LinearMMConfig()

+        if a_granularity is ScalingGranularity.AXISWISE:
+            a_axiswise_dim = -1
+        else:
+            assert a_granularity is ScalingGranularity.TENSORWISE
+            a_axiswise_dim = None
         a_fp8 = hp_tensor_to_float8_dynamic(
             a,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-1,
+            scaling_granularity=a_granularity,
+            axiswise_dim=a_axiswise_dim,
         )
         a_fp8 = a_fp8.reshape(-1, a_shape[-1])
+
+        b_axiswise_dim = 1 if b_granularity is ScalingGranularity.AXISWISE else None
+        if b_granularity is ScalingGranularity.AXISWISE:
+            b_axiswise_dim = 1 # will be transposed
+        else:
+            assert b_granularity is ScalingGranularity.TENSORWISE
+            b_axiswise_dim = None
         b_fp8 = hp_tensor_to_float8_dynamic(
             b,
             e4m3_dtype,
             linear_mm_config,
             gemm_input_role=GemmInputRole.WEIGHT,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=1, # will be transposed
+            scaling_granularity=b_granularity,
+            axiswise_dim=b_axiswise_dim,
         )
+
         c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
         a = a.reshape(-1, a_shape[-1])
         c_ref = torch.mm(a, b.t())
@@ -316,26 +344,33 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

-    @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
+    # @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
+    @pytest.mark.parametrize("emulate", [False] if is_cuda_8_9 else [True])
+    # @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
+    @pytest.mark.parametrize("x_shape", [(16, 16),])
     @pytest.mark.parametrize(
         "scaling_type_input",
-        [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+        # [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+        [ScalingType.DYNAMIC]
     )
     @pytest.mark.parametrize(
         "scaling_type_weight",
-        [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+        # [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
+        [ScalingType.DYNAMIC]
     )
     @pytest.mark.parametrize(
         "scaling_type_grad_output",
-        [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
+        # [ScalingType.DELAYED, ScalingType.DYNAMIC],
+        [ScalingType.DYNAMIC]
     )
     @pytest.mark.parametrize(
-        "scaling_granularity",
-        [ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE],
+        "scaling_granularities_by_gemm",
+        scaling_granularities_by_gemm
     )
-    @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
-    @pytest.mark.parametrize("linear_bias", [False, True])
+    # @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
+    @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, ])
+    # @pytest.mark.parametrize("linear_bias", [False, True])
+    @pytest.mark.parametrize("linear_bias", [False, ])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear(
         self,
@@ -344,7 +379,7 @@ def test_linear(
         scaling_type_input: ScalingType,
         scaling_type_weight: ScalingType,
         scaling_type_grad_output: ScalingType,
-        scaling_granularity: ScalingGranularity,
+        scaling_granularities_by_gemm: List[List[Tuple[ScalingGranularity, ScalingGranularity]]],
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
@@ -357,7 +392,23 @@ def test_linear(
                 f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
             )
             pytest.skip()
-        if scaling_granularity is ScalingGranularity.AXISWISE:
+
+        (
+            (scaling_granularity_input, scaling_granularity_weight, original_prec_input, original_prec_weight),
+            (scaling_granularity_grad_output, scaling_granularity_weight_for_grad_input, original_prec_grad_output, original_prec_weight_for_grad_input),
+            (scaling_granularity_input_for_grad_weight, scaling_granularity_grad_output_for_grad_weight, original_prec_input_for_grad_weight, original_prec_grad_output_for_grad_weight),
+        ) = scaling_granularities_by_gemm
+
+        has_any_axiswise_scaling = (
+            scaling_granularity_input is ScalingGranularity.AXISWISE or
+            scaling_granularity_weight is ScalingGranularity.AXISWISE or
+            scaling_granularity_grad_output is ScalingGranularity.AXISWISE or
+            scaling_granularity_input_for_grad_weight is ScalingGranularity.AXISWISE or
+            scaling_granularity_weight_for_grad_input is ScalingGranularity.AXISWISE or
+            scaling_granularity_grad_output_for_grad_weight is ScalingGranularity.AXISWISE
+        )
+
+        if has_any_axiswise_scaling:
             if (
                 scaling_type_input != ScalingType.DYNAMIC or
                 scaling_type_weight != ScalingType.DYNAMIC or
@@ -370,46 +421,14 @@ def test_linear(
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

-        if scaling_type_input is ScalingType.STATIC:
-            cast_config_input = CastConfig(
-                scaling_type=scaling_type_input,
-                scaling_granularity=scaling_granularity,
-                static_scale=torch.tensor([1.0], device="cuda"),
-            )
-        else:
-            cast_config_input = CastConfig(
-                scaling_type=scaling_type_input,
-                scaling_granularity=scaling_granularity,
-            )
-        if scaling_type_weight is ScalingType.STATIC:
-            cast_config_weight = CastConfig(
-                scaling_type=scaling_type_weight,
-                scaling_granularity=scaling_granularity,
-                static_scale=torch.tensor([1.0], device="cuda"),
-            )
-        else:
-            cast_config_weight = CastConfig(
-                scaling_type=scaling_type_weight,
-                scaling_granularity=scaling_granularity,
-            )
-        if scaling_type_grad_output is ScalingType.STATIC:
-            cast_config_grad_output = CastConfig(
-                scaling_type=scaling_type_grad_output,
-                scaling_granularity=scaling_granularity,
-                static_scale=torch.tensor([1.0], device="cuda"),
-            )
-        else:
-            cast_config_grad_output = CastConfig(
-                scaling_type=scaling_type_grad_output,
-                scaling_granularity=scaling_granularity,
-            )
-
-        config = Float8LinearConfig(
-            cast_config_input=cast_config_input,
-            cast_config_weight=cast_config_weight,
-            cast_config_grad_output=cast_config_grad_output,
-            emulate=emulate,
+        config = get_test_float8_linear_config(
+            scaling_type_input,
+            scaling_type_weight,
+            scaling_type_grad_output,
+            scaling_granularities_by_gemm,
+            emulate,
         )
+
         self._test_linear_impl(
             x,
             m_ref,
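
In test_axiswise_gemm above, the per-argument branching reduces to one rule: with axiswise scaling, the a operand scales along its last dim (-1) and the b operand along dim 1 (it is transposed before the mm), while tensorwise scaling passes no axiswise dim. A small helper expressing that rule is sketched below; it is not the commit's code, and the import path is assumed.

# Sketch mirroring the if/else blocks in test_axiswise_gemm
from typing import Optional
from torchao.float8.config import ScalingGranularity  # import path assumed

def axiswise_dim_for(granularity: ScalingGranularity, is_b_operand: bool) -> Optional[int]:
    # -1 for the `a` operand, 1 for the `b` operand (transposed before
    # torch.mm), and None when the operand is scaled tensorwise.
    if granularity is ScalingGranularity.AXISWISE:
        return 1 if is_b_operand else -1
    assert granularity is ScalingGranularity.TENSORWISE
    return None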

0 commit comments