119 changes: 78 additions & 41 deletions tritonbench/operators/fp8_gemm/fp8_gemm.py
@@ -1,4 +1,5 @@
import argparse

import logging

from typing import Any, Callable, List, Optional
@@ -7,6 +8,8 @@
import torch._inductor.config as inductor_config
import triton

from torch._inductor.kernel.mm import scaling_pairs, ScalingType

from tritonbench.operators.fp8_gemm.persistent import blackwell_persistent_tma
from tritonbench.utils.env_utils import get_nvidia_gpu_model, is_cuda

@@ -46,7 +49,7 @@
def parse_args(args):
parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
parser.add_argument("--llama", action="store_true")
parser.add_argument("--scaling_rowwise", action="store_true")
parser.add_argument("--scaling-pair", type=str, default="TensorWise,TensorWise")
parser.add_argument("--m", type=int)
parser.add_argument("--k", type=int)
parser.add_argument("--n", type=int)
@@ -55,6 +58,54 @@ def parse_args(args):
return parser.parse_args(args)


def get_scaling_recipe_int(scaling_recipe: str) -> int:
if scaling_recipe == "TensorWise":
return ScalingType.TensorWise
elif scaling_recipe == "RowWise":
return ScalingType.RowWise
else:
raise ValueError(f"Invalid scaling recipe: {scaling_recipe}")


def get_scale(
x: torch.Tensor,
scaling_recipe_int: int,
transpose: bool = False,
custom_scale: float = None,
):
def _get_scale_per_tensor(
x: torch.Tensor, custom_scale: float = None
) -> torch.Tensor:
# For tensor-wise scaling, kernel requires a float32 scale tensor
if custom_scale:
return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
return scale.to(torch.float32)

def _get_scale_per_row(x: torch.Tensor, transpose: bool = False) -> torch.Tensor:
if transpose: # scale_b.shape should be [1, N]
scale = (
torch.finfo(torch.float8_e4m3fn).max
/ x.abs().max(dim=0, keepdim=True).values
)
else: # scale_a.shape should be [M, 1]
scale = (
torch.finfo(torch.float8_e4m3fn).max
/ x.abs().max(dim=1, keepdim=True).values
)
return scale.to(
torch.float32
) # For row-wise scaling, kernel requires a float32 scale tensor

match scaling_recipe_int:
case ScalingType.TensorWise:
return _get_scale_per_tensor(x, custom_scale=custom_scale)
case ScalingType.RowWise:
return _get_scale_per_row(x, transpose=transpose)
case _:
raise AssertionError(f"Unsupported scaling type {scaling_recipe_int}")
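
As a quick illustration of the new helper, the sketch below exercises get_scale for both recipes; it assumes get_scale and the ScalingType import above are in scope, and the shapes, device selection, and standalone usage are illustrative only.

```python
import torch
from torch._inductor.kernel.mm import ScalingType

device = "cuda" if torch.cuda.is_available() else "cpu"
M, K = 128, 64
a = torch.randn(M, K, device=device)

# Tensor-wise: a single float32 scalar, fp8 e4m3 max divided by the tensor's abs-max.
s_tensor = get_scale(a, ScalingType.TensorWise)
assert s_tensor.dtype == torch.float32 and s_tensor.ndim == 0

# Row-wise: one float32 scale per row, shape [M, 1].
s_row = get_scale(a, ScalingType.RowWise)
assert s_row.shape == (M, 1)

# transpose=True reduces over dim 0 instead, giving [1, K] here
# (intended for an operand laid out as [K, N], where it yields the [1, N] scale).
s_col = get_scale(a, ScalingType.RowWise, transpose=True)
assert s_col.shape == (1, K)
```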


class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "gbps", "latency"]
DEFAULT_PRECISION = "fp8"
@@ -66,53 +117,39 @@ def __init__(
super().__init__(tb_args, extra_args)
self.extra_args = parse_args(extra_args)

scaling_recipe_a, scaling_recipe_b = self.extra_args.scaling_pair.split(",")
if (scaling_recipe_a, scaling_recipe_b) not in [
(a.name, b.name) for a, b in scaling_pairs
]:
raise ValueError(
f"Invalid scaling pair: {scaling_recipe_a}, {scaling_recipe_b}. See torch/_inductor/kernel/mm.py::scaling_pairs for valid pairs."
)
self.scaling_recipe_a_int = get_scaling_recipe_int(scaling_recipe_a).value
self.scaling_recipe_b_int = get_scaling_recipe_int(scaling_recipe_b).value

def _get_dtype(self):
if self.extra_args.scaling_rowwise:
return torch.bfloat16
else:
if (
self.scaling_recipe_a_int == ScalingType.TensorWise
and self.scaling_recipe_b_int == ScalingType.TensorWise
):
return torch.float16
return torch.bfloat16

def get_input_iter(self):
def _get_scale_per_tensor(
x: torch.Tensor, custom_scale: float = None
) -> torch.Tensor:
# For tensor-wise scaling, kernel requires a float32 scale tensor
if custom_scale:
return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
return scale.to(torch.float32)

def _get_scale_per_row(
x: torch.Tensor, transpose: bool = False
) -> torch.Tensor:
if transpose: # scale_b.shape should be [1, N]
scale = (
torch.finfo(torch.float8_e4m3fn).max
/ x.abs().max(dim=0, keepdim=True).values
)
else: # scale_a.shape should be [M, 1]
scale = (
torch.finfo(torch.float8_e4m3fn).max
/ x.abs().max(dim=1, keepdim=True).values
)
return scale.to(
torch.float32
) # For row-wise scaling, kernel requires a float32 scale tensor

def args(m, n, k):
a = torch.randn(m, k, device=self.device).to(self._get_dtype())
b = torch.randn(n, k, device=self.device).to(self._get_dtype())

if self.extra_args.scaling_rowwise:
scale_a = _get_scale_per_row(a)
scale_b = _get_scale_per_row(b)
else:
scale_a = _get_scale_per_tensor(
a, custom_scale=self.extra_args.per_tensor_scale_a
)
scale_b = _get_scale_per_tensor(
b, custom_scale=self.extra_args.per_tensor_scale_b
)
scale_a = get_scale(
a,
self.scaling_recipe_a_int,
custom_scale=self.extra_args.per_tensor_scale_a,
)
scale_b = get_scale(
b,
self.scaling_recipe_b_int,
custom_scale=self.extra_args.per_tensor_scale_b,
)

# Kernels expect dtype=float8_e4m3fn
a = a.to(torch.float8_e4m3fn)
@@ -192,7 +229,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
scale_a,
scale_b,
self._get_dtype(),
self.extra_args.scaling_rowwise,
0 if self.scaling_recipe_a_int == self.scaling_recipe_b_int == 0 else 1,
)

@register_benchmark(enabled=True)
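Net effect of the changes in this file: the boolean --scaling_rowwise flag is replaced by --scaling-pair, which names a recipe per operand and is validated against torch._inductor's scaling_pairs; _get_dtype returns float16 only when both operands are tensor-wise (bfloat16 otherwise); and scale construction is routed through the shared get_scale helper. A minimal sketch of the new flag handling, assuming the helpers above are importable from tritonbench.operators.fp8_gemm.fp8_gemm and using an illustrative recipe pair:

```python
from torch._inductor.kernel.mm import scaling_pairs
# Assumed import path for the module-level helpers added in this diff.
from tritonbench.operators.fp8_gemm.fp8_gemm import get_scaling_recipe_int, parse_args

args = parse_args(["--scaling-pair", "RowWise,RowWise"])
recipe_a, recipe_b = args.scaling_pair.split(",")

# Accept only combinations that torch._inductor lists as supported.
if (recipe_a, recipe_b) not in [(a.name, b.name) for a, b in scaling_pairs]:
    raise ValueError(f"Invalid scaling pair: {recipe_a}, {recipe_b}")

# Integer codes (ScalingType.<name>.value) handed to the kernels.
recipe_a_int = get_scaling_recipe_int(recipe_a).value
recipe_b_int = get_scaling_recipe_int(recipe_b).value
```
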
15 changes: 8 additions & 7 deletions tritonbench/operators/fp8_gemm/persistent.py
@@ -1,10 +1,13 @@
from functools import lru_cache

from typing import Optional

import torch
import triton
import triton.language as tl

from torch._inductor.kernel.mm import ScalingType

from tritonbench.utils.env_utils import is_cuda
from tritonbench.utils.triton_utils import has_experimental_descriptor

@@ -410,9 +413,7 @@ def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c):
# - 1 warp = 32 threads, so each thread block requires 128 / 32 = 4 warps


def blackwell_persistent_tma(
a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_rowwise
):
def blackwell_persistent_tma(a, b, scale_a_ptr, scale_b_ptr, acc_dtype, scaling_mode):
configs = matmul_configs_blackwell()

# Check constraints.
@@ -471,7 +472,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
NUM_SMS=NUM_SMS, #
num_stages=configs[shape_dtype]["num_stages"], #
num_warps=configs[shape_dtype]["num_warps"], #
SCALING_ROWWISE=scaling_rowwise,
SCALING_MODE=scaling_mode, #
WARP_SPECIALIZE=configs[shape_dtype]["WARP_SPECIALIZE"], #
EPILOGUE_SUBTILE=configs[shape_dtype]["EPILOGUE_SUBTILE"], #
)
@@ -504,7 +505,7 @@ def blackwell_persistent_tma_kernel(
GROUP_SIZE_M: tl.constexpr, #
ACC_TYPE: tl.constexpr,
NUM_SMS: tl.constexpr,
SCALING_ROWWISE: tl.constexpr, #
SCALING_MODE: tl.constexpr, #
WARP_SPECIALIZE: tl.constexpr,
EPILOGUE_SUBTILE: tl.constexpr,
): #
@@ -538,7 +539,7 @@ def blackwell_persistent_tma_kernel(
tile_id_c = start_pid - NUM_SMS
num_pid_in_group = GROUP_SIZE_M * num_pid_n

if SCALING_ROWWISE:
if SCALING_MODE == ScalingType.RowWise:
# For row-wise scaling, we'll use the pointers as-is
scale_a = scale_a_ptr
scale_b = scale_b_ptr
@@ -563,7 +564,7 @@ def blackwell_persistent_tma_kernel(
b_block = b_desc.load([offs_bn, offs_k])
accumulator = tl.dot(a_block, b_block.T, accumulator, out_dtype=tl.float32)

if SCALING_ROWWISE:
if SCALING_MODE == ScalingType.RowWise:
offs_scale_m = offs_am + tl.arange(0, BLOCK_SIZE_M)
offs_scale_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)

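
On the kernel side, the boolean SCALING_ROWWISE constexpr becomes an integer SCALING_MODE that is compared against ScalingType.RowWise. The host collapses the two per-operand recipe codes into this single mode; the hypothetical helper below just restates that call-site mapping, assuming the literal 0 there encodes ScalingType.TensorWise and that ScalingType.RowWise carries the integer value the kernel expects.

```python
def to_scaling_mode(recipe_a_int: int, recipe_b_int: int) -> int:
    """Hypothetical helper mirroring the call site in fp8_gemm.py.

    Returns 0 when both operands use tensor-wise scales, and 1 otherwise,
    which the kernel treats as the row-wise path via its comparison with
    ScalingType.RowWise.
    """
    return 0 if recipe_a_int == recipe_b_int == 0 else 1
```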