7 changes: 5 additions & 2 deletions python/sglang/srt/compilation/piecewise_context_manager.py
@@ -1,8 +1,11 @@
from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional

from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


@dataclass
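The hunk above is an instance of the lazy-typing pattern this PR applies throughout: with from __future__ import annotations, annotations are never evaluated at runtime, so imports that exist only for type hints can sit behind TYPE_CHECKING and add no import-time cost. A minimal sketch of the pattern (consume_batch is a hypothetical function, not part of this file):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers; nothing is imported at runtime.
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


def consume_batch(forward_batch: ForwardBatch) -> None:
    # With deferred annotations, "ForwardBatch" stays a string at runtime,
    # so the guarded import above is never needed for execution.
    ...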
98 changes: 77 additions & 21 deletions python/sglang/srt/configs/__init__.py
@@ -1,24 +1,3 @@
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.dots_ocr import DotsOCRConfig
from sglang.srt.configs.dots_vlm import DotsVLMConfig
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.falcon_h1 import FalconH1Config
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.nemotron_h import NemotronHConfig
from sglang.srt.configs.olmo3 import Olmo3Config
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig,
)

__all__ = [
"ExaoneConfig",
"ChatGLMConfig",
@@ -39,3 +18,80 @@
"FalconH1Config",
"NemotronHConfig",
]


def __getattr__(name):
"""Lazily import config classes on first access."""
if name == "ChatGLMConfig":
from sglang.srt.configs.chatglm import ChatGLMConfig

return ChatGLMConfig
elif name == "DbrxConfig":
from sglang.srt.configs.dbrx import DbrxConfig

return DbrxConfig
elif name == "DeepseekVL2Config":
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config

return DeepseekVL2Config
elif name == "DotsOCRConfig":
from sglang.srt.configs.dots_ocr import DotsOCRConfig

return DotsOCRConfig
elif name == "DotsVLMConfig":
from sglang.srt.configs.dots_vlm import DotsVLMConfig

return DotsVLMConfig
elif name == "ExaoneConfig":
from sglang.srt.configs.exaone import ExaoneConfig

return ExaoneConfig
elif name == "FalconH1Config":
from sglang.srt.configs.falcon_h1 import FalconH1Config

return FalconH1Config
elif name == "MultiModalityConfig":
from sglang.srt.configs.janus_pro import MultiModalityConfig

return MultiModalityConfig
elif name == "KimiLinearConfig":
from sglang.srt.configs.kimi_linear import KimiLinearConfig

return KimiLinearConfig
elif name == "KimiVLConfig":
from sglang.srt.configs.kimi_vl import KimiVLConfig

return KimiVLConfig
elif name == "MoonViTConfig":
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig

return MoonViTConfig
elif name == "LongcatFlashConfig":
from sglang.srt.configs.longcat_flash import LongcatFlashConfig

return LongcatFlashConfig
elif name == "NemotronHConfig":
from sglang.srt.configs.nemotron_h import NemotronHConfig

return NemotronHConfig
elif name == "Olmo3Config":
from sglang.srt.configs.olmo3 import Olmo3Config

return Olmo3Config
elif name == "Qwen3NextConfig":
from sglang.srt.configs.qwen3_next import Qwen3NextConfig

return Qwen3NextConfig
elif name == "Step3TextConfig":
from sglang.srt.configs.step3_vl import Step3TextConfig

return Step3TextConfig
elif name == "Step3VisionEncoderConfig":
from sglang.srt.configs.step3_vl import Step3VisionEncoderConfig

return Step3VisionEncoderConfig
elif name == "Step3VLConfig":
from sglang.srt.configs.step3_vl import Step3VLConfig

return Step3VLConfig
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
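For illustration only (not part of the diff): the module-level __getattr__ above relies on PEP 562, so the package's public interface is unchanged while the heavy config submodules are imported only when first accessed.

import sglang.srt.configs as configs  # importing the package itself is now cheap

cfg_cls = configs.ChatGLMConfig  # first access calls __getattr__ and imports the chatglm submodule
# "from sglang.srt.configs import ChatGLMConfig" keeps working and is just as lazy.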
10 changes: 8 additions & 2 deletions python/sglang/srt/configs/model_config.py
@@ -12,18 +12,19 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import json
import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import Any, List, Optional, Set, Union
from typing import TYPE_CHECKING, Any, List, Optional, Set, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.environ import envs
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_hip, retry
from sglang.srt.utils.hf_transformers_utils import (
@@ -37,6 +38,9 @@

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from transformers import PretrainedConfig


class AttentionArch(IntEnum):
MLA = auto()
@@ -623,6 +627,8 @@ def _validate_quantize_and_serve_config(self):

# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None:
from sglang.srt.layers.quantization import QUANTIZATION_METHODS

supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = [
"awq",
17 changes: 14 additions & 3 deletions python/sglang/srt/disaggregation/decode.py
@@ -46,8 +46,6 @@
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
@@ -69,7 +67,7 @@
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler

CLIP_MAX_NEW_TOKEN = get_int_env_var("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", 4096)
@@ -227,6 +225,8 @@ def __init__(
self.kv_manager = self._init_kv_manager()

def _init_kv_manager(self) -> BaseKVManager:
from sglang.srt.layers.dp_attention import get_attention_tp_size

kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
kv_args = kv_args_class()

@@ -296,6 +296,8 @@ def _init_kv_manager(self) -> BaseKVManager:

def add(self, req: Req, is_retracted: bool = False) -> None:
"""Add a request to the pending queue."""
from sglang.srt.managers.schedule_batch import RequestStage
Review comment from the PR author:

This is a very repetitive pattern, and importing from sglang.srt.managers.schedule_batch is somewhat expensive. RequestStage is just an enum, so doing that large import only to get an enum seems wasteful. Perhaps the enum could be moved into a separate file, or some other solution with the same effect, so it can be imported at the top level.
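One way to realize that suggestion, sketched here with a hypothetical module path and placeholder member names (the actual stages are defined in schedule_batch.py):

# Hypothetical sglang/srt/managers/request_stage.py: a tiny module with no heavy
# dependencies, so it would be safe to import at the top of decode.py and prefill.py.
from enum import Enum, auto


class RequestStage(Enum):
    # Placeholder members for illustration; the real enum lists the actual stages.
    PREALLOCATED = auto()
    TRANSFERRING = auto()
    DECODING = auto()


# Call sites would then import the enum directly instead of the whole module:
# from sglang.srt.managers.request_stage import RequestStage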


if self._check_if_req_exceed_kv_capacity(req):
return

@@ -410,6 +412,8 @@ def _update_handshake_waiters(self) -> None:

def pop_preallocated(self) -> List[DecodeRequest]:
"""Pop the preallocated requests from the pending queue (FIFO)."""
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage

self._update_handshake_waiters()

preallocated_reqs = []
@@ -697,6 +701,8 @@ def extend(self, decode_reqs: List[DecodeRequest]) -> None:
self.queue.extend(decode_reqs)

def pop_transferred(self) -> List[Req]:
from sglang.srt.managers.schedule_batch import RequestStage

if not self.queue:
return []
polls = poll_and_all_reduce(
@@ -824,6 +830,7 @@ class SchedulerDisaggregationDecodeMixin:
@torch.no_grad()
def event_loop_normal_disagg_decode(self: Scheduler):
"""A normal scheduler loop for decode worker in disaggregation mode."""
from sglang.srt.managers.schedule_batch import RequestStage

while True:
recv_reqs = self.recv_requests()
@@ -868,6 +875,8 @@ def event_loop_normal_disagg_decode(self: Scheduler):

@torch.no_grad()
def event_loop_overlap_disagg_decode(self: Scheduler):
from sglang.srt.managers.schedule_batch import RequestStage

self.result_queue = deque()
self.last_batch: Optional[ScheduleBatch] = None
self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
@@ -982,6 +991,8 @@ def get_next_disagg_decode_batch_to_run(

def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
"""Create a schedulebatch for fake completed prefill"""
from sglang.srt.managers.schedule_batch import RequestStage, ScheduleBatch

if self.grammar_queue:
self.move_ready_grammar_requests()

14 changes: 8 additions & 6 deletions python/sglang/srt/disaggregation/prefill.py
@@ -42,12 +42,6 @@
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import (
FINISH_LENGTH,
Req,
RequestStage,
ScheduleBatch,
)
from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
NSATokenToKVPool,
@@ -59,6 +53,7 @@
if TYPE_CHECKING:
from torch.distributed import ProcessGroup

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
from sglang.srt.mem_cache.memory_pool import KVCache

@@ -179,6 +174,8 @@ def _init_kv_manager(self) -> BaseKVManager:
return kv_manager

def add(self, req: Req, num_kv_heads: int) -> None:
from sglang.srt.managers.schedule_batch import RequestStage

if self._check_if_req_exceed_kv_capacity(req):
return

@@ -231,6 +228,7 @@ def pop_bootstrapped(
return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
"""
from sglang.srt.managers.schedule_batch import RequestStage

bootstrapped_reqs = []
failed_reqs = []
@@ -396,6 +394,8 @@ def process_batch_result_disagg_prefill(
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
Adapted from process_batch_result_prefill
"""
from sglang.srt.managers.schedule_batch import RequestStage

(
logits_output,
next_token_ids,
@@ -513,6 +513,8 @@ def process_disagg_prefill_inflight_queue(
Poll the requests in the middle of transfer. If done, return the request.
rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
"""
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, RequestStage

if len(self.disagg_prefill_inflight_queue) == 0:
return []

2 changes: 1 addition & 1 deletion python/sglang/srt/eplb/expert_distribution.py
@@ -29,14 +29,14 @@

from sglang.srt.environ import envs
from sglang.srt.metrics.collector import ExpertDispatchCollector
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, is_npu

_is_npu = is_npu()

if TYPE_CHECKING:
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

logger = logging.getLogger(__name__)

21 changes: 14 additions & 7 deletions python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -1,8 +1,10 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple

import torch
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -20,19 +22,19 @@
from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.parameter import BasevLLMParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
method_has_implemented_embedding,
)
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
from sglang.srt.utils import (
cpu_has_amx_support,
get_compiler_backend,
is_cpu,
set_weight_attrs,
)

if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)

DEFAULT_VOCAB_PADDING_SIZE = 64

_is_cpu_amx_available = cpu_has_amx_support()
@@ -255,6 +257,11 @@ def __init__(
)
self.embedding_dim = embedding_dim

from sglang.srt.layers.quantization.base_config import (
method_has_implemented_embedding,
)
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod

quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)