Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ mac = [
]

gpu = [
"sglang[all]==0.5.7",
"sglang[all]==0.5.12",
"accelerate",
"mlx-lm==0.28.4",
"mlx[cpu]==0.30.0",
]
Expand Down
4 changes: 3 additions & 1 deletion src/parallax/server/executor/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,12 @@ def _handle_raw_request(self, raw_request: Dict):
if self.tokenizer.chat_template:
messages = raw_request["messages"]
process_message_content(messages)
chat_template_kwargs = raw_request.get("chat_template_kwargs", {})
chat_template_kwargs = dict(raw_request.get("chat_template_kwargs", {}))
# check extra_body for backward compatibility
if "extra_body" in raw_request and "chat_template_kwargs" in raw_request["extra_body"]:
chat_template_kwargs.update(raw_request["extra_body"]["chat_template_kwargs"])
# Transformers 5.x defaults return_dict=True, but Parallax expects list[int].
chat_template_kwargs["return_dict"] = False

prompt = self.tokenizer.apply_chat_template(
messages,
Expand Down
2 changes: 1 addition & 1 deletion src/parallax/server/executor/sglang_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
"dp_rank": dp_rank,
"dp_size": dp_size,
"nccl_port": nccl_port,
"using_hfcache": use_hfcache,
"use_hfcache": use_hfcache,
"enable_lora": self.enable_lora,
"max_lora_rank": self.max_lora_rank,
"lora_target_modules": self.lora_target_modules,
Expand Down
27 changes: 14 additions & 13 deletions src/parallax/sglang/batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
"""

from types import SimpleNamespace
from typing import List, Optional

import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import (
Expand All @@ -18,7 +20,6 @@
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

from parallax.server.executor.sglang_executor import PageRadixCache
from parallax.server.request import Request
from parallax.server.sampling.sampling_params import (
SamplingParams as ParallaxSamplingParams,
Expand Down Expand Up @@ -95,23 +96,23 @@ def form_sgl_batch_prefill(
) -> ForwardBatch:
"""Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow"""

sgl_reqs = transform_requests_to_sglang(requests, page_tree_cache)
tree_cache = page_tree_cache
if tree_cache is None:
cache_params = CacheInitParams(
disable=True,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
page_size=model_runner.server_args.page_size,
)
tree_cache = ChunkCache(cache_params)

def dummy_evict(*args):
pass
sgl_reqs = transform_requests_to_sglang(requests, tree_cache)

dummy_tree_cache = SimpleNamespace(
page_size=model_runner.server_args.page_size,
device=model_runner.device,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
evictable_size=0,
)
dummy_tree_cache.evict = dummy_evict
schedule_batch = ScheduleBatch.init_new(
reqs=sgl_reqs,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
tree_cache=page_tree_cache if page_tree_cache is not None else dummy_tree_cache,
tree_cache=tree_cache,
model_config=model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
Expand Down
14 changes: 14 additions & 0 deletions src/parallax/sglang/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
init_distributed_environment,
set_custom_all_reduce,
set_mscclpp_all_reduce,
set_torch_symm_mem_all_reduce,
)
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
Expand Down Expand Up @@ -118,6 +119,8 @@ def init_torch_distributed(self):
backend = "gloo"
elif self.device == "npu":
backend = "hccl"
else:
backend = "gloo"

before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
if not self.server_args.enable_p2p_check:
Expand All @@ -129,6 +132,7 @@ def init_torch_distributed(self):
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

if not self.is_draft_worker:
if self.device == "cpu":
Expand All @@ -153,14 +157,21 @@ def init_torch_distributed(self):
local_rank=self.gpu_id,
distributed_init_method=dist_init_method,
timeout=self.server_args.dist_timeout,
moe_a2a_backend=self.server_args.moe_a2a_backend,
recovered_rank=self.server_args.elastic_ep_rejoin,
)

# Use monkey patch modified function
sglang.srt.distributed.parallel_state.initialize_model_parallel(
tensor_model_parallel_size=self.tp_size,
pipeline_model_parallel_size=self.pp_size,
expert_model_parallel_size=self.moe_ep_size,
attention_data_parallel_size=self.dp_size,
attention_context_model_parallel_size=self.attn_cp_size,
moe_data_model_parallel_size=self.moe_dp_size,
duplicate_tp_group=self.server_args.enable_pdmux,
enable_symm_mem=self.server_args.enable_symm_mem,
recovered_rank=self.server_args.elastic_ep_rejoin,
pp_start_layer=self.pp_start_layer,
pp_end_layer=self.pp_end_layer,
hidden_layers=self.model_config.num_hidden_layers,
Expand Down Expand Up @@ -225,6 +236,7 @@ def form_sgl_server_args(
lora_eviction_policy: Optional[str] = "lru",
lora_backend: Optional[str] = "triton",
max_lora_chunk_size: Optional[int] = 128,
max_num_tokens_per_batch: int = 16384,
):
"""Creates a SGL ServerArgs object"""
sgl_server_args = ServerArgs(
Expand All @@ -247,6 +259,7 @@ def form_sgl_server_args(
lora_backend=lora_backend,
max_lora_chunk_size=max_lora_chunk_size,
dp_size=dp_size,
max_total_tokens=max_num_tokens_per_batch,
)
return sgl_server_args

Expand Down Expand Up @@ -338,6 +351,7 @@ def initialize_sgl_model_runner(
lora_eviction_policy,
lora_backend,
max_lora_chunk_size,
max_num_tokens_per_batch=max_num_tokens_per_batch,
)
initialize_moe_config(server_args)
quant_method = None
Expand Down
Loading
Loading