From 722530fa018290fd3921c8f030fb806b190f32b7 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 20 Nov 2024 02:58:35 -0800
Subject: [PATCH] Enable overlap scheduler by default for the triton attention
 backend (#2105)

---
 .../sglang/srt/layers/attention/triton_backend.py |  2 +-
 python/sglang/srt/managers/scheduler.py           | 15 +++------------
 .../srt/managers/tp_worker_overlap_thread.py      | 11 +++++++++++
 .../srt/model_executor/cuda_graph_runner.py       |  9 +++------
 python/sglang/srt/server_args.py                  |  7 ++-----
 scripts/killall_sglang.sh                         |  1 +
 6 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index b1ec3fd6d5e..b9597b3ea41 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -53,7 +53,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
 
-            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index be0a0f699ca..6241ed6b260 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -170,18 +170,9 @@ def __init__(
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-        if (
-            server_args.attention_backend == "triton"
-            or server_args.enable_double_sparsity
-            or (
-                self.model_config.attention_arch == AttentionArch.MLA
-                and not self.server_args.disable_mla
-            )
-        ):
-            self.enable_overlap = False
-            logger.info(
-                "Overlap scheduler is disabled if using triton attention backend."
-            )
+
+        if self.enable_overlap:
+            self.disable_jump_forward = True
 
         # Launch a tensor parallel worker
         if self.enable_overlap:
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index ab37ceed261..74fd0bb6203 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -94,10 +94,21 @@ def forward_thread_func(self):
 
     @torch.no_grad()
     def forward_thread_func_(self):
+        batch_pt = 0
+        batch_lists = [None] * 2
+
         while True:
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
+
+            # Keep a reference of model_worker_batch by storing it into a list.
+            # Otherwise, the tensor members of model_worker_batch will be released
+            # by pytorch and cause CUDA illegal memory access errors.
+            batch_lists[batch_pt % 2] = model_worker_batch
+            batch_pt += 1
+
+            # Create event
             self.launch_done = threading.Event()
             copy_done = torch.cuda.Event()
 
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 91c6603a2ef..a00a30b027d 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -170,7 +170,6 @@ def __init__(self, model_runner: "ModelRunner"):
             self.encoder_lens = None
 
         if self.enable_dp_attention:
-            self.global_num_tokens = [0] * self.tp_size
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.tp_size,
@@ -264,10 +263,10 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
         mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
-            self.global_num_tokens[:] = [bs] * self.tp_size
+            global_num_tokens = [bs] * self.tp_size
             gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
         else:
-            self.global_num_tokens = None
+            global_num_tokens = None
             gathered_buffer = None
 
         # Attention backend
@@ -296,7 +295,7 @@ def run_once():
                 top_logprobs_nums=[0] * bs,
                 positions=clamp_position(seq_lens),
                 mrope_positions=mrope_positions,
-                global_num_tokens=self.global_num_tokens,
+                global_num_tokens=global_num_tokens,
                 gathered_buffer=gathered_buffer,
             )
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
@@ -348,8 +347,6 @@ def replay(self, forward_batch: ForwardBatch):
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
-            self.global_num_tokens[:] = [bs] * self.tp_size
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e1cbbd29fe0..4a1cad89e76 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -174,17 +174,17 @@ def __post_init__(self):
             self.cuda_graph_max_bs = 4
             logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
 
+        # Choose kernel backends
         if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
 
-        # Default kernel backends
         if self.attention_backend is None:
             self.attention_backend = "flashinfer"
-
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"
 
+        # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
@@ -205,9 +205,6 @@ def __post_init__(self):
             )
             self.disable_overlap_schedule = True
 
-        if not self.disable_overlap_schedule:
-            self.disable_jump_forward = True
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh
index 203da604021..fcad493c59c 100644
--- a/scripts/killall_sglang.sh
+++ b/scripts/killall_sglang.sh
@@ -2,3 +2,4 @@
 
 kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
 kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}')
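
A note on the tp_worker_overlap_thread.py change above: the new comment in forward_thread_func_ explains that the worker must keep the dequeued batch referenced so PyTorch does not free its tensors while asynchronous GPU work may still read them. The following standalone sketch (hypothetical names such as worker_loop and input_queue, not SGLang code; it assumes only PyTorch and the standard library) shows the same two-slot reference-keeping pattern in isolation:

import queue
import threading

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
input_queue = queue.Queue()


def worker_loop():
    batch_pt = 0
    batch_lists = [None] * 2  # two-slot ring holding the most recent batches

    while True:
        batch = input_queue.get()
        if batch is None:  # sentinel: shut down the worker
            break

        # Same pattern as the patch: park the batch in the ring so it stays
        # referenced; in the real overlap worker this is what prevents PyTorch
        # from freeing tensors the GPU may still be reading asynchronously.
        batch_lists[batch_pt % 2] = batch
        batch_pt += 1

        total = batch["input_ids"].sum().item()
        print(f"batch {batch_pt}: token id sum = {total}")


t = threading.Thread(target=worker_loop)
t.start()
for i in range(4):
    input_queue.put({"input_ids": torch.arange(8, device=device) + i})
input_queue.put(None)
t.join()

The two-slot size mirrors batch_lists = [None] * 2 in the patch; presumably at most two batches are in flight at a time under the overlap schedule, so keeping the last two alive is sufficient.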