Commit be516f4

jananisriram authored and facebook-github-bot committed
Refactor fp8_gemm benchmark to simplify addition of new scaling modes (#500)
Summary: Refactor the `fp8_gemm` benchmark in TritonBench to accept the scaling mode as an argument. This diff lets us extend the `fp8_gemm` benchmark to new scaling modes without adding new benchmarking arguments.

Differential Revision: D83617233
1 parent d8b41f2 commit be516f4
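With this change the scaling mode is selected by name rather than by the old boolean --scaling_rowwise switch, e.g. --scaling-mode tensor (the default) or --scaling-mode row. Assuming TritonBench's usual run.py entry point (not shown in this diff), an invocation would look roughly like: python run.py --op fp8_gemm --scaling-mode row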

File tree

2 files changed (+32, -15 lines)

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 22 additions & 7 deletions
@@ -1,6 +1,11 @@
 import argparse
+
+from enum import IntEnum
+
 import logging
 
+from torch._inductor.kernel.mm import ScalingMode
+
 from typing import Any, Callable, List, Optional
 
 import torch
@@ -46,7 +51,7 @@
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-mode", type=str, default="tensor")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -55,6 +60,15 @@ def parse_args(args):
     return parser.parse_args(args)
 
 
+def get_scaling_mode_int(scaling_mode: str) -> int:
+    if scaling_mode == "tensor":
+        return ScalingMode.TENSOR
+    elif scaling_mode == "row":
+        return ScalingMode.ROW
+    else:
+        raise ValueError(f"Invalid scaling mode: {scaling_mode}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
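The string-to-enum mapping above is the single place a new scaling mode has to be registered on the benchmark side. A minimal, self-contained sketch of the pattern, using a local IntEnum as a stand-in for torch._inductor.kernel.mm.ScalingMode (whose exact members are assumed here), might look like:

import argparse
from enum import IntEnum


class ScalingMode(IntEnum):
    # Stand-in for torch._inductor.kernel.mm.ScalingMode; member names/values are assumed.
    TENSOR = 0
    ROW = 1


def get_scaling_mode_int(scaling_mode: str) -> int:
    # Map the CLI string to an enum member; supporting a new mode
    # (e.g. a hypothetical "block") would only require another branch here.
    if scaling_mode == "tensor":
        return ScalingMode.TENSOR
    elif scaling_mode == "row":
        return ScalingMode.ROW
    else:
        raise ValueError(f"Invalid scaling mode: {scaling_mode}")


parser = argparse.ArgumentParser()
parser.add_argument("--scaling-mode", type=str, default="tensor")
args = parser.parse_args(["--scaling-mode", "row"])
print(get_scaling_mode_int(args.scaling_mode).value)  # 1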
@@ -65,11 +79,12 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+        self.scaling_mode_int = get_scaling_mode_int(self.extra_args.scaling_mode).value
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if self.scaling_mode_int == ScalingMode.TENSOR:
             return torch.float16
+        return torch.bfloat16
 
     def get_input_iter(self):
         def _get_scale_per_tensor(
@@ -102,10 +117,10 @@ def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.extra_args.scaling_rowwise:
+            if self.scaling_mode_int == ScalingMode.ROW:
                scale_a = _get_scale_per_row(a)
                scale_b = _get_scale_per_row(b)
-            else:
+            else:  # self.scaling_mode_int == ScalingMode.TENSOR
                scale_a = _get_scale_per_tensor(
                    a, custom_scale=self.extra_args.per_tensor_scale_a
                )
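The bodies of _get_scale_per_row and _get_scale_per_tensor are not part of this hunk. For context, a typical amax-based fp8 scale computation, sketched here under the assumption of the e4m3 format and the "fp8_max / amax" convention (the benchmark's actual helpers may differ), looks like:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # assumes e4m3 fp8


def scale_per_tensor(x: torch.Tensor) -> torch.Tensor:
    # One scale for the whole tensor, derived from its global amax.
    return FP8_MAX / x.abs().amax().clamp(min=1e-12)


def scale_per_row(x: torch.Tensor) -> torch.Tensor:
    # One scale per row (dim 0), derived from each row's amax.
    return FP8_MAX / x.abs().amax(dim=1).clamp(min=1e-12)


x = torch.randn(4, 8)
print(scale_per_tensor(x).shape)  # torch.Size([])
print(scale_per_row(x).shape)     # torch.Size([4])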
@@ -191,7 +206,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            self.scaling_mode_int,
         )
 
     @register_benchmark(enabled=True)

tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 10 additions & 8 deletions
@@ -1,4 +1,9 @@
+from enum import IntEnum
+
 from functools import lru_cache
+
+from torch._inductor.kernel.mm import ScalingMode
+
 from typing import Optional
 
 import torch
@@ -23,7 +28,6 @@
 except (ImportError, IOError, AttributeError):
     pass
 
-
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K = args["M"], args["N"], args["K"]
@@ -410,9 +414,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
 
 
-def blackwell_persistent_tma(
-    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
-):
+def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
     configs = matmul_configs_blackwell()
 
     # Check constraints.
@@ -471,7 +473,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
        NUM_SMS=NUM_SMS,  #
        num_stages=configs[shape_dtype]["num_stages"],  #
        num_warps=configs[shape_dtype]["num_warps"],  #
-        SCALING_ROWWISE=scaling_rowwise,
+        SCALING_MODE=scaling_mode,  #
        WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"],  #
        EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"],  #
    )
@@ -504,7 +506,7 @@ def blackwell_persistent_tma_kernel(
     GROUP_SIZE_M: tl.constexpr,  #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
-    SCALING_ROWWISE: tl.constexpr,  #
+    SCALING_MODE: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,
     EPILOGUE_SUBTILE: tl.constexpr,
 ):  #
@@ -538,7 +540,7 @@ def blackwell_persistent_tma_kernel(
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    if SCALING_ROWWISE:
+    if SCALING_MODE == ScalingMode.ROW:
         # For row-wise scaling, we'll use the pointers as-is
         scale_a = scale_a_ptr
         scale_b = scale_b_ptr
@@ -563,7 +565,7 @@ def blackwell_persistent_tma_kernel(
         b_block = b_desc.load([offs_bn, offs_k])
         accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
 
-        if SCALING_ROWWISE:
+        if SCALING_MODE == ScalingMode.ROW:
            offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
            offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)