diff --git a/python/sglang/srt/compilation/piecewise_context_manager.py b/python/sglang/srt/compilation/piecewise_context_manager.py
index 711b4f487ea..a0e07b6f081 100644
--- a/python/sglang/srt/compilation/piecewise_context_manager.py
+++ b/python/sglang/srt/compilation/piecewise_context_manager.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional

-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 @dataclass
diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py
index 690a1e3eb0c..456af1d6266 100644
--- a/python/sglang/srt/configs/__init__.py
+++ b/python/sglang/srt/configs/__init__.py
@@ -1,23 +1,27 @@
-from sglang.srt.configs.chatglm import ChatGLMConfig
-from sglang.srt.configs.dbrx import DbrxConfig
-from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
-from sglang.srt.configs.dots_ocr import DotsOCRConfig
-from sglang.srt.configs.dots_vlm import DotsVLMConfig
-from sglang.srt.configs.exaone import ExaoneConfig
-from sglang.srt.configs.falcon_h1 import FalconH1Config
-from sglang.srt.configs.janus_pro import MultiModalityConfig
-from sglang.srt.configs.jet_nemotron import JetNemotronConfig
-from sglang.srt.configs.kimi_linear import KimiLinearConfig
-from sglang.srt.configs.kimi_vl import KimiVLConfig
-from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
-from sglang.srt.configs.longcat_flash import LongcatFlashConfig
-from sglang.srt.configs.nemotron_h import NemotronHConfig
-from sglang.srt.configs.olmo3 import Olmo3Config
-from sglang.srt.configs.qwen3_next import Qwen3NextConfig
-from sglang.srt.configs.step3_vl import (
-    Step3TextConfig,
-    Step3VisionEncoderConfig,
-    Step3VLConfig,
+from sglang.utils import LazyImport
+
+ChatGLMConfig = LazyImport("sglang.srt.configs.chatglm", "ChatGLMConfig")
+DbrxConfig = LazyImport("sglang.srt.configs.dbrx", "DbrxConfig")
+DeepseekVL2Config = LazyImport("sglang.srt.configs.deepseekvl2", "DeepseekVL2Config")
+DotsOCRConfig = LazyImport("sglang.srt.configs.dots_ocr", "DotsOCRConfig")
+DotsVLMConfig = LazyImport("sglang.srt.configs.dots_vlm", "DotsVLMConfig")
+ExaoneConfig = LazyImport("sglang.srt.configs.exaone", "ExaoneConfig")
+FalconH1Config = LazyImport("sglang.srt.configs.falcon_h1", "FalconH1Config")
+MultiModalityConfig = LazyImport("sglang.srt.configs.janus_pro", "MultiModalityConfig")
+JetNemotronConfig = LazyImport("sglang.srt.configs.jet_nemotron", "JetNemotronConfig")
+KimiLinearConfig = LazyImport("sglang.srt.configs.kimi_linear", "KimiLinearConfig")
+KimiVLConfig = LazyImport("sglang.srt.configs.kimi_vl", "KimiVLConfig")
+MoonViTConfig = LazyImport("sglang.srt.configs.kimi_vl_moonvit", "MoonViTConfig")
+LongcatFlashConfig = LazyImport(
+    "sglang.srt.configs.longcat_flash", "LongcatFlashConfig"
+)
+NemotronHConfig = LazyImport("sglang.srt.configs.nemotron_h", "NemotronHConfig")
+Olmo3Config = LazyImport("sglang.srt.configs.olmo3", "Olmo3Config")
+Qwen3NextConfig = LazyImport("sglang.srt.configs.qwen3_next", "Qwen3NextConfig")
+Step3VLConfig = LazyImport("sglang.srt.configs.step3_vl", "Step3VLConfig")
+Step3TextConfig = LazyImport("sglang.srt.configs.step3_vl", "Step3TextConfig")
+Step3VisionEncoderConfig = LazyImport(
+    "sglang.srt.configs.step3_vl", "Step3VisionEncoderConfig"
 )

 __all__ = [
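Note (illustration, not part of the patch): with the `configs/__init__.py` rewrite above, `import sglang.srt.configs` no longer imports any individual config module. Each exported name is now a `LazyImport` proxy that performs the real `from module import name` on first call or attribute access. A minimal sketch of the resulting behavior, assuming an environment where sglang (with this patch) is installed:

```python
# Binding the name is cheap: sglang.srt.configs.chatglm is NOT imported yet.
from sglang.srt.configs import ChatGLMConfig


def make_default_chatglm_config():
    # The first call on the proxy imports the module, caches the real
    # class, and then constructs an instance as usual.
    return ChatGLMConfig()
```

One caveat: the proxy is not the class object itself, which is why the `__instancecheck__`/`__or__` hooks added to `LazyImport` at the end of this patch matter for call sites that do `isinstance` checks.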
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index b51d19d26f1..b81b85f3d31 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -12,18 +12,18 @@
 # limitations under the License.
 # ==============================================================================

+from __future__ import annotations
+
 import json
 import logging
 import math
 import os
 from enum import Enum, IntEnum, auto
-from typing import Any, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Union

 import torch
-from transformers import PretrainedConfig

 from sglang.srt.environ import envs
-from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import is_hip, retry
 from sglang.srt.utils.hf_transformers_utils import (
@@ -37,6 +37,9 @@

 logger = logging.getLogger(__name__)

+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+

 class AttentionArch(IntEnum):
     MLA = auto()
@@ -631,6 +634,8 @@ def _validate_quantize_and_serve_config(self):

     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
+        from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = [
             "awq",
diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index fb28f914a18..f3cef733812 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -46,8 +46,7 @@
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
+from sglang.srt.managers.request_types import FINISH_ABORT, RequestStage
 from sglang.srt.managers.utils import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -67,7 +66,7 @@
 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
-    from sglang.srt.managers.schedule_batch import Req
+    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
     from sglang.srt.managers.scheduler import Scheduler

 CLIP_MAX_NEW_TOKEN = get_int_env_var("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", 4096)
@@ -225,6 +224,8 @@ def __init__(
         self.kv_manager = self._init_kv_manager()

     def _init_kv_manager(self) -> BaseKVManager:
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()

@@ -884,6 +885,8 @@ def get_next_disagg_decode_batch_to_run(


 def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
     """Create a schedulebatch for fake completed prefill"""
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+
     if self.grammar_queue:
         self.move_ready_grammar_requests()
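Note (illustration, not part of the patch): the hunks above use the pattern this PR applies throughout. Imports needed only for annotations move under `typing.TYPE_CHECKING`, which is safe once `from __future__ import annotations` makes annotations lazy strings; imports that are heavy or would form a cycle move into the function that needs them. A generic sketch of both halves, using the stdlib `decimal` module as a stand-in for a heavyweight dependency:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; costs nothing at runtime.
    from decimal import Decimal


def describe(price: Decimal) -> str:
    # Fine at runtime: the annotation stays an unevaluated string.
    return f"price={price}"


def default_price():
    # Deferred import: paid only when the function actually runs, which is
    # also how the patch breaks import cycles (e.g. QUANTIZATION_METHODS
    # inside _verify_quantization above).
    from decimal import Decimal

    return Decimal("0")
```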
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index 69eb22f3012..e9dd38d5777 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -42,12 +42,7 @@
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.managers.schedule_batch import (
-    FINISH_LENGTH,
-    Req,
-    RequestStage,
-    ScheduleBatch,
-)
+from sglang.srt.managers.request_types import FINISH_LENGTH, RequestStage
 from sglang.srt.mem_cache.common import release_kv_cache
 from sglang.srt.mem_cache.memory_pool import (
     HybridLinearKVPool,
@@ -60,6 +55,7 @@
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup

+    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
     from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
     from sglang.srt.mem_cache.memory_pool import KVCache

@@ -232,7 +228,6 @@ def pop_bootstrapped(
             return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
             rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
         """
-
         bootstrapped_reqs = []
         failed_reqs = []
         indices_to_remove = set()
diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py
index d660172de58..ecfa81b1514 100644
--- a/python/sglang/srt/disaggregation/utils.py
+++ b/python/sglang/srt/disaggregation/utils.py
@@ -11,6 +11,7 @@
 import torch
 import torch.distributed as dist

+from sglang.srt.managers.request_types import FINISH_ABORT
 from sglang.srt.utils import is_npu

 if TYPE_CHECKING:
@@ -346,8 +347,6 @@ def is_mla_backend(target_kv_pool) -> bool:


 def prepare_abort(req: Req, error_message: str, status_code=None):
-    from sglang.srt.managers.schedule_batch import FINISH_ABORT
-
     # populate finish metadata and stream output
     req.finished_reason = FINISH_ABORT(error_message, status_code)
diff --git a/python/sglang/srt/eplb/expert_distribution.py b/python/sglang/srt/eplb/expert_distribution.py
index bfec3802e6d..6a96f89f599 100644
--- a/python/sglang/srt/eplb/expert_distribution.py
+++ b/python/sglang/srt/eplb/expert_distribution.py
@@ -29,7 +29,6 @@
 from sglang.srt.environ import envs
 from sglang.srt.metrics.collector import ExpertDispatchCollector
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import Withable, get_int_env_var, is_npu

@@ -37,6 +36,7 @@

 if TYPE_CHECKING:
     from sglang.srt.eplb.expert_location import ExpertLocationMetadata
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/layers/moe/__init__.py b/python/sglang/srt/layers/moe/__init__.py
index 74d23ecd7c7..908294ea062 100644
--- a/python/sglang/srt/layers/moe/__init__.py
+++ b/python/sglang/srt/layers/moe/__init__.py
@@ -1,4 +1,3 @@
-from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
 from sglang.srt.layers.moe.utils import (
     DeepEPMode,
     MoeA2ABackend,
@@ -12,6 +11,10 @@
     is_tbo_enabled,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
+from sglang.utils import LazyImport
+
+MoeRunner = LazyImport("sglang.srt.layers.moe.moe_runner.runner", "MoeRunner")
+MoeRunnerConfig = LazyImport("sglang.srt.layers.moe.moe_runner.base", "MoeRunnerConfig")

 __all__ = [
     "DeepEPMode",
diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py
index 6c153b25051..4bea1322955 100644
--- a/python/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -1,8 +1,10 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py

+from __future__ import annotations
+
 import logging
 from dataclasses import dataclass
-from typing import List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple

 import torch
 from torch.nn.parameter import Parameter, UninitializedParameter
@@ -20,12 +22,6 @@
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
-from sglang.srt.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-    method_has_implemented_embedding,
-)
-from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_compiler_backend,
@@ -33,6 +29,12 @@
     set_weight_attrs,
 )

+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import (
+        QuantizationConfig,
+        QuantizeMethodBase,
+    )
+
 DEFAULT_VOCAB_PADDING_SIZE = 64

 _is_cpu_amx_available = cpu_has_amx_support()
@@ -255,6 +257,11 @@ def __init__(
         )
         self.embedding_dim = embedding_dim

+        from sglang.srt.layers.quantization.base_config import (
+            method_has_implemented_embedding,
+        )
+        from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
+
         quant_method = None
         if quant_config is not None:
             quant_method = quant_config.get_quant_method(self, prefix=prefix)
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index cb897a643cd..6b0192e7c0c 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -13,6 +13,8 @@
 # ==============================================================================
 """A controller that dispatches requests to multiple data parallel workers."""

+from __future__ import annotations
+
 import faulthandler
 import logging
 import multiprocessing as mp
@@ -21,7 +23,7 @@
 import time
 from collections import deque
 from enum import Enum, auto
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import psutil
 import setproctitle
@@ -34,7 +36,7 @@
     TokenizedGenerateReqInput,
     WatchLoadUpdateReq,
 )
-from sglang.srt.managers.schedule_batch import Req, RequestStage
+from sglang.srt.managers.request_types import RequestStage
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import (
     DP_ATTENTION_HANDSHAKE_PORT_DELTA,
@@ -60,6 +62,9 @@
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
 logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index cdecebca3c4..f1d4167ac83 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -16,6 +16,8 @@
 processes (TokenizerManager, DetokenizerManager, Scheduler).
 """

+from __future__ import annotations
+
 import copy
 import uuid
 from abc import ABC
@@ -24,7 +26,6 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from sglang.srt.lora.lora_registry import LoRARef
-from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.utils import ImageData
@@ -32,6 +33,8 @@

 # Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
+
+    from sglang.srt.managers.request_types import BaseFinishReason
 else:
     Image = Any
diff --git a/python/sglang/srt/managers/request_types.py b/python/sglang/srt/managers/request_types.py
new file mode 100644
index 00000000000..1ca49bc9d00
--- /dev/null
+++ b/python/sglang/srt/managers/request_types.py
@@ -0,0 +1,107 @@
+"""Lightweight request types and enums from schedule_batch.py."""
+
+from __future__ import annotations
+
+import enum
+from typing import List, Union
+
+
+class BaseFinishReason:
+    def __init__(self, is_error: bool = False):
+        self.is_error = is_error
+
+    def to_json(self):
+        raise NotImplementedError()
+
+
+class FINISH_MATCHED_TOKEN(BaseFinishReason):
+    def __init__(self, matched: Union[int, List[int]]):
+        super().__init__()
+        self.matched = matched
+
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
+
+
+class FINISH_MATCHED_STR(BaseFinishReason):
+    def __init__(self, matched: str):
+        super().__init__()
+        self.matched = matched
+
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
+
+
+class FINISHED_MATCHED_REGEX(BaseFinishReason):
+    def __init__(self, matched: str):
+        super().__init__()
+        self.matched = matched
+
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
+
+
+class FINISH_LENGTH(BaseFinishReason):
+    def __init__(self, length: int):
+        super().__init__()
+        self.length = length
+
+    def to_json(self):
+        return {
+            "type": "length",  # to match OpenAI API's return value
+            "length": self.length,
+        }
+
+
+class FINISH_ABORT(BaseFinishReason):
+    def __init__(self, message=None, status_code=None, err_type=None):
+        super().__init__(is_error=True)
+        self.message = message or "Aborted"
+        self.status_code = status_code
+        self.err_type = err_type
+
+    def to_json(self):
+        return {
+            "type": "abort",
+            "message": self.message,
+            "status_code": self.status_code,
+            "err_type": self.err_type,
+        }
+
+
+class RequestStage(str, enum.Enum):
+    # Tokenizer
+    TOKENIZE = "tokenize"
+    TOKENIZER_DISPATCH = "dispatch"
+
+    # DP controller
+    DC_DISPATCH = "dc_dispatch"
+
+    # common/non-disaggregation
+    PREFILL_WAITING = "prefill_waiting"
+    REQUEST_PROCESS = "request_process"
+    DECODE_LOOP = "decode_loop"
+    PREFILL_FORWARD = "prefill_forward"
+    PREFILL_CHUNKED_FORWARD = "chunked_prefill"
+
+    # disaggregation prefill
+    PREFILL_PREPARE = "prefill_prepare"
+    PREFILL_BOOTSTRAP = "prefill_bootstrap"
+    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
+
+    # disaggregation decode
+    DECODE_PREPARE = "decode_prepare"
+    DECODE_BOOTSTRAP = "decode_bootstrap"
+    DECODE_WAITING = "decode_waiting"
+    DECODE_TRANSFERRED = "decode_transferred"
+    DECODE_FAKE_OUTPUT = "fake_output"
+    DECODE_QUICK_FINISH = "quick_finish"
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 326b010b2c7..e4bc64b0aa3 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1,9 +1,3 @@
-from __future__ import annotations
-
-import enum
-
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -35,6 +29,8 @@
 TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
 """

+from __future__ import annotations
+
 import copy
 import dataclasses
 import logging
@@ -56,6 +52,15 @@
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.environ import envs
+from sglang.srt.managers.request_types import (
+    FINISH_ABORT,
+    FINISH_LENGTH,
+    FINISH_MATCHED_STR,
+    FINISH_MATCHED_TOKEN,
+    FINISHED_MATCHED_REGEX,
+    BaseFinishReason,
+    RequestStage,
+)
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -98,78 +103,6 @@

 logger = logging.getLogger(__name__)

-class BaseFinishReason:
-    def __init__(self, is_error: bool = False):
-        self.is_error = is_error
-
-    def to_json(self):
-        raise NotImplementedError()
-
-
-class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: Union[int, List[int]]):
-        super().__init__()
-        self.matched = matched
-
-    def to_json(self):
-        return {
-            "type": "stop",  # to match OpenAI API's return value
-            "matched": self.matched,
-        }
-
-
-class FINISH_MATCHED_STR(BaseFinishReason):
-    def __init__(self, matched: str):
-        super().__init__()
-        self.matched = matched
-
-    def to_json(self):
-        return {
-            "type": "stop",  # to match OpenAI API's return value
-            "matched": self.matched,
-        }
-
-
-class FINISHED_MATCHED_REGEX(BaseFinishReason):
-    def __init__(self, matched: str):
-        super().__init__()
-        self.matched = matched
-
-    def to_json(self):
-        return {
-            "type": "stop",  # to match OpenAI API's return value
-            "matched": self.matched,
-        }
-
-
-class FINISH_LENGTH(BaseFinishReason):
-    def __init__(self, length: int):
-        super().__init__()
-        self.length = length
-
-    def to_json(self):
-        return {
-            "type": "length",  # to match OpenAI API's return value
-            "length": self.length,
-        }
-
-
-class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message=None, status_code=None, err_type=None):
-        super().__init__(is_error=True)
-        self.message = message or "Aborted"
-        self.status_code = status_code
-        self.err_type = err_type
-
-    def to_json(self):
-        return {
-            "type": "abort",
-            "message": self.message,
-            "status_code": self.status_code,
-            "err_type": self.err_type,
-        }
-
-
 class Modality(Enum):
     IMAGE = auto()
     MULTI_IMAGES = auto()
@@ -402,35 +335,6 @@ def merge(self, other: MultimodalInputs):
         # other args would be kept intact


-class RequestStage(str, enum.Enum):
-    # Tokenizer
-    TOKENIZE = "tokenize"
-    TOKENIZER_DISPATCH = "dispatch"
-
-    # DP controller
-    DC_DISPATCH = "dc_dispatch"
-
-    # common/non-disaggregation
-    PREFILL_WAITING = "prefill_waiting"
-    REQUEST_PROCESS = "request_process"
-    DECODE_LOOP = "decode_loop"
-    PREFILL_FORWARD = "prefill_forward"
-    PREFILL_CHUNKED_FORWARD = "chunked_prefill"
-
-    # disaggregation prefill
-    PREFILL_PREPARE = "prefill_prepare"
-    PREFILL_BOOTSTRAP = "prefill_bootstrap"
-    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
-
-    # disaggregation decode
-    DECODE_PREPARE = "decode_prepare"
-    DECODE_BOOTSTRAP = "decode_bootstrap"
-    DECODE_WAITING = "decode_waiting"
-    DECODE_TRANSFERRED = "decode_transferred"
-    DECODE_FAKE_OUTPUT = "fake_output"
-    DECODE_QUICK_FINISH = "quick_finish"
-
-
 class Req:
     """The input and output status of a request."""

diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index dc257dcdfe9..bf437411278 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -115,12 +115,11 @@
 )
 from sglang.srt.managers.mm_utils import init_mm_embedding_cache
 from sglang.srt.managers.overlap_utils import FutureMap
+from sglang.srt.managers.request_types import FINISH_ABORT, RequestStage
 from sglang.srt.managers.schedule_batch import (
-    FINISH_ABORT,
     ModelWorkerBatch,
     MultimodalInputs,
     Req,
-    RequestStage,
     ScheduleBatch,
 )
 from sglang.srt.managers.schedule_policy import (
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 6ad020cabdf..9c98b2efe2d 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -14,12 +14,8 @@
     BatchEmbeddingOutput,
     BatchTokenIDOutput,
 )
-from sglang.srt.managers.schedule_batch import (
-    BaseFinishReason,
-    Req,
-    RequestStage,
-    ScheduleBatch,
-)
+from sglang.srt.managers.request_types import BaseFinishReason, RequestStage
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.common import release_kv_cache
 from sglang.srt.tracing.trace import trace_slice, trace_slice_batch, trace_slice_end

diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py
index 4990a5cacf6..5119e64b3fa 100644
--- a/python/sglang/srt/managers/session_controller.py
+++ b/python/sglang/srt/managers/session_controller.py
@@ -15,7 +15,8 @@
 from typing import Dict, Optional

 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.managers.request_types import FINISH_ABORT
+from sglang.srt.managers.schedule_batch import Req


 class SessionReqNode:
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
index 7f49e48c418..b46ac2b39d2 100644
--- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -21,6 +21,7 @@
 import fastapi
 import zmq

+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
     ClearHiCacheReqOutput,
@@ -68,7 +69,7 @@
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
-from sglang.srt.server_args import LoRARef, ServerArgs
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var
 from sglang.utils import TypeBasedDispatcher

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 4dbf078c8b6..06a0c0936c8 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -70,7 +70,7 @@
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.request_metrics_exporter import RequestMetricsExporterManager
-from sglang.srt.managers.schedule_batch import RequestStage
+from sglang.srt.managers.request_types import RequestStage
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 036748f7c7e..af9368be332 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -26,7 +26,6 @@

 import orjson

-from sglang.srt.connector import ConnectorType
 from sglang.srt.environ import ToolStrictLevel, envs
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.lora.lora_registry import LoRARef
@@ -61,7 +60,6 @@
     wait_port_available,
     xpu_has_xmx_support,
 )
-from sglang.srt.utils.hf_transformers_utils import check_gguf_file, get_config
 from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)
@@ -904,6 +902,7 @@ def _handle_cpu_backends(self):

     def _handle_model_specific_adjustments(self):
         from sglang.srt.configs.model_config import is_deepseek_nsa
+        from sglang.srt.connector import ConnectorType

         if parse_connector_type(self.model_path) == ConnectorType.INSTANCE:
             return
@@ -1662,6 +1661,8 @@ def _handle_speculative_decoding(self):
             )

     def _handle_load_format(self):
+        from sglang.srt.utils.hf_transformers_utils import check_gguf_file
+
         if (
             self.load_format == "auto" or self.load_format == "gguf"
         ) and check_gguf_file(self.model_path):
@@ -1780,6 +1781,8 @@ def _handle_metrics_labels(self):
         )

     def _handle_deterministic_inference(self):
+        from sglang.srt.connector import ConnectorType
+
         if self.rl_on_policy_target is not None:
             logger.warning(
                 "Enable deterministic inference because of rl_on_policy_target."
@@ -3678,6 +3681,8 @@ def url(self):
         return f"http://{self.host}:{self.port}"

     def get_hf_config(self):
+        from sglang.srt.utils.hf_transformers_utils import get_config
+
         kwargs = {}
         hf_config = get_config(
             self.model_path,
diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py
index fbc34e906ee..b9dfa504cd5 100644
--- a/python/sglang/srt/utils/hf_transformers_utils.py
+++ b/python/sglang/srt/utils/hf_transformers_utils.py
@@ -13,6 +13,8 @@
 # ==============================================================================
 """Utilities for Huggingface Transformers."""

+from __future__ import annotations
+
 import contextlib
 import json
 import logging
@@ -20,74 +22,82 @@
 import tempfile
 import warnings
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

 import torch
 from huggingface_hub import snapshot_download
-from transformers import (
-    AutoConfig,
-    AutoProcessor,
-    AutoTokenizer,
-    GenerationConfig,
-    PretrainedConfig,
-    PreTrainedTokenizer,
-    PreTrainedTokenizerBase,
-    PreTrainedTokenizerFast,
-)
-from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-
-from sglang.srt.configs import (
-    ChatGLMConfig,
-    DbrxConfig,
-    DeepseekVL2Config,
-    DotsOCRConfig,
-    DotsVLMConfig,
-    ExaoneConfig,
-    FalconH1Config,
-    JetNemotronConfig,
-    KimiLinearConfig,
-    KimiVLConfig,
-    LongcatFlashConfig,
-    MultiModalityConfig,
-    NemotronHConfig,
-    Olmo3Config,
-    Qwen3NextConfig,
-    Step3VLConfig,
-)
-from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
-from sglang.srt.configs.internvl import InternVLChatConfig

 from sglang.srt.connector import create_remote_connector
-from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR
 from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
+from sglang.utils import LazyImport

-_CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
-    ChatGLMConfig,
-    DbrxConfig,
-    ExaoneConfig,
-    DeepseekVL2Config,
-    MultiModalityConfig,
-    KimiVLConfig,
-    InternVLChatConfig,
-    Step3VLConfig,
-    LongcatFlashConfig,
-    Olmo3Config,
-    KimiLinearConfig,
-    Qwen3NextConfig,
-    FalconH1Config,
-    DotsVLMConfig,
-    DotsOCRConfig,
-    NemotronHConfig,
-    DeepseekVLV2Config,
-    JetNemotronConfig,
-]
-
-_CONFIG_REGISTRY = {
-    config_cls.model_type: config_cls for config_cls in _CONFIG_REGISTRY
-}
+if TYPE_CHECKING:
+    from transformers import (
+        PretrainedConfig,
+        PreTrainedTokenizer,
+        PreTrainedTokenizerFast,
+    )

-for name, cls in _CONFIG_REGISTRY.items():
-    with contextlib.suppress(ValueError):
-        AutoConfig.register(name, cls)
+AutoConfig = LazyImport("transformers", "AutoConfig")
+AutoTokenizer = LazyImport("transformers", "AutoTokenizer")
+AutoProcessor = LazyImport("transformers", "AutoProcessor")
+GenerationConfig = LazyImport("transformers", "GenerationConfig")
+PreTrainedTokenizerFast = LazyImport("transformers", "PreTrainedTokenizerFast")
+PreTrainedTokenizerBase = LazyImport("transformers", "PreTrainedTokenizerBase")
+SiglipVisionConfig = LazyImport("transformers", "SiglipVisionConfig")
+
+
+def _register_custom_configs():
+    from sglang.srt.configs import (
+        ChatGLMConfig,
+        DbrxConfig,
+        DeepseekVL2Config,
+        DotsOCRConfig,
+        DotsVLMConfig,
+        ExaoneConfig,
+        FalconH1Config,
+        JetNemotronConfig,
+        KimiLinearConfig,
+        KimiVLConfig,
+        LongcatFlashConfig,
+        MultiModalityConfig,
+        NemotronHConfig,
+        Olmo3Config,
+        Qwen3NextConfig,
+        Step3VLConfig,
+    )
+    from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
+    from sglang.srt.configs.internvl import InternVLChatConfig
+
+    _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
+        ChatGLMConfig,
+        DbrxConfig,
+        ExaoneConfig,
+        DeepseekVL2Config,
+        MultiModalityConfig,
+        KimiVLConfig,
+        InternVLChatConfig,
+        Step3VLConfig,
+        LongcatFlashConfig,
+        Olmo3Config,
+        KimiLinearConfig,
+        Qwen3NextConfig,
+        FalconH1Config,
+        DotsVLMConfig,
+        DotsOCRConfig,
+        NemotronHConfig,
+        DeepseekVLV2Config,
+        JetNemotronConfig,
+    ]
+
+    _CONFIG_REGISTRY = {
+        config_cls.model_type: config_cls for config_cls in _CONFIG_REGISTRY
+    }
+    for name, cls in _CONFIG_REGISTRY.items():
+        with contextlib.suppress(ValueError):
+            AutoConfig.register(name, cls)
+
+    return _CONFIG_REGISTRY


 def download_from_hf(
@@ -196,6 +206,8 @@ def get_config(
     model_override_args: Optional[dict] = None,
     **kwargs,
 ):
+    _CONFIG_REGISTRY = _register_custom_configs()
+
     is_gguf = check_gguf_file(model)
     if is_gguf:
         kwargs["gguf_file"] = model
@@ -228,8 +240,6 @@ def get_config(
         # Phi4MMForCausalLM uses a hard-coded vision_config. See:
         # https://github.com/vllm-project/vllm/blob/6071e989df1531b59ef35568f83f7351afb0b51e/vllm/model_executor/models/phi4mm.py#L71
         # We set it here to support cases where num_attention_heads is not divisible by the TP size.
-        from transformers import SiglipVisionConfig
-
         vision_config = {
             "hidden_size": 1152,
             "image_size": 448,
@@ -271,6 +281,10 @@ def get_config(

     # Special architecture mapping check for GGUF models
     if is_gguf:
+        from transformers.models.auto.modeling_auto import (
+            MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+        )
+
         if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
             raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
@@ -457,6 +471,10 @@ def get_processor(
     use_fast: Optional[bool] = True,
     **kwargs,
 ):
+    from sglang.srt.multimodal.customized_mm_processor_utils import (
+        _CUSTOMIZED_MM_PROCESSOR,
+    )
+
     # pop 'revision' from kwargs if present.
     revision = kwargs.pop("revision", tokenizer_revision)

diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 041630093da..cfc57ad2aa8 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -323,6 +323,19 @@ def __call__(self, *args, **kwargs):
         module = self._load()
         return module(*args, **kwargs)

+    def __or__(self, other):
+        self_type = self._load()
+        if isinstance(other, LazyImport):
+            other_type = other._load()
+        elif isinstance(other, tuple):
+            return (self_type,) + other
+        else:
+            other_type = other
+        return (self_type, other_type)
+
+    def __instancecheck__(self, instance):
+        return isinstance(instance, self._load())
+

 def download_and_cache_file(url: str, filename: Optional[str] = None):
     """Read and cache a file from a url."""
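Note (illustration, not part of the patch): the two dunders added to `LazyImport` above keep `isinstance` checks working when a lazily imported class appears as the check target, alone or in a PEP 604-style union. Since Python routes `isinstance(x, y)` through `type(y).__instancecheck__` and `a | b` through `a.__or__`, the proxy can materialize the real class on demand. A small demo, assuming this PR's `sglang.utils` and using a stdlib class so it is self-contained:

```python
import collections

from sglang.utils import LazyImport

# Any (module, attribute) pair works; OrderedDict keeps the demo stdlib-only.
LazyOrderedDict = LazyImport("collections", "OrderedDict")

od = LazyOrderedDict(a=1)  # first use triggers the real import

# __instancecheck__ lets the proxy stand in for its class in isinstance():
assert isinstance(od, LazyOrderedDict)

# __or__ materializes both sides into a plain tuple of real classes, so a
# union with a lazy left operand still works as an isinstance() target:
assert isinstance(od, LazyOrderedDict | dict)
assert not isinstance(3.0, LazyOrderedDict | dict)

# A tuple right operand is spliced in (the elif branch in the hunk above):
assert LazyOrderedDict | (int, float) == (collections.OrderedDict, int, float)
```

Returning a tuple rather than a `types.UnionType` is a pragmatic choice: `isinstance` accepts a tuple of classes, so no further special-casing is needed at call sites.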