From 50dc0433532df8226b98cfe0fcc61ae9c5ea4f8b Mon Sep 17 00:00:00 2001 From: gufengc Date: Fri, 22 May 2026 05:07:41 +0000 Subject: [PATCH 1/2] feat(sglang): upgrade sglang to 0.5.12 --- pyproject.toml | 3 +- src/parallax/server/executor/base_executor.py | 4 +- .../server/executor/sglang_executor.py | 2 +- src/parallax/sglang/batch_info.py | 27 +-- src/parallax/sglang/model_runner.py | 14 ++ .../monkey_patch_utils/model_parallel.py | 214 ++++++------------ .../monkey_patch_utils/triton_backend.py | 119 +++------- 7 files changed, 130 insertions(+), 253 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4c0ef40f..8b4cf7ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,8 @@ mac = [ ] gpu = [ - "sglang[all]==0.5.7", + "sglang[all]==0.5.12", + "accelerate", "mlx-lm==0.28.4", "mlx[cpu]==0.30.0", ] diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 7180242c..3d8d0089 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -618,10 +618,12 @@ def _handle_raw_request(self, raw_request: Dict): if self.tokenizer.chat_template: messages = raw_request["messages"] process_message_content(messages) - chat_template_kwargs = raw_request.get("chat_template_kwargs", {}) + chat_template_kwargs = dict(raw_request.get("chat_template_kwargs", {})) # check extra_body for backward compatibility if "extra_body" in raw_request and "chat_template_kwargs" in raw_request["extra_body"]: chat_template_kwargs.update(raw_request["extra_body"]["chat_template_kwargs"]) + # Transformers 5.x defaults return_dict=True, but Parallax expects list[int]. + chat_template_kwargs["return_dict"] = False prompt = self.tokenizer.apply_chat_template( messages, diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 555db81c..0f4979b5 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -126,7 +126,7 @@ def __init__( "dp_rank": dp_rank, "dp_size": dp_size, "nccl_port": nccl_port, - "using_hfcache": use_hfcache, + "use_hfcache": use_hfcache, "enable_lora": self.enable_lora, "max_lora_rank": self.max_lora_rank, "lora_target_modules": self.lora_target_modules, diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 56720040..e8f5e731 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -5,11 +5,13 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch """ -from types import SimpleNamespace from typing import List, Optional import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.mem_cache.cache_init_params import CacheInitParams +from sglang.srt.mem_cache.chunk_cache import ChunkCache +from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import ( @@ -18,7 +20,6 @@ from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from parallax.server.executor.sglang_executor import PageRadixCache from parallax.server.request import Request from parallax.server.sampling.sampling_params import ( SamplingParams as ParallaxSamplingParams, @@ -95,23 +96,23 @@ def form_sgl_batch_prefill( ) -> ForwardBatch: """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" - sgl_reqs = transform_requests_to_sglang(requests, page_tree_cache) + tree_cache = page_tree_cache + if tree_cache is None: + cache_params = CacheInitParams( + disable=True, + req_to_token_pool=model_runner.req_to_token_pool, + token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, + page_size=model_runner.server_args.page_size, + ) + tree_cache = ChunkCache(cache_params) - def dummy_evict(*args): - pass + sgl_reqs = transform_requests_to_sglang(requests, tree_cache) - dummy_tree_cache = SimpleNamespace( - page_size=model_runner.server_args.page_size, - device=model_runner.device, - token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, - evictable_size=0, - ) - dummy_tree_cache.evict = dummy_evict schedule_batch = ScheduleBatch.init_new( reqs=sgl_reqs, req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, - tree_cache=page_tree_cache if page_tree_cache is not None else dummy_tree_cache, + tree_cache=tree_cache, model_config=model_runner.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index e24f2595..8ba9aa6d 100644 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -21,6 +21,7 @@ init_distributed_environment, set_custom_all_reduce, set_mscclpp_all_reduce, + set_torch_symm_mem_all_reduce, ) from sglang.srt.layers.dp_attention import ( get_attention_tp_group, @@ -118,6 +119,8 @@ def init_torch_distributed(self): backend = "gloo" elif self.device == "npu": backend = "hccl" + else: + backend = "gloo" before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) if not self.server_args.enable_p2p_check: @@ -129,6 +132,7 @@ def init_torch_distributed(self): dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) set_mscclpp_all_reduce(self.server_args.enable_mscclpp) + set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem) if not self.is_draft_worker: if self.device == "cpu": @@ -153,6 +157,8 @@ def init_torch_distributed(self): local_rank=self.gpu_id, distributed_init_method=dist_init_method, timeout=self.server_args.dist_timeout, + moe_a2a_backend=self.server_args.moe_a2a_backend, + recovered_rank=self.server_args.elastic_ep_rejoin, ) # Use monkey patch modified function @@ -160,7 +166,12 @@ def init_torch_distributed(self): tensor_model_parallel_size=self.tp_size, pipeline_model_parallel_size=self.pp_size, expert_model_parallel_size=self.moe_ep_size, + attention_data_parallel_size=self.dp_size, + attention_context_model_parallel_size=self.attn_cp_size, + moe_data_model_parallel_size=self.moe_dp_size, duplicate_tp_group=self.server_args.enable_pdmux, + enable_symm_mem=self.server_args.enable_symm_mem, + recovered_rank=self.server_args.elastic_ep_rejoin, pp_start_layer=self.pp_start_layer, pp_end_layer=self.pp_end_layer, hidden_layers=self.model_config.num_hidden_layers, @@ -225,6 +236,7 @@ def form_sgl_server_args( lora_eviction_policy: Optional[str] = "lru", lora_backend: Optional[str] = "triton", max_lora_chunk_size: Optional[int] = 128, + max_num_tokens_per_batch: int = 16384, ): """Creates a SGL ServerArgs object""" sgl_server_args = ServerArgs( @@ -247,6 +259,7 @@ def form_sgl_server_args( lora_backend=lora_backend, max_lora_chunk_size=max_lora_chunk_size, dp_size=dp_size, + max_total_tokens=max_num_tokens_per_batch, ) return sgl_server_args @@ -338,6 +351,7 @@ def initialize_sgl_model_runner( lora_eviction_policy, lora_backend, max_lora_chunk_size, + max_num_tokens_per_batch=max_num_tokens_per_batch, ) initialize_moe_config(server_args) quant_method = None diff --git a/src/parallax/sglang/monkey_patch_utils/model_parallel.py b/src/parallax/sglang/monkey_patch_utils/model_parallel.py index e387936e..234c1529 100644 --- a/src/parallax/sglang/monkey_patch_utils/model_parallel.py +++ b/src/parallax/sglang/monkey_patch_utils/model_parallel.py @@ -18,9 +18,11 @@ """ import logging +from datetime import timedelta from typing import Any, Dict, List, Optional, Tuple, Union import sglang +import sglang.srt.distributed import sglang.srt.distributed.parallel_state import torch from sglang.srt.distributed import get_world_group @@ -30,9 +32,8 @@ from sglang.srt.utils import ( LayerFn, add_prefix, - cpu_has_amx_support, - get_bool_env_var, is_npu, + is_xpu, ) from torch.distributed import Backend @@ -40,7 +41,11 @@ logger = logging.getLogger(__name__) -_is_cpu_amx_available = cpu_has_amx_support() +_is_npu = is_npu() +_is_xpu = is_xpu() +_sgl_initialize_model_parallel = ( + sglang.srt.distributed.parallel_state.initialize_model_parallel +) class ParallaxGroupCoordinator(SGLGroupCoordinator): @@ -64,6 +69,8 @@ def __init__( use_torch_symm_mem: bool = False, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, + gloo_timeout: timedelta = timedelta(seconds=120 * 60), + recovered_rank: bool = False, pp_start_layer: int = 0, pp_end_layer: int = 0, hidden_layers: int = 0, @@ -82,6 +89,8 @@ def __init__( use_torch_symm_mem_all_reduce=use_torch_symm_mem, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, + gloo_timeout=gloo_timeout, + recovered_rank=recovered_rank, ) self.pp_start_layer = pp_start_layer self.pp_end_layer = pp_end_layer @@ -102,10 +111,13 @@ def monkey_patch_init_model_parallel_group( group_ranks: List[List[int]], local_rank: int, backend: str, + use_pynccl: Optional[bool] = None, use_custom_allreduce: Optional[bool] = None, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, use_mscclpp_allreduce: Optional[bool] = None, + use_torch_symm_mem_allreduce: Optional[bool] = None, + recovered_rank: bool = False, pp_start_layer: int = 0, pp_end_layer: int = 0, hidden_layers: int = 0, @@ -115,18 +127,28 @@ def monkey_patch_init_model_parallel_group( use_custom_allreduce = sglang.srt.distributed.parallel_state._ENABLE_CUSTOM_ALL_REDUCE if use_mscclpp_allreduce is None: use_mscclpp_allreduce = sglang.srt.distributed.parallel_state._ENABLE_MSCCLPP_ALL_REDUCE + if use_torch_symm_mem_allreduce is None: + use_torch_symm_mem_allreduce = ( + sglang.srt.distributed.parallel_state._ENABLE_TORCH_SYMM_MEM_ALL_REDUCE + ) return ParallaxGroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=not is_npu(), + use_pynccl=( + not (_is_npu or _is_xpu or backend == "mooncake") + if use_pynccl is None + else use_pynccl + ), use_pymscclpp=use_mscclpp_allreduce, use_custom_allreduce=use_custom_allreduce, + use_torch_symm_mem=use_torch_symm_mem_allreduce, use_hpu_communicator=True, use_xpu_communicator=True, use_npu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, + recovered_rank=recovered_rank, pp_start_layer=pp_start_layer, pp_end_layer=pp_end_layer, hidden_layers=hidden_layers, @@ -137,152 +159,51 @@ def monkey_patch_initialize_model_parallel( tensor_model_parallel_size: int = 1, expert_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + attention_data_parallel_size: int = 1, + attention_context_model_parallel_size: int = 1, + moe_data_model_parallel_size: int = 1, backend: Optional[str] = None, duplicate_tp_group: bool = False, + enable_symm_mem: bool = False, + recovered_rank: bool = False, pp_start_layer: int = 0, pp_end_layer: int = 0, hidden_layers: int = 0, ) -> None: """A monkey patch to replace sglang.srt.distributed.parallel_state.initialize_model_parallel""" - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend(get_world_group().device_group) - - if world_size != tensor_model_parallel_size * pipeline_model_parallel_size: - raise RuntimeError( - f"world_size ({world_size}) is not equal to " - f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - f"pipeline_model_parallel_size ({pipeline_model_parallel_size})" - ) - - # Build the tensor model-parallel groups. - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - ############################################################################ - ## This is a patch code for sgalng - ## Ignore parallel state already set alert - # assert ( - # sglang.srt.distributed.parallel_state._TP is None - # ), "tensor model parallel group is already initialized" - ## End of patch - ############################################################################ - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) - group_ranks.append(ranks) - - # message queue broadcaster is only used in tensor model parallel group - sglang.srt.distributed.parallel_state._TP = ( - sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=get_bool_env_var( - "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" - ), - group_name="tp", - ) - ) - - if duplicate_tp_group: - global _PDMUX_PREFILL_TP_GROUP - assert ( - _PDMUX_PREFILL_TP_GROUP is None - ), "tensor model parallel group for PD-Multiplexing Prefill is already initialized" - _PDMUX_PREFILL_TP_GROUP = sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=get_bool_env_var( - "SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true" - ), - group_name="pdmux_prefill_tp", - ) - sglang.srt.distributed.parallel_state._TP.pynccl_comm.disabled = False - _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False - - moe_ep_size = expert_model_parallel_size - - moe_tp_size = tensor_model_parallel_size // moe_ep_size - ############################################################################ - ## This is a patch code for sgalng - ## Ignore parallel state already set alert - # assert ( - # sglang.srt.distributed.parallel_state._MOE_EP is None - # ), "expert model parallel group is already initialized" - ## End of patch - ############################################################################ - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - for j in range(moe_tp_size): - st = i * tensor_model_parallel_size + j - en = (i + 1) * tensor_model_parallel_size + j - ranks = list(range(st, en, moe_tp_size)) - group_ranks.append(ranks) - - sglang.srt.distributed.parallel_state._MOE_EP = ( - sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False, - group_name="moe_ep", - ) - ) - - ############################################################################ - ## This is a patch code for sgalng - ## Ignore parallel state already set alert - # assert ( - # sglang.srt.distributed.parallel_state._MOE_TP is None - # ), "expert model parallel group is already initialized" - ## End of patch - ############################################################################ - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - for j in range(moe_ep_size): - st = i * tensor_model_parallel_size + j * moe_tp_size - en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size - ranks = list(range(st, en)) - group_ranks.append(ranks) - - sglang.srt.distributed.parallel_state._MOE_TP = ( - sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False, - group_name="moe_tp", + parallel_state = sglang.srt.distributed.parallel_state + if any( + getattr(parallel_state, group_name) is not None + for group_name in ( + "_TP", + "_PP", + "_MOE_EP", + "_MOE_TP", + "_ATTN_CP", + "_ATTN_TP", + "_MOE_DP", + "_PDMUX_PREFILL_TP_GROUP", ) + ): + parallel_state.destroy_model_parallel() + + _sgl_initialize_model_parallel( + tensor_model_parallel_size=tensor_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + attention_data_parallel_size=attention_data_parallel_size, + attention_context_model_parallel_size=attention_context_model_parallel_size, + moe_data_model_parallel_size=moe_data_model_parallel_size, + backend=backend, + duplicate_tp_group=duplicate_tp_group, + enable_symm_mem=enable_symm_mem, + recovered_rank=recovered_rank, ) - # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - ############################################################################ - ## This is a patch code for sgalng - ## Ignore parallel state already set alert - # assert ( - # sglang.srt.distributed.parallel_state._PP is None - # ), "pipeline model parallel group is already initialized" - ## End of patch - ############################################################################ - group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) - # pipeline parallel does not need custom allreduce - sglang.srt.distributed.parallel_state._PP = ( - sglang.srt.distributed.parallel_state.init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False, - group_name="pp", - pp_start_layer=pp_start_layer, - pp_end_layer=pp_end_layer, - hidden_layers=hidden_layers, - ) - ) + pp_group = parallel_state._PP + pp_group.pp_start_layer = pp_start_layer + pp_group.pp_end_layer = pp_end_layer + pp_group.hidden_layers = hidden_layers def monkey_patch_make_layers( @@ -291,9 +212,9 @@ def monkey_patch_make_layers( pp_rank: Optional[int] = None, pp_size: Optional[int] = None, prefix: str = "", - return_tuple: bool = True, - offloader_kwargs: Dict[str, Any] = {}, -) -> Tuple[int, int, torch.nn.ModuleList]: + return_tuple: bool = False, + offloader_kwargs: Optional[Dict[str, Any]] = None, +) -> Tuple[torch.nn.ModuleList, int, int]: """A monkey patch to replace sglang.srt.utils.make_layers""" # circula imports from sglang.srt.distributed import get_pp_group @@ -310,7 +231,7 @@ def monkey_patch_make_layers( layer_fn(idx=idx, prefix=add_prefix(idx, prefix)) for idx in range(start_layer, end_layer) ), - **offloader_kwargs, + **(offloader_kwargs or {}), ) + [PPMissingLayer(return_tuple=return_tuple) for _ in range(end_layer, num_hidden_layers)] ) @@ -326,4 +247,9 @@ def apply_model_parallel_monkey_patch(): sglang.srt.distributed.parallel_state.initialize_model_parallel = ( monkey_patch_initialize_model_parallel ) + sglang.srt.distributed.init_model_parallel_group = monkey_patch_init_model_parallel_group + sglang.srt.distributed.initialize_model_parallel = monkey_patch_initialize_model_parallel sglang.srt.utils.make_layers = monkey_patch_make_layers + import sglang.srt.utils.common as utils_common + + utils_common.make_layers = monkey_patch_make_layers diff --git a/src/parallax/sglang/monkey_patch_utils/triton_backend.py b/src/parallax/sglang/monkey_patch_utils/triton_backend.py index da7e2169..040b9271 100644 --- a/src/parallax/sglang/monkey_patch_utils/triton_backend.py +++ b/src/parallax/sglang/monkey_patch_utils/triton_backend.py @@ -2,9 +2,9 @@ import torch from sglang.srt.layers.attention.triton_backend import TritonAttnBackend -from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.utils import get_bool_env_var, get_device_core_count, get_int_env_var + +_original_triton_backend_init = TritonAttnBackend.__init__ def parallax_triton_backend_init( @@ -13,98 +13,31 @@ def parallax_triton_backend_init( skip_prefill: bool = False, kv_indptr_buf: Optional[torch.Tensor] = None, ): - # Lazy import to avoid the initialization of cuda context - from sglang.srt.layers.attention.triton_ops.decode_attention import ( - decode_attention_fwd, - ) - from sglang.srt.layers.attention.triton_ops.extend_attention import ( - extend_attention_fwd, - ) - - self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd) - self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd) - - # Parse args - self.skip_prefill = skip_prefill - max_bs = model_runner.req_to_token_pool.size - self.sliding_window_size = model_runner.sliding_window_size - self.req_to_token = model_runner.req_to_token_pool.req_to_token - self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator - self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens - self.speculative_num_steps = model_runner.server_args.speculative_num_steps - self.num_head = model_runner.model_config.num_attention_heads // get_attention_tp_size() - self.num_kv_head = model_runner.model_config.get_num_kv_heads(get_attention_tp_size()) - # Modifies layer id to support pipeline parallel - if model_runner.hybrid_gdn_config is not None: - # For hybrid linear models, layer_id = 0 may not be full attention - self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() - else: - - ################################################################################ - ## Patch for PP: get pp_start_layer - self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer( - model_runner.pp_start_layer - ).shape[-1] - ## End of patch - ################################################################################ - self.max_context_len = model_runner.model_config.context_len - self.device = model_runner.device - self.device_core_count = get_device_core_count(model_runner.gpu_id) - self.static_kv_splits = get_bool_env_var("SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false") - self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits - - # Decide whether enable deterministic inference with batch-invariant operations - self.enable_deterministic = model_runner.server_args.enable_deterministic_inference - - # Configure deterministic inference settings - if self.enable_deterministic: - # Use fixed split tile size for batch invariance - self.split_tile_size = get_int_env_var("SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE", 256) - # Set static_kv_splits to False to use deterministic logic instead - self.static_kv_splits = False - else: - self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size - - if self.split_tile_size is not None: - self.max_kv_splits = ( - self.max_context_len + self.split_tile_size - 1 - ) // self.split_tile_size - # Check arguments - assert not ( - model_runner.sliding_window_size is not None - and model_runner.model_config.is_encoder_decoder - ), "Sliding window and cross attention are not supported together" - - # Initialize buffers - # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled - if kv_indptr_buf is None: - self.kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) - else: - self.kv_indptr = kv_indptr_buf - - # If sliding window is enabled, we might need two sets of buffers - # because of interleaved attention types (e.g. for Gemma3) - self.window_kv_indptr = None - if self.sliding_window_size is not None and self.sliding_window_size > 0: - if kv_indptr_buf is None: - self.window_kv_indptr = torch.zeros( - (max_bs + 1,), dtype=torch.int32, device=model_runner.device - ) + pp_start_layer = getattr(model_runner, "pp_start_layer", 0) + token_to_kv_pool = model_runner.token_to_kv_pool + token_to_kv_pool_dict = getattr(token_to_kv_pool, "__dict__", {}) + had_get_value_buffer_override = "get_value_buffer" in token_to_kv_pool_dict + get_value_buffer_override = token_to_kv_pool_dict.get("get_value_buffer") + original_get_value_buffer = token_to_kv_pool.get_value_buffer + + def get_value_buffer(layer_id: int, *args, **kwargs): + if layer_id == 0: + layer_id = pp_start_layer + return original_get_value_buffer(layer_id, *args, **kwargs) + + token_to_kv_pool.get_value_buffer = get_value_buffer + try: + return _original_triton_backend_init( + self, + model_runner, + skip_prefill=skip_prefill, + kv_indptr_buf=kv_indptr_buf, + ) + finally: + if had_get_value_buffer_override: + token_to_kv_pool.get_value_buffer = get_value_buffer_override else: - # When provided a buffer, create a clone for the second buffer - self.window_kv_indptr = torch.zeros_like(kv_indptr_buf) - - if not self.skip_prefill: - self.qo_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) - - self.mask_indptr = torch.zeros((max_bs + 1,), dtype=torch.int64, device=model_runner.device) - - # Initialize forward metadata - from sglang.srt.layers.attention.triton_backend import ForwardMetadata - - self.forward_metadata: ForwardMetadata = None - - self.cuda_graph_custom_mask = None + delattr(token_to_kv_pool, "get_value_buffer") def apply_triton_backend_init_monkey_patch(): From 70779aabe117748e2f047af9bd16714c360814e8 Mon Sep 17 00:00:00 2001 From: gufengc Date: Fri, 22 May 2026 05:11:05 +0000 Subject: [PATCH 2/2] update --- .../sglang/monkey_patch_utils/model_parallel.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/parallax/sglang/monkey_patch_utils/model_parallel.py b/src/parallax/sglang/monkey_patch_utils/model_parallel.py index 234c1529..f95b1d72 100644 --- a/src/parallax/sglang/monkey_patch_utils/model_parallel.py +++ b/src/parallax/sglang/monkey_patch_utils/model_parallel.py @@ -25,16 +25,10 @@ import sglang.srt.distributed import sglang.srt.distributed.parallel_state import torch -from sglang.srt.distributed import get_world_group from sglang.srt.distributed.parallel_state import ( GroupCoordinator as SGLGroupCoordinator, ) -from sglang.srt.utils import ( - LayerFn, - add_prefix, - is_npu, - is_xpu, -) +from sglang.srt.utils import LayerFn, add_prefix, is_npu, is_xpu from torch.distributed import Backend # from parallax.sglang.monkey_patch.model_runner import ModelRunner as SGLModelRunner @@ -43,9 +37,7 @@ _is_npu = is_npu() _is_xpu = is_xpu() -_sgl_initialize_model_parallel = ( - sglang.srt.distributed.parallel_state.initialize_model_parallel -) +_sgl_initialize_model_parallel = sglang.srt.distributed.parallel_state.initialize_model_parallel class ParallaxGroupCoordinator(SGLGroupCoordinator): @@ -136,9 +128,7 @@ def monkey_patch_init_model_parallel_group( local_rank=local_rank, torch_distributed_backend=backend, use_pynccl=( - not (_is_npu or _is_xpu or backend == "mooncake") - if use_pynccl is None - else use_pynccl + not (_is_npu or _is_xpu or backend == "mooncake") if use_pynccl is None else use_pynccl ), use_pymscclpp=use_mscclpp_allreduce, use_custom_allreduce=use_custom_allreduce,