1
1
import argparse
2
+
3
+ from enum import IntEnum
4
+
2
5
import logging
3
6
4
7
from typing import Any , Callable , List , Optional
43
46
logger .warning (f"Failed to import TMA: { e } " )
44
47
45
48
49
class ScalingMode(IntEnum):
    """Supported fp8 scaling strategies.

    The integer values are significant: they are passed through to kernels
    that expect 0 for per-tensor scaling and 1 for per-row scaling, so they
    must not be renumbered.
    """

    TENSOR = 0  # a single scale factor for the whole tensor
    ROW = 1  # one scale factor per row
52
+
53
+
46
54
def parse_args (args ):
47
55
parser = argparse .ArgumentParser (description = "TritonBench fp8_gemm" )
48
56
parser .add_argument ("--llama" , action = "store_true" )
49
- parser .add_argument ("--scaling_rowwise " , action = "store_true " )
57
+ parser .add_argument ("--scaling-mode " , type = str , default = "tensor " )
50
58
parser .add_argument ("--m" , type = int )
51
59
parser .add_argument ("--k" , type = int )
52
60
parser .add_argument ("--n" , type = int )
@@ -55,6 +63,15 @@ def parse_args(args):
55
63
return parser .parse_args (args )
56
64
57
65
66
def get_scaling_mode_int(scaling_mode: str) -> int:
    """Map a ``--scaling-mode`` string onto its :class:`ScalingMode` member.

    Args:
        scaling_mode: either ``"tensor"`` or ``"row"``.

    Returns:
        The matching ``ScalingMode`` member (an ``IntEnum``, hence an ``int``;
        callers that need the raw integer read ``.value`` on the result).

    Raises:
        ValueError: if ``scaling_mode`` names no known scaling mode.
    """
    lookup = {"tensor": ScalingMode.TENSOR, "row": ScalingMode.ROW}
    mode = lookup.get(scaling_mode)
    if mode is None:
        raise ValueError(f"Invalid scaling mode: {scaling_mode}")
    return mode
73
+
74
+
58
75
class Operator (BenchmarkOperator ):
59
76
DEFAULT_METRICS = ["tflops" , "gbps" , "latency" ]
60
77
DEFAULT_PRECISION = "fp8"
@@ -65,11 +82,12 @@ def __init__(
65
82
super ().__init__ (tb_args , extra_args )
66
83
self .extra_args = parse_args (extra_args )
67
84
85
+ self .scaling_mode_int = get_scaling_mode_int (self .extra_args .scaling_mode ).value
86
+
68
87
def _get_dtype(self):
    """Return the reference dtype for benchmark inputs.

    Per-tensor scaling benchmarks use float16; any other mode
    (currently row-wise) uses bfloat16.
    """
    tensor_scaled = self.scaling_mode_int == ScalingMode.TENSOR
    return torch.float16 if tensor_scaled else torch.bfloat16
73
91
74
92
def get_input_iter (self ):
75
93
def _get_scale_per_tensor (
@@ -102,10 +120,10 @@ def args(m, n, k):
102
120
a = torch .randn (m , k , device = self .device ).to (self ._get_dtype ())
103
121
b = torch .randn (n , k , device = self .device ).to (self ._get_dtype ())
104
122
105
- if self .extra_args . scaling_rowwise :
123
+ if self .scaling_mode_int == ScalingMode . ROW :
106
124
scale_a = _get_scale_per_row (a )
107
125
scale_b = _get_scale_per_row (b )
108
- else :
126
+ else : # self.scaling_mode_int == ScalingMode.TENSOR
109
127
scale_a = _get_scale_per_tensor (
110
128
a , custom_scale = self .extra_args .per_tensor_scale_a
111
129
)
@@ -191,7 +209,7 @@ def blackwell_persistent_tma_fp8_gemm(self, a, b, scale_a, scale_b):
191
209
scale_a ,
192
210
scale_b ,
193
211
self ._get_dtype (),
194
- self .extra_args . scaling_rowwise ,
212
+ self .scaling_mode_int ,
195
213
)
196
214
197
215
@register_benchmark (enabled = True )
0 commit comments