diff --git a/tritonbench/operators/fp8_gemm/fp8_gemm.py b/tritonbench/operators/fp8_gemm/fp8_gemm.py
index 44953f06..f899226c 100644
--- a/tritonbench/operators/fp8_gemm/fp8_gemm.py
+++ b/tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -1,4 +1,5 @@
 import argparse
+import logging
 
 from typing import Any, Callable, List, Optional
 
@@ -7,6 +8,8 @@
 import torch._inductor.config as inductor_config
 import triton
 
+from torch._inductor.kernel.mm import scaling_pairs, ScalingType
+
 from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
 
 from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda
@@ -46,7 +49,7 @@ def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
-    parser.add_argument("--scaling_rowwise", action="store_true")
+    parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
@@ -55,6 +58,54 @@ def parse_args(args):
     return parser.parse_args(args)
 
 
+def get_scaling_recipe_int(scaling_recipe: str) -> int:
+    if scaling_recipe == "TensorWise":
+        return ScalingType.TensorWise
+    elif scaling_recipe == "RowWise":
+        return ScalingType.RowWise
+    else:
+        raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")
+
+
+def get_scale(
+    x: torch.Tensor,
+    scaling_recipe_int: int,
+    transpose: bool = False,
+    custom_scale: float = None,
+):
+    def _get_scale_per_tensor(
+        x: torch.Tensor, custom_scale: float = None
+    ) -> torch.Tensor:
+        # For tensor-wise scaling, kernel requires a float32 scale tensor
+        if custom_scale:
+            return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+        scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+        return scale.to(torch.float32)
+
+    def _get_scale_per_row(x: torch.Tensor, transpose: bool = False) -> torch.Tensor:
+        if transpose:  # scale_b.shape should be [1, N]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=0, keepdim=True).values
+            )
+        else:  # scale_a.shape should be [M, 1]
+            scale = (
+                torch.finfo(torch.float8_e4m3fn).max
+                / x.abs().max(dim=1, keepdim=True).values
+            )
+        return scale.to(
+            torch.float32
+        )  # For row-wise scaling, kernel requires a float32 scale tensor
+
+    match scaling_recipe_int:
+        case ScalingType.TensorWise:
+            return _get_scale_per_tensor(x, custom_scale=custom_scale)
+        case ScalingType.RowWise:
+            return _get_scale_per_row(x, transpose=transpose)
+        case _:
+            raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")
+
+
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
     DEFAULT_PRECISION = "fp8"
@@ -66,53 +117,39 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+        scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
+        if (scaling_recipe_a, scaling_recipe_b) not in [
+            (a.name, b.name) for a, b in scaling_pairs
+        ]:
+            raise ValueError(
+                f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
+            )
+        self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
+        self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value
+
     def _get_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
+        if (
+            self.scaling_recipe_a_int == ScalingType.TensorWise
+            and self.scaling_recipe_b_int == ScalingType.TensorWise
+        ):
             return torch.float16
+        return torch.bfloat16
 
     def get_input_iter(self):
-        def _get_scale_per_tensor(
-            x: torch.Tensor, custom_scale: float = None
-        ) -> torch.Tensor:
-            # For tensor-wise scaling, kernel requires a float32 scale tensor
-            if custom_scale:
-                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
-            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
-            return scale.to(torch.float32)
-
-        def _get_scale_per_row(
-            x: torch.Tensor, transpose: bool = False
-        ) -> torch.Tensor:
-            if transpose:  # scale_b.shape should be [1, N]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=0, keepdim=True).values
-                )
-            else:  # scale_a.shape should be [M, 1]
-                scale = (
-                    torch.finfo(torch.float8_e4m3fn).max
-                    / x.abs().max(dim=1, keepdim=True).values
-                )
-            return scale.to(
-                torch.float32
-            )  # For row-wise scaling, kernel requires a float32 scale tensor
-
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = torch.randn(n, k, device=self.device).to(self._get_dtype())
 
-            if self.extra_args.scaling_rowwise:
-                scale_a = _get_scale_per_row(a)
-                scale_b = _get_scale_per_row(b)
-            else:
-                scale_a = _get_scale_per_tensor(
-                    a, custom_scale=self.extra_args.per_tensor_scale_a
-                )
-                scale_b = _get_scale_per_tensor(
-                    b, custom_scale=self.extra_args.per_tensor_scale_b
-                )
+            scale_a = get_scale(
+                a,
+                self.scaling_recipe_a_int,
+                custom_scale=self.extra_args.per_tensor_scale_a,
+            )
+            scale_b = get_scale(
+                b,
+                self.scaling_recipe_b_int,
+                custom_scale=self.extra_args.per_tensor_scale_b,
+            )
 
             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -192,7 +229,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
             scale_a,
             scale_b,
             self._get_dtype(),
-            self.extra_args.scaling_rowwise,
+            0 if self.scaling_recipe_a_int == self.scaling_recipe_b_int == 0 else 1,
         )
 
     @register_benchmark(enabled=True)
diff --git a/tritonbench/operators/fp8_gemm/persistent.py b/tritonbench/operators/fp8_gemm/persistent.py
index 5ce97117..9c44e5a6 100644
--- a/tritonbench/operators/fp8_gemm/persistent.py
+++ b/tritonbench/operators/fp8_gemm/persistent.py
@@ -1,10 +1,13 @@
 from functools import lru_cache
+
 from typing import Optional
 
 import torch
 import triton
 import triton.language as tl
 
+from torch._inductor.kernel.mm import ScalingType
+
 from tritonbench.utils.env_utils import is_cuda
 from tritonbench.utils.triton_utils import has_experimental_descriptor
 
@@ -410,9 +413,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
 # - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps
 
 
-def blackwell_persistent_tma(
-    a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
-):
+def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
     configs = matmul_configs_blackwell()
 
     # Check constraints.
@@ -471,7 +472,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
         NUM_SMS=NUM_SMS,  #
         num_stages=configs[shape_dtype]["num_stages"],  #
         num_warps=configs[shape_dtype]["num_warps"],  #
-        SCALING_ROWWISE=scaling_rowwise,
+        SCALING_MODE=scaling_mode,  #
         WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"],  #
         EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"],  #
     )
@@ -504,7 +505,7 @@ def blackwell_persistent_tma_kernel(
     GROUP_SIZE_M: tl.constexpr,  #
     ACC_TYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
-    SCALING_ROWWISE: tl.constexpr,  #
+    SCALING_MODE: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,
     EPILOGUE_SUBTILE: tl.constexpr,
 ):  #
@@ -538,7 +539,7 @@ def blackwell_persistent_tma_kernel(
     tile_id_c = start_pid - NUM_SMS
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
-    if SCALING_ROWWISE:
+    if SCALING_MODE == ScalingType.RowWise:
         # For row-wise scaling, we'll use the pointers as-is
         scale_a = scale_a_ptr
         scale_b = scale_b_ptr
@@ -563,7 +564,7 @@ def blackwell_persistent_tma_kernel(
             b_block = b_desc.load([offs_bn, offs_k])
             accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)
 
-        if SCALING_ROWWISE:
+        if SCALING_MODE == ScalingType.RowWise:
            offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
            offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)
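
For reference, here is a minimal sketch (not part of the patch) of the scale shapes the new module-level `get_scale` helper produces for the two recipes the benchmark currently accepts, `TensorWise` and `RowWise`. It recomputes the scales inline with plain `torch` calls instead of importing the helper, and the sizes are arbitrary; shapes follow `args()`, where `a` is `[M, K]` and `b` is `[N, K]`.

```python
# Illustrative sketch only: mirrors the recipe logic in get_scale() above.
import torch

m, n, k = 64, 32, 16
fp8_max = torch.finfo(torch.float8_e4m3fn).max
a = torch.randn(m, k, dtype=torch.bfloat16)
b = torch.randn(n, k, dtype=torch.bfloat16)

# TensorWise: a single float32 scale per operand (0-dim tensor).
scale_a_tensorwise = (fp8_max / a.abs().max()).to(torch.float32)
scale_b_tensorwise = (fp8_max / b.abs().max()).to(torch.float32)

# RowWise: one float32 scale per row of each operand.
scale_a_rowwise = (fp8_max / a.abs().max(dim=1, keepdim=True).values).to(
    torch.float32
)  # [M, 1]
scale_b_rowwise = (fp8_max / b.abs().max(dim=1, keepdim=True).values).to(
    torch.float32
)  # [N, 1]

assert scale_a_rowwise.shape == (m, 1) and scale_b_rowwise.shape == (n, 1)
```

With the patch, the recipe is chosen per operand via `--scaling-pair` (default `TensorWise,TensorWise`). `_get_dtype` now returns `torch.float16` only when both operands use `TensorWise` and `torch.bfloat16` otherwise, and the Blackwell persistent TMA kernel receives the pair collapsed into a single `SCALING_MODE` constexpr.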