diff --git a/python/__pycache__/common.cpython-312.pyc b/python/__pycache__/common.cpython-312.pyc new file mode 100644 index 0000000..9c19320 Binary files /dev/null and b/python/__pycache__/common.cpython-312.pyc differ diff --git a/python/__pycache__/deepgemm_utils.cpython-312.pyc b/python/__pycache__/deepgemm_utils.cpython-312.pyc new file mode 100644 index 0000000..f8f1370 Binary files /dev/null and b/python/__pycache__/deepgemm_utils.cpython-312.pyc differ diff --git a/python/common.py b/python/common.py index b3cdfc7..e6c549a 100644 --- a/python/common.py +++ b/python/common.py @@ -55,6 +55,21 @@ def param_num_to_GB(param_num, ele_size=1): "volume": 80, "intra_node_bw": 180, "inter_node_bw": 39, + }, + "H200-144": { + "volume": 144, + "intra_node_bw": 380, + "inter_node_bw": 39, + }, + "MI300X-192": { + "volume": 192, + "intra_node_bw": 315, + "inter_node_bw": 39, + }, + "MI308X-192": { + "volume": 192, + "intra_node_bw": 315, + "inter_node_bw": 39, } } @@ -96,7 +111,7 @@ def total_nodup_expert_params(self): @dataclass class TestConfig: device_nums: List[int] = None - s: int = 5000 + s: int = 5000 # mean seqlen gpu: str = "H800-80" model_config: ModelConfig = None tp_nums: List[int] = None diff --git a/python/deepgemm_utils.py b/python/deepgemm_utils.py new file mode 100644 index 0000000..ff211ce --- /dev/null +++ b/python/deepgemm_utils.py @@ -0,0 +1,288 @@ +# porting from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/utils.py + +import os +import sys +import time +import torch +import torch.distributed as dist + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, + trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True): + # Conflict with Nsight Systems + using_nsys = os.environ.get('DG_NSYS_PROFILING', False) + + # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle + # this avoid thermal throttling while keeping DVFS at max clocks (slight gain vs sleep / more consistent on GH200) + sleep_between_tests = 0.0 + flush_l2_size = int(8e9 // 4) + if os.environ.get('DG_BENCH_DISABLE_L2_FLUSH', False): + flush_l2 = False + if os.environ.get('DG_BENCH_POWER_LIMITED', False): + # if we want to be thermally limited, we need to run many iterations non-stop for a fairly long time + # and spend as little time as possible doing memset and other setup work (80MiB should be enough to flush L2) + num_tests = 2000 + flush_l2_size = int(80e6 // 4) + sleep_val = os.environ.get('DG_BENCH_SLEEP_BETWEEN_TESTS', False) + if sleep_val: + try: + sleep_between_tests = float(sleep_val) + except ValueError: + pass # Keep default + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() + with profiler: + for i in range(2): + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + if barrier_comm_profiling: + lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + lhs @ rhs + dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) + for _ in range(num_tests): + if sleep_between_tests > 0.0: + time.sleep(sleep_between_tests) + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + fn() + + if not using_nsys: + profiler.step() + + # Return 1 if using Nsight Systems + if using_nsys: + return 1 + + # Parse the profiling table + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tupled = isinstance(kernel_names, tuple) + # print(f"bill-dbg: prof_lines: ") + # print(profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100)) + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + assert all([isinstance(name, str) for name in kernel_names]) + for name in kernel_names: + assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + for unit, scale in units.items(): + if unit in time_str: + kernel_times.append(float(time_str.replace(unit, '')) / scale) + break + break + return tuple(kernel_times) if is_tupled else kernel_times[0] + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(tensors): + total = 0 + for t in tensors: + if isinstance(t, tuple): + total += count_bytes(t) + else: + total += t.numel() * t.element_size() + return total + + +_num_sms = None + + +def set_num_sms(num_sms: int) -> None: + """ + Set the maximum SM count for all GEMM kernels to use. + + Arguments: + num_sms: the desired maximum SM count for all GEMM kernels to use. + """ + global _num_sms + assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count + _num_sms = num_sms + + +def get_num_sms() -> int: + """ + Get the current maximum limit of SM count for all GEMM kernels to use. + If the count is never specified, the function will return the number of device SMs. + + Returns: + Current maximum limit of SM count for all GEMM kernels to use. + """ + global _num_sms + if _num_sms is None: + _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count + return _num_sms + + +def ceil_div(x: int, y: int) -> int: + """ + Perform ceiling division of two integers. + + Args: + x: the dividend. + y: the divisor. + + Returns: + The result of the ceiling division. + """ + return (x + y - 1) // y + + +def get_m_alignment_for_contiguous_layout(): + """ + When we do a grouped GEMM in contiguous format, LHS are grouped into several batches along the M axis. + Since we deal with exactly one sub-matrix of RHS for each GEMM block, batch sizes above should align well + with GEMM block shape. + + Returns: + Group-level alignment requirement for grouped contiguous layout, which is always 128. + """ + return 128 + + +def get_tma_aligned_size(x: int, element_size: int) -> int: + """ + Global memory address of TMA must be 16-byte aligned. + Since we use column-major layout for the LHS scaling tensor, + the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. + + Arguments: + x: original M-axis shape of the LHS scaling tensor. + element_size: element size of the LHS scaling tensor. + + Returns: + M-axis shape of the LHS scaling tensor after padding. + """ + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return ceil_div(x, alignment) * alignment + + +def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + """ + Returns TMA-aligned transposed format of the input tensor. `torch.transpose` will be called if necessary. + If the input tensor is already column-major layout and 16-byte aligned along the M axis + (thus meets the requirement of LHS scaling tensor in DeepGEMM), this function will do nothing. + + Arguments: + x: usually the LHS scaling tensor in GEMM. + + Returns: + The LHS scaling tensor of TMA-aligned transposed format. + """ + # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + assert x.dim() in (2, 3) + remove_dim = False + m, n = x.shape[-2], x.shape[-1] + aligned_m = get_tma_aligned_size(m, x.element_size()) + if x.dim() == 2: + if x.stride(0) == 1 and x.stride(1) == aligned_m: + return x + x, remove_dim = x.unsqueeze(0), True + + b = x.shape[0] + + # The last kernel gives a column-major TMA aligned layout + if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m: + return x.squeeze(0) if remove_dim else x + + # Normal layout requires transposing + aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x \ No newline at end of file diff --git a/python/process_table.py b/python/process_table.py index 5b2bc28..416acb5 100644 --- a/python/process_table.py +++ b/python/process_table.py @@ -49,10 +49,10 @@ def process_data(dense_gemm_file: str, group_gemm_file: str, batch_gemm_file: st (group_df['b_mla'] == b_mla) & (group_df['m_per_group'] == m_per_group) & (group_df['matrix_idx'] == 7)]['time_us'].iloc[0]) - down_gemm = int(group_df[(group_df['d'] == d) & - (group_df['m_per_group'] == m_per_group) & - (group_df['b_mla'] == b_mla) & - (group_df['matrix_idx'] == 8)]['time_us'].iloc[0]) + # down_gemm = int(group_df[(group_df['d'] == d) & + # (group_df['m_per_group'] == m_per_group) & + # (group_df['b_mla'] == b_mla) & + # (group_df['matrix_idx'] == 8)]['time_us'].iloc[0]) dispatch_alltoall = int( config.calculate_alltoall_time(d, tp, b_mla, True)) @@ -65,15 +65,15 @@ def process_data(dense_gemm_file: str, group_gemm_file: str, batch_gemm_file: st # 计算两种模式下的层时间 # Two microbatch overlapping t_moe_layer_two = int(2 * (max(dispatch_alltoall, shared_time + qkv_time) + - up_gemm + down_gemm + + up_gemm + max(attn_time + o_time + allreduce, combine_alltoall))) t_dense_layer_two = int( - 2 * (shared_time + qkv_time + up_gemm + down_gemm + attn_time + o_time + allreduce)) + 2 * (shared_time + qkv_time + up_gemm + attn_time + o_time + allreduce)) # Single batch comp-compute overlapping - t_moe_layer_single = int(max(dispatch_alltoall, shared_time) + qkv_time + up_gemm + - max(down_gemm, combine_alltoall) + attn_time + o_time + allreduce) + t_moe_layer_single = int(max(dispatch_alltoall, shared_time) + qkv_time + + max(up_gemm, combine_alltoall) + attn_time + o_time + allreduce) t_dense_layer_single = int( - shared_time + qkv_time + up_gemm + down_gemm + attn_time + o_time + allreduce) + shared_time + qkv_time + up_gemm + attn_time + o_time + allreduce) # 计算TPOT和吞吐量 tpot_two = int((t_moe_layer_two * 58 + t_dense_layer_two * 3) / 1000) @@ -93,7 +93,7 @@ def process_data(dense_gemm_file: str, group_gemm_file: str, batch_gemm_file: st 'O(us)': o_time, 'Shared(us)': shared_time, 'Up_Gemm(us)': up_gemm, - 'Down_Gemm(us)': down_gemm, + 'Down_Gemm(us)': 0, 'Dispatch_AlltoAll(us)': dispatch_alltoall, 'Combine_AlltoAll(us)': combine_alltoall, 'AllReduce(us)': allreduce, diff --git a/python/test_decode_gemms.py b/python/test_decode_gemms.py index 53678ef..d6fa3a9 100644 --- a/python/test_decode_gemms.py +++ b/python/test_decode_gemms.py @@ -7,8 +7,8 @@ from typing import Tuple, Callable from enum import Enum -import deep_gemm -from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor, set_num_sms +# import deep_gemm +from deepgemm_utils import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor, set_num_sms from common import TestConfig @@ -102,6 +102,124 @@ def construct_grouped(num_groups: int, m: int, k: int, n: int, is_masked: bool) return x_fp8, y_fp8, out, ref_out +block_shape = (128, 128) + +from aiter import gemm_a8w8_blockscale + +def test_aiter_gemm_asm(dtype, m, n, k): + dim = (m, n, k) + block_shape_n, block_shape_k = block_shape + scale_n = (n + block_shape_n - 1) // block_shape_n + scale_k = (k + block_shape_k - 1) // block_shape_k + x = (torch.rand((m, k), dtype=torch.float16, device="cuda")/10).to(torch.float8_e4m3fnuz) + weight = (torch.rand( (n, k), dtype=torch.float16, device="cuda")/10).to(torch.float8_e4m3fnuz) + x_scale = torch.rand([m, scale_k], dtype=torch.float32, device="cuda") + w_scale = torch.rand([scale_n, scale_k], dtype=torch.float32, device="cuda") + output = torch.zeros( + [x.shape[0], weight.shape[0]], + dtype=dtype, + device=x.device, + ) + + gemm_a8w8_blockscale(x, weight, x_scale, w_scale, output) + + +from aiter import batched_gemm_a8w8 + +## TODO: fix the error from aiter bmm : RuntimeError: This GEMM is not supported! +def test_aiter_batch_gemm(dtype, b, m, n, k): + dim = (b, m, n, k) + x = torch.randint(-20, 20, (b, m, k), dtype=torch.int8).cuda() + weight = torch.randint(-20, 20, (b, n, k), dtype=torch.int8).cuda() + x_scale = torch.rand([b, m, 1], dtype=torch.float32).cuda() + 1e-6 + w_scale = torch.rand([b, 1, n], dtype=torch.float32).cuda() + 1e-6 + output = torch.zeros( + (b, m , n), + dtype=dtype, + device=x.device, + ) + + batched_gemm_a8w8(x, weight, x_scale, w_scale, output, None) + + +# #TODO: bill fail back to torch.bmm +# # from /sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py:L796 +def test_torch_batch_gemm(dtype, b, m, n, k): + dim = (b, m, n, k) + x = torch.randint(-20, 20, (b, m, k), dtype=dtype).cuda() + weight = torch.randint(-20, 20, (b, k, n), dtype=dtype).cuda() + q_nope_out = torch.bmm(x, weight) + + + + +######################################## fused moe ######################################## + +BLOCK_SIZE_M = 32 +quant_algo = [ + "No", # g1u0/ck(g1ux) support + "int8quant", # g1u1 support + "fp8quant", # g1u1 support + "int8smoothquant", # g1u1/g1u0 support + "fp8smoothquant", # g1u1 support + "wint4afp8smoothquant", # g1u1 support +] + +from aiter import ActivationType +from aiter.fused_moe_bf16_asm import asm_moe, torch_moe, moe_sorting_ck +from aiter.fused_moe_gelu import fused_topk, moe_align_block_size, fused_experts +from aiter import pertoken_quant, ck_moe +from aiter.ops.shuffle import shuffle_weight + + +def test_aiter_fmoe(dtype, token, model_dim, inter_dim, E, topk, quant='No', use_g1u1=True, shared_E=1, activation = ActivationType.Silu): + quant_dtype = torch.float8_e4m3fnuz + + input = torch.randn((token, model_dim), dtype=dtype, device="cuda") + + w13 = torch.randn((E+shared_E, inter_dim*2, model_dim), + dtype=dtype, device="cuda") / 10.0 + + w2 = torch.randn((E+shared_E, model_dim, inter_dim), + dtype=dtype, device="cuda") + score = torch.randn((token, E), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(input, score, topk, True) + + if shared_E > 0: + shared_E_score = 0.5 + s_topk_weights = torch.tensor([[shared_E_score, shared_E_score],] * token, + dtype=torch.float32, + device=input.device) + topk_weights = torch.cat((topk_weights, s_topk_weights), dim=1) + s_topk_ids = torch.tensor([[E, E+1],] * token, + dtype=torch.int32, + device=input.device) + topk_ids = torch.cat((topk_ids, s_topk_ids), dim=1) + + w13, fc1_scale = pertoken_quant( + w13, torch.float, quant_dtype=quant_dtype, dtypeMax=None) + w2, fc2_scale = pertoken_quant( + w2, torch.float, quant_dtype=quant_dtype, dtypeMax=None) + + sp1 = (E+shared_E, inter_dim) + sp2 = (E+shared_E, model_dim) + + + fc1_smooth_scale = None + fc2_smooth_scale = None + + # b implement + w13b = shuffle_weight(w13) + w2b = shuffle_weight(w2) + + + asm_moe(input, w13b, w2b, topk_weights, topk_ids, + fc1_scale, fc2_scale, + fc1_smooth_scale, fc2_smooth_scale, + a16=False, activation=activation) + + + class PerformanceLogger: def __init__(self, base_csv_file: str = 'performance_metrics.csv'): # 为两种不同的日志创建不同的文件 @@ -382,24 +500,25 @@ def add_tp_shapes(tp_vars, test_set): print(updated_set) for m in m_set: for matrix_idx, tp, k, n in updated_set: - x_fp8, y_fp8, out, ref_out = construct(m, k, n) - deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' + # x_fp8, y_fp8, out, ref_out = construct(m, k, n) + # deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) + # diff = calc_diff(out, ref_out) + # assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' def test_func(): - x_fp8, y_fp8, out, ref_out = construct(m, k, n) - deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) + test_aiter_gemm_asm(torch.bfloat16, m, n, k) + # x_fp8, y_fp8, out, ref_out = construct(m, k, n) + # deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) self.run_benchmark(test_func, matrix_idx, m, n, k, - tp=tp, compute_mode=Mode.BASE, tag='fp8_gemm') + tp=tp, compute_mode=Mode.BASE, tag='_ZN2ck27kernel_gemm_xdl_cshuffle_v3INS_42GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3INS_13tensor_l') print() def test_bmm(self, config: TestConfig, *, use_flashinfer_bmm: bool = False) -> None: import torch.cuda.nvtx as nvtx print('Testing batch GEMM:') - if use_flashinfer_bmm: - from flashinfer import bmm_fp8 + # if use_flashinfer_bmm: + # from flashinfer import bmm_fp8 b_and_m_per_groups = config.generate_b_and_m_per_groups() tp_vars = config.get_tp_configs() m_set = sorted( @@ -423,90 +542,105 @@ def add_tp_shapes(tp_vars, test_set): for m in m_set: for matrix_idx, tp, b, k, n in updated_set: - x_bf16, y_bf16, x_fp8, y_fp8, out, ref_out = construct_bmm( - b, m, k, n) - if use_flashinfer_bmm: - bmm_fp8(x_fp8[0], y_fp8[0], x_fp8[1], - y_fp8[1], "torch.bfloat16", out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' - - def test_func(): - x_bf16, y_bf16, x_fp8, y_fp8, out, ref_out = construct_bmm( - b, m, k, n) - bmm_fp8(x_fp8[0], y_fp8[0], x_fp8[1], - y_fp8[1], "torch.bfloat16", out) - self.run_benchmark(test_func, matrix_idx, - m, n, k, Mode.BATCH, 'cutlass') - else: - def test_func(): - x_bf16, y_bf16, x_fp8, y_fp8, out, ref_out = construct_bmm( - b, m, k, n) - # with nvtx.range("matmul"): - out = torch.bmm(x_bf16, y_bf16) - # torch.cuda.synchronize() - # nvtx.range_pop() - # test_func() - try: - self.run_benchmark( - test_func, matrix_idx, m, n, k, tp=tp, compute_mode=Mode.BATCH, tag='gemm_bf16', batch=b) - except: - # H20 use a strange kernel named "nvjet_tst_176x64_64x7_1x1_v_bz_TNN" to perform - self.run_benchmark( - test_func, matrix_idx, m, n, k, tp=tp, compute_mode=Mode.BATCH, tag='nvjet_tst', batch=b) + + + def test_func(): + # test_aiter_batch_gemm(torch.bfloat16, b, m, n, k) + test_torch_batch_gemm(torch.bfloat16, b, m, n, k) + + self.run_benchmark(test_func, matrix_idx, m, n, k, tp=tp, compute_mode=Mode.BATCH, tag='Cijk_Ailk_Bljk_BBS_BH_Bias_HAS_SAV_UserArgs_MT', batch=b) ## TODO: tag cutlass is wrong!! + + # x_bf16, y_bf16, x_fp8, y_fp8, out, ref_out = construct_bmm( + # b, m, k, n) + # if use_flashinfer_bmm: + # bmm_fp8(x_fp8[0], y_fp8[0], x_fp8[1], + # y_fp8[1], "torch.bfloat16", out) + # diff = calc_diff(out, ref_out) + # assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' + + # def test_func(): + # x_bf16, y_bf16, x_fp8, y_fp8, out, ref_out = construct_bmm( + # b, m, k, n) + # bmm_fp8(x_fp8[0], y_fp8[0], x_fp8[1], + # y_fp8[1], "torch.bfloat16", out) + # self.run_benchmark(test_func, matrix_idx, + # m, n, k, Mode.BATCH, 'cutlass') + # else: + # def test_func(): + # x_bf16, y_bf16, x_fp8, y_fp8, out, ref_out = construct_bmm( + # b, m, k, n) + # # with nvtx.range("matmul"): + # out = torch.bmm(x_bf16, y_bf16) + # # torch.cuda.synchronize() + # # nvtx.range_pop() + # # test_func() + # try: + # self.run_benchmark( + # test_func, matrix_idx, m, n, k, tp=tp, compute_mode=Mode.BATCH, tag='gemm_bf16', batch=b) + # except: + # # H20 use a strange kernel named "nvjet_tst_176x64_64x7_1x1_v_bz_TNN" to perform + # self.run_benchmark( + # test_func, matrix_idx, m, n, k, tp=tp, compute_mode=Mode.BATCH, tag='nvjet_tst', batch=b) + def test_m_grouped_gemm_masked(self, config: TestConfig) -> None: print('Testing grouped masked GEMM:') b_and_m_per_groups = config.generate_b_and_m_per_groups() for d, _, num_groups, b_mla, m_per_group in b_and_m_per_groups: - for matrix_idx, k, n in ((7, 7168, 4096), (8, 2048, 7168)): - masked_m_candidates = list(filter( - lambda candidate: candidate <= m_per_group, - (4, 8, 16, 32, 64, 128, 192, 256, 320, 384) - )) - - # Correctness testing - for i in range(10): - x_fp8, y_fp8, out, ref_out = construct_grouped( - num_groups, m_per_group, k, n, is_masked=True - ) - masked_m = torch.empty( - (num_groups,), device='cuda', dtype=torch.int) - for j in range(num_groups): - masked_m[j] = random.choice(masked_m_candidates) - expected_m = min( - int(masked_m.float().mean()) + 1, m_per_group) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( - x_fp8, y_fp8, out, masked_m, expected_m - ) - - for j in range(num_groups): - diff = calc_diff( - out[j, :masked_m[j].item()], - ref_out[j, :masked_m[j].item()] - ) - assert diff < 0.001, ( - f'{m_per_group=}, {k=}, {n=}, {j=}, ' - f'masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' - ) + for matrix_idx, k, n in ((7, 2048, 7168),): + # masked_m_candidates = list(filter( + # lambda candidate: candidate <= m_per_group, + # (4, 8, 16, 32, 64, 128, 192, 256, 320, 384) + # )) + + # # Correctness testing + # for i in range(10): + # x_fp8, y_fp8, out, ref_out = construct_grouped( + # num_groups, m_per_group, k, n, is_masked=True + # ) + # masked_m = torch.empty( + # (num_groups,), device='cuda', dtype=torch.int) + # for j in range(num_groups): + # masked_m[j] = random.choice(masked_m_candidates) + # expected_m = min( + # int(masked_m.float().mean()) + 1, m_per_group) + # deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + # x_fp8, y_fp8, out, masked_m, expected_m + # ) + + # for j in range(num_groups): + # diff = calc_diff( + # out[j, :masked_m[j].item()], + # ref_out[j, :masked_m[j].item()] + # ) + # assert diff < 0.001, ( + # f'{m_per_group=}, {k=}, {n=}, {j=}, ' + # f'masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + # ) def test_func(): - x_fp8, y_fp8, out, ref_out = construct_grouped( - num_groups, m_per_group, k, n, is_masked=True - ) - masked_m = torch.ones( - (num_groups,), device='cuda', dtype=torch.int) * m_per_group - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( - x_fp8, y_fp8, out, masked_m, m_per_group - ) + test_aiter_fmoe(dtype=torch.bfloat16, + token=m_per_group, + model_dim=n, + inter_dim=k, + E=num_groups, + topk=8, + quant='fp8quant', + use_g1u1=True, + shared_E=1, + activation=ActivationType.Silu) self.run_benchmark( - test_func, matrix_idx, m_per_group, n, k, tp=1, compute_mode=Mode.GROUP, tag='fp8_gemm', + test_func, matrix_idx, m_per_group, n, k, tp=1, compute_mode=Mode.GROUP, tag='fmoe_fp8_g1u1_subGU_', ## TODO: fp8_gemm needs to change d=d, b_mla=b_mla, num_groups=num_groups, m_per_group=m_per_group ) print() + + + + def parse_args(): parser = argparse.ArgumentParser(description='GEMM Performance Testing') @@ -538,8 +672,8 @@ def main(): torch.manual_seed(0) random.seed(0) - print('Library path:') - print(f' > {deep_gemm.__path__}\n') + # print('Library path:') + # print(f' > {deep_gemm.__path__}\n') output_filename = f"{args.prefix}dense_gemm.csv" if args.prefix else "dense_gemm.csv" output_path = os.path.join(args.output_dir, output_filename) diff --git a/python/test_flash_mla.py b/python/test_flash_mla.py index 6af1820..7651958 100644 --- a/python/test_flash_mla.py +++ b/python/test_flash_mla.py @@ -5,7 +5,14 @@ import torch import triton -from flash_mla import flash_mla_with_kvcache, get_mla_metadata +# from flash_mla import flash_mla_with_kvcache, get_mla_metadata + +from aiter.ops.triton import decode_mla + +from aiter.test_mha_common import attention_ref + +import aiter + from common import TestConfig @@ -52,60 +59,101 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): mean_seqlens = cache_seqlens.float().mean().int().item() max_seqlen = cache_seqlens.max().item() max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 - # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") - - q = torch.randn(b, s_q, h_q, d) - block_size = 64 - block_table = torch.arange( - b * max_seqlen_pad // block_size, dtype=torch.int32 - ).view(b, max_seqlen_pad // block_size) - blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) - for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = ( - float("nan") - ) - blocked_v = blocked_k[..., :dv] + print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + num_kv_splits = 16 # don't why but sglang force 16.... for triton + + kv_max_sz = 65536 # calculated by rest of mem after weight loaded in frameworks + page_size = 1 + num_page = (kv_max_sz + page_size - 1) // page_size + - tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens, s_q * h_q // h_kv, h_kv + # d = qk_head_dim, h_kv = nhead_kv = 1, qk_head_dim = d, nhead = h_q, v_head_dim = dv + q = torch.randn(b, h_q, d) + + kv_buffer = torch.randn( + (num_page * page_size, h_kv, d), # decode kv head ) - def flash_mla(): - return flash_mla_with_kvcache( + sm_scale = 1.0 / (d**0.5) + + # seq_lens = torch.tensor([ctx_lens for _ in range(b)], dtype=torch.int) + kv_indptr = torch.zeros((b + 1,), dtype=torch.int) + kv_indptr[1 : b + 1] = torch.cumsum(cache_seqlens, dim=0) + kv_indices = torch.randint( + 0, num_page, (kv_indptr[-1].item() + 1,), dtype=torch.int + ) + + # block_size = 64 + # block_table = torch.arange( + # b * max_seqlen_pad // block_size, dtype=torch.int32 + # ).view(b, max_seqlen_pad // block_size) + # blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + # for i in range(b): + # blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = ( + # float("nan") + # ) + # blocked_v = blocked_k[..., :dv] + + # tile_scheduler_metadata, num_splits = get_mla_metadata( + # cache_seqlens, s_q * h_q // h_kv, h_kv + # ) + + def aiter_flash_mla(): + attn_logits = torch.empty( + (b, h_q, num_kv_splits, dv + 1), + dtype=torch.float32, + ) + + kv_last_page_lens = torch.ones(b, dtype=torch.int) + out_asm = torch.empty((b, h_q, dv), ).fill_(-1) + attn_logits, attn_lse = aiter.mla.mla_decode_fwd( q, - blocked_k, - block_table, - cache_seqlens, - dv, - tile_scheduler_metadata, - num_splits, - causal=causal, + kv_buffer.view(num_page, page_size, h_kv, d), + out_asm, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale, ) - def ref_mla(): - out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) - lse = torch.empty(b, h_q, s_q, dtype=torch.float32) - for i in range(b): - begin = i * max_seqlen_pad - end = begin + cache_seqlens[i] - O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), - h_q=h_q, - h_kv=h_kv, - is_causal=causal, - ) - out[i] = O.transpose(0, 1) - lse[i] = LSE - return out, lse - - out_flash, lse_flash = flash_mla() - out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") - cal_diff(lse_flash, lse_torch, "lse") - - t = triton.testing.do_bench(flash_mla) + + # def flash_mla(): + # return flash_mla_with_kvcache( + # q, + # blocked_k, + # block_table, + # cache_seqlens, + # dv, + # tile_scheduler_metadata, + # num_splits, + # causal=causal, + # ) + + # def ref_mla(): + # out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + # lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + # for i in range(b): + # begin = i * max_seqlen_pad + # end = begin + cache_seqlens[i] + # O, LSE = scaled_dot_product_attention( + # q[i].transpose(0, 1), + # blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + # blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + # h_q=h_q, + # h_kv=h_kv, + # is_causal=causal, + # ) + # out[i] = O.transpose(0, 1) + # lse[i] = LSE + # return out, lse + + # out_flash, lse_flash = flash_mla() + # out_torch, lse_torch = ref_mla() + # cal_diff(out_flash, out_torch, "out") + # cal_diff(lse_flash, lse_torch, "lse") + + t = triton.testing.do_bench(aiter_flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( torch.finfo(q.dtype).bits // 8 @@ -136,7 +184,8 @@ def main(torch_dtype): for b in m_set: for s in [config.s]: for h_q in [math.ceil(config.model_config.q_head / tp) for tp in config.get_tp_configs()]: - for s_q in [1, 2]: # MTP = 1, 2 + # for s_q in [1, 2]: # MTP = 1, 2 + for s_q in [1, ]: # only need to calculate MTP=1 cause MTP=2 is got by calculation from MTP=1 for varlen in [False, True]: test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) diff --git a/result.jsonl b/result.jsonl new file mode 100644 index 0000000..fb019a8 --- /dev/null +++ b/result.jsonl @@ -0,0 +1,2 @@ +{"run_name": "test_run", "batch_size": 1, "input_len": 100, "output_len": 100, "prefill_latency": 0.07541227340698242, "prefill_throughput": 1326.044097097078, "median_decode_latency": 0.023357629776000977, "median_decode_throughput": 42.81256315773357, "total_latency": 2.3878631591796875, "overall_throughput": 83.7568933676697} +{"run_name": "test_run", "batch_size": 1, "input_len": 100, "output_len": 100, "prefill_latency": 0.07527947425842285, "prefill_throughput": 1328.3833473214145, "median_decode_latency": 0.023305416107177734, "median_decode_throughput": 42.908480818414326, "total_latency": 2.3827123641967773, "overall_throughput": 83.93795365535901} diff --git a/results/MI300X-single-batch-comp-comm-overlapping.csv b/results/MI300X-single-batch-comp-comm-overlapping.csv new file mode 100644 index 0000000..f5f184b --- /dev/null +++ b/results/MI300X-single-batch-comp-comm-overlapping.csv @@ -0,0 +1,171 @@ +d,tp,b_mla,QKV(us),ATTN(us),O(us),Shared(us),Up_Gemm(us),Down_Gemm(us),Dispatch_AlltoAll(us),Combine_AlltoAll(us),AllReduce(us),t_{dense_layer}(us),t_{moe_layer}(us),TPOT(ms),Single-Device Throughput(Tokens/s) +16,1,8,49,112,124,41,215,0,5,5,0,541,541,33,242 +16,1,16,49,214,115,33,226,0,5,5,0,637,637,38,421 +16,1,32,59,334,118,35,228,0,5,10,0,774,774,47,680 +16,1,64,65,655,119,36,230,0,10,21,0,1105,1105,67,955 +16,1,128,92,1255,171,38,282,0,21,43,0,1838,1838,112,1142 +24,1,8,49,112,124,41,171,0,5,5,0,497,497,30,266 +24,1,16,49,214,115,33,173,0,5,10,0,584,584,35,457 +24,1,32,59,334,118,35,187,0,10,21,0,733,733,44,727 +24,1,64,65,655,119,36,190,0,21,43,0,1065,1065,64,1000 +24,1,128,92,1255,171,38,272,0,43,87,0,1828,1833,111,1153 +32,1,8,49,112,124,41,149,0,5,8,0,475,475,28,285 +32,1,16,49,214,115,33,151,0,8,16,0,562,562,34,470 +32,1,32,59,334,118,35,170,0,16,32,0,716,716,43,744 +32,1,64,65,655,119,36,172,0,32,65,0,1047,1047,63,1015 +32,1,128,92,1255,171,38,243,0,65,131,0,1799,1826,111,1153 +32,1,256,136,2490,228,71,479,0,131,262,0,3404,3464,211,1213 +48,1,8,49,112,124,41,130,0,6,13,0,456,456,27,296 +48,1,16,49,214,115,33,137,0,13,27,0,548,548,33,484 +48,1,32,59,334,118,35,151,0,27,54,0,697,697,42,761 +48,1,64,65,655,119,36,214,0,54,109,0,1089,1107,67,955 +48,1,128,92,1255,171,38,257,0,109,219,0,1813,1884,114,1122 +48,1,256,136,2490,228,71,486,0,219,438,0,3411,3559,216,1185 +72,1,8,49,112,124,41,114,0,10,21,0,440,440,26,307 +72,1,16,49,214,115,33,117,0,21,43,0,528,528,32,500 +72,1,32,59,334,118,35,133,0,43,87,0,679,687,41,780 +72,1,64,65,655,119,36,195,0,87,175,0,1070,1121,68,941 +72,1,128,92,1255,171,38,213,0,175,350,0,1769,2043,123,1040 +72,1,256,136,2490,228,71,449,0,350,701,0,3374,3905,236,1084 +96,1,8,49,112,124,41,104,0,10,21,0,430,430,26,307 +96,1,16,49,214,115,33,116,0,21,43,0,527,527,32,500 +96,1,32,59,334,118,35,175,0,43,87,0,721,729,44,727 +96,1,64,65,655,119,36,181,0,87,175,0,1056,1107,67,955 +96,1,128,92,1255,171,38,262,0,175,350,0,1818,2043,123,1040 +96,1,256,136,2490,228,71,514,0,350,701,0,3439,3905,236,1084 +144,1,8,49,112,124,41,96,0,10,21,0,422,422,25,320 +144,1,16,49,214,115,33,108,0,21,43,0,519,519,31,516 +144,1,32,59,334,118,35,160,0,43,87,0,706,714,43,744 +144,1,64,65,655,119,36,162,0,87,175,0,1037,1101,66,969 +144,1,128,92,1255,171,38,207,0,175,350,0,1763,2043,123,1040 +144,1,256,136,2490,228,71,438,0,350,701,0,3363,3905,236,1084 +288,1,8,49,112,124,41,101,0,10,21,0,427,427,26,307 +288,1,16,49,214,115,33,146,0,21,43,0,557,557,33,484 +288,1,32,59,334,118,35,143,0,43,87,0,689,697,42,761 +288,1,64,65,655,119,36,155,0,87,175,0,1030,1101,66,969 +288,1,128,92,1255,171,38,247,0,175,350,0,1803,2043,123,1040 +288,1,256,136,2490,228,71,521,0,350,701,0,3446,3905,236,1084 +16,2,16,44,112,35,33,215,0,5,5,5,444,444,27,296 +16,2,32,48,212,37,35,225,0,5,5,5,562,562,34,470 +16,2,64,54,365,43,36,228,0,5,10,5,731,731,44,727 +16,2,128,64,703,76,38,231,0,10,21,9,1121,1121,68,941 +24,2,8,46,76,42,41,170,0,5,5,5,380,380,23,173 +24,2,16,44,112,35,33,172,0,5,5,5,401,401,24,333 +24,2,32,48,212,37,35,174,0,5,10,5,511,511,31,516 +24,2,64,54,365,43,36,188,0,10,21,5,691,691,42,761 +24,2,128,64,703,76,38,191,0,21,43,9,1081,1081,65,984 +32,2,8,46,76,42,41,150,0,5,5,5,360,360,21,190 +32,2,16,44,112,35,33,150,0,5,8,5,379,379,23,347 +32,2,32,48,212,37,35,152,0,8,16,5,489,489,29,551 +32,2,64,54,365,43,36,170,0,16,32,5,673,673,41,780 +32,2,128,64,703,76,38,173,0,32,65,9,1063,1063,64,1000 +32,2,256,91,1271,104,71,244,0,65,131,18,1799,1799,109,1174 +48,2,8,46,76,42,41,130,0,5,6,5,340,340,20,200 +48,2,16,44,112,35,33,131,0,6,13,5,360,360,21,380 +48,2,32,48,212,37,35,138,0,13,27,5,475,475,28,571 +48,2,64,54,365,43,36,153,0,27,54,5,656,656,40,800 +48,2,128,64,703,76,38,214,0,54,109,9,1104,1120,68,941 +48,2,256,91,1271,104,71,258,0,109,219,18,1813,1851,112,1142 +72,2,8,46,76,42,41,115,0,5,10,5,325,325,19,210 +72,2,16,44,112,35,33,116,0,10,21,5,345,345,21,380 +72,2,32,48,212,37,35,117,0,21,43,5,454,454,27,592 +72,2,64,54,365,43,36,136,0,43,87,5,639,646,39,820 +72,2,128,64,703,76,38,196,0,87,175,9,1086,1135,69,927 +72,2,256,91,1271,104,71,213,0,175,350,18,1768,2009,121,1057 +96,2,8,46,76,42,41,103,0,5,10,5,313,313,19,210 +96,2,16,44,112,35,33,104,0,10,21,5,333,333,20,400 +96,2,32,48,212,37,35,117,0,21,43,5,454,454,27,592 +96,2,64,54,365,43,36,177,0,43,87,5,680,687,41,780 +96,2,128,64,703,76,38,182,0,87,175,9,1072,1121,68,941 +96,2,256,91,1271,104,71,258,0,175,350,18,1813,2009,121,1057 +144,2,8,46,76,42,41,96,0,5,10,5,306,306,18,222 +144,2,16,44,112,35,33,97,0,10,21,5,326,326,19,421 +144,2,32,48,212,37,35,108,0,21,43,5,445,445,27,592 +144,2,64,54,365,43,36,159,0,43,87,5,662,669,40,800 +144,2,128,64,703,76,38,164,0,87,175,9,1054,1114,67,955 +144,2,256,91,1271,104,71,207,0,175,350,18,1762,2009,121,1057 +288,2,8,46,76,42,41,87,0,5,10,5,297,297,18,222 +288,2,16,44,112,35,33,100,0,10,21,5,329,329,20,400 +288,2,32,48,212,37,35,147,0,21,43,5,484,484,29,551 +288,2,64,54,365,43,36,145,0,43,87,5,648,655,39,820 +288,2,128,64,703,76,38,156,0,87,175,9,1046,1114,67,955 +288,2,256,91,1271,104,71,251,0,175,350,18,1806,2009,121,1057 +16,4,32,44,116,25,35,216,0,5,5,5,441,441,26,307 +16,4,64,45,206,26,36,227,0,5,5,7,547,547,33,484 +16,4,128,53,374,43,38,229,0,5,10,14,751,751,45,711 +24,4,16,43,81,24,33,172,0,5,5,5,358,358,21,190 +24,4,32,44,116,25,35,173,0,5,5,5,398,398,24,333 +24,4,64,45,206,26,36,174,0,5,10,7,494,494,30,533 +24,4,128,53,374,43,38,188,0,10,21,14,710,710,43,744 +32,4,16,43,81,24,33,150,0,5,5,5,336,336,20,200 +32,4,32,44,116,25,35,151,0,5,8,5,376,376,22,363 +32,4,64,45,206,26,36,151,0,8,16,7,471,471,28,571 +32,4,128,53,374,43,38,170,0,16,32,14,692,692,42,761 +32,4,256,66,752,67,71,174,0,32,65,28,1158,1158,70,914 +48,4,8,43,76,27,41,130,0,5,5,5,322,322,19,105 +48,4,16,43,81,24,33,130,0,5,6,5,316,316,19,210 +48,4,32,44,116,25,35,131,0,6,13,5,356,356,21,380 +48,4,64,45,206,26,36,138,0,13,27,7,458,458,27,592 +48,4,128,53,374,43,38,152,0,27,54,14,674,674,41,780 +48,4,256,66,752,67,71,215,0,54,109,28,1199,1199,73,876 +72,4,8,43,76,27,41,114,0,5,5,5,306,306,18,111 +72,4,16,43,81,24,33,118,0,5,10,5,304,304,18,222 +72,4,32,44,116,25,35,116,0,10,21,5,341,341,20,400 +72,4,64,45,206,26,36,116,0,21,43,7,436,436,26,615 +72,4,128,53,374,43,38,133,0,43,87,14,655,660,40,800 +72,4,256,66,752,67,71,196,0,87,175,28,1180,1196,72,888 +96,4,8,43,76,27,41,101,0,5,5,5,293,293,17,117 +96,4,16,43,81,24,33,103,0,5,10,5,289,289,17,235 +96,4,32,44,116,25,35,105,0,10,21,5,330,330,20,400 +96,4,64,45,206,26,36,117,0,21,43,7,437,437,26,615 +96,4,128,53,374,43,38,177,0,43,87,14,699,704,42,761 +96,4,256,66,752,67,71,182,0,87,175,28,1166,1182,72,888 +144,4,8,43,76,27,41,94,0,5,5,5,286,286,17,117 +144,4,16,43,81,24,33,95,0,5,10,5,281,281,17,235 +144,4,32,44,116,25,35,97,0,10,21,5,322,322,19,421 +144,4,64,45,206,26,36,109,0,21,43,7,429,429,26,615 +144,4,128,53,374,43,38,160,0,43,87,14,682,687,41,780 +144,4,256,66,752,67,71,163,0,87,175,28,1147,1175,71,901 +288,4,8,43,76,27,41,89,0,5,5,5,281,281,17,117 +288,4,16,43,81,24,33,89,0,5,10,5,275,275,16,250 +288,4,32,44,116,25,35,99,0,10,21,5,324,324,19,421 +288,4,64,45,206,26,36,146,0,21,43,7,466,466,28,571 +288,4,128,53,374,43,38,143,0,43,87,14,665,670,40,800 +288,4,256,66,752,67,71,156,0,87,175,28,1140,1175,71,901 +16,8,64,43,109,18,36,214,0,5,5,8,428,428,26,307 +16,8,128,47,222,19,38,227,0,5,5,16,569,569,34,470 +24,8,32,43,70,18,35,173,0,5,5,5,344,344,20,200 +24,8,64,43,109,18,36,173,0,5,5,8,387,387,23,347 +24,8,128,47,222,19,38,174,0,5,10,16,516,516,31,516 +32,8,32,43,70,18,35,151,0,5,5,5,322,322,19,210 +32,8,64,43,109,18,36,151,0,5,8,8,365,365,22,363 +32,8,128,47,222,19,38,152,0,8,16,16,494,494,30,533 +32,8,256,49,404,29,71,172,0,16,32,33,758,758,46,695 +48,8,16,40,72,17,33,129,0,5,5,5,296,296,18,111 +48,8,32,43,70,18,35,130,0,5,6,5,301,301,18,222 +48,8,64,43,109,18,36,132,0,6,13,8,346,346,21,380 +48,8,128,47,222,19,38,138,0,13,27,16,480,480,29,551 +48,8,256,49,404,29,71,152,0,27,54,33,738,738,45,711 +72,8,16,40,72,17,33,113,0,5,5,5,280,280,17,117 +72,8,32,43,70,18,35,115,0,5,10,5,286,286,17,235 +72,8,64,43,109,18,36,116,0,10,21,8,330,330,20,400 +72,8,128,47,222,19,38,117,0,21,43,16,459,459,27,592 +72,8,256,49,404,29,71,133,0,43,87,33,719,719,43,744 +96,8,8,41,75,19,41,101,0,5,5,5,282,282,17,58 +96,8,16,40,72,17,33,102,0,5,5,5,269,269,16,125 +96,8,32,43,70,18,35,103,0,5,10,5,274,274,16,250 +96,8,64,43,109,18,36,105,0,10,21,8,319,319,19,421 +96,8,128,47,222,19,38,117,0,21,43,16,459,459,27,592 +96,8,256,49,404,29,71,178,0,43,87,33,764,764,46,695 +144,8,8,41,75,19,41,93,0,5,5,5,274,274,16,62 +144,8,16,40,72,17,33,94,0,5,5,5,261,261,15,133 +144,8,32,43,70,18,35,96,0,5,10,5,267,267,16,250 +144,8,64,43,109,18,36,98,0,10,21,8,312,312,19,421 +144,8,128,47,222,19,38,108,0,21,43,16,450,450,27,592 +144,8,256,49,404,29,71,161,0,43,87,33,747,747,45,711 +288,8,8,41,75,19,41,87,0,5,5,5,268,268,16,62 +288,8,16,40,72,17,33,87,0,5,5,5,254,254,15,133 +288,8,32,43,70,18,35,89,0,5,10,5,260,260,15,266 +288,8,64,43,109,18,36,100,0,10,21,8,314,314,19,421 +288,8,128,47,222,19,38,147,0,21,43,16,489,489,29,551 +288,8,256,49,404,29,71,144,0,43,87,33,730,730,44,727 diff --git a/results/MI300X-two-microbatch-overlapping.csv b/results/MI300X-two-microbatch-overlapping.csv new file mode 100644 index 0000000..67290a6 --- /dev/null +++ b/results/MI300X-two-microbatch-overlapping.csv @@ -0,0 +1,171 @@ +d,tp,b_mla,QKV(us),ATTN(us),O(us),Shared(us),Up_Gemm(us),Down_Gemm(us),Dispatch_AlltoAll(us),Combine_AlltoAll(us),AllReduce(us),t_{dense_layer}(us),t_{moe_layer}(us),TPOT(ms),Single-Device Throughput(Tokens/s) +16,1,8,49,112,124,41,215,0,5,5,0,1082,1082,66,242 +16,1,16,49,214,115,33,226,0,5,5,0,1274,1274,77,415 +16,1,32,59,334,118,35,228,0,5,10,0,1548,1548,94,680 +16,1,64,65,655,119,36,230,0,10,21,0,2210,2210,134,955 +16,1,128,92,1255,171,38,282,0,21,43,0,3676,3676,224,1142 +24,1,8,49,112,124,41,171,0,5,5,0,994,994,60,266 +24,1,16,49,214,115,33,173,0,5,10,0,1168,1168,71,450 +24,1,32,59,334,118,35,187,0,10,21,0,1466,1466,89,719 +24,1,64,65,655,119,36,190,0,21,43,0,2130,2130,129,992 +24,1,128,92,1255,171,38,272,0,43,87,0,3656,3656,223,1147 +32,1,8,49,112,124,41,149,0,5,8,0,950,950,57,280 +32,1,16,49,214,115,33,151,0,8,16,0,1124,1124,68,470 +32,1,32,59,334,118,35,170,0,16,32,0,1432,1432,87,735 +32,1,64,65,655,119,36,172,0,32,65,0,2094,2094,127,1007 +32,1,128,92,1255,171,38,243,0,65,131,0,3598,3598,219,1168 +32,1,256,136,2490,228,71,479,0,131,262,0,6808,6808,415,1233 +48,1,8,49,112,124,41,130,0,6,13,0,912,912,55,290 +48,1,16,49,214,115,33,137,0,13,27,0,1096,1096,66,484 +48,1,32,59,334,118,35,151,0,27,54,0,1394,1394,85,752 +48,1,64,65,655,119,36,214,0,54,109,0,2178,2178,132,969 +48,1,128,92,1255,171,38,257,0,109,219,0,3626,3626,221,1158 +48,1,256,136,2490,228,71,486,0,219,438,0,6822,6846,417,1227 +72,1,8,49,112,124,41,114,0,10,21,0,880,880,53,301 +72,1,16,49,214,115,33,117,0,21,43,0,1056,1056,64,500 +72,1,32,59,334,118,35,133,0,43,87,0,1358,1358,82,780 +72,1,64,65,655,119,36,195,0,87,175,0,2140,2140,130,984 +72,1,128,92,1255,171,38,213,0,175,350,0,3538,3628,221,1158 +72,1,256,136,2490,228,71,449,0,350,701,0,6748,7034,428,1196 +96,1,8,49,112,124,41,104,0,10,21,0,860,860,52,307 +96,1,16,49,214,115,33,116,0,21,43,0,1054,1054,64,500 +96,1,32,59,334,118,35,175,0,43,87,0,1442,1442,87,735 +96,1,64,65,655,119,36,181,0,87,175,0,2112,2112,128,1000 +96,1,128,92,1255,171,38,262,0,175,350,0,3636,3726,227,1127 +96,1,256,136,2490,228,71,514,0,350,701,0,6878,7164,436,1174 +144,1,8,49,112,124,41,96,0,10,21,0,844,844,51,313 +144,1,16,49,214,115,33,108,0,21,43,0,1038,1038,63,507 +144,1,32,59,334,118,35,160,0,43,87,0,1412,1412,86,744 +144,1,64,65,655,119,36,162,0,87,175,0,2074,2074,126,1015 +144,1,128,92,1255,171,38,207,0,175,350,0,3526,3616,220,1163 +144,1,256,136,2490,228,71,438,0,350,701,0,6726,7012,426,1201 +288,1,8,49,112,124,41,101,0,10,21,0,854,854,52,307 +288,1,16,49,214,115,33,146,0,21,43,0,1114,1114,67,477 +288,1,32,59,334,118,35,143,0,43,87,0,1378,1378,84,761 +288,1,64,65,655,119,36,155,0,87,175,0,2060,2060,125,1024 +288,1,128,92,1255,171,38,247,0,175,350,0,3606,3696,225,1137 +288,1,256,136,2490,228,71,521,0,350,701,0,6892,7178,437,1171 +16,2,16,44,112,35,33,215,0,5,5,5,888,888,54,296 +16,2,32,48,212,37,35,225,0,5,5,5,1124,1124,68,470 +16,2,64,54,365,43,36,228,0,5,10,5,1462,1462,89,719 +16,2,128,64,703,76,38,231,0,10,21,9,2242,2242,136,941 +24,2,8,46,76,42,41,170,0,5,5,5,760,760,46,173 +24,2,16,44,112,35,33,172,0,5,5,5,802,802,48,333 +24,2,32,48,212,37,35,174,0,5,10,5,1022,1022,62,516 +24,2,64,54,365,43,36,188,0,10,21,5,1382,1382,84,761 +24,2,128,64,703,76,38,191,0,21,43,9,2162,2162,131,977 +32,2,8,46,76,42,41,150,0,5,5,5,720,720,43,186 +32,2,16,44,112,35,33,150,0,5,8,5,758,758,46,347 +32,2,32,48,212,37,35,152,0,8,16,5,978,978,59,542 +32,2,64,54,365,43,36,170,0,16,32,5,1346,1346,82,780 +32,2,128,64,703,76,38,173,0,32,65,9,2126,2126,129,992 +32,2,256,91,1271,104,71,244,0,65,131,18,3598,3598,219,1168 +48,2,8,46,76,42,41,130,0,5,6,5,680,680,41,195 +48,2,16,44,112,35,33,131,0,6,13,5,720,720,43,372 +48,2,32,48,212,37,35,138,0,13,27,5,950,950,57,561 +48,2,64,54,365,43,36,153,0,27,54,5,1312,1312,80,800 +48,2,128,64,703,76,38,214,0,54,109,9,2208,2208,134,955 +48,2,256,91,1271,104,71,258,0,109,219,18,3626,3626,221,1158 +72,2,8,46,76,42,41,115,0,5,10,5,650,650,39,205 +72,2,16,44,112,35,33,116,0,10,21,5,690,690,42,380 +72,2,32,48,212,37,35,117,0,21,43,5,908,908,55,581 +72,2,64,54,365,43,36,136,0,43,87,5,1278,1278,77,831 +72,2,128,64,703,76,38,196,0,87,175,9,2172,2172,132,969 +72,2,256,91,1271,104,71,213,0,175,350,18,3536,3562,217,1179 +96,2,8,46,76,42,41,103,0,5,10,5,626,626,38,210 +96,2,16,44,112,35,33,104,0,10,21,5,666,666,40,400 +96,2,32,48,212,37,35,117,0,21,43,5,908,908,55,581 +96,2,64,54,365,43,36,177,0,43,87,5,1360,1360,82,780 +96,2,128,64,703,76,38,182,0,87,175,9,2144,2144,130,984 +96,2,256,91,1271,104,71,258,0,175,350,18,3626,3652,222,1153 +144,2,8,46,76,42,41,96,0,5,10,5,612,612,37,216 +144,2,16,44,112,35,33,97,0,10,21,5,652,652,39,410 +144,2,32,48,212,37,35,108,0,21,43,5,890,890,54,592 +144,2,64,54,365,43,36,159,0,43,87,5,1324,1324,80,800 +144,2,128,64,703,76,38,164,0,87,175,9,2108,2108,128,1000 +144,2,256,91,1271,104,71,207,0,175,350,18,3524,3550,216,1185 +288,2,8,46,76,42,41,87,0,5,10,5,594,594,36,222 +288,2,16,44,112,35,33,100,0,10,21,5,658,658,40,400 +288,2,32,48,212,37,35,147,0,21,43,5,968,968,59,542 +288,2,64,54,365,43,36,145,0,43,87,5,1296,1296,79,810 +288,2,128,64,703,76,38,156,0,87,175,9,2092,2092,127,1007 +288,2,256,91,1271,104,71,251,0,175,350,18,3612,3638,221,1158 +16,4,32,44,116,25,35,216,0,5,5,5,882,882,53,301 +16,4,64,45,206,26,36,227,0,5,5,7,1094,1094,66,484 +16,4,128,53,374,43,38,229,0,5,10,14,1502,1502,91,703 +24,4,16,43,81,24,33,172,0,5,5,5,716,716,43,186 +24,4,32,44,116,25,35,173,0,5,5,5,796,796,48,333 +24,4,64,45,206,26,36,174,0,5,10,7,988,988,60,533 +24,4,128,53,374,43,38,188,0,10,21,14,1420,1420,86,744 +32,4,16,43,81,24,33,150,0,5,5,5,672,672,40,200 +32,4,32,44,116,25,35,151,0,5,8,5,752,752,45,355 +32,4,64,45,206,26,36,151,0,8,16,7,942,942,57,561 +32,4,128,53,374,43,38,170,0,16,32,14,1384,1384,84,761 +32,4,256,66,752,67,71,174,0,32,65,28,2316,2316,141,907 +48,4,8,43,76,27,41,130,0,5,5,5,644,644,39,102 +48,4,16,43,81,24,33,130,0,5,6,5,632,632,38,210 +48,4,32,44,116,25,35,131,0,6,13,5,712,712,43,372 +48,4,64,45,206,26,36,138,0,13,27,7,916,916,55,581 +48,4,128,53,374,43,38,152,0,27,54,14,1348,1348,82,780 +48,4,256,66,752,67,71,215,0,54,109,28,2398,2398,146,876 +72,4,8,43,76,27,41,114,0,5,5,5,612,612,37,108 +72,4,16,43,81,24,33,118,0,5,10,5,608,608,37,216 +72,4,32,44,116,25,35,116,0,10,21,5,682,682,41,390 +72,4,64,45,206,26,36,116,0,21,43,7,872,872,53,603 +72,4,128,53,374,43,38,133,0,43,87,14,1310,1310,79,810 +72,4,256,66,752,67,71,196,0,87,175,28,2360,2360,143,895 +96,4,8,43,76,27,41,101,0,5,5,5,586,586,35,114 +96,4,16,43,81,24,33,103,0,5,10,5,578,578,35,228 +96,4,32,44,116,25,35,105,0,10,21,5,660,660,40,400 +96,4,64,45,206,26,36,117,0,21,43,7,874,874,53,603 +96,4,128,53,374,43,38,177,0,43,87,14,1398,1398,85,752 +96,4,256,66,752,67,71,182,0,87,175,28,2332,2332,142,901 +144,4,8,43,76,27,41,94,0,5,5,5,572,572,34,117 +144,4,16,43,81,24,33,95,0,5,10,5,562,562,34,235 +144,4,32,44,116,25,35,97,0,10,21,5,644,644,39,410 +144,4,64,45,206,26,36,109,0,21,43,7,858,858,52,615 +144,4,128,53,374,43,38,160,0,43,87,14,1364,1364,83,771 +144,4,256,66,752,67,71,163,0,87,175,28,2294,2294,139,920 +288,4,8,43,76,27,41,89,0,5,5,5,562,562,34,117 +288,4,16,43,81,24,33,89,0,5,10,5,550,550,33,242 +288,4,32,44,116,25,35,99,0,10,21,5,648,648,39,410 +288,4,64,45,206,26,36,146,0,21,43,7,932,932,56,571 +288,4,128,53,374,43,38,143,0,43,87,14,1330,1330,81,790 +288,4,256,66,752,67,71,156,0,87,175,28,2280,2280,139,920 +16,8,64,43,109,18,36,214,0,5,5,8,856,856,52,307 +16,8,128,47,222,19,38,227,0,5,5,16,1138,1138,69,463 +24,8,32,43,70,18,35,173,0,5,5,5,688,688,41,195 +24,8,64,43,109,18,36,173,0,5,5,8,774,774,47,340 +24,8,128,47,222,19,38,174,0,5,10,16,1032,1032,62,516 +32,8,32,43,70,18,35,151,0,5,5,5,644,644,39,205 +32,8,64,43,109,18,36,151,0,5,8,8,730,730,44,363 +32,8,128,47,222,19,38,152,0,8,16,16,988,988,60,533 +32,8,256,49,404,29,71,172,0,16,32,33,1516,1516,92,695 +48,8,16,40,72,17,33,129,0,5,5,5,592,592,36,111 +48,8,32,43,70,18,35,130,0,5,6,5,602,602,36,222 +48,8,64,43,109,18,36,132,0,6,13,8,692,692,42,380 +48,8,128,47,222,19,38,138,0,13,27,16,960,960,58,551 +48,8,256,49,404,29,71,152,0,27,54,33,1476,1476,90,711 +72,8,16,40,72,17,33,113,0,5,5,5,560,560,34,117 +72,8,32,43,70,18,35,115,0,5,10,5,572,572,34,235 +72,8,64,43,109,18,36,116,0,10,21,8,660,660,40,400 +72,8,128,47,222,19,38,117,0,21,43,16,918,918,55,581 +72,8,256,49,404,29,71,133,0,43,87,33,1438,1438,87,735 +96,8,8,41,75,19,41,101,0,5,5,5,564,564,34,58 +96,8,16,40,72,17,33,102,0,5,5,5,538,538,32,125 +96,8,32,43,70,18,35,103,0,5,10,5,548,548,33,242 +96,8,64,43,109,18,36,105,0,10,21,8,638,638,38,421 +96,8,128,47,222,19,38,117,0,21,43,16,918,918,55,581 +96,8,256,49,404,29,71,178,0,43,87,33,1528,1528,93,688 +144,8,8,41,75,19,41,93,0,5,5,5,548,548,33,60 +144,8,16,40,72,17,33,94,0,5,5,5,522,522,31,129 +144,8,32,43,70,18,35,96,0,5,10,5,534,534,32,250 +144,8,64,43,109,18,36,98,0,10,21,8,624,624,38,421 +144,8,128,47,222,19,38,108,0,21,43,16,900,900,54,592 +144,8,256,49,404,29,71,161,0,43,87,33,1494,1494,91,703 +288,8,8,41,75,19,41,87,0,5,5,5,536,536,32,62 +288,8,16,40,72,17,33,87,0,5,5,5,508,508,30,133 +288,8,32,43,70,18,35,89,0,5,10,5,520,520,31,258 +288,8,64,43,109,18,36,100,0,10,21,8,628,628,38,421 +288,8,128,47,222,19,38,147,0,21,43,16,978,978,59,542 +288,8,256,49,404,29,71,144,0,43,87,33,1460,1460,89,719 diff --git a/results/MI300X_batch_gemm.csv b/results/MI300X_batch_gemm.csv new file mode 100644 index 0000000..d3f8be1 --- /dev/null +++ b/results/MI300X_batch_gemm.csv @@ -0,0 +1,49 @@ +matrix_idx,tp,batch,m,n,k,time_us,throughput_TFLOPS,bandwidth_GBps +3,1,128,8,512,128,8,16,1153 +3,2,64,8,512,128,7,10,717 +3,4,32,8,512,128,5,6,446 +3,8,16,8,512,128,5,3,249 +9,1,128,8,128,512,9,15,1053 +9,2,64,8,128,512,8,9,601 +9,4,32,8,128,512,8,4,296 +9,8,16,8,128,512,8,2,148 +3,1,128,16,512,128,8,33,1307 +3,2,64,16,512,128,6,21,852 +3,4,32,16,512,128,5,13,532 +3,8,16,16,512,128,4,8,302 +9,1,128,16,128,512,10,26,976 +9,2,64,16,128,512,7,18,670 +9,4,32,16,128,512,8,9,325 +9,8,16,16,128,512,8,4,159 +3,1,128,32,512,128,13,40,977 +3,2,64,32,512,128,9,29,720 +3,4,32,32,512,128,6,21,519 +3,8,16,32,512,128,5,14,340 +9,1,128,32,128,512,14,38,807 +9,2,64,32,128,512,9,29,614 +9,4,32,32,128,512,8,17,361 +9,8,16,32,128,512,7,9,195 +3,1,128,64,512,128,13,84,1398 +3,2,64,64,512,128,9,57,946 +3,4,32,64,512,128,7,41,685 +3,8,16,64,512,128,5,25,407 +9,1,128,64,128,512,17,65,882 +9,2,64,64,128,512,13,41,566 +9,4,32,64,128,512,9,31,429 +9,8,16,64,128,512,7,19,258 +3,1,128,128,512,128,20,107,1362 +3,2,64,128,512,128,12,93,1182 +3,4,32,128,512,128,8,69,878 +3,8,16,128,512,128,6,45,573 +9,1,128,128,128,512,22,96,938 +9,2,64,128,128,512,20,55,534 +9,4,32,128,128,512,10,53,519 +9,8,16,128,128,512,7,40,389 +3,1,128,256,512,128,29,146,1567 +3,2,64,256,512,128,17,125,1338 +3,4,32,256,512,128,13,84,902 +3,8,16,256,512,128,7,76,817 +9,1,128,256,128,512,34,128,999 +9,2,64,256,128,512,19,111,870 +9,4,32,256,128,512,16,68,530 +9,8,16,256,128,512,9,60,467 diff --git a/results/MI300X_dense_gemm.csv b/results/MI300X_dense_gemm.csv new file mode 100644 index 0000000..706d381 --- /dev/null +++ b/results/MI300X_dense_gemm.csv @@ -0,0 +1,67 @@ +matrix_idx,tp,m,n,k,time_us,throughput_TFLOPS,bandwidth_GBps +1,1,8,2112,7168,29,8,524 +2,1,8,24576,1536,12,51,3233 +2,2,8,12288,1536,10,31,1972 +2,4,8,6144,1536,9,16,1034 +2,8,8,3072,1536,7,11,705 +4,1,8,7168,16384,115,16,1027 +4,2,8,7168,8192,34,27,1716 +4,4,8,7168,4096,19,25,1558 +4,8,8,7168,2048,11,22,1362 +5,1,8,4096,7168,30,16,979 +6,1,8,7168,2048,11,21,1354 +1,1,16,2112,7168,30,16,515 +2,1,16,24576,1536,11,109,3487 +2,2,16,12288,1536,8,73,2331 +2,4,16,6144,1536,8,38,1228 +2,8,16,3072,1536,6,25,808 +4,1,16,7168,16384,105,36,1123 +4,2,16,7168,8192,28,66,2086 +4,4,16,7168,4096,16,57,1810 +4,8,16,7168,2048,9,52,1642 +5,1,16,4096,7168,24,40,1246 +6,1,16,7168,2048,9,50,1600 +1,1,32,2112,7168,30,32,517 +2,1,32,24576,1536,16,152,2474 +2,2,32,12288,1536,9,139,2269 +2,4,32,6144,1536,8,75,1222 +2,8,32,3072,1536,8,37,613 +4,1,32,7168,16384,104,72,1135 +4,2,32,7168,8192,28,132,2092 +4,4,32,7168,4096,17,111,1776 +4,8,32,7168,2048,11,82,1327 +5,1,32,4096,7168,24,79,1247 +6,1,32,7168,2048,11,82,1329 +1,1,64,2112,7168,30,65,530 +2,1,64,24576,1536,22,219,1855 +2,2,64,12288,1536,15,159,1351 +2,4,64,6144,1536,8,143,1226 +2,8,64,3072,1536,8,71,616 +4,1,64,7168,16384,102,148,1176 +4,2,64,7168,8192,30,255,2038 +4,4,64,7168,4096,17,216,1754 +4,8,64,7168,2048,11,170,1424 +5,1,64,4096,7168,25,149,1201 +6,1,64,7168,2048,11,174,1459 +1,1,128,2112,7168,30,129,551 +2,1,128,24576,1536,42,231,1059 +2,2,128,12288,1536,22,215,988 +2,4,128,6144,1536,15,160,741 +2,8,128,3072,1536,11,111,526 +4,1,128,7168,16384,149,201,813 +4,2,128,7168,8192,56,270,1106 +4,4,128,7168,4096,33,227,960 +4,8,128,7168,2048,12,323,1443 +5,1,128,4096,7168,26,289,1205 +6,1,128,7168,2048,12,322,1436 +1,1,256,2112,7168,31,253,589 +2,1,256,24576,1536,76,253,665 +2,2,256,12288,1536,43,224,594 +2,4,256,6144,1536,22,215,579 +2,8,256,3072,1536,11,211,583 +4,1,256,7168,16384,194,310,645 +4,2,256,7168,8192,85,353,757 +4,4,256,7168,4096,51,296,670 +4,8,256,7168,2048,20,379,952 +5,1,256,4096,7168,51,296,654 +6,1,256,7168,2048,20,381,956 diff --git a/results/MI300X_group_gemm.csv b/results/MI300X_group_gemm.csv new file mode 100644 index 0000000..0da0405 --- /dev/null +++ b/results/MI300X_group_gemm.csv @@ -0,0 +1,171 @@ +d,num_groups,b_mla,m_per_group,matrix_idx,tp,m,n,k,time_us,throughput_TFLOPS,bandwidth_GBps +16,18,8,4,7,1,4,7168,2048,215,10,1237 +16,18,16,8,7,1,8,7168,2048,226,19,1177 +16,18,32,16,7,1,16,7168,2048,228,37,1180 +16,18,64,32,7,1,32,7168,2048,230,74,1190 +16,18,128,64,7,1,64,7168,2048,282,120,1004 +24,12,8,8,7,1,8,7168,2048,171,16,1038 +24,12,16,16,7,1,16,7168,2048,173,33,1039 +24,12,32,32,7,1,32,7168,2048,187,60,975 +24,12,64,64,7,1,64,7168,2048,190,118,992 +24,12,128,128,7,1,128,7168,2048,272,166,740 +32,9,8,8,7,1,8,7168,2048,149,14,892 +32,9,16,16,7,1,16,7168,2048,151,28,892 +32,9,32,32,7,1,32,7168,2048,170,50,804 +32,9,64,64,7,1,64,7168,2048,172,98,822 +32,9,128,128,7,1,128,7168,2048,243,139,621 +32,9,256,256,7,1,256,7168,2048,479,141,355 +48,6,8,16,7,1,16,7168,2048,130,22,687 +48,6,16,32,7,1,32,7168,2048,137,41,665 +48,6,32,64,7,1,64,7168,2048,151,74,623 +48,6,64,128,7,1,128,7168,2048,214,106,471 +48,6,128,256,7,1,256,7168,2048,257,176,441 +48,6,256,512,7,1,512,7168,2048,486,186,285 +72,4,8,16,7,1,16,7168,2048,114,16,522 +72,4,16,32,7,1,32,7168,2048,117,32,520 +72,4,32,64,7,1,64,7168,2048,133,57,473 +72,4,64,128,7,1,128,7168,2048,195,77,345 +72,4,128,256,7,1,256,7168,2048,213,141,354 +72,4,256,512,7,1,512,7168,2048,449,134,206 +96,3,8,32,7,1,32,7168,2048,104,27,437 +96,3,16,64,7,1,64,7168,2048,116,49,406 +96,3,32,128,7,1,128,7168,2048,175,64,287 +96,3,64,256,7,1,256,7168,2048,181,124,312 +96,3,128,512,7,1,512,7168,2048,262,172,265 +96,3,256,1024,7,1,1024,7168,2048,514,176,184 +144,2,8,32,7,1,32,7168,2048,96,20,316 +144,2,16,64,7,1,64,7168,2048,108,35,293 +144,2,32,128,7,1,128,7168,2048,160,47,209 +144,2,64,256,7,1,256,7168,2048,162,93,233 +144,2,128,512,7,1,512,7168,2048,207,146,223 +144,2,256,1024,7,1,1024,7168,2048,438,137,144 +288,1,8,64,7,1,64,7168,2048,101,19,156 +288,1,16,128,7,1,128,7168,2048,146,26,115 +288,1,32,256,7,1,256,7168,2048,143,53,132 +288,1,64,512,7,1,512,7168,2048,155,97,149 +288,1,128,1024,7,1,1024,7168,2048,247,122,127 +288,1,256,2048,7,1,2048,7168,2048,521,115,93 +16,18,16,4,7,1,4,7168,2048,215,10,1233 +16,18,32,8,7,1,8,7168,2048,225,19,1183 +16,18,64,16,7,1,16,7168,2048,228,37,1177 +16,18,128,32,7,1,32,7168,2048,231,73,1186 +24,12,8,4,7,1,4,7168,2048,170,8,1040 +24,12,16,8,7,1,8,7168,2048,172,16,1032 +24,12,32,16,7,1,16,7168,2048,174,32,1033 +24,12,64,32,7,1,32,7168,2048,188,60,972 +24,12,128,64,7,1,64,7168,2048,191,118,988 +32,9,8,4,7,1,4,7168,2048,150,7,885 +32,9,16,8,7,1,8,7168,2048,150,14,887 +32,9,32,16,7,1,16,7168,2048,152,28,884 +32,9,64,32,7,1,32,7168,2048,170,50,803 +32,9,128,64,7,1,64,7168,2048,173,98,816 +32,9,256,128,7,1,128,7168,2048,244,139,619 +48,6,8,8,7,1,8,7168,2048,130,11,684 +48,6,16,16,7,1,16,7168,2048,131,21,682 +48,6,32,32,7,1,32,7168,2048,138,41,662 +48,6,64,64,7,1,64,7168,2048,153,74,616 +48,6,128,128,7,1,128,7168,2048,214,105,471 +48,6,256,256,7,1,256,7168,2048,258,175,439 +72,4,8,8,7,1,8,7168,2048,115,8,514 +72,4,16,16,7,1,16,7168,2048,116,16,516 +72,4,32,32,7,1,32,7168,2048,117,32,520 +72,4,64,64,7,1,64,7168,2048,136,55,463 +72,4,128,128,7,1,128,7168,2048,196,77,342 +72,4,256,256,7,1,256,7168,2048,213,141,354 +96,3,8,16,7,1,16,7168,2048,103,14,434 +96,3,16,32,7,1,32,7168,2048,104,27,438 +96,3,32,64,7,1,64,7168,2048,117,48,402 +96,3,64,128,7,1,128,7168,2048,177,64,284 +96,3,128,256,7,1,256,7168,2048,182,124,311 +96,3,256,512,7,1,512,7168,2048,258,175,268 +144,2,8,16,7,1,16,7168,2048,96,10,312 +144,2,16,32,7,1,32,7168,2048,97,19,314 +144,2,32,64,7,1,64,7168,2048,108,35,292 +144,2,64,128,7,1,128,7168,2048,159,47,210 +144,2,128,256,7,1,256,7168,2048,164,92,231 +144,2,256,512,7,1,512,7168,2048,207,145,223 +288,1,8,32,7,1,32,7168,2048,87,11,175 +288,1,16,64,7,1,64,7168,2048,100,19,157 +288,1,32,128,7,1,128,7168,2048,147,26,114 +288,1,64,256,7,1,256,7168,2048,145,52,131 +288,1,128,512,7,1,512,7168,2048,156,96,148 +288,1,256,1024,7,1,1024,7168,2048,251,120,126 +16,18,32,4,7,1,4,7168,2048,216,10,1227 +16,18,64,8,7,1,8,7168,2048,227,19,1172 +16,18,128,16,7,1,16,7168,2048,229,37,1173 +24,12,16,4,7,1,4,7168,2048,172,8,1031 +24,12,32,8,7,1,8,7168,2048,173,16,1028 +24,12,64,16,7,1,16,7168,2048,174,32,1033 +24,12,128,32,7,1,32,7168,2048,188,60,972 +32,9,16,4,7,1,4,7168,2048,150,7,883 +32,9,32,8,7,1,8,7168,2048,151,14,883 +32,9,64,16,7,1,16,7168,2048,151,28,891 +32,9,128,32,7,1,32,7168,2048,170,50,803 +32,9,256,64,7,1,64,7168,2048,174,97,815 +48,6,8,4,7,1,4,7168,2048,130,5,682 +48,6,16,8,7,1,8,7168,2048,130,11,681 +48,6,32,16,7,1,16,7168,2048,131,21,682 +48,6,64,32,7,1,32,7168,2048,138,41,663 +48,6,128,64,7,1,64,7168,2048,152,74,620 +48,6,256,128,7,1,128,7168,2048,215,105,468 +72,4,8,4,7,1,4,7168,2048,114,4,518 +72,4,16,8,7,1,8,7168,2048,118,8,500 +72,4,32,16,7,1,16,7168,2048,116,16,513 +72,4,64,32,7,1,32,7168,2048,116,32,523 +72,4,128,64,7,1,64,7168,2048,133,57,474 +72,4,256,128,7,1,128,7168,2048,196,77,342 +96,3,8,8,7,1,8,7168,2048,101,7,439 +96,3,16,16,7,1,16,7168,2048,103,14,436 +96,3,32,32,7,1,32,7168,2048,105,27,434 +96,3,64,64,7,1,64,7168,2048,117,48,405 +96,3,128,128,7,1,128,7168,2048,177,64,284 +96,3,256,256,7,1,256,7168,2048,182,124,311 +144,2,8,8,7,1,8,7168,2048,94,5,315 +144,2,16,16,7,1,16,7168,2048,95,10,315 +144,2,32,32,7,1,32,7168,2048,97,19,313 +144,2,64,64,7,1,64,7168,2048,109,35,290 +144,2,128,128,7,1,128,7168,2048,160,47,210 +144,2,256,256,7,1,256,7168,2048,163,92,231 +288,1,8,16,7,1,16,7168,2048,89,5,169 +288,1,16,32,7,1,32,7168,2048,89,11,171 +288,1,32,64,7,1,64,7168,2048,99,19,159 +288,1,64,128,7,1,128,7168,2048,146,26,115 +288,1,128,256,7,1,256,7168,2048,143,53,132 +288,1,256,512,7,1,512,7168,2048,156,96,148 +16,18,64,4,7,1,4,7168,2048,214,10,1242 +16,18,128,8,7,1,8,7168,2048,227,19,1177 +24,12,32,4,7,1,4,7168,2048,173,8,1025 +24,12,64,8,7,1,8,7168,2048,173,16,1028 +24,12,128,16,7,1,16,7168,2048,174,32,1033 +32,9,32,4,7,1,4,7168,2048,151,7,877 +32,9,64,8,7,1,8,7168,2048,151,14,882 +32,9,128,16,7,1,16,7168,2048,152,28,882 +32,9,256,32,7,1,32,7168,2048,172,49,798 +48,6,16,4,7,1,4,7168,2048,129,5,684 +48,6,32,8,7,1,8,7168,2048,130,11,685 +48,6,64,16,7,1,16,7168,2048,132,21,681 +48,6,128,32,7,1,32,7168,2048,138,41,659 +48,6,256,64,7,1,64,7168,2048,152,74,619 +72,4,16,4,7,1,4,7168,2048,113,4,521 +72,4,32,8,7,1,8,7168,2048,115,8,517 +72,4,64,16,7,1,16,7168,2048,116,16,516 +72,4,128,32,7,1,32,7168,2048,117,32,521 +72,4,256,64,7,1,64,7168,2048,133,56,472 +96,3,8,4,7,1,4,7168,2048,101,3,436 +96,3,16,8,7,1,8,7168,2048,102,7,435 +96,3,32,16,7,1,16,7168,2048,103,14,437 +96,3,64,32,7,1,32,7168,2048,105,27,435 +96,3,128,64,7,1,64,7168,2048,117,48,402 +96,3,256,128,7,1,128,7168,2048,178,63,283 +144,2,8,4,7,1,4,7168,2048,93,3,316 +144,2,16,8,7,1,8,7168,2048,94,5,315 +144,2,32,16,7,1,16,7168,2048,96,10,312 +144,2,64,32,7,1,32,7168,2048,98,19,312 +144,2,128,64,7,1,64,7168,2048,108,35,290 +144,2,256,128,7,1,128,7168,2048,161,47,209 +288,1,8,8,7,1,8,7168,2048,87,3,169 +288,1,16,16,7,1,16,7168,2048,87,5,171 +288,1,32,32,7,1,32,7168,2048,89,11,171 +288,1,64,64,7,1,64,7168,2048,100,19,157 +288,1,128,128,7,1,128,7168,2048,147,26,114 +288,1,256,256,7,1,256,7168,2048,144,52,131 diff --git a/results/MI300X_mla.csv b/results/MI300X_mla.csv new file mode 100644 index 0000000..cc16114 --- /dev/null +++ b/results/MI300X_mla.csv @@ -0,0 +1,99 @@ +b,s_q,mean_sk,h_q,h_kv,d,dv,causal,varlen,latency,tflops,bandwidth +,,,,,,,,,,, +,,,,,,,,,,, +,,,,,,,,,,, +8,1,5000,128,1,576,512,True,False,0.122,91,396 +,,,,,,,,,,, +8,1,5000,128,1,576,512,True,True,0.112,96,416 +,,,,,,,,,,, +8,1,5000,64,1,576,512,True,False,0.075,74,627 +,,,,,,,,,,, +8,1,5000,64,1,576,512,True,True,0.076,76,644 +,,,,,,,,,,, +8,1,5000,32,1,576,512,True,False,0.075,37,625 +,,,,,,,,,,, +8,1,5000,32,1,576,512,True,True,0.076,31,514 +,,,,,,,,,,, +8,1,5000,16,1,576,512,True,False,0.076,18,607 +,,,,,,,,,,, +8,1,5000,16,1,576,512,True,True,0.075,20,661 +,,,,,,,,,,, +16,1,5000,128,1,576,512,True,False,0.219,102,442 +,,,,,,,,,,, +16,1,5000,128,1,576,512,True,True,0.214,115,496 +,,,,,,,,,,, +16,1,5000,64,1,576,512,True,False,0.123,91,768 +,,,,,,,,,,, +16,1,5000,64,1,576,512,True,True,0.112,90,766 +,,,,,,,,,,, +16,1,5000,32,1,576,512,True,False,0.071,79,1316 +,,,,,,,,,,, +16,1,5000,32,1,576,512,True,True,0.081,70,1172 +,,,,,,,,,,, +16,1,5000,16,1,576,512,True,False,0.072,39,1287 +,,,,,,,,,,, +16,1,5000,16,1,576,512,True,True,0.072,46,1518 +,,,,,,,,,,, +32,1,5000,128,1,576,512,True,False,0.376,119,514 +,,,,,,,,,,, +32,1,5000,128,1,576,512,True,True,0.334,125,544 +,,,,,,,,,,, +32,1,5000,64,1,576,512,True,False,0.196,114,962 +,,,,,,,,,,, +32,1,5000,64,1,576,512,True,True,0.212,120,1013 +,,,,,,,,,,, +32,1,5000,32,1,576,512,True,False,0.107,105,1750 +,,,,,,,,,,, +32,1,5000,32,1,576,512,True,True,0.116,98,1634 +,,,,,,,,,,, +32,1,5000,16,1,576,512,True,False,0.071,79,2619 +,,,,,,,,,,, +32,1,5000,16,1,576,512,True,True,0.070,73,2422 +,,,,,,,,,,, +64,1,5000,128,1,576,512,True,False,0.656,136,589 +,,,,,,,,,,, +64,1,5000,128,1,576,512,True,True,0.655,139,601 +,,,,,,,,,,, +64,1,5000,64,1,576,512,True,False,0.348,128,1086 +,,,,,,,,,,, +64,1,5000,64,1,576,512,True,True,0.365,127,1076 +,,,,,,,,,,, +64,1,5000,32,1,576,512,True,False,0.176,126,2118 +,,,,,,,,,,, +64,1,5000,32,1,576,512,True,True,0.206,113,1895 +,,,,,,,,,,, +64,1,5000,16,1,576,512,True,False,0.096,116,3858 +,,,,,,,,,,, +64,1,5000,16,1,576,512,True,True,0.109,102,3386 +,,,,,,,,,,, +128,1,5000,128,1,576,512,True,False,1.224,146,631 +,,,,,,,,,,, +128,1,5000,128,1,576,512,True,True,1.255,143,622 +,,,,,,,,,,, +128,1,5000,64,1,576,512,True,False,0.642,139,1176 +,,,,,,,,,,, +128,1,5000,64,1,576,512,True,True,0.703,131,1111 +,,,,,,,,,,, +128,1,5000,32,1,576,512,True,False,0.321,139,2326 +,,,,,,,,,,, +128,1,5000,32,1,576,512,True,True,0.374,120,2013 +,,,,,,,,,,, +128,1,5000,16,1,576,512,True,False,0.167,133,4429 +,,,,,,,,,,, +128,1,5000,16,1,576,512,True,True,0.222,102,3398 +,,,,,,,,,,, +256,1,5000,128,1,576,512,True,False,2.355,151,657 +,,,,,,,,,,, +256,1,5000,128,1,576,512,True,True,2.490,147,636 +,,,,,,,,,,, +256,1,5000,64,1,576,512,True,False,1.230,145,1228 +,,,,,,,,,,, +256,1,5000,64,1,576,512,True,True,1.271,134,1140 +,,,,,,,,,,, +256,1,5000,32,1,576,512,True,False,0.608,146,2453 +,,,,,,,,,,, +256,1,5000,32,1,576,512,True,True,0.752,120,2005 +,,,,,,,,,,, +256,1,5000,16,1,576,512,True,False,0.313,142,4743 +,,,,,,,,,,, +256,1,5000,16,1,576,512,True,True,0.404,112,3728 diff --git a/run_test.sh b/run_test.sh index 88aed47..6ad9b6e 100644 --- a/run_test.sh +++ b/run_test.sh @@ -2,7 +2,7 @@ set -e OUTPUT_FOLDER=${OUTPUT_FOLDER:-"results"} -PREFIX=${PREFIX:-"H800"} +PREFIX=${PREFIX:-"MI300X"} function parse_mla_result_to_csv { # Create output CSV file with header