10 changes: 10 additions & 0 deletions vllm/distributed/ec_transfer/ec_connector/base.py
@@ -189,6 +189,16 @@ def get_finished(
"""
return None, None

@abstractmethod
def get_stats(self) -> Any:
"""
Get the statistics of the connector.

Returns:
Statistics object.
"""
pass
Comment on lines +192 to +200
Copilot AI Nov 28, 2025
The get_stats method is marked as @abstractmethod but ECSharedStorageConnector (in shared_storage_connector.py) does not implement this method. This will cause instantiation failures for ECSharedStorageConnector. Either remove the @abstractmethod decorator to make it optional, or ensure all subclasses implement this method.
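A minimal, hypothetical sketch of the first option (drop @abstractmethod and provide an optional default so subclasses such as ECSharedStorageConnector stay instantiable); illustrative only, not the change made in this PR:

def get_stats(self) -> Any:
    """
    Get the statistics of the connector.

    Returns:
        Statistics object, or None if this connector collects no stats.
    """
    # Sketch: non-abstract default; connectors with telemetry override this.
    return None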


# ==============================
# Scheduler-side methods
# ==============================
38 changes: 38 additions & 0 deletions vllm/distributed/ec_transfer/ec_connector/metrics.py
@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any, Union


@dataclass
class ECConnectorStats:
"""
Base class for EC Connector Stats, a container for transfer performance
metrics or otherwise important telemetry from the connector.
All sub-classes need to be serializable as stats are sent from worker to
logger process.
"""
data: dict[str, Any] = field(default_factory=dict)

def reset(self):
"""Reset the stats, clear the state."""
raise NotImplementedError

def aggregate(self, other: "ECConnectorStats") -> "ECConnectorStats":
"""
Aggregate stats with another `ECConnectorStats` object.
"""
raise NotImplementedError

def reduce(self) -> dict[str, Union[int, float]]:
"""
Reduce the observations collected during a time interval to one or
more representative values (eg avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
raise NotImplementedError

def is_empty(self) -> bool:
"""Return True if the stats are empty."""
raise NotImplementedError
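
The docstrings above describe the intended lifecycle: workers produce ECConnectorStats objects, and the logger process aggregates them over an interval before calling reduce() for a summary. A small, hypothetical sketch of that logger-side flow (the helper name summarize_interval is invented for illustration):

from typing import Union

def summarize_interval(interval_stats: list[ECConnectorStats]) -> dict[str, Union[int, float]]:
    # Fold the per-worker stats collected during the interval into one object,
    # then reduce it to representative values for logging.
    non_empty = [s for s in interval_stats if not s.is_empty()]
    if not non_empty:
        return {}
    total = non_empty[0]
    for stats in non_empty[1:]:
        total = total.aggregate(stats)
    return total.reduce()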
@@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import os
from contextlib import contextmanager
from dataclasses import dataclass
from importlib import import_module
from time import perf_counter
from typing import TYPE_CHECKING, Optional, Union

from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorBase, ECConnectorMetadata, ECConnectorRole)
from vllm.distributed.ec_transfer.ec_connector.metrics import ECConnectorStats
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

@@ -56,6 +60,7 @@ def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
# mm_hash -> num_tokens
self._mm_datas_need_loads: dict[str, int] = {}
self.store = ECMooncakeStore(vllm_config)
self.stats = MooncakeECConnectorStats()

def start_load_caches(self, encoder_cache, **kwargs) -> None:
"""
@@ -79,18 +84,19 @@ def start_load_caches(self, encoder_cache, **kwargs) -> None:
if not metadata.mm_datas:
return

mm_hashes = [
mm_data.mm_hash for mm_data in metadata.mm_datas
if mm_data.mm_hash not in encoder_cache
]
device = self._vllm_config.device_config.device
tensors = self.store.batch_get(mm_hashes, device)
with self.stats.load_timer():
mm_hashes = [
mm_data.mm_hash for mm_data in metadata.mm_datas
if mm_data.mm_hash not in encoder_cache
]
device = self._vllm_config.device_config.device
tensors = self.store.batch_get(mm_hashes, device)

for mm_hash, ec_cache in zip(mm_hashes, tensors):
encoder_cache[mm_hash] = ec_cache
if ec_cache is None:
logger.error("Load failed for %s", mm_hash)
logger.debug("Load tensor for %s successfully", mm_hash)
for mm_hash, ec_cache in zip(mm_hashes, tensors):
encoder_cache[mm_hash] = ec_cache
if ec_cache is None:
logger.error("Load failed for %s", mm_hash)
logger.debug("Load tensor for %s successfully", mm_hash)

def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
"""
@@ -113,6 +119,9 @@ def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
self.store.batch_put([mm_hash], [encoder_cache[mm_hash]])

def wait_for_save(self):
if not self.is_producer:
return

self.store.wait_for_put()

def has_caches(
@@ -167,3 +176,74 @@ def build_connector_meta(
meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
self._mm_datas_need_loads.clear()
return meta

def get_stats(self) -> ECConnectorStats:
return self.stats.clone_and_reset()


@dataclass
class MooncakeECConnectorStats(ECConnectorStats):
"""Container for transfer performance metrics"""

def __post_init__(self):
if "load_time_ms" not in self.data:
self.data["load_time_ms"] = 0.0
if "save_time_ms" not in self.data:
self.data["save_time_ms"] = 0.0
if "num_loads" not in self.data:
self.data["num_loads"] = 0
if "num_saves" not in self.data:
self.data["num_saves"] = 0

def reset(self):
self.data = {
"load_time_ms": 0.0,
"save_time_ms": 0.0,
"num_loads": 0,
"num_saves": 0,
}

@contextmanager
def load_timer(self):
start = perf_counter()
try:
yield
finally:
elapsed_ms = (perf_counter() - start) * 1000.0
self.record_load(elapsed_ms)

def record_load(self, load_time_ms: float):
self.data["load_time_ms"] += load_time_ms
self.data["num_loads"] += 1

def record_save(self, save_time_ms: float):
self.data["save_time_ms"] += save_time_ms
self.data["num_saves"] += 1

def clone_and_reset(self) -> "MooncakeECConnectorStats":
old = copy.copy(self)
self.reset()
return old

def is_empty(self) -> bool:
return self.data["num_loads"] == 0 and self.data["num_saves"] == 0

def aggregate(self, other: ECConnectorStats) -> ECConnectorStats:
if not other.is_empty():
self.data["load_time_ms"] += other.data["load_time_ms"]
self.data["save_time_ms"] += other.data["save_time_ms"]
self.data["num_loads"] += other.data["num_loads"]
self.data["num_saves"] += other.data["num_saves"]
return self
Comment on lines +231 to +237
Copilot AI Nov 28, 2025
The aggregate method accepts ECConnectorStats but directly accesses keys specific to MooncakeECConnectorStats (e.g., other.data["load_time_ms"]). If a different ECConnectorStats subclass is passed, this will raise a KeyError. Consider either:

  1. Type-checking with isinstance(other, MooncakeECConnectorStats) before accessing the keys, or
  2. Changing the parameter type to "MooncakeECConnectorStats" to make the expectation explicit.
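
A hypothetical sketch of option 1, with an isinstance guard before the Mooncake-specific keys are touched; illustrative only:

def aggregate(self, other: ECConnectorStats) -> ECConnectorStats:
    # Sketch: fail loudly on a foreign stats type instead of a later KeyError.
    if not isinstance(other, MooncakeECConnectorStats):
        raise TypeError(
            f"Cannot aggregate MooncakeECConnectorStats with {type(other).__name__}")
    if not other.is_empty():
        self.data["load_time_ms"] += other.data["load_time_ms"]
        self.data["save_time_ms"] += other.data["save_time_ms"]
        self.data["num_loads"] += other.data["num_loads"]
        self.data["num_saves"] += other.data["num_saves"]
    return self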


def reduce(self) -> dict[str, Union[int, float]]:
return {
"avg_load_time_ms":
(self.data["load_time_ms"] / max(1, self.data["num_loads"])),
"avg_save_time_ms":
(self.data["save_time_ms"] / max(1, self.data["num_saves"])),
"total_loads":
self.data["num_loads"],
"total_saves":
self.data["num_saves"],
}
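
For reference, a small usage sketch of the container defined above; the timings are synthetic and the printed values are approximate:

stats = MooncakeECConnectorStats()
with stats.load_timer():      # times one load; stand-in for store.batch_get(...)
    pass
stats.record_save(2.5)        # a 2.5 ms save, recorded manually

snapshot = stats.clone_and_reset()   # what the connector's get_stats() returns
print(snapshot.reduce())
# -> {'avg_load_time_ms': <elapsed ms>, 'avg_save_time_ms': 2.5,
#     'total_loads': 1, 'total_saves': 1}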
13 changes: 10 additions & 3 deletions vllm/v1/core/sched/scheduler.py
@@ -13,6 +13,7 @@
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole
from vllm.distributed.ec_transfer.ec_connector.factory import (
ECConnectorFactory)
from vllm.distributed.ec_transfer.ec_connector.metrics import ECConnectorStats
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
@@ -917,11 +918,14 @@ def update_from_output(
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
kv_connector_output = model_runner_output.kv_connector_output
ec_connector_output = model_runner_output.ec_connector_output

outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None
kv_connector_stats = (kv_connector_output.kv_connector_stats
if kv_connector_output else None)
ec_connector_stats = (ec_connector_output.ec_connector_stats
if ec_connector_output else None)

# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
@@ -1057,8 +1061,8 @@ def update_from_output(
finished_requests=finished_set)
finished_req_ids.clear()

if (stats := self.make_stats(spec_decoding_stats,
kv_connector_stats)) is not None:
if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats,
ec_connector_stats)) is not None:
# Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
# We must return the stats even if there are no request
@@ -1224,6 +1228,7 @@ def make_stats(
self,
spec_decoding_stats: Optional[SpecDecodingStats] = None,
kv_connector_stats: Optional[KVConnectorStats] = None,
ec_connector_stats: Optional[ECConnectorStats] = None,
) -> Optional[SchedulerStats]:
if not self.log_stats:
return None
@@ -1237,7 +1242,9 @@
num_corrupted_reqs=sum(req.is_output_corrupted
for req in self.running),
kv_connector_stats=kv_connector_stats.data
if kv_connector_stats else None)
if kv_connector_stats else None,
ec_connector_stats=ec_connector_stats.data
if ec_connector_stats else None)

def make_spec_decoding_stats(
self,
1 change: 1 addition & 0 deletions vllm/v1/metrics/stats.py
@@ -44,6 +44,7 @@ class SchedulerStats:

spec_decoding_stats: Optional[SpecDecodingStats] = None
kv_connector_stats: Optional[dict[str, Any]] = None
ec_connector_stats: Optional[dict[str, Any]] = None

num_corrupted_reqs: int = 0

19 changes: 12 additions & 7 deletions vllm/v1/outputs.py
@@ -10,6 +10,8 @@
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
from vllm.distributed.ec_transfer.ec_connector.metrics import (
ECConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)

@@ -100,6 +102,7 @@ class ECConnectorOutput:
# [mm_hash]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
ec_connector_stats: Optional[ECConnectorStats] = None


# ModelRunnerOutput is serialized and sent to the scheduler process.
@@ -200,10 +203,12 @@ def make_empty_encoder_model_runner_output(
)


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
num_nans_in_logits=None)
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
num_nans_in_logits=None,
)
2 changes: 1 addition & 1 deletion vllm/v1/worker/ec_connector_model_runner_mixin.py
@@ -84,5 +84,5 @@ def _get_ec_connector_output(
finally:
output.finished_sending, output.finished_recving = (
ec_connector.get_finished(scheduler_output.finished_req_ids))

output.ec_connector_stats = ec_connector.get_stats()
ec_connector.clear_connector_metadata()