import argparse

import numpy as np
import nvtx
import torch
import yaml

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import (
    BalanceMethod, DeepSeekV3Runner)


def comma_separated_ints(s):
    return [int(x) for x in s.split(",")]

# Parse cmdline
parser = argparse.ArgumentParser()
parser.add_argument("config_path", type=str)
parser.add_argument("--model", type=str, help="Pretrained model name or path")
parser.add_argument(
    "--layer-indices",
    type=comma_separated_ints,
    help="Comma-separated layer indices; should be a contiguous range")
parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
parser.add_argument("--scaled-from", type=int)
# KV cache related args
parser.add_argument("--tokens-per-block", type=int)
parser.add_argument("--max-seq-len", type=int)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--enable-attention-dp",
                   action="store_true",
                   dest="enable_attention_dp")
group.add_argument("--no-enable-attention-dp",
                   action="store_false",
                   dest="enable_attention_dp")
parser.set_defaults(enable_attention_dp=None)
# Model init args
parser.add_argument("--max-num-tokens", type=int)
parser.add_argument("--moe-backend", type=str)
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--use-cuda-graph",
                   action="store_true",
                   dest="use_cuda_graph")
group.add_argument("--no-use-cuda-graph",
                   action="store_false",
                   dest="use_cuda_graph")
parser.set_defaults(use_cuda_graph=None)
# Per iteration args
parser.add_argument("--batch-size", type=int)
parser.add_argument("--seq-len-q", type=int)
parser.add_argument("--seq-len-kv-cache", type=int)
parser.add_argument("--balance-method", type=str)
parser.add_argument("--balance-ratio", type=float)
args = parser.parse_args()
with open(args.config_path) as f:
    config = yaml.safe_load(f)
del args.config_path
# Any argument not given on the command line falls back to the YAML config value
for k, v in vars(args).items():
    if v is None:
        setattr(args, k, config[k])
print(args)
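
# Example launch (hypothetical script/config file names; one MPI rank per GPU):
#   mpirun -n 8 python benchmark.py config.yaml --run-type GEN --batch-size 32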

# MPI args
rank = mpi_rank()
world_size = mpi_world_size()
local_rank = local_mpi_rank()
torch.cuda.set_device(local_rank)

# Create KV cache manager
mapping = DeepSeekV3Runner.create_mapping(
    enable_attention_dp=args.enable_attention_dp)
max_batch_size = 2048
kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager(
    args.model,
    mapping,
    tokens_per_block=args.tokens_per_block,
    max_batch_size=max_batch_size,
    max_seq_len=args.max_seq_len,
    layer_indices=args.layer_indices)
attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8)

# Create other global objects
AutoTuner.get().clear_cache()  # start from an empty autotuner cache
capture_stream = torch.cuda.Stream()  # side stream for warm-up and graph capture

# Create Runner
runner = DeepSeekV3Runner(args.model,
                          mapping,
                          moe_backend=args.moe_backend,
                          layer_indices=args.layer_indices,
                          scaled_from=args.scaled_from,
                          max_seq_len=args.max_seq_len,
                          max_num_tokens=args.max_num_tokens,
                          use_cuda_graph=args.use_cuda_graph)

# Warm up
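# Run the pack once as a plain warm-up pass and once under autotune() so the
# AutoTuner can profile and cache kernel choices before capture and timing.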
assert args.batch_size <= max_batch_size
assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
run_pack = runner.create_run_pack(args.run_type,
                                  batch_size=args.batch_size,
                                  seq_len_q=args.seq_len_q,
                                  seq_len_kv_cache=args.seq_len_kv_cache,
                                  kv_cache_manager=kv_cache_manager,
                                  attn_workspace=attn_workspace)
runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method],
                              balance_ratio=args.balance_ratio)
capture_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(capture_stream):
    run_pack()
    with autotune():
        run_pack()
torch.cuda.current_stream().wait_stream(capture_stream)
torch.cuda.synchronize()

# Profile: capture graph and replay it
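# With --use-cuda-graph, the run pack is captured once into a CUDA graph and the
# timing loop replays it, keeping per-iteration launch overhead low; otherwise
# every timed iteration re-runs the pack eagerly.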
torch.cuda.cudart().cudaProfilerStart()
if args.use_cuda_graph:
    with with_multi_stream(True):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g,
                              stream=capture_stream,
                              capture_error_mode="global"):
            run_pack()

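# Each iteration is bracketed by CUDA events; Event.elapsed_time() returns
# milliseconds, so the summary below multiplies by 1000 to report microseconds.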
warmup_times = 20
run_times = 100
events = [
    torch.cuda.Event(enable_timing=True)
    for _ in range(warmup_times + run_times + 1)
]
for i in range(warmup_times + run_times):
    events[i].record()
    with nvtx.annotate(
            f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
        if args.use_cuda_graph:
            g.replay()
        else:
            run_pack()
events[-1].record()
torch.cuda.synchronize()

# Print statistics
# Print before `cudaProfilerStop` to ensure messages are included in the profile
time_list = [
    start.elapsed_time(stop) for start, stop in zip(events, events[1:])
]
time_list = time_list[warmup_times:]
print(f"[RANK {rank}]"
      f" min {np.min(time_list) * 1000:.1f}"
      f" max {np.max(time_list) * 1000:.1f}"
      f" mean {np.mean(time_list) * 1000:.1f}"
      f" median {np.median(time_list) * 1000:.1f}"
      f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
      f" (us)")

torch.cuda.cudart().cudaProfilerStop()