Commit 722530f
Enable overlap scheduler by default for the triton attention backend (s…
merrymercy authored Nov 20, 2024
1 parent 56a347f commit 722530f
Showing 6 changed files with 21 additions and 24 deletions.

python/sglang/srt/layers/attention/triton_backend.py (2 changes: 1 addition & 1 deletion)

@@ -53,7 +53,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

total_num_tokens = torch.sum(forward_batch.seq_lens).item()
total_num_tokens = forward_batch.seq_lens_sum
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.reduce_dtype,
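
Note on the triton_backend.py change above: `torch.sum(forward_batch.seq_lens).item()` launches a reduction on the GPU and then blocks the host until the scalar is copied back, while `forward_batch.seq_lens_sum` is an integer already available on the CPU. Avoiding that synchronization matters once the overlap scheduler runs CPU-side scheduling concurrently with GPU execution. A minimal, self-contained sketch of the difference (illustrative tensors only, not the sglang code):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([5, 3, 7], device=device)

# Synchronizing version: .item() copies the scalar back to the host and
# stalls until the queued GPU work producing it has finished.
total_num_tokens_sync = torch.sum(seq_lens).item()

# Non-synchronizing version: keep the running total as a Python int on the
# host (the role played by forward_batch.seq_lens_sum), so buffer shapes such
# as attn_logits can be computed without waiting on the GPU.
seq_lens_host = [5, 3, 7]
total_num_tokens_nosync = sum(seq_lens_host)

assert total_num_tokens_sync == total_num_tokens_nosync == 15
```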

python/sglang/srt/managers/scheduler.py (15 changes: 3 additions & 12 deletions)

@@ -170,18 +170,9 @@ def __init__(
if not self.is_generation:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for embedding models.")
if (
server_args.attention_backend == "triton"
or server_args.enable_double_sparsity
or (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
)
):
self.enable_overlap = False
logger.info(
"Overlap scheduler is disabled if using triton attention backend."
)

if self.enable_overlap:
self.disable_jump_forward = True

# Launch a tensor parallel worker
if self.enable_overlap:
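
Note on the scheduler.py change above: the deleted block force-disabled the overlap scheduler for the triton attention backend, for double sparsity, and for MLA models, so overlap now stays on in those cases; instead, jump-forward decoding is switched off whenever overlap is enabled (logic that previously lived in ServerArgs.__post_init__, see the server_args.py hunk below). A condensed sketch of the resulting gating, using plain arguments instead of the Scheduler's attributes:

```python
from typing import Tuple

def resolve_overlap_flags(is_generation: bool, enable_overlap: bool) -> Tuple[bool, bool]:
    """Return (enable_overlap, disable_jump_forward) under this commit's rules."""
    disable_jump_forward = False

    # Overlap scheduling is still disabled for embedding models.
    if not is_generation:
        enable_overlap = False

    # The triton / double-sparsity / MLA exclusions are gone; the remaining
    # consequence of overlap being on is that jump-forward decoding is disabled.
    if enable_overlap:
        disable_jump_forward = True

    return enable_overlap, disable_jump_forward

# e.g. resolve_overlap_flags(True, True) -> (True, True)
#      resolve_overlap_flags(False, True) -> (False, False)
```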

python/sglang/srt/managers/tp_worker_overlap_thread.py (11 changes: 11 additions & 0 deletions)

@@ -94,10 +94,21 @@ def forward_thread_func(self):

@torch.no_grad()
def forward_thread_func_(self):
batch_pt = 0
batch_lists = [None] * 2

while True:
model_worker_batch, future_token_ids_ct = self.input_queue.get()
if not model_worker_batch:
break

# Keep a reference of model_worker_batch by storing it into a list.
# Otherwise, the tensor members of model_worker_batch will be released
# by pytorch and cause CUDA illegal memory access errors.
batch_lists[batch_pt % 2] = model_worker_batch
batch_pt += 1

# Create event
self.launch_done = threading.Event()
copy_done = torch.cuda.Event()

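
Note on the tp_worker_overlap_thread.py change above: the forward thread launches GPU work asynchronously, so if the Python reference to `model_worker_batch` were dropped at the end of the loop iteration, PyTorch could recycle its tensors while kernels were still reading them, which surfaces as CUDA illegal memory access errors. The two-slot list keeps the current and the previous batch alive. A generic, self-contained sketch of this keep-alive pattern (names are illustrative, not the sglang API):

```python
from queue import Queue
from typing import Callable, Optional

def forward_loop(input_queue: Queue, launch_async: Callable) -> None:
    batch_pt = 0
    batch_lists = [None, None]  # slots for the current batch and the previous one

    while True:
        batch: Optional[object] = input_queue.get()
        if batch is None:
            break

        # Hold a reference so the batch (and any CUDA tensors it owns) cannot
        # be garbage-collected before the asynchronously launched work that
        # reads it has completed one iteration later.
        batch_lists[batch_pt % 2] = batch
        batch_pt += 1

        launch_async(batch)
```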

python/sglang/srt/model_executor/cuda_graph_runner.py (9 changes: 3 additions & 6 deletions)

@@ -170,7 +170,6 @@ def __init__(self, model_runner: "ModelRunner"):
self.encoder_lens = None

if self.enable_dp_attention:
self.global_num_tokens = [0] * self.tp_size
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.tp_size,

@@ -264,10 +263,10 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
mrope_positions = self.mrope_positions[:, :bs]

if self.enable_dp_attention:
self.global_num_tokens[:] = [bs] * self.tp_size
global_num_tokens = [bs] * self.tp_size
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
else:
self.global_num_tokens = None
global_num_tokens = None
gathered_buffer = None

# Attention backend

@@ -296,7 +295,7 @@ def run_once():
top_logprobs_nums=[0] * bs,
positions=clamp_position(seq_lens),
mrope_positions=mrope_positions,
global_num_tokens=self.global_num_tokens,
global_num_tokens=global_num_tokens,
gathered_buffer=gathered_buffer,
)
logits_output = forward(input_ids, forward_batch.positions, forward_batch)

@@ -348,8 +347,6 @@ def replay(self, forward_batch: ForwardBatch):
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention:
self.global_num_tokens[:] = [bs] * self.tp_size

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
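
Note on the cuda_graph_runner.py changes above: `global_num_tokens` is no longer a mutable attribute on the runner that `replay()` rewrites; each call to `capture_one_batch_size` now builds a local list `[bs] * tp_size` that is baked into the ForwardBatch of the captured graph. Presumably this avoids shared state being mutated at replay time, which is easier to reason about once the overlap scheduler drives forwards from a separate thread. A toy sketch of the pattern (an illustrative class, not sglang's CudaGraphRunner):

```python
from typing import Callable, Dict, List, Optional

class TinyGraphRunner:
    def __init__(self, tp_size: int, enable_dp_attention: bool):
        self.tp_size = tp_size
        self.enable_dp_attention = enable_dp_attention
        self.captured: Dict[int, Callable] = {}

    def capture_one_batch_size(self, bs: int, forward: Callable) -> None:
        # Local to this capture: every batch size gets its own token counts,
        # so nothing on `self` has to be rewritten later.
        global_num_tokens: Optional[List[int]] = (
            [bs] * self.tp_size if self.enable_dp_attention else None
        )
        self.captured[bs] = lambda batch: forward(batch, global_num_tokens)

    def replay(self, bs: int, batch) -> object:
        # No shared self.global_num_tokens to update before replaying.
        return self.captured[bs](batch)
```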

python/sglang/srt/server_args.py (7 changes: 2 additions & 5 deletions)

@@ -174,17 +174,17 @@ def __post_init__(self):
self.cuda_graph_max_bs = 4
logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")

# Choose kernel backends
if not is_flashinfer_available():
self.attention_backend = "triton"
self.sampling_backend = "pytorch"

# Default kernel backends
if self.attention_backend is None:
self.attention_backend = "flashinfer"

if self.sampling_backend is None:
self.sampling_backend = "flashinfer"

# Others
if self.enable_dp_attention:
self.dp_size = self.tp_size
self.chunked_prefill_size = self.chunked_prefill_size // 2

@@ -205,9 +205,6 @@ def __post_init__(self):
)
self.disable_overlap_schedule = True

if not self.disable_overlap_schedule:
self.disable_jump_forward = True

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args
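
Note on the server_args.py changes above: the kernel-backend defaults fall back to the triton attention backend and the pytorch sampling backend when flashinfer is not available, and otherwise default both to flashinfer; the blanket coupling of `disable_jump_forward` to the overlap schedule is removed here and handled in the Scheduler instead. A condensed sketch of the defaulting logic (a free function for illustration, mirroring but not copying ServerArgs.__post_init__):

```python
from typing import Optional, Tuple

def choose_kernel_backends(
    attention_backend: Optional[str],
    sampling_backend: Optional[str],
    flashinfer_available: bool,
) -> Tuple[str, str]:
    if not flashinfer_available:
        attention_backend = "triton"
        sampling_backend = "pytorch"
    if attention_backend is None:
        attention_backend = "flashinfer"
    if sampling_backend is None:
        sampling_backend = "flashinfer"
    return attention_backend, sampling_backend

# Examples:
#   choose_kernel_backends(None, None, True)     -> ("flashinfer", "flashinfer")
#   choose_kernel_backends(None, None, False)    -> ("triton", "pytorch")
#   choose_kernel_backends("triton", None, True) -> ("triton", "flashinfer")
```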

scripts/killall_sglang.sh (1 change: 1 addition & 0 deletions)

@@ -2,3 +2,4 @@

kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}')
