diff --git a/vllm/distributed/ec_transfer/ec_connector/base.py b/vllm/distributed/ec_transfer/ec_connector/base.py
index 477c50457c6f..cd739adae3bc 100644
--- a/vllm/distributed/ec_transfer/ec_connector/base.py
+++ b/vllm/distributed/ec_transfer/ec_connector/base.py
@@ -189,6 +189,16 @@ def get_finished(
         """
         return None, None
 
+    @abstractmethod
+    def get_stats(self) -> Any:
+        """
+        Get the statistics of the connector.
+
+        Returns:
+            Statistics object.
+        """
+        pass
+
     # ==============================
     # Scheduler-side methods
     # ==============================
diff --git a/vllm/distributed/ec_transfer/ec_connector/metrics.py b/vllm/distributed/ec_transfer/ec_connector/metrics.py
new file mode 100644
index 000000000000..aacc229f4419
--- /dev/null
+++ b/vllm/distributed/ec_transfer/ec_connector/metrics.py
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass, field
+from typing import Any, Union
+
+
+@dataclass
+class ECConnectorStats:
+    """
+    Base class for EC Connector Stats, a container for transfer performance
+    metrics or otherwise important telemetry from the connector.
+    All sub-classes need to be serializable as stats are sent from worker to
+    logger process.
+    """
+    data: dict[str, Any] = field(default_factory=dict)
+
+    def reset(self):
+        """Reset the stats, clear the state."""
+        raise NotImplementedError
+
+    def aggregate(self, other: "ECConnectorStats") -> "ECConnectorStats":
+        """
+        Aggregate stats with another `ECConnectorStats` object.
+        """
+        raise NotImplementedError
+
+    def reduce(self) -> dict[str, Union[int, float]]:
+        """
+        Reduce the observations collected during a time interval to one or
+        more representative values (eg avg/median/sum of the series).
+        This is meant to be called by the logger to produce a summary of the
+        stats for the last time interval.
+        """
+        raise NotImplementedError
+
+    def is_empty(self) -> bool:
+        """Return True if the stats are empty."""
+        raise NotImplementedError
diff --git a/vllm/distributed/ec_transfer/ec_connector/mooncake_storage_connector.py b/vllm/distributed/ec_transfer/ec_connector/mooncake_storage_connector.py
index ec1c76131ec2..85eddc60b497 100644
--- a/vllm/distributed/ec_transfer/ec_connector/mooncake_storage_connector.py
+++ b/vllm/distributed/ec_transfer/ec_connector/mooncake_storage_connector.py
@@ -1,13 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import os
+from contextlib import contextmanager
 from dataclasses import dataclass
 from importlib import import_module
+from time import perf_counter
 from typing import TYPE_CHECKING, Optional, Union
 
 from vllm.config import VllmConfig
 from vllm.distributed.ec_transfer.ec_connector.base import (
     ECConnectorBase, ECConnectorMetadata, ECConnectorRole)
+from vllm.distributed.ec_transfer.ec_connector.metrics import ECConnectorStats
 from vllm.logger import init_logger
 from vllm.v1.core.sched.output import SchedulerOutput
 
@@ -56,6 +60,7 @@ def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
         # mm_hash -> num_tokens
         self._mm_datas_need_loads: dict[str, int] = {}
         self.store = ECMooncakeStore(vllm_config)
+        self.stats = MooncakeECConnectorStats()
 
     def start_load_caches(self, encoder_cache, **kwargs) -> None:
         """
@@ -79,18 +84,19 @@ def start_load_caches(self, encoder_cache, **kwargs) -> None:
         if not metadata.mm_datas:
             return
 
-        mm_hashes = [
-            mm_data.mm_hash for mm_data in metadata.mm_datas
-            if mm_data.mm_hash not in encoder_cache
-        ]
-        device = self._vllm_config.device_config.device
-        tensors = self.store.batch_get(mm_hashes, device)
+        with self.stats.load_timer():
+            mm_hashes = [
+                mm_data.mm_hash for mm_data in metadata.mm_datas
+                if mm_data.mm_hash not in encoder_cache
+            ]
+            device = self._vllm_config.device_config.device
+            tensors = self.store.batch_get(mm_hashes, device)
 
-        for mm_hash, ec_cache in zip(mm_hashes, tensors):
-            encoder_cache[mm_hash] = ec_cache
-            if ec_cache is None:
-                logger.error("Load failed for %s", mm_hash)
-            logger.debug("Load tensor for %s successfully", mm_hash)
+            for mm_hash, ec_cache in zip(mm_hashes, tensors):
+                encoder_cache[mm_hash] = ec_cache
+                if ec_cache is None:
+                    logger.error("Load failed for %s", mm_hash)
+                logger.debug("Load tensor for %s successfully", mm_hash)
 
     def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
         """
@@ -113,6 +119,9 @@ def save_caches(self, encoder_cache, mm_hash, **kwargs) -> None:
         self.store.batch_put([mm_hash], [encoder_cache[mm_hash]])
 
     def wait_for_save(self):
+        if not self.is_producer:
+            return
+
         self.store.wait_for_put()
 
     def has_caches(
@@ -167,3 +176,74 @@ def build_connector_meta(
             meta.add_mm_data(MMMeta.make_meta(mm_hash, num_encoder_token))
         self._mm_datas_need_loads.clear()
         return meta
+
+    def get_stats(self) -> ECConnectorStats:
+        return self.stats.clone_and_reset()
+
+
+@dataclass
+class MooncakeECConnectorStats(ECConnectorStats):
+    """Container for transfer performance metrics"""
+
+    def __post_init__(self):
+        if "load_time_ms" not in self.data:
+            self.data["load_time_ms"] = 0.0
+        if "save_time_ms" not in self.data:
+            self.data["save_time_ms"] = 0.0
+        if "num_loads" not in self.data:
+            self.data["num_loads"] = 0
+        if "num_saves" not in self.data:
+            self.data["num_saves"] = 0
+
+    def reset(self):
+        self.data = {
+            "load_time_ms": 0.0,
+            "save_time_ms": 0.0,
+            "num_loads": 0,
+            "num_saves": 0,
+        }
+
+    @contextmanager
+    def load_timer(self):
+        start = perf_counter()
+        try:
+            yield
+        finally:
+            elapsed_ms = (perf_counter() - start) * 1000.0
+            self.record_load(elapsed_ms)
+
+    def record_load(self, load_time_ms: float):
+        self.data["load_time_ms"] += load_time_ms
+        self.data["num_loads"] += 1
+
+    def record_save(self, save_time_ms: float):
+        self.data["save_time_ms"] += save_time_ms
+        self.data["num_saves"] += 1
+
+    def clone_and_reset(self) -> "MooncakeECConnectorStats":
+        old = copy.copy(self)
+        self.reset()
+        return old
+
+    def is_empty(self) -> bool:
+        return self.data["num_loads"] == 0 and self.data["num_saves"] == 0
+
+    def aggregate(self, other: ECConnectorStats) -> ECConnectorStats:
+        if not other.is_empty():
+            self.data["load_time_ms"] += other.data["load_time_ms"]
+            self.data["save_time_ms"] += other.data["save_time_ms"]
+            self.data["num_loads"] += other.data["num_loads"]
+            self.data["num_saves"] += other.data["num_saves"]
+        return self
+
+    def reduce(self) -> dict[str, Union[int, float]]:
+        return {
+            "avg_load_time_ms":
+            (self.data["load_time_ms"] / max(1, self.data["num_loads"])),
+            "avg_save_time_ms":
+            (self.data["save_time_ms"] / max(1, self.data["num_saves"])),
+            "total_loads":
+            self.data["num_loads"],
+            "total_saves":
+            self.data["num_saves"],
+        }
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 8e99869bae19..a7db582e2011 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -13,6 +13,7 @@
 from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole
 from vllm.distributed.ec_transfer.ec_connector.factory import (
     ECConnectorFactory)
+from vllm.distributed.ec_transfer.ec_connector.metrics import ECConnectorStats
 from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
@@ -917,11 +918,14 @@ def update_from_output(
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
         kv_connector_output = model_runner_output.kv_connector_output
+        ec_connector_output = model_runner_output.ec_connector_output
 
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: Optional[SpecDecodingStats] = None
         kv_connector_stats = (kv_connector_output.kv_connector_stats
                               if kv_connector_output else None)
+        ec_connector_stats = (ec_connector_output.ec_connector_stats
+                              if ec_connector_output else None)
 
         # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
         # the below loop can be a performance bottleneck. We should do our best
@@ -1057,8 +1061,8 @@
                         finished_requests=finished_set)
             finished_req_ids.clear()
 
-        if (stats := self.make_stats(spec_decoding_stats,
-                                     kv_connector_stats)) is not None:
+        if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats,
+                                     ec_connector_stats)) is not None:
             # Return stats to only one of the front-ends.
             if (eco := next(iter(engine_core_outputs.values()), None)) is None:
                 # We must return the stats even if there are no request
@@ -1224,6 +1228,7 @@ def make_stats(
         self,
         spec_decoding_stats: Optional[SpecDecodingStats] = None,
         kv_connector_stats: Optional[KVConnectorStats] = None,
+        ec_connector_stats: Optional[ECConnectorStats] = None,
     ) -> Optional[SchedulerStats]:
         if not self.log_stats:
             return None
@@ -1237,7 +1242,9 @@ def make_stats(
             num_corrupted_reqs=sum(req.is_output_corrupted
                                    for req in self.running),
             kv_connector_stats=kv_connector_stats.data
-            if kv_connector_stats else None)
+            if kv_connector_stats else None,
+            ec_connector_stats=ec_connector_stats.data
+            if ec_connector_stats else None)
 
     def make_spec_decoding_stats(
         self,
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 296c39e8cdb5..05033b4aaced 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -44,6 +44,7 @@ class SchedulerStats:
 
     spec_decoding_stats: Optional[SpecDecodingStats] = None
     kv_connector_stats: Optional[dict[str, Any]] = None
+    ec_connector_stats: Optional[dict[str, Any]] = None
 
     num_corrupted_reqs: int = 0
 
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index dad487c1dbe5..91199616bea0 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -10,6 +10,8 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 
 if TYPE_CHECKING:
+    from vllm.distributed.ec_transfer.ec_connector.metrics import (
+        ECConnectorStats)
     from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
         KVConnectorStats)
 
@@ -100,6 +102,7 @@ class ECConnectorOutput:
     # [mm_hash]
     finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None
+    ec_connector_stats: Optional[ECConnectorStats] = None
 
 
 # ModelRunnerOutput is serialized and sent to the scheduler process.
@@ -200,10 +203,12 @@ def make_empty_encoder_model_runner_output(
     )
 
 
-EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
-                                              req_id_to_index={},
-                                              sampled_token_ids=[],
-                                              logprobs=None,
-                                              prompt_logprobs_dict={},
-                                              pooler_output=[],
-                                              num_nans_in_logits=None)
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
+    req_ids=[],
+    req_id_to_index={},
+    sampled_token_ids=[],
+    logprobs=None,
+    prompt_logprobs_dict={},
+    pooler_output=[],
+    num_nans_in_logits=None,
+)
diff --git a/vllm/v1/worker/ec_connector_model_runner_mixin.py b/vllm/v1/worker/ec_connector_model_runner_mixin.py
index 9228ad45fa6e..cfef0d8e2f3d 100644
--- a/vllm/v1/worker/ec_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/ec_connector_model_runner_mixin.py
@@ -84,5 +84,5 @@ def _get_ec_connector_output(
         finally:
             output.finished_sending, output.finished_recving = (
                 ec_connector.get_finished(scheduler_output.finished_req_ids))
-
+            output.ec_connector_stats = ec_connector.get_stats()
             ec_connector.clear_connector_metadata()
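For reference, a minimal sketch of how the stats introduced by this diff are meant to flow end to end. It is illustrative only and not part of the patch: the local variable names and the standalone "worker"/"logger" steps are hypothetical stand-ins, and the save path is not actually timed anywhere in this change; the only API exercised is the MooncakeECConnectorStats class added above.

# Hypothetical usage sketch of the new EC connector stats, assuming a vLLM
# checkout with this patch applied.
from vllm.distributed.ec_transfer.ec_connector.mooncake_storage_connector import (
    MooncakeECConnectorStats)

# Worker side: the connector times loads via the context manager that now
# wraps store.batch_get() in start_load_caches().
stats = MooncakeECConnectorStats()
with stats.load_timer():
    pass  # ... encoder caches would be fetched here ...
stats.record_save(3.2)  # hypothetical: the save path does not record yet

# get_stats() returns clone_and_reset(): the snapshot is attached to
# ECConnectorOutput.ec_connector_stats while the live counters start over.
snapshot = stats.clone_and_reset()
assert stats.is_empty() and not snapshot.is_empty()

# Logger side: per-step snapshots can be merged with aggregate(), and
# reduce() produces a summary; the scheduler itself forwards the raw .data
# dict into SchedulerStats.ec_connector_stats.
accumulated = MooncakeECConnectorStats()
accumulated.aggregate(snapshot)
print(accumulated.reduce())
# e.g. {'avg_load_time_ms': ..., 'avg_save_time_ms': 3.2,
#       'total_loads': 1, 'total_saves': 1}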