mooncake-wheel/mooncake/mooncake_connector_v1.py (32 changes: 23 additions & 9 deletions)
@@ -9,6 +9,7 @@
 import contextlib
 import threading
 import time
+import importlib.metadata
 from collections import defaultdict
 from collections.abc import Iterator
 from concurrent.futures import ThreadPoolExecutor
@@ -22,16 +23,18 @@
 import torch
 import zmq
 
-from vllm.attention.selector import backend_name_to_enum, get_attn_backend
+from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                              get_tp_group)
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
-from vllm.platforms import _Backend
-from vllm.utils import get_ip, make_zmq_path, make_zmq_socket
+try:
+    from vllm.utils import get_ip, make_zmq_path, make_zmq_socket
+except ImportError:
+    from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
 from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
@@ -114,8 +117,9 @@ class MooncakeConnector(KVConnectorBase_V1):
     def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
         assert vllm_config.kv_transfer_config is not None
         assert vllm_config.kv_transfer_config.engine_id is not None
+        super().__init__(vllm_config, role)
         self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
 
         if role == KVConnectorRole.SCHEDULER:
             self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \
                 MooncakeConnectorScheduler(vllm_config, self.engine_id)
@@ -425,12 +429,22 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
                                    self.model_config.dtype,
                                    self.cache_config.cache_dtype,
                                    self.block_size,
-                                   self.model_config.is_attention_free,
                                    use_mla=self.use_mla)
         self.backend_name = backend.get_name()
-        attn_backend = backend_name_to_enum(self.backend_name)
-        self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
-        self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        vllm_version = importlib.metadata.version("vllm")
+        if vllm_version.startswith("0.11.0"):
+            from vllm.attention.selector import backend_name_to_enum
+            from vllm.platforms import _Backend
+            attn_backend = backend_name_to_enum(self.backend_name)
+            self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
+            self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        elif vllm_version.startswith("0.11.1") or vllm_version.startswith("0.11.2"):
+            from vllm.attention.selector import AttentionBackendEnum
+            attn_backend = AttentionBackendEnum[self.backend_name]
+            self._use_flashinfer = attn_backend in [AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASHINFER_MLA]
+            self._use_pallas_v1 = attn_backend == AttentionBackendEnum.PALLAS
+        else:
+            raise Exception("Unsupported vllm version %s", vllm_version)
         self.kv_cache_layout = get_kv_cache_layout()
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
@@ -759,4 +773,4 @@ def group_concurrent_contiguous(
     src_groups = [g.tolist() for g in src_groups]
     dst_groups = [g.tolist() for g in dst_groups]
 
-    return src_groups, dst_groups
\ No newline at end of file
+    return src_groups, dst_groups
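
For readers who want to see the new backend-resolution path in isolation, below is a minimal standalone sketch of the version-gated detection that the @@ -425 hunk adds to the worker's __init__. It assumes a vllm 0.11.0, 0.11.1, or 0.11.2 install; the detect_backend_flags helper is hypothetical (not part of this PR) and only mirrors the connector's logic.

import importlib.metadata


def detect_backend_flags(backend_name: str) -> tuple[bool, bool]:
    """Return (use_flashinfer, use_pallas) for the installed vllm release."""
    vllm_version = importlib.metadata.version("vllm")
    if vllm_version.startswith("0.11.0"):
        # 0.11.0 still ships backend_name_to_enum and the private _Backend enum.
        from vllm.attention.selector import backend_name_to_enum
        from vllm.platforms import _Backend
        attn_backend = backend_name_to_enum(backend_name)
        return (attn_backend == _Backend.FLASHINFER_VLLM_V1,
                attn_backend == _Backend.PALLAS_VLLM_V1)
    if vllm_version.startswith(("0.11.1", "0.11.2")):
        # Later 0.11.x releases replace both with AttentionBackendEnum.
        from vllm.attention.selector import AttentionBackendEnum
        attn_backend = AttentionBackendEnum[backend_name]
        return (attn_backend in (AttentionBackendEnum.FLASHINFER,
                                 AttentionBackendEnum.FLASHINFER_MLA),
                attn_backend == AttentionBackendEnum.PALLAS)
    raise RuntimeError(f"Unsupported vllm version {vllm_version}")


# Example usage (requires a supported vllm install):
# use_flashinfer, use_pallas = detect_backend_flags("FLASH_ATTN")

Gating on importlib.metadata.version lets a single mooncake wheel track several vllm releases without a hard dependency pin; the try/except in the import hunk serves the same goal, falling back to vllm.utils.network_utils for get_ip, make_zmq_path, and make_zmq_socket on releases where those helpers moved out of vllm.utils.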