
Commit f7af150

jananisriram authored and facebook-github-bot committed
Refactor fp8_gemm benchmark to simplify addition of new scaling modes (#500)
Summary: Refactor the `fp8_gemm` benchmark in TritonBench to accept the scaling mode as a single `--scaling-mode` argument. This diff lets us extend the `fp8_gemm` benchmark to new scaling modes without adding new benchmarking arguments.

Differential Revision: D83617233
1 parent d8b41f2 commit f7af150
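
For illustration only (not part of the commit), the sketch below shows how the new flag flows through the refactored argument path; it assumes TritonBench and its GPU-side dependencies (torch, triton) are importable in the current environment.

# Hedged sketch: exercising the new --scaling-mode flag via the helpers added in
# tritonbench/operators/fp8_gemm/fp8_gemm.py below.
from tritonbench.operators.fp8_gemm.fp8_gemm import (
    ScalingMode,
    get_scaling_mode_int,
    parse_args,
)

args = parse_args(["--scaling-mode", "row"])    # previously: ["--scaling_rowwise"]
mode = get_scaling_mode_int(args.scaling_mode)  # -> ScalingMode.ROW
assert mode == ScalingMode.ROW and mode.value == 1  # .value is what reaches the kernel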

File tree

tritonbench/operators/fp8_gemm/fp8_gemm.py
tritonbench/operators/fp8_gemm/persistent.py

2 files changed: +30 -14 lines changed

tritonbench/operators/fp8_gemm/fp8_gemm.py

Lines changed: 25 additions & 7 deletions
@@ -1,4 +1,7 @@
 import argparse
+
+from enum import IntEnum
+
 import logging
 
 from typing import Any, Callable, List, Optional
@@ -43,10 +46,15 @@
     logger.warning(f"Failed to import TMA: {e}")
 
 
+class ScalingMode(IntEnum):
+    TENSOR = 0
+    ROW = 1
+
+
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-mode", type=str, default="tensor")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -55,6 +63,15 @@ def parse_args(args):
     return parser.parse_args(args)
 
 
+def get_scaling_mode_int(scaling_mode: str) -> int:
+    if scaling_mode == "tensor":
+        return ScalingMode.TENSOR
+    elif scaling_mode == "row":
+        return ScalingMode.ROW
+    else:
+        raise ValueError(f"Invalid scaling mode: {scaling_mode}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
@@ -65,11 +82,12 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+        self.scaling_mode_int = get_scaling_mode_int(self.extra_args.scaling_mode).value
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if self.scaling_mode_int == ScalingMode.TENSOR:
             return torch.float16
+        return torch.bfloat16
 
     def get_input_iter(self):
         def _get_scale_per_tensor(
@@ -102,10 +120,10 @@ def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.extra_args.scaling_rowwise:
+            if self.scaling_mode_int == ScalingMode.ROW:
                 scale_a = _get_scale_per_row(a)
                 scale_b = _get_scale_per_row(b)
-            else:
+            else: # self.scaling_mode_int == ScalingMode.TENSOR
                 scale_a = _get_scale_per_tensor(
                     a, custom_scale=self.extra_args.per_tensor_scale_a
                 )
@@ -191,7 +209,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
            scale_a,
            scale_b,
            self._get_dtype(),
-           self.extra_args.scaling_rowwise,
+           self.scaling_mode_int,
        )
 
    @register_benchmark(enabled=True)

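The point of the refactor is easiest to see with a hypothetical extension. In the sketch below, the "block" mode and `ScalingMode.BLOCK` are invented for illustration and are not part of this commit; supporting another scaling mode only grows the enum and the string-to-enum mapping, while `--scaling-mode` stays the single user-facing argument.

from enum import IntEnum


class ScalingMode(IntEnum):
    TENSOR = 0
    ROW = 1
    BLOCK = 2  # hypothetical future mode, for illustration only


def get_scaling_mode_int(scaling_mode: str) -> int:
    # Same structure as the helper added in this diff, with one extra branch.
    if scaling_mode == "tensor":
        return ScalingMode.TENSOR
    elif scaling_mode == "row":
        return ScalingMode.ROW
    elif scaling_mode == "block":  # hypothetical
        return ScalingMode.BLOCK
    else:
        raise ValueError(f"Invalid scaling mode: {scaling_mode}")
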
tritonbench/operators/fp8_gemm/persistent.py

Lines changed: 5 additions & 7 deletions
@@ -410,9 +410,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
 
 
-def blackwell_persistent_tma(
-    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
-):
+def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
     configs = matmul_configs_blackwell()
 
     # Check constraints.
@@ -471,7 +469,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         NUM_SMS=NUM_SMS, #
         num_stages=configs[shape_dtype]["num_stages"], #
         num_warps=configs[shape_dtype]["num_warps"], #
-        SCALING_ROWWISE=scaling_rowwise,
+        SCALING_MODE=scaling_mode, #
         WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"], #
         EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"], #
     )
@@ -504,7 +502,7 @@ def blackwell_persistent_tma_kernel(
     GROUP_SIZE_M: tl.constexpr, #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
-    SCALING_ROWWISE: tl.constexpr, #
+    SCALING_MODE: tl.constexpr, #
     WARP_SPECIALIZE: tl.constexpr,
     EPILOGUE_SUBTILE: tl.constexpr,
 ): #
@@ -538,7 +536,7 @@ def blackwell_persistent_tma_kernel(
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    if SCALING_ROWWISE:
+    if SCALING_MODE == 1:
         # For row-wise scaling, we'll use the pointers as-is
         scale_a = scale_a_ptr
         scale_b = scale_b_ptr
@@ -563,7 +561,7 @@ def blackwell_persistent_tma_kernel(
         b_block = b_desc.load([offs_bn, offs_k])
         accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
 
-        if SCALING_ROWWISE:
+        if SCALING_MODE == 1:
             offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
             offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
569567

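One note on the kernel side: `SCALING_MODE` arrives as a `tl.constexpr` integer (the host passes `ScalingMode.ROW.value`, i.e. 1, via `self.scaling_mode_int`), so the `SCALING_MODE == 1` branches are specialized at compile time. As a possible follow-up, sketched below under the assumption that module-level constants are acceptable in persistent.py, the kernel could compare against named values instead of a bare 1; the constant names are invented for illustration.

# Hypothetical sketch, not part of this commit: module-level integers referenced
# inside a @triton.jit kernel are baked in as compile-time constants, so the branch
# could read `if SCALING_MODE == SCALING_ROW:` instead of `if SCALING_MODE == 1:`.
SCALING_TENSOR = 0  # matches ScalingMode.TENSOR.value on the host side
SCALING_ROW = 1  # matches ScalingMode.ROW.value on the host side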