Binary file added python/__pycache__/common.cpython-312.pyc
Binary file added python/__pycache__/deepgemm_utils.cpython-312.pyc
17 changes: 16 additions & 1 deletion python/common.py
@@ -55,6 +55,21 @@ def param_num_to_GB(param_num, ele_size=1):
"volume": 80,
"intra_node_bw": 180,
"inter_node_bw": 39,
},
"H200-144": {
"volume": 144,
"intra_node_bw": 380,
"inter_node_bw": 39,
},
"MI300X-192": {
"volume": 192,
"intra_node_bw": 315,
"inter_node_bw": 39,
},
"MI308X-192": {
"volume": 192,
"intra_node_bw": 315,
"inter_node_bw": 39,
}
}
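The new entries follow the existing schema: `volume` is HBM capacity in GB, while `intra_node_bw` and `inter_node_bw` are link bandwidths (presumably GB/s, matching the existing H800-80 entry). A minimal lookup sketch against this table; the dict name `GPU_SPECS` and the helper are hypothetical stand-ins for whatever common.py actually calls them:

GPU_SPECS = {
    "MI300X-192": {"volume": 192, "intra_node_bw": 315, "inter_node_bw": 39},
}

def hbm_capacity_gb(gpu: str) -> int:
    # The key matches TestConfig.gpu, e.g. "H800-80" or "MI300X-192".
    return GPU_SPECS[gpu]["volume"]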

@@ -96,7 +111,7 @@ def total_nodup_expert_params(self):
@dataclass
class TestConfig:
device_nums: List[int] = None
s: int = 5000
s: int = 5000 # mean seqlen
gpu: str = "H800-80"
model_config: ModelConfig = None
tp_nums: List[int] = None
288 changes: 288 additions & 0 deletions python/deepgemm_utils.py
@@ -0,0 +1,288 @@
# ported from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/utils.py

import os
import sys
import time
import torch
import torch.distributed as dist


def bench(fn, num_warmups: int = 5, num_tests: int = 10,
high_precision: bool = False):
# Flush L2 cache with 256 MB data
torch.cuda.synchronize()
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
cache.zero_()

# Warmup
for _ in range(num_warmups):
fn()

# Add a large kernel to eliminate the CPU launch overhead
if high_precision:
x = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
y = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
x @ y

# Testing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(num_tests):
fn()
end_event.record()
torch.cuda.synchronize()

return start_event.elapsed_time(end_event) / num_tests
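`bench` returns the mean time per call in milliseconds (CUDA events report elapsed time in ms). A minimal usage sketch, assuming a CUDA device is available:

import torch
from deepgemm_utils import bench

a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)

# Mean milliseconds per call, averaged over num_tests runs after warmup.
ms = bench(lambda: a @ b, num_warmups=5, num_tests=10)
print(f'{ms:.3f} ms/iter')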


class empty_suppress:
def __enter__(self):
return self

def __exit__(self, *_):
pass


class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, 'w')
self.errnull_file = open(os.devnull, 'w')

self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()

self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())

self.old_stdout = sys.stdout
self.old_stderr = sys.stderr

os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self

def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr

os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)

self.outnull_file.close()
self.errnull_file.close()
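Because the class redirects the underlying file descriptors with `os.dup2` (not just the Python-level `sys.stdout`/`sys.stderr`), it also silences output from native C++/CUDA code. A small usage sketch; `noisy_fn` is a placeholder:

with suppress_stdout_stderr():
    noisy_fn()  # stdout/stderr from both Python and native code are discarded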


def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True):
# Kineto profiling conflicts with Nsight Systems, so skip it when running under nsys
using_nsys = os.environ.get('DG_NSYS_PROFILING', False)

# By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without going fully idle;
# this avoids thermal throttling while keeping DVFS at max clocks (a slight gain vs. sleeping, and more consistent on GH200)
sleep_between_tests = 0.0
flush_l2_size = int(8e9 // 4)
if os.environ.get('DG_BENCH_DISABLE_L2_FLUSH', False):
flush_l2 = False
if os.environ.get('DG_BENCH_POWER_LIMITED', False):
# if we want to be thermally limited, we need to run many iterations non-stop for a fairly long time
# and spend as little time as possible doing memset and other setup work (80MiB should be enough to flush L2)
num_tests = 2000
flush_l2_size = int(80e6 // 4)
sleep_val = os.environ.get('DG_BENCH_SLEEP_BETWEEN_TESTS', False)
if sleep_val:
try:
sleep_between_tests = float(sleep_val)
except ValueError:
pass # Keep default

# Run once first: some auto-tuning kernels print on their first invocation
fn()

# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
with profiler:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
lhs @ rhs
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
for _ in range(num_tests):
if sleep_between_tests > 0.0:
time.sleep(sleep_between_tests)
if flush_l2:
torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
fn()

if not using_nsys:
profiler.step()

# Return 1 if using Nsight Systems
if using_nsys:
return 1

# Parse the profiling table
assert isinstance(kernel_names, (str, tuple))
is_tupled = isinstance(kernel_names, tuple)
# print(f"bill-dbg: prof_lines: ")
# print(profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100))
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all(isinstance(name, str) for name in kernel_names)
for name in kernel_names:
assert sum(name in line for line in prof_lines) == 1, f'Kernel {name} must appear exactly once in the profiling table'

# Save chrome traces
if trace_path is not None:
profiler.export_chrome_trace(trace_path)

# Return average kernel times (converted to seconds)
units = {'ms': 1e3, 'us': 1e6}
kernel_times = []
for name in kernel_names:
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
for unit, scale in units.items():
if unit in time_str:
kernel_times.append(float(time_str.replace(unit, '')) / scale)
break
break
return tuple(kernel_times) if is_tupled else kernel_times[0]
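Note the return value is in seconds: the `units` map divides an 'ms' entry by 1e3 and a 'us' entry by 1e6, and a tuple is returned only when a tuple of names was passed. A usage sketch; the substring 'gemm' is hypothetical and must match exactly one row of the profiler table:

x = torch.randn(8192, 8192, device='cuda', dtype=torch.bfloat16)
t = bench_kineto(lambda: x @ x, kernel_names='gemm', suppress_kineto_output=True)
print(f'{t * 1e6:.1f} us/iter')  # t is in seconds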


def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
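`calc_diff` returns 1 - 2*sum(x*y) / (sum(x*x) + sum(y*y)): 0 for identical tensors, growing with relative error. For y = c*x it equals 1 - 2c/(1 + c^2), so a 1% scale error yields roughly 5e-5:

x = torch.tensor([1.0, 2.0, 3.0])
print(calc_diff(x, x).item())         # 0.0 (identical tensors)
print(calc_diff(x, 1.01 * x).item())  # ~5e-5 (1% relative error)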


def count_bytes(tensors):
total = 0
for t in tensors:
if isinstance(t, tuple):
total += count_bytes(t)
else:
total += t.numel() * t.element_size()
return total


_num_sms = None


def set_num_sms(num_sms: int) -> None:
"""
Set the maximum SM count for all GEMM kernels to use.

Arguments:
num_sms: the desired maximum SM count for all GEMM kernels to use.
"""
global _num_sms
assert 0 < num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
_num_sms = num_sms


def get_num_sms() -> int:
"""
Get the current maximum limit of SM count for all GEMM kernels to use.
If the count is never specified, the function will return the number of device SMs.

Returns:
Current maximum limit of SM count for all GEMM kernels to use.
"""
global _num_sms
if _num_sms is None:
_num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
return _num_sms
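Together these form a lazily initialized module-level limit. A usage sketch, e.g. reserving a few SMs for overlapped communication kernels (the reservation of 20 is illustrative):

full = torch.cuda.get_device_properties(device='cuda').multi_processor_count
set_num_sms(full - 20)  # hypothetical: leave 20 SMs for comm kernels
assert get_num_sms() == full - 20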


def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.

Args:
x: the dividend.
y: the divisor.

Returns:
The result of the ceiling division.
"""
return (x + y - 1) // y


def get_m_alignment_for_contiguous_layout():
"""
When we do a grouped GEMM in contiguous format, the LHS is grouped into several batches along the M axis.
Since each GEMM block deals with exactly one sub-matrix of the RHS, the batch sizes above must align
with the GEMM block shape.

Returns:
Group-level alignment requirement for grouped contiguous layout, which is always 128.
"""
return 128
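In practice, each group's M extent is padded up to this alignment before the groups are concatenated; a small sketch with illustrative group sizes, reusing `ceil_div` from above:

alignment = get_m_alignment_for_contiguous_layout()  # 128
group_ms = [300, 50, 1000]
padded = [ceil_div(m, alignment) * alignment for m in group_ms]
# padded == [384, 128, 1024]: each group's rows become a multiple of 128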


def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
The global memory address used by TMA must be 16-byte aligned.
Since we use a column-major layout for the LHS scaling tensor,
the M axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.

Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.

Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return ceil_div(x, alignment) * alignment
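Worked example: for `torch.float` scales (element_size = 4), the 16-byte requirement becomes an M multiple of 16 // 4 = 4:

print(get_tma_aligned_size(130, 4))  # 132 (next multiple of 4)
print(get_tma_aligned_size(128, 1))  # 128 (already a multiple of 16)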


def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
"""
Returns a TMA-aligned, transposed format of the input tensor; `torch.transpose` is called if necessary.
If the input tensor is already in column-major layout and 16-byte aligned along the M axis
(and thus meets the requirement for the LHS scaling tensor in DeepGEMM), this function does nothing.

Arguments:
x: usually the LHS scaling tensor in GEMM.

Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTE: for extreme performance, you may rewrite/fuse this function in CUDA
assert x.dim() in (2, 3)
remove_dim = False
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
if x.stride(0) == 1 and x.stride(1) == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True

b = x.shape[0]

# The last kernel gives a column-major TMA aligned layout
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(2) == aligned_m:
return x.squeeze(0) if remove_dim else x

# Normal layout requires transposing
aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
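A usage sketch verifying the output strides; the shape (130, 64) is illustrative, with M = 130 padded to 132 for float32 scales:

scales = torch.randn(130, 64, device='cuda')        # row-major (M, N)
aligned = get_col_major_tma_aligned_tensor(scales)
assert aligned.shape == (130, 64)
assert aligned.stride() == (1, 132)                 # column-major with padded leading dim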
20 changes: 10 additions & 10 deletions python/process_table.py
@@ -49,10 +49,10 @@ def process_data(dense_gemm_file: str, group_gemm_file: str, batch_gemm_file: st
(group_df['b_mla'] == b_mla) &
(group_df['m_per_group'] == m_per_group) &
(group_df['matrix_idx'] == 7)]['time_us'].iloc[0])
down_gemm = int(group_df[(group_df['d'] == d) &
(group_df['m_per_group'] == m_per_group) &
(group_df['b_mla'] == b_mla) &
(group_df['matrix_idx'] == 8)]['time_us'].iloc[0])
# down_gemm = int(group_df[(group_df['d'] == d) &
# (group_df['m_per_group'] == m_per_group) &
# (group_df['b_mla'] == b_mla) &
# (group_df['matrix_idx'] == 8)]['time_us'].iloc[0])

dispatch_alltoall = int(
config.calculate_alltoall_time(d, tp, b_mla, True))
@@ -65,15 +65,15 @@ def process_data(dense_gemm_file: str, group_gemm_file: str, batch_gemm_file: st
# Compute the layer time under the two modes
# Two-microbatch overlapping
t_moe_layer_two = int(2 * (max(dispatch_alltoall, shared_time + qkv_time) +
up_gemm + down_gemm +
up_gemm +
max(attn_time + o_time + allreduce, combine_alltoall)))
t_dense_layer_two = int(
2 * (shared_time + qkv_time + up_gemm + down_gemm + attn_time + o_time + allreduce))
2 * (shared_time + qkv_time + up_gemm + attn_time + o_time + allreduce))
# Single-batch comm-compute overlapping
t_moe_layer_single = int(max(dispatch_alltoall, shared_time) + qkv_time + up_gemm +
max(down_gemm, combine_alltoall) + attn_time + o_time + allreduce)
t_moe_layer_single = int(max(dispatch_alltoall, shared_time) + qkv_time +
max(up_gemm, combine_alltoall) + attn_time + o_time + allreduce)
t_dense_layer_single = int(
shared_time + qkv_time + up_gemm + down_gemm + attn_time + o_time + allreduce)
shared_time + qkv_time + up_gemm + attn_time + o_time + allreduce)

# Compute TPOT and throughput
tpot_two = int((t_moe_layer_two * 58 + t_dense_layer_two * 3) / 1000)
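The weights 58 and 3 presumably reflect a model with 58 MoE layers and 3 dense layers (61 in total, as in DeepSeek-V3), and dividing by 1000 converts per-layer microseconds into a per-token millisecond TPOT. A worked example with illustrative layer times:

# 400 us per MoE layer, 300 us per dense layer (illustrative numbers)
tpot_ms = (400 * 58 + 300 * 3) / 1000  # 24.1 ms per output token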
@@ -93,7 +93,7 @@ def process_data(dense_gemm_file: str, group_gemm_file: str, batch_gemm_file: st
'O(us)': o_time,
'Shared(us)': shared_time,
'Up_Gemm(us)': up_gemm,
'Down_Gemm(us)': down_gemm,
'Down_Gemm(us)': 0,
'Dispatch_AlltoAll(us)': dispatch_alltoall,
'Combine_AlltoAll(us)': combine_alltoall,
'AllReduce(us)': allreduce,