diff --git a/vllm/disaggregated/disagg_worker.py b/vllm/disaggregated/disagg_worker.py
index 322ce9e07892..206bc9bacbeb 100644
--- a/vllm/disaggregated/disagg_worker.py
+++ b/vllm/disaggregated/disagg_worker.py
@@ -96,9 +96,19 @@ async def _abort_handler(self, req: GenerationRequest):
         self.engine.abort(request_id=req.request_id)
 
     async def _heartbeat_handler(self, req: HeartbeatRequest):
+        # check_health() is declared ``async`` on AsyncLLM: it must be awaited,
+        # otherwise the check never runs and the status is always "OK".
+        try:
+            await self.engine.check_health()
+            status = "OK"
+        except Exception:
+            status = "FAIL"
+            logger.exception("Check health Failed.")
+
         msg = (ResponseType.HEARTBEAT,
                self.encoder.encode(
-                   HeartbeatResponse(request_id=req.request_id, status="OK")))
+                   HeartbeatResponse(request_id=req.request_id,
+                                     status=status)))
         await self.to_proxy.send_multipart(msg, copy=False)
 
     async def _generate(
diff --git a/vllm/disaggregated/proxy.py b/vllm/disaggregated/proxy.py
index 5d77c04bb359..d6c96c6ddd8f 100644
--- a/vllm/disaggregated/proxy.py
+++ b/vllm/disaggregated/proxy.py
@@ -53,6 +53,12 @@ def __init__(
         health_check_interval=10,
         health_threshold=3,
     ):
+        self._check_type("enable_health_monitor", enable_health_monitor, bool)
+        self._check_positive_int("health_check_interval",
+                                 health_check_interval)
+        self._check_positive_int("health_threshold", health_threshold)
+        self._check_subclass("router", router, RoutingInterface)
+
         self.queues: dict[str, asyncio.Queue] = {}
 
         self.encoder = msgspec.msgpack.Encoder()
@@ -498,6 +504,27 @@ async def get_vllm_config(self) -> VllmConfig:
     async def reset_mm_cache(self) -> None:
         raise NotImplementedError
 
+    def _check_type(self, name, value, expected_type):
+        """Raise TypeError if ``value`` is not an ``expected_type``."""
+        if not isinstance(value, expected_type):
+            # Adjacent f-strings concatenate into one message; a comma between
+            # them would pass two args and render the error as a tuple.
+            raise TypeError(f"{name} must be {expected_type.__name__}, "
+                            f"got {type(value).__name__}")
+
+    def _check_positive_int(self, name, value):
+        """Raise ValueError unless ``value`` is an int greater than zero."""
+        # bool is a subclass of int, so True would otherwise be accepted.
+        if (isinstance(value, bool) or not isinstance(value, int)
+                or value <= 0):
+            raise ValueError(f"{name} must be a positive integer")
+
+    def _check_subclass(self, name, value, base_class):
+        """Raise TypeError unless ``value`` is a subclass of ``base_class``."""
+        if not isinstance(value, type) or not issubclass(value, base_class):
+            raise TypeError(
+                f"{name} must be a subclass of {base_class.__name__}")
+
 
 def _has_mm_data(prompt: PromptType) -> bool:
     if isinstance(prompt, dict):
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 7fb36cf5941e..198c3a474344 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -486,6 +486,8 @@ async def do_log_stats(
 
     async def check_health(self) -> None:
         logger.debug("Called check_health.")
+        if self.errored:
+            raise self.dead_error
 
     async def start_profile(self) -> None:
         await self.engine_core.profile_async(True)