From 0d74ecaf8430f2f30155d0de7e14e6d46296b267 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Mon, 4 May 2026 06:55:24 -0700 Subject: [PATCH 01/18] feat(rl): import RL surface from bis/parity-tokenize-tcp as baseline for bis/dynamo-rl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Squash-imports the v1 RL surface from bis/parity-tokenize-tcp as the testable starting point for bis/dynamo-rl. Subsequent commits refactor toward the cleaner design in bis-dev/design-docs/rl-support.md (phases 0-5). Merge conflict resolutions: - nvext.rs: keep main's NvExtResponseFieldSelection refactor; add v1's completion_token_ids field + selection flag + gating in build_response_nvext. - preprocessor.rs: combine main's tracing::info! [SIDECAR-SKIP-TOKENIZE] log with v1's skip_token_annotation = has_backend_instance_id semantics (so RL/TITO callers without backend_instance_id keep the token_ids annotation; GAIE EPP callers continue to skip it). - chat_completions/delta.rs: route through main's build_response_nvext helper; layer v1's RL completion_token_ids accumulator on top so the full token list is emitted only on the finish chunk. - completions/delta.rs: drop v1's redundant inline NvExtResponse build (main's helper covers it). - worker_factory.py: extend main's register_engine_routes() helper with v1's RL routes (pause_generation, resume_generation, flush_cache, update_weights_from_path, get_weight_version, load_lora_adapter, unload_lora_adapter). - handlers.py: keep both main's start_profile/stop_profile and v1's RL handlers (pause_generation through unload_lora_adapter). - publisher.py: keep v1's 'scheduler_stats can be None right after a weight reload / cache reset' explanatory comment. - install_vllm.sh: keep VLLM_VER=0.19.1 from v1 (matches the venv the smoke runs against; main's 0.20.0 will be picked up by the long-term plan's upgrade workstream). cargo check --workspace clean (1 pre-existing benign warning). 
Test loop: ~/dev/rl/work/bis-dev/4-02/{lora,sft}/run.sh — mirrors the known-working bis-dev/4 reference smokes. --- .../src/dynamo/frontend/vllm_processor.py | 39 + components/src/dynamo/vllm/handlers.py | 274 +++++ components/src/dynamo/vllm/publisher.py | 1 + components/src/dynamo/vllm/worker_factory.py | 26 +- container/deps/vllm/install_vllm.sh | 33 +- docs/Dynamo-RL-api-draft.md | 973 ++++++++++++++++++ lib/llm/src/audit/stream.rs | 3 + lib/llm/src/entrypoint/input/text.rs | 2 + lib/llm/src/http/service/openai.rs | 946 +++++++++++++++++ lib/llm/src/http/service/service_v2.rs | 30 +- lib/llm/src/preprocessor.rs | 11 +- lib/llm/src/protocols/anthropic/types.rs | 2 + lib/llm/src/protocols/openai.rs | 6 +- .../src/protocols/openai/chat_completions.rs | 16 + .../openai/chat_completions/aggregator.rs | 1 + .../openai/chat_completions/delta.rs | 81 +- lib/llm/src/protocols/openai/nvext.rs | 22 + lib/llm/src/protocols/openai/responses/mod.rs | 2 + lib/llm/src/protocols/openai/tokenization.rs | 124 +++ lib/llm/src/protocols/openai/validate.rs | 20 +- lib/tokenizers/src/fastokens.rs | 26 +- lib/tokenizers/src/hf.rs | 38 +- lib/tokenizers/src/lib.rs | 33 +- lib/tokenizers/src/tiktoken.rs | 53 +- 24 files changed, 2712 insertions(+), 50 deletions(-) create mode 100644 docs/Dynamo-RL-api-draft.md create mode 100644 lib/llm/src/protocols/openai/tokenization.rs diff --git a/components/src/dynamo/frontend/vllm_processor.py b/components/src/dynamo/frontend/vllm_processor.py index 8962bfd02ec5..df29f1bd6dd6 100644 --- a/components/src/dynamo/frontend/vllm_processor.py +++ b/components/src/dynamo/frontend/vllm_processor.py @@ -633,6 +633,40 @@ async def _generate_and_stream( break choice = post.process_output(output) if choice: + # ── RL logprobs injection ────────────────────── + # The vLLM worker sends log_probs/top_logprobs in + # the engine_response dict. Since we can't easily + # construct LogprobsLists for EngineCoreOutput, we + # inject them directly into the choice here. 
+ worker_log_probs = engine_response.get("log_probs") + worker_top_logprobs = engine_response.get("top_logprobs") + if worker_log_probs is not None and choice.get("logprobs") is None: + oai_logprobs_content = [] + new_tids = engine_response.get("token_ids", []) + for i, lp in enumerate(worker_log_probs): + # Always populate token/bytes so consumers never see a + # missing key. If top_logprobs is absent or the token + # string cannot be resolved we fall back to the numeric + # ID as a string — better than a KeyError / silent None. + tid_str = str(new_tids[i]) if i < len(new_tids) else "" + entry: dict = { + "logprob": lp, + "token": tid_str, + "bytes": None, + } + # Resolve the human-readable token string and top_logprobs + # from the engine's top_logprobs table when available. + if worker_top_logprobs and i < len(worker_top_logprobs): + tops = worker_top_logprobs[i] + entry["top_logprobs"] = tops + if i < len(new_tids): + for tp in tops: + if tp.get("token_id") == new_tids[i]: + entry["token"] = tp.get("token", tid_str) + break + oai_logprobs_content.append(entry) + choice["logprobs"] = {"content": oai_logprobs_content} + choices.append(choice) if choices: @@ -646,6 +680,11 @@ async def _generate_and_stream( if usage := engine_response.get("completion_usage"): dynamo_out["usage"] = usage + # ── RL: pass output token IDs for nvext.completion_token_ids ── + new_token_ids = engine_response.get("token_ids", []) + if new_token_ids: + dynamo_out["_completion_token_ids"] = new_token_ids + yield dynamo_out _nvtx.end_range(rng_stream) except Exception as e: diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 46964076b8c8..25b898b8d630 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -825,6 +825,280 @@ async def stop_profile(self, body: dict) -> dict: logger.error(f"Failed to stop profiling: {e}") return {"status": "error", "message": str(e)} + # ── RL weight lifecycle engine 
routes ────────────────────────────── + # Signatures kept compatible with SGLang's merged #6094 routes so + # a single admin coordinator can talk to either backend. + + async def pause_generation(self, body: dict) -> dict: + """Pause the engine: drain in-flight requests, keep model loaded. + + Called by RL admin coordinator before weight updates. + Uses engine_client.pause_generation() directly -- does NOT sleep + (no GPU memory release) and does NOT unregister from discovery. + """ + body = body or {} + try: + await self.engine_client.pause_generation() + logger.info("[RL] Engine paused (generation quiesced)") + return {"status": "ok", "message": "Engine paused"} + except Exception as e: + logger.error(f"[RL] Failed to pause: {e}") + return {"status": "error", "message": str(e)} + + async def resume_generation(self, body: dict) -> dict: + """Resume the engine after a weight update.""" + body = body or {} + try: + await self.engine_client.resume_generation() + logger.info("[RL] Engine resumed") + return {"status": "ok", "message": "Engine resumed"} + except Exception as e: + logger.error(f"[RL] Failed to resume: {e}") + return {"status": "error", "message": str(e)} + + async def flush_cache(self, body: dict) -> dict: + """Invalidate prefix/KV cache. Called after weight updates.""" + body = body or {} + try: + await self.engine_client.reset_prefix_cache() + logger.info("[RL] Prefix cache flushed") + return {"status": "ok", "message": "Cache flushed"} + except Exception as e: + logger.error(f"[RL] Failed to flush cache: {e}") + return {"status": "error", "message": str(e)} + + async def update_weights_from_path(self, body: dict) -> dict: + """Load weights from a filesystem path (safetensors/torch checkpoint). + + Expects body: {"path": "/path/to/weights", "version": "step_N"} + The caller is responsible for pausing/resuming around this call. 
+ """ + body = body or {} + path = body.get("path") + version = body.get("version", "unknown") + if not path: + return {"status": "error", "message": "Missing 'path' in body"} + try: + # Use vLLM's built-in reload_weights via collective RPC. + # This calls Worker.reload_weights() -> GPUModelRunner.reload_weights() + # which handles loading safetensors from a directory using vLLM's + # model loader with proper layerwise reload. + await self.engine_client.collective_rpc( + "reload_weights", + kwargs={"weights_path": path}, + ) + self._weight_version = version + logger.info(f"[RL] Weights loaded from {path} (version={version})") + return { + "status": "ok", + "message": f"Weights loaded from {path}", + "version": version, + } + except Exception as e: + logger.error(f"[RL] Failed to load weights from {path}: {e}") + return {"status": "error", "message": str(e)} + + async def get_weight_version(self, body: dict) -> dict: + """Return the current weight version tag.""" + return {"version": getattr(self, "_weight_version", "initial")} + + async def load_lora_adapter(self, body: dict) -> dict: + """Load (or hot-swap) a LoRA adapter from a filesystem path. + + Expects body: {"lora_name": str, "lora_path": "/path/to/adapter_dir"} + + The adapter directory must contain ``adapter_model.safetensors`` and + ``adapter_config.json`` -- the standard PEFT output layout that Prime-RL + writes each training step. + + Unlike :meth:`load_lora` (which downloads from a URI via ``LoRAManager`` + streaming a gRPC response), this method is the RL admin equivalent used + for training-loop weight updates: + + * Reads the adapter directly from the given filesystem path (no URI / + no network fetch, no LoRAManager needed). + * Hot-swaps if ``lora_name`` is already loaded (remove old id then + re-add) so every training step replaces the same logical adapter. + * Resets the prefix cache after a hot-swap so stale KV entries keyed + to the previous adapter weights do not poison subsequent rollouts. 
+ * Publishes a ModelDeploymentCard the first time a new ``lora_name`` is + loaded. Prime-RL switches its request ``model`` field to the LoRA + name after load (``scheduler.py``: ``self.model_name = self.lora_name``) + so the frontend needs an MDC entry to route ``r16-a32`` → this worker. + On subsequent hot-swaps the MDC is already published and we skip + re-registration. + """ + body = body or {} + lora_name = body.get("lora_name") + lora_path = body.get("lora_path") + if not lora_name: + return {"status": "error", "message": "Missing 'lora_name' in body"} + if not lora_path: + return {"status": "error", "message": "Missing 'lora_path' in body"} + try: + lock = self._get_lora_lock(lora_name) + async with lock: + lora_id = lora_name_to_id(lora_name) + is_hot_swap = lora_name in self.loaded_loras + + # Hot-swap: vLLM's add_lora is a no-op when the lora_int_id is + # already registered, so we must remove the previous adapter + # first. remove_lora is best-effort on a fresh add. + if is_hot_swap: + old_id = self.loaded_loras[lora_name].id + try: + await self.engine_client.remove_lora(old_id) + # Invalidate the cache entry immediately after remove succeeds. + # If add_lora below fails, this prevents a stale entry pointing + # at an adapter the engine no longer holds from poisoning future + # rollouts with wrong importance ratios (Tier-1 RL correctness risk). + self.loaded_loras.pop(lora_name, None) + except Exception as e: + logger.warning( + f"[RL] remove_lora({lora_name}, id={old_id}) failed during hot-swap: {e}" + ) + + await self.engine_client.add_lora( + LoRARequest( + lora_name=lora_name, + lora_int_id=lora_id, + lora_path=lora_path, + ) + ) + self.loaded_loras[lora_name] = LoRAInfo(id=lora_id, path=lora_path) + + # Invalidate KV cache on hot-swap so stale prefix entries keyed + # to the previous LoRA weights can't contaminate new rollouts. 
+ if is_hot_swap: + try: + await self.engine_client.reset_prefix_cache() + except Exception as e: + # ERROR not WARNING: a failed cache reset means subsequent requests + # sharing a prefix with an old rollout can reuse KV state computed + # under the previous adapter — causing silent logprobs mismatch. + logger.error( + f"[RL] reset_prefix_cache after LoRA swap failed — KV cache may " + f"be contaminated with stale entries from the old adapter. " + f"Rollouts on this worker are unreliable until the next successful " + f"swap: {e}" + ) + + # Publish an MDC for the LoRA on first load so Dynamo's frontend + # can route requests with model= to this worker. + # Mirror the logic in load_lora() (URI variant). Skip on hot-swap + # since the MDC was already published on the first load. + if not is_hot_swap and self.generate_endpoint is not None: + try: + runtime_config = ModelRuntimeConfig() + runtime_config.tool_call_parser = self.config.dyn_tool_call_parser + runtime_config.reasoning_parser = self.config.dyn_reasoning_parser + await register_model( + model_input=ModelInput.Tokens, + model_type=ModelType.Chat | ModelType.Completions, + endpoint=self.generate_endpoint, + model_path=self.config.model, + kv_cache_block_size=self.config.engine_args.block_size, + runtime_config=runtime_config, + user_data={"lora_adapter": True, "lora_id": lora_id}, + lora_name=lora_name, + base_model_path=self.config.model, + ) + logger.info( + f"[RL] Published LoRA '{lora_name}' ModelDeploymentCard" + ) + except Exception as e: + # Rollback: remove the LoRA from the engine to keep state consistent. + logger.exception( + f"[RL] Failed to publish LoRA '{lora_name}' MDC: {e}; rolling back add_lora" + ) + try: + await self.engine_client.remove_lora(lora_id) + except Exception as rollback_err: + # The adapter is now leaked in the engine: it is registered but + # unreachable via loaded_loras (we pop it below). Log at ERROR + # so this doesn't go unnoticed in production. 
+ logger.error( + f"[RL] Rollback remove_lora({lora_name}, id={lora_id}) failed " + f"— adapter is leaked in the engine: {rollback_err}" + ) + self.loaded_loras.pop(lora_name, None) + return { + "status": "error", + "message": f"Failed to register LoRA '{lora_name}' in discovery registry: {e}", + "lora_name": lora_name, + } + + logger.info( + f"[RL] LoRA adapter {'hot-swapped' if is_hot_swap else 'loaded'}: " + f"name={lora_name} id={lora_id} path={lora_path}" + ) + return { + "status": "ok", + "message": f"LoRA adapter '{lora_name}' loaded from {lora_path}", + "lora_name": lora_name, + "lora_id": lora_id, + "hot_swap": is_hot_swap, + } + except Exception as e: + logger.exception( + f"[RL] Failed to load LoRA adapter '{lora_name}' from {lora_path}: {e}" + ) + return {"status": "error", "message": str(e)} + + async def unload_lora_adapter(self, body: dict) -> dict: + """Unload a LoRA adapter previously loaded via :meth:`load_lora_adapter`. + + Expects body: {"lora_name": str} + + Idempotent: unloading an already-absent LoRA returns status=ok so + callers can safely retry without special-casing the not-found path. + """ + body = body or {} + lora_name = body.get("lora_name") + if not lora_name: + return {"status": "error", "message": "Missing 'lora_name' in body"} + try: + lock = self._get_lora_lock(lora_name) + async with lock: + lora = self.loaded_loras.get(lora_name) + if lora is None: + return { + "status": "ok", + "message": f"LoRA adapter '{lora_name}' not loaded (no-op)", + "lora_name": lora_name, + } + lora_id = lora.id + await self.engine_client.remove_lora(lora_id) + del self.loaded_loras[lora_name] + + # Unregister the MDC published on load so the frontend stops + # routing `model=` requests to this worker. 
+ if self.generate_endpoint is not None: + try: + await unregister_model( + endpoint=self.generate_endpoint, + lora_name=lora_name, + ) + except Exception as e: + logger.warning( + f"[RL] Failed to unregister LoRA '{lora_name}' MDC (adapter already removed from engine): {e}" + ) + + logger.info( + f"[RL] LoRA adapter unloaded: name={lora_name} id={lora_id}" + ) + return { + "status": "ok", + "message": f"LoRA adapter '{lora_name}' unloaded", + "lora_name": lora_name, + "lora_id": lora_id, + } + except Exception as e: + logger.exception( + f"[RL] Failed to unload LoRA adapter '{lora_name}': {e}" + ) + return {"status": "error", "message": str(e)} + @abstractmethod def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]: raise NotImplementedError diff --git a/components/src/dynamo/vllm/publisher.py b/components/src/dynamo/vllm/publisher.py index d3a8619ad9d9..d233c55f636d 100644 --- a/components/src/dynamo/vllm/publisher.py +++ b/components/src/dynamo/vllm/publisher.py @@ -58,6 +58,7 @@ def record( *args: object, **kwargs: object, ) -> None: + # scheduler_stats can be None right after a weight reload / cache reset. 
if scheduler_stats is None: return diff --git a/components/src/dynamo/vllm/worker_factory.py b/components/src/dynamo/vllm/worker_factory.py index ce3d473fdb24..0c1649c4e3af 100644 --- a/components/src/dynamo/vllm/worker_factory.py +++ b/components/src/dynamo/vllm/worker_factory.py @@ -366,7 +366,7 @@ async def _create_decode_worker( component_name=config.component, ) - # Register engine routes + # Register engine routes (sleep/wake_up + RL weight-lifecycle + RL LoRA) self.register_engine_routes(runtime, handler) # Parse endpoint types from --endpoint-types flag @@ -576,7 +576,7 @@ async def _create_prefill_worker( component_name=config.component, ) - # Register engine routes + # Register engine routes (sleep/wake_up + RL weight-lifecycle + RL LoRA) self.register_engine_routes(runtime, handler) await self._maybe_wait_for_failover_lock(handler, runtime, config) @@ -676,6 +676,26 @@ def register_engine_routes( runtime.register_engine_route("wake_up", handler.wake_up) runtime.register_engine_route("scale_elastic_ep", handler.scale_elastic_ep) + # RL weight-lifecycle routes (parity with SGLang #6094) — driven by the + # /v1/rl/{pause,resume,update_weights} bracket in the Rust frontend. + runtime.register_engine_route("pause_generation", handler.pause_generation) + runtime.register_engine_route("resume_generation", handler.resume_generation) + runtime.register_engine_route("flush_cache", handler.flush_cache) + runtime.register_engine_route( + "update_weights_from_path", handler.update_weights_from_path + ) + runtime.register_engine_route("get_weight_version", handler.get_weight_version) + + # RL LoRA adapter routes: filesystem-native hot-swap used by Prime-RL + # every training step to broadcast new adapter weights into the engine. 
+ runtime.register_engine_route("load_lora_adapter", handler.load_lora_adapter) + runtime.register_engine_route( + "unload_lora_adapter", handler.unload_lora_adapter + ) + logger.info( - "Registered engine routes: /engine/sleep, /engine/wake_up, /engine/scale_elastic_ep, /engine/start_profile, /engine/stop_profile" + "Registered engine routes: sleep, wake_up, scale_elastic_ep, " + "start_profile, stop_profile, pause_generation, resume_generation, " + "flush_cache, update_weights_from_path, get_weight_version, " + "load_lora_adapter, unload_lora_adapter" ) diff --git a/container/deps/vllm/install_vllm.sh b/container/deps/vllm/install_vllm.sh index 5ab0aa0e6896..90bee7bf3034 100755 --- a/container/deps/vllm/install_vllm.sh +++ b/container/deps/vllm/install_vllm.sh @@ -13,7 +13,7 @@ set -euo pipefail -VLLM_VER="0.20.0" +VLLM_VER="0.19.1" VLLM_REF="v${VLLM_VER}" DEVICE="cuda" @@ -300,4 +300,35 @@ if [ "$DEVICE" = "cuda" ]; then # TODO we will be able to specify which pplx and deepep commit we want in future TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" bash install_python_libraries.sh fi + +# --------------------------------------------------------------------------- +# prime-rl inference-side vLLM plugin (pinned tag). +# +# Registers the ``vllm.general_plugins`` entry-point that applies prime-rl's +# monkey patches (LoRA adapter load, DP engine pause/resume deadlock, Qwen 3.5 +# LoRA, etc.) automatically in every vLLM worker process -- including spawned +# subprocesses. Required for prime-rl / Dynamo RL training integration. +# +# Override at build time: --build-arg PRIME_RL_REF=v0.5.1.dev101 +# --no-deps: prime-rl's full dep tree includes trainer + wandb; Dynamo only +# needs the inference-side plugin and worker-extension classes. +# Python version: prime-rl pins requires-python = "~=3.12.0"; Dynamo containers +# are Python 3.12, so no version override is needed. For 3.11 local +# dev venvs use the regular pip (not uv) with --ignore-requires-python. 
+# --------------------------------------------------------------------------- +PRIME_RL_REF="${PRIME_RL_REF:-v0.5.1.dev101}" +echo "\n=== Installing prime-rl vLLM plugin (ref=${PRIME_RL_REF}) ===" +uv pip install --no-deps \ + "prime-rl @ git+https://github.com/PrimeIntellect-ai/prime-rl@${PRIME_RL_REF}" + +# Sanity-check: confirm vllm.general_plugins entry-point is registered. +python3 - <<'PY_SANITY' +from importlib.metadata import entry_points +names = [ep.name for ep in entry_points(group="vllm.general_plugins")] +assert "prime_rl" in names, ( + f"prime-rl plugin NOT registered; vllm.general_plugins={names}" +) +print(f"✓ prime-rl plugin registered (vllm.general_plugins={names})") +PY_SANITY + echo "\n✅ All installations completed successfully!" diff --git a/docs/Dynamo-RL-api-draft.md b/docs/Dynamo-RL-api-draft.md new file mode 100644 index 000000000000..ae8113ed3073 --- /dev/null +++ b/docs/Dynamo-RL-api-draft.md @@ -0,0 +1,973 @@ +# Dynamo RL API Draft + +**Branch:** `bis/parity-tokenize-tcp` (HEAD: `d837fbd67`) + +Commit `70f84570b` (the current auto-enable-token-ids commit on HEAD) is an +equivalent cherry-pick of the earlier `19d1bf13d` referenced in prior drafts — +same subject, same patch semantics, different parent tree after rebase onto +`origin/main`. + +--- + +## Table of Contents + +1. [Overview](#1-overview) +2. [Architecture](#2-architecture) +3. [Configuration](#3-configuration) +4. [API Reference](#4-api-reference) + - 4.1 Chat Completions (RL-enhanced) + - 4.2 Token-In / Token-Out (TITO) + - 4.3 Tokenization + - 4.4 Fleet Control (`/v1/rl/*`) +5. [Data Flow](#5-data-flow) +6. [Key Data Structures](#6-key-data-structures) +7. [Worker Engine Routes (Internal)](#7-worker-engine-routes-internal) +8. [Known Limitations](#8-known-limitations) +9. [Validation Results](#9-validation-results) + +--- + +## 1. Overview + +This document describes the RL training API surface on the Dynamo serving stack for integration with prime-rl. 
The Dynamo frontend (Rust) exposes: + +- An `/v1/rl/*` router for the full RL control-plane lifecycle (pause/resume, weight updates, readiness checks) +- Automatic token-level data injection (`prompt_token_ids`, `completion_token_ids`) in chat completion responses +- `/v1/tokenize` and `/v1/detokenize` endpoints +- A `/v1/chat/completions/tokens` TITO endpoint for pre-tokenized prompt bypass + +Zero Python in the inference or admin data path. The Rust frontend handles all HTTP API surface while vLLM workers expose engine routes for weight lifecycle operations on the GPU. + +### Endpoint Summary + +| Capability | Endpoint | Purpose | +|------------|----------|---------| +| Inference | `POST /v1/chat/completions` | Generate rollouts; responses include `prompt_token_ids` + `choice.token_ids` | +| TITO inference | `POST /v1/chat/completions/tokens` | Pre-tokenized prompt bypass (turn 2+ in multi-turn RL) | +| Tokenization | `POST /v1/tokenize` | Consistent tokenization using the model's chat template | +| Detokenization | `POST /v1/detokenize` | Token IDs back to text | +| Pause fleet | `POST /v1/rl/pause` | Drain in-flight requests before weight update | +| Resume fleet | `POST /v1/rl/resume` | Resume generation after weight update | +| Update weights | `POST /v1/rl/update_weights` | Atomic flush + reload from checkpoint directory | +| Load LoRA adapter | `POST /v1/rl/load_lora_adapter` | Hot-load/swap a PEFT-style adapter from filesystem path | +| Unload LoRA adapter | `POST /v1/rl/unload_lora_adapter` | Remove a previously loaded adapter by name | +| Weight version | `GET /v1/rl/weight_version` | Query current weight version across workers | +| Health | `GET /v1/rl/health` | Lightweight frontend health check | +| Readiness | `GET /v1/rl/ready` | Deep check: are workers reachable and healthy? | + +### What Changed vs. Stock Dynamo + +All changes are on `bis/parity-tokenize-tcp` (18 commits, 26 files, +3030/-41 — the diff counts include this doc). 
Nothing touches Dynamo's core serving pipeline (NATS, scheduler, KV cache, disaggregation). The changes are additive: + +- **Rust frontend** (`lib/llm/`): New routes, response post-processing, tokenization endpoints, LoRA hot-swap admin routes +- **vLLM worker** (`components/`): 7 engine route handlers (5 weight-lifecycle + 2 LoRA), publisher crash guard +- **Deps** (`container/`): default `VLLM_VER` bumped 0.19.0 → 0.19.1; prime-rl plugin installed via `pip` so `vllm.general_plugins` patches apply at engine start +- **Compat fixes**: `/v1/tokenize` and `/v1/detokenize` adapted to upstream `DecodeResult`-returning decoder (main commit `2cabf4414`, #8022) + +--- + +## 2. Architecture + +### Component Topology + +```mermaid +flowchart TD + subgraph prime_rl["prime-rl"] + orch["Orchestrator
(prime_rl.orchestrator)"] + trainer["Trainer
(prime_rl.trainer.rl.train)
torchrun --nproc-per-node=N"] + end + + subgraph dynamo["Dynamo Serving Stack"] + subgraph frontend["Frontend Pod (Rust, port 8000)"] + cc["/v1/chat/completions
+ prompt_token_ids
+ choice.token_ids"] + tito["/v1/chat/completions/tokens
(TITO)"] + tok["/v1/tokenize   /v1/detokenize"] + rl["/v1/rl/*
health, ready, pause, resume,
update_weights, weight_version"] + end + subgraph worker["vLLM Worker Pod (Python, system port 9090)"] + eng["/engine/*
pause_generation
resume_generation
flush_cache
update_weights_from_path
get_weight_version"] + gpu["GPU
Model Weights"] + end + end + + subgraph storage["Shared Storage (PVC)"] + pvc["prime-rl-shared-data
safetensors checkpoints"] + end + + orch -- "rollouts
POST /v1/chat/completions" --> cc + orch -- "TITO turn 2+
POST /v1/chat/completions/tokens" --> tito + orch -- "POST /v1/tokenize" --> tok + orch -- "weight lifecycle
pause / update_weights / resume" --> rl + rl -- "HTTP fan-out
(concurrent to all workers)" --> eng + eng --> gpu + trainer -- "write checkpoint" --> pvc + eng -- "reload_weights
(collective_rpc)" --> pvc +``` + +### Key Design Decisions + +1. **Single entry point.** Prime-RL points both `base_url` and `admin_base_url` at the Dynamo frontend. No separate admin service to deploy. + +2. **Fan-out in Rust.** The `/v1/rl/*` handlers fan out to all vLLM workers via `DYN_RL_WORKER_SYSTEM_URLS`. This supports DP>1 without Prime-RL needing to discover workers. The frontend returns HTTP 200 only when all workers respond OK, and HTTP 502 otherwise with per-worker details. + +3. **Token IDs as a response extension.** When `DYN_ENABLE_RL=true`, `prompt_token_ids` and `choices[i].token_ids` are injected into every non-streaming response automatically. No client-side configuration needed. + +4. **Backward compatible.** All new response fields use `#[serde(skip_serializing_if = "Option::is_none")]`. When the frontend runs without `DYN_ENABLE_RL`, clients see standard OpenAI-compatible responses with no extra fields. + +--- + +## 3. Configuration + +### Environment Variables (Frontend) + +| Variable | Default | Description | +|----------|---------|-------------| +| `DYN_ENABLE_RL` | `false` | Master switch. Mounts `/v1/rl/*` routes, auto-injects token IDs in chat completion responses, mounts TITO endpoint. | +| `DYN_RL_WORKER_SYSTEM_URLS` | `http://localhost:8081` | Comma-separated list of vLLM worker system HTTP base URLs for fan-out. | + +### Environment Variables (Worker) + +| Variable | Default | Description | +|----------|---------|-------------| +| `DYN_SYSTEM_PORT` | `8081` (local) / `9090` (k8s) | Worker's system HTTP port where engine routes are registered. | + +### Prime-RL Configuration (`orch.toml`) + +```toml +max_steps = 20 +seq_len = 512 +batch_size = 16 +rollouts_per_example = 4 +use_token_client = false + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[sampling] +max_tokens = 64 + +[[env]] +id = "reverse-text" + +[client] +# Point BOTH base_url and admin_base_url at the Dynamo frontend. 
+# admin_base_url uses /v1/rl because Prime-RL strips trailing /v1 +# from admin URLs, but /v1/rl is preserved. +base_url = ["http://<frontend-host>:8000/v1"] +admin_base_url = ["http://<frontend-host>:8000/v1/rl"] +skip_model_check = true + +[weight_broadcast] +type = "filesystem" + +[experimental] +# Disable prefix cache salt until Dynamo supports it. +# verifiers dev6+ defaults use_prefix_cache_salt=True; current image returns 400. +use_prefix_cache_salt = false +``` + +**Important:** Do NOT set `send_return_token_ids = true` in `[sampling]`. The Rust frontend handles token ID injection automatically when `DYN_ENABLE_RL=true`. Sending `return_token_ids=true` in the request causes the OpenAI SDK to parse the response and strip unknown fields. + +### Kubernetes (DGD) + +```yaml +# Frontend pod env +- name: DYN_ENABLE_RL + value: "true" +- name: DYN_RL_WORKER_SYSTEM_URLS + value: "http://prime-rl-dynamo-vllmworker.<namespace>.svc.cluster.local:9090" +``` + +### Launch Commands (Local) + +```bash +# Frontend with RL routes enabled +DYN_ENABLE_RL=true \ +DYN_RL_WORKER_SYSTEM_URLS=http://localhost:8081 \ + python -m dynamo.frontend + +# vLLM Worker +CUDA_VISIBLE_DEVICES=0 \ +DYN_SYSTEM_PORT=8081 \ + python -m dynamo.vllm \ + --model PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT \ + --served-model-name PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT \ + --enforce-eager \ + --max-model-len 2048 \ + --gpu-memory-utilization 0.5 +``` + +--- + +## 4. API Reference + +All endpoints live on the Dynamo Rust frontend (default port 8000). Unless noted, request/response formats follow the OpenAI API specification. + +### 4.1 Chat Completions (RL-enhanced) + +``` +POST /v1/chat/completions +``` + +Standard OpenAI chat completions with RL extensions. When `DYN_ENABLE_RL=true`, every non-streaming response is automatically enriched with token IDs for the trainer. + +#### Request + +Standard OpenAI `ChatCompletionRequest`. 
Two additional fields are accepted and silently consumed (never forwarded to the vLLM worker): + +| Field | Type | Default | Mapped to | +|-------|------|---------|-----------| +| `tokens` | `u32[]` | `null` | `nvext.token_data` (tokenizer bypass) | +| `return_token_ids` | `bool` | `null` | `nvext.extra_fields: ["token_ids", "completion_token_ids"]` + `logprobs: true` | + +When `DYN_ENABLE_RL=true`, `return_token_ids` is implicitly `true` for every request. + +#### Sample Request + +```bash +curl -s -X POST http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", + "messages": [ + {"role": "user", "content": "Reverse this: hello world"} + ], + "max_tokens": 64, + "temperature": 1.0 + }' +``` + +#### Sample Response (Non-Streaming, with `DYN_ENABLE_RL=true`) + +```json +{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "dlrow olleh"}, + "finish_reason": "stop", + "logprobs": { + "content": [ + {"token": "dl", "logprob": -0.523, "top_logprobs": []}, + {"token": "row", "logprob": -0.102, "top_logprobs": []}, + {"token": " ol", "logprob": -0.834, "top_logprobs": []}, + {"token": "leh", "logprob": -0.211, "top_logprobs": []} + ] + }, + "token_ids": [67, 1245, 893, 15] + }], + "prompt_token_ids": [151644, 8948, 198, 151645, 198, 151644, 872, 198, + 49, 1075, 513, 420, 25, 24748, 1879, 198, 151645, + 198, 151644, 77091, 198], + "usage": {"prompt_tokens": 21, "completion_tokens": 4, "total_tokens": 25}, + "nvext": { + "completion_token_ids": [67, 1245, 893, 15] + } +} +``` + +#### Response Field Reference + +| Field | JSON path | Description | +|-------|-----------|-------------| +| `prompt_token_ids` | `response.prompt_token_ids` | Token IDs from tokenizing the prompt messages through the model's chat template. 
Generated by the Rust frontend's tokenizer after the response is fully received. | +| `token_ids` | `response.choices[i].token_ids` | Completion token IDs generated by the engine. Promoted from `nvext.completion_token_ids`. | +| `completion_token_ids` | `response.nvext.completion_token_ids` | Canonical Dynamo location for output token IDs. Accumulated across all SSE chunks by `DeltaGenerator`. | + +**Why `token_ids` appears in two locations:** Prime-RL's verifiers library reads `response.prompt_token_ids` and `choices[i].token_ids` (top-level on the choice object). Dynamo natively emits output token IDs in `nvext.completion_token_ids`. The Rust post-processor promotes the latter to the former for compatibility. Both contain the same values. + +**Invariant:** `len(completion_token_ids) == len(logprobs.content)` -- the output token IDs are in exact 1:1 correspondence with the logprob entries. + +#### Sample Response (Streaming / SSE -- final chunk only) + +Intermediate chunks carry `delta.content` only. Token IDs appear exclusively on the **final chunk** (the one with a non-null `finish_reason`): + +``` +data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk", + "choices":[{"index":0,"delta":{},"finish_reason":"stop", + "nvext":{"completion_token_ids":[67,1245,893,15]}}], + "prompt_token_ids":[151644,8948,198,151645,198,151644,872,198, + 49,1075,513,420,25,24748,1879,198,151645, + 198,151644,77091,198]} + +data: [DONE] +``` + +#### RL Post-Processing Pipeline + +For non-streaming requests, the handler performs the following after the backend response is fully aggregated: + +```mermaid +flowchart LR + A["Request arrives
DYN_ENABLE_RL=true"] --> B["Save messages
for later tokenization"] + B --> C["Inject nvext.extra_fields:
[token_ids, completion_token_ids]
Force logprobs=true"] + C --> D["Standard pipeline
(preprocessor, backend,
delta generator, aggregator)"] + D --> E["Aggregate response
(nvext.completion_token_ids
accumulated in delta.rs)"] + E --> F["rl_tokenize_prompt()
messages -> prompt_token_ids
via model chat template"] + F --> G["rl_promote_token_ids_in_response()
nvext.completion_token_ids
-> choices[i].token_ids"] + G --> H["Return enriched
JSON response"] +``` + +--- + +### 4.2 Token-In / Token-Out (TITO) + +``` +POST /v1/chat/completions/tokens +``` + +Dedicated endpoint for Prime-RL's pre-tokenized prompt flow (multi-turn RL, turn 2+). The orchestrator sends raw token IDs instead of text messages, bypassing the frontend's tokenizer entirely. This avoids redundant encode/decode round-trips and ensures token-level alignment. + +#### Sample Request + +```bash +curl -s -X POST http://localhost:8000/v1/chat/completions/tokens \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", + "messages": [{"role": "user", "content": "(token-in mode)"}], + "tokens": [151644, 8948, 198, 151645, 198, 151644, 872, 198, + 49, 1075, 513, 420, 25, 24748, 1879, 198, 151645, + 198, 151644, 77091, 198, 67, 1245, 893, 15], + "max_tokens": 64 + }' +``` + +| Field | Required | Description | +|-------|----------|-------------| +| `tokens` | **Yes** | Pre-tokenized prompt IDs. Must be non-empty. Injected as `nvext.token_data`. | +| `messages` | Yes (can be placeholder) | A placeholder `{"role": "user", "content": "(token-in mode)"}` is auto-injected if empty. | + +#### Behavior + +1. Extracts `tokens`, returns 400 if missing or empty +2. Injects into `nvext.token_data` (triggers tokenizer bypass in `preprocessor.rs`) +3. Adds `extra_fields: ["token_ids", "completion_token_ids"]` +4. Forces `logprobs = true` +5. Delegates to the standard `chat_completions()` pipeline (zero HTTP proxy) + +#### Sample Response + +Same shape as section 4.1. The response includes `prompt_token_ids` and `choices[i].token_ids`. + +--- + +### 4.3 Tokenization + +``` +POST /v1/tokenize +POST /v1/detokenize +``` + +Consistent tokenization using the model's tokenizer and chat template, running entirely in Rust. These are critical for RL: prompt token IDs in the chat completion response must match what the tokenizer produces for the same messages. 
Both endpoints use the same tokenizer instance that the frontend uses for its own request preprocessing. + +#### Sample: Tokenize (Chat variant) + +```bash +curl -s -X POST http://localhost:8000/v1/tokenize \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", + "messages": [ + {"role": "user", "content": "Reverse this: hello world"} + ], + "add_generation_prompt": true, + "add_special_tokens": true + }' +``` + +**Response:** + +```json +{ + "count": 21, + "max_model_len": 2048, + "tokens": [151644, 8948, 198, 151645, 198, 151644, 872, 198, + 49, 1075, 513, 420, 25, 24748, 1879, 198, 151645, + 198, 151644, 77091, 198], + "token_strs": null +} +``` + +#### Tokenize Request Fields (Chat variant) + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `model` | `string?` | auto-resolve | Model name for tokenizer lookup | +| `messages` | `ChatMessage[]` | -- | Messages to tokenize through the model's chat template | +| `add_generation_prompt` | `bool` | `true` | Append generation prompt (e.g., `<\|im_start\|>assistant\n`) | +| `add_special_tokens` | `bool` | `true` | Add BOS/EOS tokens | +| `return_token_strs` | `bool` | `false` | Include human-readable string representation of each token | +| `chat_template` | `string?` | `null` | Override the model's default chat template (Jinja2) | +| `chat_template_kwargs` | `object?` | `null` | Extra template variables | +| `continue_final_message` | `bool` | `false` | Continue last message instead of starting a new turn | + +The chat variant renders messages through the model's chat template before tokenizing, so the token count is the exact number of tokens that a corresponding chat completion request would consume. 
+ +#### Sample: Tokenize (Completion variant) + +```bash +curl -s -X POST http://localhost:8000/v1/tokenize \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", + "prompt": "hello world", + "add_special_tokens": true, + "return_token_strs": true + }' +``` + +**Response:** + +```json +{ + "count": 3, + "max_model_len": 2048, + "tokens": [9707, 1879, 3], + "token_strs": ["hello", " world", ""] +} +``` + +#### Tokenize Response Fields + +| Field | Type | Description | +|-------|------|-------------| +| `count` | `int` | Number of tokens | +| `max_model_len` | `int` | Model's configured maximum context length | +| `tokens` | `list[int]` | Token ID list | +| `token_strs` | `list[str]?` | Human-readable token strings; only present if `return_token_strs: true` | + +#### Sample: Detokenize + +```bash +curl -s -X POST http://localhost:8000/v1/detokenize \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", + "tokens": [9707, 1879, 3] + }' +``` + +**Response:** + +```json +{"prompt": "hello world"} +``` + +--- + +### 4.4 Fleet Control (`/v1/rl/*`) + +All `/v1/rl/*` routes are mounted only when `DYN_ENABLE_RL=true`. They fan out to vLLM worker system ports defined by `DYN_RL_WORKER_SYSTEM_URLS`. + +In prime-rl's config: + +```toml +[client] +base_url = ["http://:8000/v1"] +admin_base_url = ["http://:8000/v1/rl"] +``` + +--- + +#### `GET /v1/rl/health` + +Lightweight liveness probe. Returns immediately as long as the frontend process is running. Used by prime-rl's `check_health()` on the admin client. + +```bash +curl -s http://localhost:8000/v1/rl/health +``` + +```json +{"status": "ok"} +``` + +--- + +#### `GET /v1/rl/ready` + +Composite readiness probe. Polls `/health` on every configured worker system URL concurrently. Returns 200 only when all workers respond with HTTP 2xx. 
+ +```bash +curl -s http://localhost:8000/v1/rl/ready +``` + +```json +// All workers ready (200) +{ + "status": "ready", + "workers": [ + {"url": "http://localhost:8081", "healthy": true} + ] +} + +// Not all workers ready (503) +{ + "status": "not_ready", + "workers_ready": 0, + "workers_total": 1, + "workers": [ + {"url": "http://localhost:8081", "healthy": false, "error": "connection refused"} + ] +} +``` + +--- + +#### `POST /v1/rl/pause` + +Quiesces generation on all workers. Each worker calls `engine_client.pause_generation()` which drains in-flight requests without unloading the model from GPU memory. + +```bash +curl -s -X POST http://localhost:8000/v1/rl/pause -H 'Content-Type: application/json' -d '{}' +``` + +```json +// Success (200) +{ + "status": "ok", + "workers": [ + {"status": "ok", "message": "Engine paused"} + ] +} + +// Failure (502) +{ + "status": "error", + "workers": [ + {"status": "ok", "message": "Engine paused"}, + {"status": "error", "message": "timeout"} + ] +} +``` + +--- + +#### `POST /v1/rl/resume` + +Resumes generation on all workers after a weight update. + +```bash +curl -s -X POST http://localhost:8000/v1/rl/resume -H 'Content-Type: application/json' -d '{}' +``` + +```json +{ + "status": "ok", + "workers": [ + {"status": "ok", "message": "Engine resumed"} + ] +} +``` + +--- + +#### `POST /v1/rl/update_weights` + +Atomic weight-loading sequence: `flush_cache` then `update_weights_from_path` on all workers concurrently. The frontend performs the two-phase fan-out internally; prime-rl does not need to call these separately. + +For filesystem-backed weight broadcast, the trainer writes safetensors files to a shared PVC directory and passes that path here. The worker calls `engine_client.collective_rpc("reload_weights", kwargs={"weights_path": path})` which triggers vLLM's layerwise in-place weight reload on every GPU worker. 
+ +For NCCL-based weight broadcast (`weight_dir: null`), Dynamo returns 200 immediately -- the actual weight transfer happens out-of-band via NCCL and Dynamo does not participate. + +```bash +# Filesystem mode +curl -s -X POST http://localhost:8000/v1/rl/update_weights \ + -H 'Content-Type: application/json' \ + -d '{"weight_dir": "/data/outputs/run_default/broadcasts/step_5"}' + +# NCCL mode (Dynamo no-op) +curl -s -X POST http://localhost:8000/v1/rl/update_weights \ + -H 'Content-Type: application/json' \ + -d '{"weight_dir": null}' +``` + +```json +// Success (200) +{ + "status": "ok", + "version": "step_5", + "workers": [ + {"status": "ok", "message": "Weights loaded from /data/outputs/...", "version": "step_5"} + ] +} + +// Failure at flush_cache stage (502) +{ + "status": "error", + "stage": "flush_cache", + "workers": [...] +} + +// Failure at update_weights stage (502) +{ + "status": "error", + "stage": "update_weights_from_path", + "workers": [...] +} +``` + +The `version` string is derived from the basename of `weight_dir` (e.g., `step_5` from `/data/outputs/run_default/broadcasts/step_5`). This version is stored in the worker and retrievable via `/v1/rl/weight_version`. + +--- + +#### `POST /v1/rl/load_lora_adapter` + +Hot-load or hot-swap a LoRA adapter from a filesystem path. The adapter directory must contain PEFT-style `adapter_model.safetensors` and `adapter_config.json` -- the default output layout of prime-rl's LoRA trainer. + +This is the RL-native LoRA path, distinct from Dynamo's URI-based `load_lora` gRPC endpoint (which downloads from S3/file URIs via `LoRAManager`). The admin route is optimized for the training loop: no URI fetch, no MDC churn on hot-swap. + +- **First call for a given `lora_name`**: `add_lora` in the engine, publish a ModelDeploymentCard so subsequent inference requests with `model=` route to this worker. +- **Subsequent calls (hot-swap)**: `remove_lora(old_id)` → `add_lora` with new weights → `reset_prefix_cache`. 
The MDC is left in place since it already points at this worker. + +Pair with `/v1/rl/pause` and `/v1/rl/resume` for full drain-swap-resume semantics. + +```bash +curl -s -X POST http://localhost:8000/v1/rl/load_lora_adapter \ + -H 'Content-Type: application/json' \ + -d '{"lora_name": "r16-a32", "lora_path": "/data/outputs/run_default/broadcasts/step_5"}' +``` + +```json +// Success (200) +{ + "status": "ok", + "workers": [ + { + "status": "ok", + "message": "LoRA adapter 'r16-a32' loaded from /data/outputs/...", + "lora_name": "r16-a32", + "lora_id": 788776416, + "hot_swap": false + } + ] +} + +// Missing / empty field (400) +{ + "status": "error", + "message": "Expected body: {\"lora_name\": str, \"lora_path\": str} (both required, non-empty)" +} + +// Worker-side failure (502) -- e.g. bad adapter file, rank mismatch, vLLM not --enable-lora +{ + "status": "error", + "workers": [{"status": "error", "message": "..."}] +} +``` + +**vLLM worker requirements**: the engine must be started with `--enable-lora --max-lora-rank R --max-loras N`, with `R` ≥ the adapter rank and `N` ≥ the number of distinct `lora_name` values you expect to have loaded at once. For Prime-RL's single-adapter training loop, `--max-loras 1` is sufficient. + +--- + +#### `POST /v1/rl/unload_lora_adapter` + +Remove a previously loaded LoRA adapter by name. Idempotent: unloading an already-absent adapter returns `status: ok` so callers can retry safely. + +Unregisters the adapter's ModelDeploymentCard so the frontend stops routing `model=` requests to this worker. 
+ +```bash +curl -s -X POST http://localhost:8000/v1/rl/unload_lora_adapter \ + -H 'Content-Type: application/json' \ + -d '{"lora_name": "r16-a32"}' +``` + +```json +// Success (200) +{ + "status": "ok", + "workers": [ + { + "status": "ok", + "message": "LoRA adapter 'r16-a32' unloaded", + "lora_name": "r16-a32", + "lora_id": 788776416 + } + ] +} + +// Already absent -- still ok (200) +{ + "status": "ok", + "workers": [{"status": "ok", "message": "LoRA adapter 'r16-a32' not loaded (no-op)", "lora_name": "r16-a32"}] +} +``` + +--- + +#### `GET /v1/rl/weight_version` + +Returns the currently loaded weight version from all workers. Useful for debugging weight update races or confirming that all workers converged to the same checkpoint. + +```bash +curl -s http://localhost:8000/v1/rl/weight_version +``` + +```json +// All workers consistent (200) +{ + "status": "ok", + "version": "step_5", + "workers": [ + {"version": "step_5"}, + {"version": "step_5"} + ] +} + +// Workers inconsistent (200, with warning) +{ + "status": "inconsistent", + "versions": ["step_4", "step_5"], + "workers": [ + {"version": "step_4"}, + {"version": "step_5"} + ] +} +``` + +Returns HTTP 200 even when versions are inconsistent -- the `status` field distinguishes the cases. A 502 is only returned for network-level failures. + +--- + +## 5. Data Flow + +### 5.1 Rollout (Inference) Path + +```mermaid +sequenceDiagram + participant Orch as prime-rl Orchestrator + participant FE as Dynamo Frontend (Rust) + participant Worker as vLLM Worker (GPU) + + Orch->>FE: POST /v1/chat/completions
{messages, max_tokens, ...} + Note over FE: DYN_ENABLE_RL=true:
inject nvext.extra_fields
= ["token_ids", "completion_token_ids"]
force logprobs=true
save messages for tokenization + FE->>Worker: forward request (TCP/NATS) + Worker-->>FE: SSE chunks
(delta.content + delta.token_ids per chunk) + Note over FE: DeltaGenerator accumulates
completion_token_ids across chunks + Worker-->>FE: final chunk
(finish_reason + nvext.completion_token_ids) + Note over FE: Post-process:
1. rl_tokenize_prompt(messages)
-> response.prompt_token_ids
2. Promote nvext.completion_token_ids
-> choices[i].token_ids + FE-->>Orch: Enriched response:
prompt_token_ids + choices[i].token_ids
+ nvext.completion_token_ids +``` + +### 5.2 Weight Update Path + +```mermaid +sequenceDiagram + participant Trainer as prime-rl Trainer + participant PVC as Shared Storage + participant Orch as prime-rl Orchestrator + participant FE as Dynamo Frontend (Rust) + participant W1 as vLLM Worker 1 + participant W2 as vLLM Worker 2 + + Trainer->>PVC: write checkpoint
/data/outputs/.../step_N/*.safetensors + Trainer->>Orch: notify weight update ready
(internal IPC) + Orch->>FE: POST /v1/rl/pause + FE->>W1: POST /engine/pause_generation + FE->>W2: POST /engine/pause_generation + W1-->>FE: {status: ok} + W2-->>FE: {status: ok} + FE-->>Orch: {status: ok} + Orch->>FE: POST /v1/rl/update_weights
{weight_dir: /data/outputs/.../step_N} + FE->>W1: POST /engine/flush_cache + FE->>W2: POST /engine/flush_cache + W1-->>FE: {status: ok} + W2-->>FE: {status: ok} + FE->>W1: POST /engine/update_weights_from_path
{path: ..., version: step_N} + FE->>W2: POST /engine/update_weights_from_path
{path: ..., version: step_N} + Note over W1,W2: collective_rpc(reload_weights)
vLLM GPUModelRunner.reload_weights()
in-place layer-by-layer load from safetensors + W1-->>FE: {status: ok, version: step_N} + W2-->>FE: {status: ok, version: step_N} + FE-->>Orch: {status: ok, version: step_N} + Orch->>FE: POST /v1/rl/resume + FE->>W1: POST /engine/resume_generation + FE->>W2: POST /engine/resume_generation + W1-->>FE: {status: ok} + W2-->>FE: {status: ok} + FE-->>Orch: {status: ok} + Note over Orch: Continue rollouts
with updated weights +``` + +### 5.3 TITO (Tokens-In, Tokens-Out) Path + +```mermaid +sequenceDiagram + participant Client as Orchestrator (turn 2+) + participant FE as Dynamo Frontend (Rust) + participant PP as Preprocessor + participant Worker as vLLM Worker + + Client->>FE: POST /v1/chat/completions/tokens
{tokens: [9707, 1879, 3], max_tokens: 64} + Note over FE: Extract tokens field
Inject as nvext.token_data
Force extra_fields, logprobs + FE->>PP: Request with nvext.token_data + Note over PP: token_data present:
skip tokenization,
use provided IDs directly + PP->>Worker: Token IDs sent as-is + Worker-->>FE: Completion response
(with completion_token_ids) + FE-->>Client: Enriched response
(prompt_token_ids + token_ids) +``` + +--- + +## 6. Key Data Structures + +### `NvExtResponse` (Rust -- response side) + +Serialized as the `nvext` field in each SSE chunk or the unary response body: + +``` +NvExtResponse { + worker_id?: WorkerIdInfo -- prefill/decode worker IDs for disaggregated serving + timing?: TimingInfo -- request timing (enabled via extra_fields: ["timing"]) + token_ids?: Vec<u32> -- GAIE Stage 1: tokenized prompt for Stage 2 reuse + routed_experts?: serde_json::Value -- SGLang-specific expert routing payload + completion_token_ids?: Vec<u32> -- RL: generated output token IDs (final chunk only) +} +``` + +The `completion_token_ids` field is populated automatically for all requests when `DYN_ENABLE_RL=true`, or when the client sends `nvext.extra_fields: ["completion_token_ids"]`. + +### `NvCreateChatCompletionRequest` (Rust -- request side) + +New fields relevant to RL: + +| Field | Serialized | Description | +|-------|-----------|-------------| +| `tokens` | No (`skip_serializing`) | Pre-tokenized prompt token IDs (TITO path via `/v1/chat/completions/tokens`) | +| `return_token_ids` | No (`skip_serializing`) | prime-rl compat field; when `true`, adds `token_ids` and `completion_token_ids` to `nvext.extra_fields` and sets `logprobs: true` if unset; when omitted, defaults to the `DYN_ENABLE_RL` setting | + +Both fields are stripped before the request is forwarded to the vLLM worker, preventing 400 errors from the vLLM OpenAI-compatible API. + +### `NvCreateChatCompletionResponse` (Rust -- response side) + +``` +NvCreateChatCompletionResponse { + inner: CreateChatCompletionResponse -- standard OpenAI response fields + nvext?: serde_json::Value -- NvExtResponse serialized as JSON + prompt_token_ids?: Vec<u32> -- RL: tokenized prompt IDs (DYN_ENABLE_RL only) +} +``` + +### `DeltaGenerator` (Rust -- streaming pipeline) + +Manages per-request streaming state. Accumulates output token IDs across chunks: + +``` +DeltaGenerator { + ... 
+ accumulated_completion_token_ids: Vec<u32> -- grows per chunk +} +``` + +- **Activation:** `options.enable_completion_token_ids` is set to `true` when `extra_fields` includes `"completion_token_ids"` (auto-set when `DYN_ENABLE_RL=true`). +- **Accumulation:** On each postprocessor output chunk, appends `delta.token_ids` to the accumulator. +- **Emission:** On the final chunk (`finish_reason` is set), the full list is emitted in `nvext.completion_token_ids` and the accumulator is cleared. + +### `NvExt` (Rust -- request-side NVIDIA extensions) + +Relevant fields for RL: + +``` +NvExt { + token_data?: Vec<u32> -- Pre-tokenized prompt IDs (TITO / EPP bypass) + extra_fields?: Vec<String> -- Request extra response fields, e.g. ["completion_token_ids"] + backend_instance_id?: u64 -- Targeted routing to a specific worker + ... +} +``` + +### Tokenization Types + +``` +TokenizeRequest = Completion { prompt, model?, add_special_tokens? } + | Chat { messages, model?, add_generation_prompt?, chat_template?, ... } + +TokenizeResponse { count: int, max_model_len: int, tokens: Vec<u32>, token_strs?: Vec<String> } + +DetokenizeRequest { model?: String, tokens: Vec<u32> } +DetokenizeResponse { prompt: String } +``` + +### Post-Processing Helpers + +**`rl_tokenize_prompt(state, model, messages) -> Option<Vec<u32>>`** + +Tokenizes the original prompt messages using the model's chat template and tokenizer: resolves model card from `state`, gets the tokenizer instance, builds a `PromptFormatter` from the model deployment card, renders messages through the chat template (same logic as the preprocessor), tokenizes the rendered string, and returns the token IDs. + +**`rl_promote_token_ids_in_response(json_val)`** + +Copies `response.nvext.completion_token_ids` to `response.choices[i].token_ids` for each choice. Bridges Dynamo's `nvext` convention with the field paths that Prime-RL/verifiers expects. + +--- + +## 7. 
Worker Engine Routes (Internal) + +Five engine route handlers are registered on each vLLM worker's system HTTP port (default `8081` local / `9090` in k8s). These are **internal** routes called by the Rust frontend's `/v1/rl/*` handlers -- not called directly by Prime-RL. + +| Route | Method | vLLM API called | Description | +|-------|--------|-----------------|-------------| +| `/engine/pause_generation` | POST | `engine_client.pause_generation()` | Drain in-flight requests, keep model loaded in GPU memory | +| `/engine/resume_generation` | POST | `engine_client.resume_generation()` | Resume accepting inference requests | +| `/engine/flush_cache` | POST | `engine_client.reset_prefix_cache()` | Invalidate prefix/KV cache (required before weight reload) | +| `/engine/update_weights_from_path` | POST | `engine_client.collective_rpc("reload_weights", ...)` | Load weights from filesystem (safetensors checkpoint) | +| `/engine/get_weight_version` | POST | `self._weight_version` | Return current weight version string | + +Both decode and prefill worker types register all 5 routes. Route signatures are compatible with SGLang's merged `#6094` routes for backend interoperability. + +### Registration (worker_factory.py) + +```python +runtime.register_engine_route("pause_generation", handler.pause_generation) +runtime.register_engine_route("resume_generation", handler.resume_generation) +runtime.register_engine_route("flush_cache", handler.flush_cache) +runtime.register_engine_route("update_weights_from_path", handler.update_weights_from_path) +runtime.register_engine_route("get_weight_version", handler.get_weight_version) +``` + +### publisher.py Crash Guard + +The `DynamoStatLoggerPublisher.record()` method includes a guard for `scheduler_stats is None`. This prevents an `AttributeError` crash during the transient window right after a weight reload when the vLLM engine's stats logger fires before the engine core has re-initialized its scheduler stats. + +--- + +## 8. 
Known Limitations + +| Limitation | Workaround | Notes | +|-----------|-----------|-------| +| `cache_salt` not supported -- returns 400 for requests with `cache_salt` in body | Set `[experimental] use_prefix_cache_salt = false` in prime-rl `orch.toml` | verifiers dev6+ defaults `use_prefix_cache_salt=True` | +| `prompt_token_ids` only injected for non-streaming responses | Use non-streaming mode for RL rollouts (the default) | Streaming final-chunk injection is planned | +| Weight version `"initial"` before first update | Do not depend on version string for correctness; use `/v1/rl/ready` for readiness | | +| NCCL weight broadcast is a no-op on Dynamo side | Use `type = "filesystem"` in `[weight_broadcast]` for all current deployments | | +| ~~`VLLM_USE_V1=0` required~~ **Resolved on vLLM 0.19.1** | Set `VLLM_USE_V1=1` on images shipping `VLLM_VER≥0.19.1` (current default). Keep `VLLM_USE_V1=0` only on legacy 0.18.x images where Meta-tensor crash with `--enforce-eager` still reproduces. | Verified under Run D (Qwen3.5-35B-A3B-FP8, 12 workers, batch=64) with V1 enabled and CUDA graphs for LoRA decode. | +| Filesystem weight broadcast scales poorly for large models | Acceptable for 0.6B (257ms load); marginal at 7B (~25s); impractical at 70B (~5 min) | RDMA pull transfer planned | + +--- + +## 9. 
Validation Results + +### Local (2x A6000, Qwen3-0.6B-Reverse-Text-SFT, 20 steps) + +| Metric | vLLM Baseline | Dynamo | Delta | +|--------|:-------------:|:------:|:-----:| +| Steps completed | 20/20 | 20/20 | -- | +| Peak reward | 0.798 | **0.825** | +3.4% | +| Final reward | 0.716 | **0.724** | +1.2% | +| `is_masked/mean` | 1.2% | **0.13%** | -92% (better) | +| Mismatch KL (final) | 0.0075 | **0.0056** | -25% (better) | +| Weight update cycles | 19 | 19 | -- | +| Mean weight cycle time | -- | 257.5ms | pause: 3.2ms, load: 249ms, resume: 4.8ms | + +W&B: https://wandb.ai/test232/prime-rl-parity-apr17 + +### Kubernetes (GB200, same model, 20 steps) + +| Metric | K8s | +|--------|:---:| +| Steps completed | 20/20 | +| Reward at step 13 | 0.714 (climbing) | +| Mismatch KL (steps 0-3) | 0.0007 - 0.0009 | +| Pods | 4 | +| All RL routes verified | Yes | + +The Rust RL API produces **better token alignment than native vLLM** (0.13% masked vs 1.2%). diff --git a/lib/llm/src/audit/stream.rs b/lib/llm/src/audit/stream.rs index 2663cdd7eeb0..4985139d37c8 100644 --- a/lib/llm/src/audit/stream.rs +++ b/lib/llm/src/audit/stream.rs @@ -101,6 +101,7 @@ where service_tier: None, }, nvext: None, + prompt_token_ids: None, } }) }), @@ -138,6 +139,7 @@ where service_tier: None, }, nvext: None, + prompt_token_ids: None, }; let _ = tx.send(fallback.clone()); final_response_to_one_chunk_stream(fallback) @@ -160,6 +162,7 @@ where service_tier: None, }, nvext: None, + prompt_token_ids: None, } }) }); diff --git a/lib/llm/src/entrypoint/input/text.rs b/lib/llm/src/entrypoint/input/text.rs index 1c0138fd34b3..6c7046602d7e 100644 --- a/lib/llm/src/entrypoint/input/text.rs +++ b/lib/llm/src/entrypoint/input/text.rs @@ -116,6 +116,8 @@ async fn main_loop( chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }; // Call the model diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 
a7733428105e..1855140200b1 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -55,6 +55,10 @@ use crate::protocols::openai::{ embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, images::{NvCreateImageRequest, NvImagesResponse}, responses::{NvCreateResponse, NvResponse, ResponseParams, chat_completion_to_response}, + tokenization::{ + DetokenizeRequest, DetokenizeResponse, TokenizeCompletionRequest, TokenizeRequest, + TokenizeResponse, + }, videos::{NvCreateVideoRequest, NvVideosResponse}, }; use crate::protocols::unified::UnifiedRequest; @@ -203,6 +207,21 @@ impl ErrorMessage { /// Not Implemented Error /// Return this error when the client requests a feature that is not yet implemented. /// This should be used for features that are planned but not available. + /// Bad Request Error + /// Return this error when the client sends an invalid request. + pub fn bad_request(msg: &str) -> ErrorResponse { + let code = StatusCode::BAD_REQUEST; + let error_type = map_error_code_to_error_type(code); + ( + code, + Json(ErrorMessage { + message: msg.to_string(), + error_type, + code: code.as_u16(), + }), + ) + } + pub fn not_implemented_error(msg: T) -> ErrorResponse { tracing::error!("Not Implemented error: {msg}"); let code = StatusCode::NOT_IMPLEMENTED; @@ -910,6 +929,69 @@ async fn handler_chat_completions( request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers); + // RL field promotion: wire `tokens` and `return_token_ids` when provided on the standard + // chat completions endpoint. This eliminates the need for the rl_admin Python proxy to + // intercept and rewrite these fields. + // + // If `return_token_ids` is true, request completion_token_ids in the response. + // Auto-enable when DYN_ENABLE_RL is set -- ensures token IDs flow even if the + // client forgets to request them. 
+ let rl_want_token_ids = request + .return_token_ids + .take() + .unwrap_or_else(|| dynamo_runtime::config::env_is_truthy("DYN_ENABLE_RL")); + if rl_want_token_ids { + tracing::info!("RL: want_token_ids=true, will promote nvext.extra_fields"); + } + { + // If `tokens` is provided, inject into nvext.token_data (pre-tokenized prompt path). + let token_data = request.tokens.take(); + + if token_data.is_some() || rl_want_token_ids { + let mut nvext = request.nvext.take().unwrap_or_default(); + + if let Some(ids) = token_data { + if !ids.is_empty() { + nvext.token_data = Some(ids); + // Ensure messages is non-empty for model lookup / chat template + if request.inner.messages.is_empty() { + use dynamo_protocols::types::{ + ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, + ChatCompletionRequestUserMessageContent, + }; + request + .inner + .messages + .push(ChatCompletionRequestMessage::User( + ChatCompletionRequestUserMessage { + content: ChatCompletionRequestUserMessageContent::Text( + "(token-in mode)".to_string(), + ), + name: None, + }, + )); + } + } + } + + if rl_want_token_ids { + let mut extra_fields = nvext.extra_fields.take().unwrap_or_default(); + for field in &["token_ids", "completion_token_ids"] { + if !extra_fields.contains(&field.to_string()) { + extra_fields.push(field.to_string()); + } + } + nvext.extra_fields = Some(extra_fields); + // Also force logprobs on when RL is requesting token IDs + if request.inner.logprobs.is_none() { + request.inner.logprobs = Some(true); + } + } + + request.nvext = Some(nvext); + } + } + // create the context for the request let request_id = get_or_create_request_id(&headers); let streaming = request.inner.stream.unwrap_or(false); @@ -1195,6 +1277,26 @@ async fn chat_completions( // todo - decide on default let streaming = request.inner.stream.unwrap_or(false); + // RL: save messages for post-response prompt tokenization (needed for prompt_token_ids). 
+ let rl_saved_messages = if dynamo_runtime::config::env_is_truthy("DYN_ENABLE_RL") && !streaming + { + Some(request.inner.messages.clone()) + } else { + None + }; + + // RL: for TITO requests the caller (handler_chat_completions_tokens) injects a + // placeholder message so Dynamo can select a chat template, but then saves the + // real token IDs in nvext.token_data. Capture them now — before the request is + // consumed by engine.generate() — so the post-processing step can use them + // directly as prompt_token_ids instead of re-tokenizing the placeholder. + let rl_tito_token_ids: Option> = + if dynamo_runtime::config::env_is_truthy("DYN_ENABLE_RL") && !streaming { + request.nvext.as_ref().and_then(|nv| nv.token_data.clone()) + } else { + None + }; + // Apply template values first to resolve the model before creating metrics guards if let Some(template) = template { if request.inner.model.is_empty() { @@ -1406,6 +1508,41 @@ async fn chat_completions( if ctx.is_killed() { inflight_guard.mark_error(ErrorType::Cancelled); } + + // RL post-processing: when DYN_ENABLE_RL is active, promote + // token IDs to the top-level locations that Prime-RL / verifiers expects: + // response.prompt_token_ids (from tokenizing the prompt) + // response.choices[i].token_ids (from nvext.completion_token_ids) + let response = if let Some(ref messages) = rl_saved_messages { + let mut response = response; + // For TITO requests, nvext.token_data IS the prompt — use those IDs + // directly. Falling back to rl_tokenize_prompt would re-tokenize the + // placeholder message injected by handler_chat_completions_tokens and + // return the wrong IDs. 
+ response.prompt_token_ids = + rl_tito_token_ids.or_else(|| rl_tokenize_prompt(&state, &model, messages)); + match serde_json::to_value(&response) { + Ok(mut json_val) => { + rl_promote_token_ids_in_response(&mut json_val); + return Ok(Json(json_val).into_response()); + } + Err(e) => { + // This path means choice.token_ids will NOT be promoted — Prime-RL + // will see None for completion token IDs and may silently drop the + // rollout or crash. Log at error so data-loss does not go unnoticed. + tracing::error!( + request_id, + "rl_promote_token_ids: serde_json serialization failed — \ + choice.token_ids will NOT be promoted to top-level; \ + Prime-RL rollout may be dropped or corrupt: {e}" + ); + } + } + response + } else { + response + }; + Ok(Json(response).into_response()) } } @@ -1909,6 +2046,179 @@ pub(crate) fn check_ready(_state: &Arc) -> Result<(), ErrorRe Ok(()) } +// ── Tokenize / Detokenize ──────────────────────────────────────────── + +fn bad_request>(message: T) -> ErrorResponse { + let code = StatusCode::BAD_REQUEST; + ( + code, + Json(ErrorMessage { + message: message.into(), + error_type: map_error_code_to_error_type(code), + code: code.as_u16(), + }), + ) +} + +fn resolve_tokenizer_model_name( + state: &Arc, + requested_model: Option<&str>, +) -> Result { + if let Some(model) = requested_model { + if state.manager().has_model_any(model) { + return Ok(model.to_string()); + } + return Err(ErrorMessage::model_not_found()); + } + let served_models = state.manager().model_display_names(); + if served_models.len() == 1 { + return Ok(served_models.into_iter().next().unwrap()); + } + Err(bad_request( + "Model must be specified when more than one model is served.", + )) +} + +fn resolve_model_card( + state: &Arc, + requested_model: Option<&str>, +) -> Result<(String, crate::model_card::ModelDeploymentCard), ErrorResponse> { + let model = resolve_tokenizer_model_name(state, requested_model)?; + let card = state + .manager() + .get_model_cards() + 
.into_iter() + .find(|card| card.display_name == model) + .ok_or_else(|| { + ErrorMessage::internal_server_error(&format!( + "Tokenizer metadata is not available for model '{}'", + model + )) + })?; + Ok((model, card)) +} + +async fn tokenize( + State(state): State>, + Json(request): Json, +) -> Result { + check_ready(&state)?; + + let (_, card) = resolve_model_card(&state, request.model())?; + let tokenizer = card + .tokenizer() + .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to load tokenizer"))?; + + let (tokens, token_strs) = match request { + TokenizeRequest::Completion(TokenizeCompletionRequest { + prompt, + add_special_tokens, + return_token_strs, + .. + }) => { + let encoding = tokenizer + .encode_with_special_tokens(&prompt, add_special_tokens) + .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to tokenize prompt"))?; + let token_ids = encoding.token_ids().to_vec(); + let token_strs = if return_token_strs { + Some(tokenizer.convert_ids_to_tokens(&token_ids).map_err(|err| { + ErrorMessage::from_anyhow(err, "Failed to resolve token strings") + })?) 
+ } else { + None + }; + (token_ids, token_strs) + } + TokenizeRequest::Chat(request) => { + let model = request + .model + .clone() + .unwrap_or_else(|| card.display_name.clone()); + // Render the chat messages to a prompt string via the model's chat template + let formatter = crate::preprocessor::prompt::PromptFormatter::from_mdc(&card) + .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to build chat formatter"))?; + let inner_request = dynamo_protocols::types::CreateChatCompletionRequest { + model, + messages: request.messages.clone(), + tools: request.tools.clone(), + ..Default::default() + }; + let wrapped = + crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest { + inner: inner_request, + common: Default::default(), + nvext: None, + chat_template_args: Some(request.merged_chat_template_kwargs()), + media_io_kwargs: None, + tokens: None, + return_token_ids: None, + unsupported_fields: Default::default(), + }; + let prompt = match formatter { + crate::preprocessor::prompt::PromptFormatter::OAI(f) => f.render(&wrapped), + } + .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to render chat prompt"))?; + + let encoding = tokenizer + .encode_with_special_tokens(&prompt, request.add_special_tokens) + .map_err(|err| { + ErrorMessage::from_anyhow(err, "Failed to tokenize rendered chat prompt") + })?; + let token_ids = encoding.token_ids().to_vec(); + let token_strs = if request.return_token_strs { + Some(tokenizer.convert_ids_to_tokens(&token_ids).map_err(|err| { + ErrorMessage::from_anyhow(err, "Failed to resolve token strings") + })?) 
+ } else { + None + }; + (token_ids, token_strs) + } + }; + + Ok(Json(TokenizeResponse { + count: tokens.len(), + max_model_len: card.context_length, + tokens, + token_strs, + }) + .into_response()) +} + +async fn detokenize( + State(state): State>, + Json(request): Json, +) -> Result { + check_ready(&state)?; + + let (_, card) = resolve_model_card(&state, request.model.as_deref())?; + let tokenizer = card + .tokenizer() + .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to load tokenizer"))?; + let prompt: String = tokenizer + .decode(&request.tokens, false) + .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to detokenize prompt"))? + .into(); + + Ok(Json(DetokenizeResponse { prompt }).into_response()) +} + +pub fn tokenization_router(state: Arc) -> (Vec, Router) { + let tokenize_path = "/v1/tokenize"; + let detokenize_path = "/v1/detokenize"; + let docs = vec![ + RouteDoc::new(axum::http::Method::POST, tokenize_path), + RouteDoc::new(axum::http::Method::POST, detokenize_path), + ]; + let router = Router::new() + .route(tokenize_path, post(tokenize)) + .route(detokenize_path, post(detokenize)) + .layer(middleware::from_fn(smart_json_error_middleware)) + .layer(axum::extract::DefaultBodyLimit::max(get_body_limit())) + .with_state(state); + (docs, router) +} + /// openai compatible format /// Example: /// { @@ -2020,6 +2330,137 @@ pub fn chat_completions_router( (vec![doc], router) } +/// Create an Axum [`Router`] for the RL TITO (Token-In / Token-Out) endpoint. +/// +/// This endpoint accepts Prime-RL's `tokens` field (pre-tokenized prompt), +/// translates it to `nvext.token_data`, forces logprobs on, and delegates +/// to the standard chat_completions handler -- all in Rust, eliminating the +/// Python rl-admin proxy from the hot inference path. 
+/// +/// If no path is provided, the default path is `/v1/chat/completions/tokens` +pub fn chat_completions_tokens_router( + state: Arc, + template: Option, + path: Option, +) -> (Vec, Router) { + let path = path.unwrap_or("/v1/chat/completions/tokens".to_string()); + let doc = RouteDoc::new(axum::http::Method::POST, &path); + let router = Router::new() + .route(&path, post(handler_chat_completions_tokens)) + .layer(middleware::from_fn(smart_json_error_middleware)) + .layer(axum::extract::DefaultBodyLimit::max(get_body_limit())) + .with_state((state, template)); + (vec![doc], router) +} + +/// Handler for TITO (Token-In / Token-Out) chat completions. +/// +/// Accepts Prime-RL's request format which includes a `tokens` field containing +/// pre-tokenized prompt token IDs. The handler: +/// 1. Extracts the `tokens` field +/// 2. Injects them as `nvext.token_data` (Dynamo's native pre-tokenized input) +/// 3. Requests `token_ids` and `completion_token_ids` in the response via `nvext.extra_fields` +/// 4. Forces `logprobs = true` (RL always needs logprobs) +/// 5. Ensures `messages` is non-empty (Dynamo requires it for chat template selection) +/// 6. 
Delegates to the standard `chat_completions()` internal function (zero HTTP proxy) +async fn handler_chat_completions_tokens( + State((state, template)): State<(Arc, Option)>, + headers: HeaderMap, + Json(mut request): Json, +) -> Result { + check_ready(&state)?; + + // Extract the tokens field (Prime-RL's TITO input) + let tokens = request.tokens.take(); + // Clear return_token_ids (not supported by Dynamo, avoid confusion) + request.return_token_ids = None; + + if let Some(token_ids) = tokens { + if token_ids.is_empty() { + return Err(ErrorMessage::bad_request( + "TITO endpoint requires non-empty 'tokens' field", + )); + } + + // Inject tokens into nvext.token_data + let mut nvext = request.nvext.take().unwrap_or_default(); + nvext.token_data = Some(token_ids); + + // Request token echo and completion token IDs in response + let mut extra_fields = nvext.extra_fields.take().unwrap_or_default(); + for field in &["token_ids", "completion_token_ids"] { + if !extra_fields.contains(&field.to_string()) { + extra_fields.push(field.to_string()); + } + } + nvext.extra_fields = Some(extra_fields); + request.nvext = Some(nvext); + + // Force logprobs on (RL always needs them) + if request.inner.logprobs.is_none() { + request.inner.logprobs = Some(true); + } + + // Ensure messages is non-empty (Dynamo requires it for model lookup / chat template) + if request.inner.messages.is_empty() { + use dynamo_protocols::types::{ + ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, + ChatCompletionRequestUserMessageContent, + }; + request + .inner + .messages + .push(ChatCompletionRequestMessage::User( + ChatCompletionRequestUserMessage { + content: ChatCompletionRequestUserMessageContent::Text( + "(token-in mode)".to_string(), + ), + name: None, + }, + )); + } + } else { + return Err(ErrorMessage::bad_request( + "Missing 'tokens' field for TITO endpoint. 
\ + Use /v1/chat/completions for message-based requests.", + )); + } + + // Apply header routing overrides + request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers); + + // Delegate to the standard chat completions flow (no HTTP proxy!) + let request_id = get_or_create_request_id(&headers); + let streaming = request.inner.stream.unwrap_or(false); + let cancellation_labels = CancellationLabels { + model: request.inner.model.clone(), + endpoint: Endpoint::ChatCompletions.to_string(), + request_type: if streaming { "stream" } else { "unary" }.to_string(), + }; + let request = Context::with_id(request, request_id); + let context = request.context(); + + let (mut connection_handle, stream_handle) = create_connection_monitor( + context.clone(), + Some(state.metrics_clone()), + cancellation_labels, + ) + .await; + + let response = + tokio::spawn(chat_completions(state, template, request, stream_handle).in_current_span()) + .await + .map_err(|e| { + ErrorMessage::internal_server_error(&format!( + "Failed to await TITO chat completions task: {:?}", + e, + )) + })?; + + connection_handle.disarm(); + response +} + /// Create an Axum [`Router`] for the OpenAI API Embeddings endpoint /// If not path is provided, the default path is `/v1/embeddings` pub fn embeddings_router( @@ -2639,6 +3080,511 @@ pub fn audios_router( (vec![doc], router) } +// ────────────────────────────────────────────────────────────────────────── +// RL Admin router: /v1/rl/* +// ────────────────────────────────────────────────────────────────────────── + +/// Environment variable for comma-separated worker system HTTP URLs. +/// Defaults to `http://localhost:8081` when not set. +const DYN_RL_WORKER_SYSTEM_URLS_ENV: &str = "DYN_RL_WORKER_SYSTEM_URLS"; + +/// Shared state for the RL admin router. +#[derive(Clone)] +struct RlState { + /// Worker system HTTP base URLs (e.g. `http://localhost:8081`). + /// Set via `DYN_RL_WORKER_SYSTEM_URLS` (comma-separated list). 
+ worker_system_urls: Vec, + /// Shared HTTP client for all fan-out calls to worker system ports. + http_client: reqwest::Client, +} + +impl RlState { + fn from_env() -> Self { + let worker_system_urls = std::env::var(DYN_RL_WORKER_SYSTEM_URLS_ENV) + .unwrap_or_else(|_| "http://localhost:8081".to_string()) + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect::>(); + tracing::info!( + "RL admin router configured with {} worker(s): {:?}", + worker_system_urls.len(), + worker_system_urls + ); + Self { + worker_system_urls, + http_client: reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(600)) + .build() + .expect("Failed to create RL router HTTP client"), + } + } + + /// Call a single engine route on one worker. Returns the JSON body. + async fn call_engine_route( + &self, + url: &str, + route: &str, + body: &serde_json::Value, + ) -> serde_json::Value { + let endpoint = format!("{url}/engine/{route}"); + match self.http_client.post(&endpoint).json(body).send().await { + Ok(resp) => { + let status = resp.status(); + match resp.json::().await { + Ok(v) => v, + Err(e) => serde_json::json!({ + "status": "error", + "message": format!("Failed to decode response from {endpoint}: {e}"), + "http_status": status.as_u16() + }), + } + } + Err(e) => serde_json::json!({ + "status": "error", + "message": format!("Request to {endpoint} failed: {e}") + }), + } + } + + /// Fan out an engine route call to all configured workers concurrently. + async fn fan_out(&self, route: &str, body: serde_json::Value) -> Vec { + let futures: Vec<_> = self + .worker_system_urls + .iter() + .map(|url| self.call_engine_route(url, route, &body)) + .collect(); + futures::future::join_all(futures).await + } + + /// Returns true only if all results have `status: "ok"`. 
+    ///
+    /// An empty slice is treated as failure: with no workers configured nothing
+    /// was actually paused/resumed/updated, so reporting success would mislead the
+    /// trainer. This mirrors the `!results.is_empty()` guard in `rl_ready`.
+    fn all_ok(results: &[serde_json::Value]) -> bool {
+        !results.is_empty()
+            && results
+                .iter()
+                .all(|r| r.get("status").and_then(|s| s.as_str()) == Some("ok"))
+    }
+}
+
+/// `GET /v1/rl/ready` — composite readiness check: worker health via system port.
+///
+/// Issues `GET {url}/health` against every configured worker concurrently.
+/// Responds `200 {"status": "ready"}` only when at least one worker is
+/// configured and every probe returned a success status; otherwise responds
+/// `503` with the ready/total worker counts. A transport error counts as
+/// "not ready" for that worker.
+async fn rl_ready(State(state): State<Arc<RlState>>) -> impl IntoResponse {
+    let probes: Vec<_> = state
+        .worker_system_urls
+        .iter()
+        .map(|url| {
+            let client = state.http_client.clone();
+            let health_url = format!("{url}/health");
+            async move {
+                client
+                    .get(&health_url)
+                    .send()
+                    .await
+                    .map(|r| r.status().is_success())
+                    .unwrap_or(false)
+            }
+        })
+        .collect();
+    let results = futures::future::join_all(probes).await;
+    let workers_ready = results.iter().filter(|ok| **ok).count();
+    // Zero configured workers is never "ready".
+    let all_ready = !results.is_empty() && workers_ready == results.len();
+    if all_ready {
+        (StatusCode::OK, Json(serde_json::json!({"status": "ready"})))
+    } else {
+        (
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(serde_json::json!({
+                "status": "not_ready",
+                "workers_ready": workers_ready,
+                "workers_total": results.len()
+            })),
+        )
+    }
+}
+
+/// `POST /v1/rl/pause` — fan out `pause_generation` to all workers.
+///
+/// Responds `200` with the per-worker results when every worker answered
+/// `status: "ok"`, otherwise `502` with the same per-worker payload.
+async fn rl_pause(State(state): State<Arc<RlState>>) -> impl IntoResponse {
+    let results = state
+        .fan_out("pause_generation", serde_json::json!({}))
+        .await;
+    if RlState::all_ok(&results) {
+        tracing::info!("RL pause: all {} worker(s) paused", results.len());
+        (
+            StatusCode::OK,
+            Json(serde_json::json!({"status": "ok", "workers": results})),
+        )
+    } else {
+        tracing::warn!("RL pause: some workers failed: {:?}", results);
+        (
+            StatusCode::BAD_GATEWAY,
+            Json(serde_json::json!({"status": "error", "workers": results})),
+        )
+    }
+}
+
+/// `POST /v1/rl/resume` — fan out `resume_generation` to all workers.
+async fn rl_resume(State(state): State>) -> impl IntoResponse { + let results = state + .fan_out("resume_generation", serde_json::json!({})) + .await; + if RlState::all_ok(&results) { + tracing::info!("RL resume: all {} worker(s) resumed", results.len()); + ( + StatusCode::OK, + Json(serde_json::json!({"status": "ok", "workers": results})), + ) + } else { + tracing::warn!("RL resume: some workers failed: {:?}", results); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({"status": "error", "workers": results})), + ) + } +} + +/// `POST /v1/rl/update_weights` — atomic `flush_cache → update_weights_from_path` across all workers. +/// +/// Expected body: `{"weight_dir": "/path/to/checkpoint"}` or `{"weight_dir": null}` for NCCL mode. +/// +/// The sequence per worker is: `flush_cache → update_weights_from_path`. +/// The pause/resume envelope is left to Prime-RL, which can call `/v1/rl/pause` and +/// `/v1/rl/resume` explicitly for full drain-and-swap semantics. +async fn rl_update_weights( + State(state): State>, + body: axum::extract::Json, +) -> impl IntoResponse { + let weight_dir = body + .get("weight_dir") + .and_then(|v| v.as_str()) + .map(str::to_string); + + if weight_dir.is_none() { + tracing::info!("RL update_weights: weight_dir=null (NCCL mode, no-op on Dynamo side)"); + return ( + StatusCode::OK, + Json(serde_json::json!({"status": "ok", "message": "NCCL mode, no-op on Dynamo side"})), + ); + } + + let weight_dir = weight_dir.unwrap(); + tracing::info!("RL update_weights: weight_dir={weight_dir}"); + + // Step 1: flush_cache across all workers + let flush_results = state.fan_out("flush_cache", serde_json::json!({})).await; + if !RlState::all_ok(&flush_results) { + tracing::warn!("RL update_weights: flush_cache failed: {:?}", flush_results); + return ( + StatusCode::BAD_GATEWAY, + Json( + serde_json::json!({"status": "error", "stage": "flush_cache", "workers": flush_results}), + ), + ); + } + + // Step 2: update_weights_from_path across all workers 
+ let version = std::path::Path::new(&weight_dir) + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string(); + let load_body = serde_json::json!({"path": weight_dir, "version": version}); + let load_results = state.fan_out("update_weights_from_path", load_body).await; + if RlState::all_ok(&load_results) { + tracing::info!( + "RL update_weights: all {} worker(s) updated weights to {weight_dir}", + load_results.len() + ); + ( + StatusCode::OK, + Json(serde_json::json!({"status": "ok", "workers": load_results})), + ) + } else { + tracing::warn!( + "RL update_weights: update_weights_from_path failed: {:?}", + load_results + ); + ( + StatusCode::BAD_GATEWAY, + Json( + serde_json::json!({"status": "error", "stage": "update_weights_from_path", "workers": load_results}), + ), + ) + } +} + +/// `POST /v1/rl/load_lora_adapter` — hot-load/swap a LoRA adapter from a filesystem path. +/// +/// Expected body: `{"lora_name": "r16-a32.0", "lora_path": "/path/to/adapter_dir"}` +/// +/// The adapter directory must contain PEFT-style `adapter_model.safetensors` and +/// `adapter_config.json`. This is the RL-specific LoRA path used by Prime-RL every +/// training step (separate from Dynamo's URI-based `load_lora` gRPC endpoint which +/// downloads adapters from S3/file URIs and publishes a new ModelDeploymentCard). +/// +/// Hot-swap semantics: calling with a `lora_name` that is already loaded removes +/// the previous adapter and loads the new one under the same deterministic int ID, +/// then resets the prefix cache so stale KV entries don't poison new rollouts. +/// +/// Pair with `/v1/rl/pause` and `/v1/rl/resume` for a full drain-swap-resume cycle. 
+async fn rl_load_lora_adapter(
+    State(state): State<Arc<RlState>>,
+    body: axum::extract::Json<serde_json::Value>,
+) -> impl IntoResponse {
+    // Pull a required, non-empty string field out of the JSON body.
+    let required_str = |key: &str| -> Option<String> {
+        body.get(key)
+            .and_then(|v| v.as_str())
+            .filter(|s| !s.is_empty())
+            .map(str::to_string)
+    };
+
+    let (lora_name, lora_path) = match (required_str("lora_name"), required_str("lora_path")) {
+        (Some(name), Some(path)) => (name, path),
+        _ => {
+            return (
+                StatusCode::BAD_REQUEST,
+                Json(serde_json::json!({
+                    "status": "error",
+                    "message": "Expected body: {\"lora_name\": str, \"lora_path\": str} (both required, non-empty)"
+                })),
+            );
+        }
+    };
+
+    tracing::info!("RL load_lora_adapter: lora_name={lora_name} lora_path={lora_path}");
+    let payload = serde_json::json!({"lora_name": lora_name, "lora_path": lora_path});
+    let results = state.fan_out("load_lora_adapter", payload).await;
+
+    // 200 only when every worker acknowledged with `status: "ok"`; otherwise 502
+    // with the per-worker results so the caller can see which one failed.
+    if RlState::all_ok(&results) {
+        tracing::info!(
+            "RL load_lora_adapter: all {} worker(s) loaded LoRA '{lora_name}' from {lora_path}",
+            results.len()
+        );
+        (
+            StatusCode::OK,
+            Json(serde_json::json!({"status": "ok", "workers": results})),
+        )
+    } else {
+        tracing::warn!("RL load_lora_adapter: some workers failed: {:?}", results);
+        (
+            StatusCode::BAD_GATEWAY,
+            Json(serde_json::json!({"status": "error", "workers": results})),
+        )
+    }
+}
+
+/// `POST /v1/rl/unload_lora_adapter` — remove a previously loaded LoRA adapter by name.
+///
+/// Expected body: `{"lora_name": "r16-a32.0"}`
+///
+/// Idempotent: unloading an already-absent LoRA returns `status: ok` so callers
+/// can retry safely without special-casing not-found.
+async fn rl_unload_lora_adapter(
+    State(state): State<Arc<RlState>>,
+    body: axum::extract::Json<serde_json::Value>,
+) -> impl IntoResponse {
+    // `lora_name` must be present, a string, and non-empty.
+    let requested = body
+        .get("lora_name")
+        .and_then(|v| v.as_str())
+        .filter(|name| !name.is_empty())
+        .map(str::to_string);
+
+    let lora_name = match requested {
+        Some(name) => name,
+        None => {
+            return (
+                StatusCode::BAD_REQUEST,
+                Json(serde_json::json!({
+                    "status": "error",
+                    "message": "Expected body: {\"lora_name\": str} (required, non-empty)"
+                })),
+            );
+        }
+    };
+
+    tracing::info!("RL unload_lora_adapter: lora_name={lora_name}");
+    let payload = serde_json::json!({"lora_name": lora_name});
+    let results = state.fan_out("unload_lora_adapter", payload).await;
+
+    // 200 only when every worker acknowledged with `status: "ok"`; otherwise 502
+    // with the per-worker results.
+    if RlState::all_ok(&results) {
+        tracing::info!(
+            "RL unload_lora_adapter: all {} worker(s) unloaded LoRA '{lora_name}'",
+            results.len()
+        );
+        (
+            StatusCode::OK,
+            Json(serde_json::json!({"status": "ok", "workers": results})),
+        )
+    } else {
+        tracing::warn!("RL unload_lora_adapter: some workers failed: {:?}", results);
+        (
+            StatusCode::BAD_GATEWAY,
+            Json(serde_json::json!({"status": "error", "workers": results})),
+        )
+    }
+}
+
+/// `GET /v1/rl/weight_version` — query weight version from all workers.
+async fn rl_weight_version(State(state): State>) -> impl IntoResponse { + let results = state + .fan_out("get_weight_version", serde_json::json!({})) + .await; + + // Collect distinct versions and check for consistency + let versions: Vec<_> = results + .iter() + .filter_map(|r| { + r.get("version") + .and_then(|v| v.as_str()) + .map(str::to_string) + }) + .collect(); + + let unique: std::collections::HashSet<&str> = versions.iter().map(String::as_str).collect(); + if unique.len() == 1 { + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "version": unique.into_iter().next().unwrap_or(""), + "workers": results + })), + ) + } else { + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "inconsistent", + "versions": unique.into_iter().collect::>(), + "workers": results + })), + ) + } +} + +/// Promote token IDs from the Dynamo `nvext` response object to the top-level +/// locations that Prime-RL / verifiers expects: +/// +/// response.nvext.completion_token_ids → response.choices[i].token_ids +/// +/// Tokenize chat messages using the model's tokenizer and return prompt token IDs. +/// Used by the RL post-processing path to populate `response.prompt_token_ids`. 
+fn rl_tokenize_prompt( + state: &Arc, + model: &str, + messages: &[dynamo_protocols::types::ChatCompletionRequestMessage], +) -> Option> { + if messages.is_empty() { + return None; + } + let (_, card) = resolve_model_card(state, Some(model)).ok()?; + let tokenizer = card.tokenizer().ok()?; + let formatter = crate::preprocessor::prompt::PromptFormatter::from_mdc(&card).ok()?; + let inner_request = dynamo_protocols::types::CreateChatCompletionRequest { + model: model.to_string(), + messages: messages.to_vec(), + ..Default::default() + }; + let wrapped = crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest { + inner: inner_request, + common: Default::default(), + nvext: None, + chat_template_args: None, + media_io_kwargs: None, + tokens: None, + return_token_ids: None, + unsupported_fields: Default::default(), + }; + let prompt = match formatter { + crate::preprocessor::prompt::PromptFormatter::OAI(f) => f.render(&wrapped), + } + .ok()?; + let encoding = tokenizer.encode_with_special_tokens(&prompt, true).ok()?; + Some(encoding.token_ids().to_vec()) +} + +/// This lets Prime-RL read `choice.token_ids` without knowing about the `nvext` +/// extension structure. Called on non-streaming responses when RL token ID mode +/// is active. +fn rl_promote_token_ids_in_response(json_val: &mut serde_json::Value) { + // Move completion_token_ids from response-level nvext to each choice. 
+ // Prime-RL / verifiers expects: + // response.choices[i].token_ids (not response.nvext.completion_token_ids) + let has_nvext = json_val.get("nvext").is_some(); + let has_completion_ids = json_val + .get("nvext") + .and_then(|nv| nv.get("completion_token_ids")) + .is_some(); + + tracing::debug!( + has_nvext, + has_completion_ids, + "rl_promote_token_ids_in_response: inspecting response" + ); + + if let Some(nvext) = json_val.get("nvext") { + if let Some(completion_ids) = nvext.get("completion_token_ids").cloned() { + let n = completion_ids.as_array().map(|a| a.len()).unwrap_or(0); + tracing::info!( + n_completion_ids = n, + "rl_promote: copying completion_token_ids to choices[].token_ids" + ); + if let Some(choices) = json_val.get_mut("choices").and_then(|c| c.as_array_mut()) { + for choice in choices.iter_mut() { + if let Some(obj) = choice.as_object_mut() { + obj.insert("token_ids".to_string(), completion_ids.clone()); + } + } + } + } + } +} + +/// `GET /v1/rl/health` — lightweight health check for Prime-RL admin client. +/// +/// Prime-RL's `check_health()` calls `GET /health` on the admin client. When +/// `admin_base_url = ["http://dynamo:8000/v1/rl"]` the request arrives here. +/// Returns 200 OK if the frontend process is running (no deep probe needed — +/// the frontend's own `/health` endpoint handles that separately). +async fn rl_health() -> impl IntoResponse { + (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))) +} + +/// Create an Axum [`Router`] for the RL admin endpoints at `/v1/rl/*`. +/// +/// Worker system URLs are read from the `DYN_RL_WORKER_SYSTEM_URLS` environment +/// variable (comma-separated, defaults to `http://localhost:8081`). +/// +/// Exposed only when `DYN_ENABLE_RL=true` or `HttpServiceConfig.enable_rl` is set. +/// +/// Prime-RL usage: set `admin_base_url = ["http://dynamo-frontend:8000/v1/rl"]` +/// in the orchestrator config. 
Prime-RL strips the trailing `/v1` suffix only +/// if present, so `/v1/rl` is preserved and all routes resolve correctly. +pub fn rl_router() -> (Vec, Router) { + let rl_state = Arc::new(RlState::from_env()); + let docs = vec![ + RouteDoc::new(axum::http::Method::GET, "/v1/rl/health"), + RouteDoc::new(axum::http::Method::GET, "/v1/rl/ready"), + RouteDoc::new(axum::http::Method::POST, "/v1/rl/pause"), + RouteDoc::new(axum::http::Method::POST, "/v1/rl/resume"), + RouteDoc::new(axum::http::Method::POST, "/v1/rl/update_weights"), + RouteDoc::new(axum::http::Method::POST, "/v1/rl/load_lora_adapter"), + RouteDoc::new(axum::http::Method::POST, "/v1/rl/unload_lora_adapter"), + RouteDoc::new(axum::http::Method::GET, "/v1/rl/weight_version"), + ]; + let router = Router::new() + .route("/v1/rl/health", get(rl_health)) + .route("/v1/rl/ready", get(rl_ready)) + .route("/v1/rl/pause", post(rl_pause)) + .route("/v1/rl/resume", post(rl_resume)) + .route("/v1/rl/update_weights", post(rl_update_weights)) + .route("/v1/rl/load_lora_adapter", post(rl_load_lora_adapter)) + .route("/v1/rl/unload_lora_adapter", post(rl_unload_lora_adapter)) + .route("/v1/rl/weight_version", get(rl_weight_version)) + .layer(middleware::from_fn(smart_json_error_middleware)) + .with_state(rl_state); + (docs, router) +} + #[cfg(test)] mod tests { diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 1eeaefdff912..fe5d8ea6ba13 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -246,6 +246,12 @@ pub struct HttpServiceConfig { #[builder(default = "false")] enable_anthropic_endpoints: bool, + /// When true, expose the RL admin routes at `/v1/rl/*` (pause, resume, + /// update_weights, weight_version, ready). Worker system URLs are read + /// from `DYN_RL_WORKER_SYSTEM_URLS` (comma-separated, default `http://localhost:8081`). 
+ #[builder(default = "false")] + enable_rl: bool, + #[builder(default = "None")] request_template: Option, @@ -518,7 +524,7 @@ impl HttpServiceConfigBuilder { }; // System routes (health, metrics, models) — debug-level spans - let system_routes = vec![ + let mut system_routes = vec![ metrics::router( registry, var(HTTP_SVC_METRICS_PATH_ENV).ok(), @@ -532,10 +538,16 @@ impl HttpServiceConfigBuilder { } else { super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()) }, + super::openai::tokenization_router(state.clone()), super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()), super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), super::busy_threshold::busy_threshold_router(state.clone(), None), ]; + // RL admin routes: enabled when builder flag is set OR when DYN_ENABLE_RL env var is truthy. + if config.enable_rl || env_is_truthy("DYN_ENABLE_RL") { + tracing::info!("RL admin routes enabled at /v1/rl/*"); + system_routes.push(super::openai::rl_router()); + } let mut system_router = axum::Router::new(); for (route_docs, route) in system_routes { system_router = system_router.merge(route); @@ -600,6 +612,15 @@ impl HttpServiceConfigBuilder { request_template.clone(), var(HTTP_SVC_CHAT_PATH_ENV).ok(), ); + // RL TITO (Token-In / Token-Out) endpoint -- mounted alongside chat completions. + // Accepts Prime-RL's `tokens` field, translates to nvext.token_data, and delegates + // to the standard chat completions pipeline. Eliminates the Python rl-admin proxy. 
+ let (tito_docs, tito_route) = super::openai::chat_completions_tokens_router( + state.clone(), + request_template.clone(), + None, + ); + let (cmpl_docs, cmpl_route) = super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok()); let (embed_docs, embed_route) = @@ -612,8 +633,13 @@ impl HttpServiceConfigBuilder { request_template.clone(), var(HTTP_SVC_RESPONSES_PATH_ENV).ok(), ); + // Merge TITO route and docs into the chat route (shares enable/disable flag) + let chat_route = chat_route.merge(tito_route); + let mut combined_chat_docs = chat_docs; + combined_chat_docs.extend(tito_docs); + let mut endpoint_routes = HashMap::new(); - endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route)); + endpoint_routes.insert(EndpointType::Chat, (combined_chat_docs, chat_route)); endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route)); endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route)); endpoint_routes.insert(EndpointType::Images, (images_docs, images_route)); diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 53756bb90c49..e8838513d588 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -596,15 +596,24 @@ impl OpenAIPreprocessor { let token_data = request.nvext().and_then(|ext| ext.token_data.as_ref()); + // Use token_data when provided (TITO / EPP / RL), + // regardless of backend_instance_id. + // + // skip_token_annotation = has_backend_instance_id: GAIE EPP-style + // callers (which set backend_instance_id) pre-tokenize and don't + // want the annotation echoed back; RL / TITO callers (no + // backend_instance_id) DO want the token_ids annotation in the + // response so the trainer can validate. 
let (tokens_vec, skip_token_annotation) = if let Some(tokens) = token_data { tracing::info!( token_count = tokens.len(), first_tokens = ?&tokens[..std::cmp::min(5, tokens.len())], + backend_instance_id = has_backend_instance_id, "[SIDECAR-SKIP-TOKENIZE] Found nvext.token_data — using pre-computed tokens, SKIPPING tokenization" ); - (tokens.clone(), true) + (tokens.clone(), has_backend_instance_id) } else if has_backend_instance_id { tracing::warn!( "backend_instance_id provided but no token_data; tokenizing prompt" diff --git a/lib/llm/src/protocols/anthropic/types.rs b/lib/llm/src/protocols/anthropic/types.rs index 5214ee10b0a7..4104191a61ae 100644 --- a/lib/llm/src/protocols/anthropic/types.rs +++ b/lib/llm/src/protocols/anthropic/types.rs @@ -141,6 +141,8 @@ impl TryFrom for NvCreateChatCompletionRequest { }, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }) } } diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 42ef621f8797..4d022ac01c83 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -20,6 +20,7 @@ pub mod images; pub mod models; pub mod nvext; pub mod responses; +pub mod tokenization; pub mod tools; pub mod validate; pub mod videos; @@ -90,6 +91,10 @@ pub(crate) trait OpenAIOutputOptionsProvider { fn get_skip_special_tokens(&self) -> Option; fn get_formatted_prompt(&self) -> Option; + + fn get_return_tokens_as_token_ids(&self) -> Option { + None + } } impl SamplingOptionsProvider for T { @@ -203,7 +208,6 @@ impl OutputOptionsProvider for T { let prompt_logprobs = self.get_prompt_logprobs(); let skip_special_tokens = self.get_skip_special_tokens(); let formatted_prompt = self.get_formatted_prompt(); - Ok(common::OutputOptions { logprobs, prompt_logprobs, diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs index 8a77038d5834..6e200e53bd67 100644 --- 
a/lib/llm/src/protocols/openai/chat_completions.rs +++ b/lib/llm/src/protocols/openai/chat_completions.rs @@ -59,6 +59,18 @@ pub struct NvCreateChatCompletionRequest { #[serde(default, skip_serializing_if = "Option::is_none")] pub media_io_kwargs: Option, + /// RL: Pre-tokenized prompt tokens from Prime-RL's TITO interface. + /// On the standard `/v1/chat/completions` endpoint this field is accepted but ignored + /// (use `/v1/chat/completions/tokens` for TITO mode where tokens are authoritative). + /// Accepting it here avoids 400 errors when Prime-RL sends it without the rl-admin proxy. + #[serde(default, skip_serializing)] + pub tokens: Option>, + + /// RL: Prime-RL requests token IDs in the response via this field. + /// Accepted but ignored on standard chat completions (use `nvext.extra_fields` instead). + #[serde(default, skip_serializing)] + pub return_token_ids: Option, + /// Catch-all for unsupported fields - checked during validation #[serde(flatten, default, skip_serializing)] pub unsupported_fields: std::collections::HashMap, @@ -72,6 +84,10 @@ pub struct NvCreateChatCompletionResponse { pub inner: dynamo_protocols::types::CreateChatCompletionResponse, #[serde(skip_serializing_if = "Option::is_none")] pub nvext: Option, + /// RL: Prompt token IDs for Prime-RL/verifiers alignment. + /// Populated when `DYN_ENABLE_RL=true` or `return_token_ids=true`. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_token_ids: Option>, } /// A response structure for streamed chat completions, embedding OpenAI's diff --git a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs index 5def709448d8..2b5fc16e1414 100644 --- a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs @@ -419,6 +419,7 @@ impl DeltaAggregator { service_tier: aggregator.service_tier, }, nvext: aggregator.nvext, + prompt_token_ids: None, }; Ok(response) diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 4a3dfae2d405..27943317b7a2 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -10,7 +10,7 @@ use crate::{ common::{self, timing::RequestTracker}, openai::{ convert_backend_top_logprobs, - nvext::{NvExtProvider, NvExtResponseFieldSelection}, + nvext::{NvExtProvider, NvExtResponse, NvExtResponseFieldSelection}, token_to_utf8_bytes, }, }, @@ -51,6 +51,7 @@ impl NvCreateChatCompletionRequest { /// # Returns /// * [`DeltaGenerator`] configured with model name and response options. pub fn response_generator(&self, request_id: String) -> DeltaGenerator { + // `completion_token_ids` is parsed by from_nvext into response_fields. let response_fields = NvExtResponseFieldSelection::from_nvext(self.nvext()); let options = DeltaGeneratorOptions { @@ -86,6 +87,7 @@ pub struct DeltaGeneratorOptions { /// Determines whether log probabilities should be included in the response. pub enable_logprobs: bool, /// Determines which nvext response fields may be emitted for this request. + /// (Includes `completion_token_ids` for the RL inference path.) 
pub response_fields: NvExtResponseFieldSelection, pub runtime_config: ModelRuntimeConfig, @@ -112,6 +114,10 @@ pub struct DeltaGenerator { options: DeltaGeneratorOptions, /// Optional request tracker for per-request metrics (shared with PreprocessedRequest). tracker: Option>, + /// Accumulated output token IDs across chunks. Only used when + /// `options.response_fields.completion_token_ids` is true. Emitted in `nvext.completion_token_ids` + /// on the final (finish_reason-bearing) chunk. + accumulated_completion_token_ids: Vec, } impl DeltaGenerator { @@ -160,6 +166,7 @@ impl DeltaGenerator { msg_counter: 0, options, tracker, + accumulated_completion_token_ids: Vec::new(), } } @@ -353,6 +360,12 @@ impl crate::protocols::openai::DeltaGeneratorExt, + + /// Output token IDs generated by the engine (RL inference path). + /// Populated when client requests `extra_fields: ["completion_token_ids"]` + /// or auto-enabled under `DYN_ENABLE_RL=true` for the chat-completions path. + /// For RL: `len(completion_token_ids) == len(logprobs.content)` is a hard invariant. + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_token_ids: Option>, } /// Response nvext fields requested for a given request. @@ -137,6 +144,7 @@ pub struct NvExtResponseFieldSelection { pub token_ids: bool, pub routed_experts: bool, pub engine_data: bool, + pub completion_token_ids: bool, } impl NvExtResponseFieldSelection { @@ -153,6 +161,8 @@ impl NvExtResponseFieldSelection { "timing" => selection.timing = true, "routed_experts" => selection.routed_experts = true, "engine_data" => selection.engine_data = true, + "completion_token_ids" => selection.completion_token_ids = true, + "token_ids" => selection.token_ids = true, _ => {} } } @@ -181,6 +191,8 @@ impl NvExtResponseFieldSelection { /// - `worker_id` requires the selection flag **and** `tracker.get_worker_info()` to return `Some`. 
/// - `token_ids` requires the selection flag **and** a `"token_ids"` key on `disaggregated_params` /// that deserializes into `Vec`; malformed values silently fall back to `None`. + /// - `completion_token_ids` requires the selection flag **and** a `"completion_token_ids"` key on + /// `disaggregated_params` that deserializes into `Vec`; malformed values silently fall back to `None`. /// - `routed_experts` requires the selection flag **and** a `"routed_experts"` key on /// `disaggregated_params` (cloned as-is, no validation). /// - `timing` requires the selection flag, `finish_reason_present == true`, **and** a tracker. @@ -206,6 +218,14 @@ impl NvExtResponseFieldSelection { None }; + let completion_token_ids = if self.completion_token_ids { + disaggregated_params + .and_then(|params| params.get("completion_token_ids")) + .and_then(|v| serde_json::from_value::>(v.clone()).ok()) + } else { + None + }; + let routed_experts = if self.routed_experts { disaggregated_params .and_then(|params| params.get("routed_experts")) @@ -228,6 +248,7 @@ impl NvExtResponseFieldSelection { if worker_id.is_none() && token_ids.is_none() + && completion_token_ids.is_none() && routed_experts.is_none() && timing.is_none() && engine_data.is_none() @@ -241,6 +262,7 @@ impl NvExtResponseFieldSelection { token_ids, routed_experts, engine_data, + completion_token_ids, }) } } diff --git a/lib/llm/src/protocols/openai/responses/mod.rs b/lib/llm/src/protocols/openai/responses/mod.rs index 2dcc530573c6..e806b3c74732 100644 --- a/lib/llm/src/protocols/openai/responses/mod.rs +++ b/lib/llm/src/protocols/openai/responses/mod.rs @@ -732,6 +732,8 @@ impl TryFrom for NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }) } } diff --git a/lib/llm/src/protocols/openai/tokenization.rs b/lib/llm/src/protocols/openai/tokenization.rs new file mode 100644 index 000000000000..95559684ad89 --- /dev/null 
+++ b/lib/llm/src/protocols/openai/tokenization.rs @@ -0,0 +1,124 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::preprocessor::media::MediaDecoder; +use crate::types::TokenIdType; + +fn default_true() -> bool { + true +} + +fn default_false() -> bool { + false +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct TokenizeCompletionRequest { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model: Option, + pub prompt: String, + #[serde(default = "default_true")] + pub add_special_tokens: bool, + #[serde(default = "default_false")] + pub return_token_strs: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct TokenizeChatRequest { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model: Option, + pub messages: Vec, + #[serde(default = "default_true")] + pub add_generation_prompt: bool, + #[serde(default = "default_false")] + pub return_token_strs: bool, + #[serde(default = "default_false")] + pub continue_final_message: bool, + #[serde(default = "default_false")] + pub add_special_tokens: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub chat_template: Option, + #[serde( + default, + skip_serializing_if = "Option::is_none", + alias = "chat_template_args" + )] + pub chat_template_kwargs: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub media_io_kwargs: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub mm_processor_kwargs: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tools: Option>, +} + +impl TokenizeChatRequest { + pub fn validate(&self) -> Result<(), String> { + if self.continue_final_message && self.add_generation_prompt { + return Err( + "Cannot set 
both `continue_final_message` and `add_generation_prompt` to True." + .to_string(), + ); + } + + Ok(()) + } + + pub fn merged_chat_template_kwargs(&self) -> HashMap { + let mut kwargs = self.chat_template_kwargs.clone().unwrap_or_default(); + kwargs.insert( + "add_generation_prompt".to_string(), + serde_json::Value::Bool(self.add_generation_prompt), + ); + kwargs.insert( + "continue_final_message".to_string(), + serde_json::Value::Bool(self.continue_final_message), + ); + kwargs + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +#[allow(clippy::large_enum_variant)] +pub enum TokenizeRequest { + Completion(TokenizeCompletionRequest), + Chat(TokenizeChatRequest), +} + +impl TokenizeRequest { + pub fn model(&self) -> Option<&str> { + match self { + Self::Completion(request) => request.model.as_deref(), + Self::Chat(request) => request.model.as_deref(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct TokenizeResponse { + pub count: usize, + pub max_model_len: u32, + pub tokens: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub token_strs: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct DetokenizeRequest { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model: Option, + pub tokens: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct DetokenizeResponse { + pub prompt: String, +} diff --git a/lib/llm/src/protocols/openai/validate.rs b/lib/llm/src/protocols/openai/validate.rs index 237e84bc75be..fff6ea8edd65 100644 --- a/lib/llm/src/protocols/openai/validate.rs +++ b/lib/llm/src/protocols/openai/validate.rs @@ -97,16 +97,24 @@ pub const MAX_REPETITION_PENALTY: f32 = 2.0; // Shared Fields // +/// Fields that Prime-RL / verifiers may send as extra_body hints which Dynamo +/// does not implement but should not reject with a 400. 
They are silently +/// accepted and ignored so the RL client stack is forward-compatible. +const PASSTHROUGH_EXTRA_FIELDS: &[&str] = &[ + "cache_salt", // KV prefix-cache isolation hint from prime-rl orchestrator +]; + /// Validates that no unsupported fields are present in the request pub fn validate_no_unsupported_fields( unsupported_fields: &std::collections::HashMap, ) -> Result<(), anyhow::Error> { - if !unsupported_fields.is_empty() { - let fields: Vec<_> = unsupported_fields - .keys() - .map(|s| format!("`{}`", s)) - .collect(); - anyhow::bail!("Unsupported parameter(s): {}", fields.join(", ")); + let unknown: Vec<_> = unsupported_fields + .keys() + .filter(|k| !PASSTHROUGH_EXTRA_FIELDS.contains(&k.as_str())) + .map(|s| format!("`{}`", s)) + .collect(); + if !unknown.is_empty() { + anyhow::bail!("Unsupported parameter(s): {}", unknown.join(", ")); } Ok(()) } diff --git a/lib/tokenizers/src/fastokens.rs b/lib/tokenizers/src/fastokens.rs index 93e855cc5c58..4295e69fcb91 100644 --- a/lib/tokenizers/src/fastokens.rs +++ b/lib/tokenizers/src/fastokens.rs @@ -39,16 +39,28 @@ impl FastTokenizer { impl Encoder for FastTokenizer { fn encode(&self, input: &str) -> Result { + self.encode_with_special_tokens(input, false) + } + + fn encode_batch(&self, inputs: &[&str]) -> Result> { + inputs.par_iter().map(|input| self.encode(input)).collect() + } + + fn encode_with_special_tokens( + &self, + input: &str, + add_special_tokens: bool, + ) -> Result { + if add_special_tokens { + return self.hf_decoder.encode_with_special_tokens(input, true); + } + let ids = self .fast_encoder .encode(input) .map_err(|e| Error::msg(format!("Fastokens encode error: {e}")))?; Ok(Encoding::Sp(ids)) } - - fn encode_batch(&self, inputs: &[&str]) -> Result> { - inputs.par_iter().map(|input| self.encode(input)).collect() - } } impl Decoder for FastTokenizer { @@ -57,7 +69,11 @@ impl Decoder for FastTokenizer { } } -impl Tokenizer for FastTokenizer {} +impl Tokenizer for FastTokenizer { + fn 
convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result> { + self.hf_decoder.convert_ids_to_tokens(token_ids) + } +} #[cfg(test)] mod tests { diff --git a/lib/tokenizers/src/hf.rs b/lib/tokenizers/src/hf.rs index 080a775719fe..68c720c3ea8c 100644 --- a/lib/tokenizers/src/hf.rs +++ b/lib/tokenizers/src/hf.rs @@ -27,19 +27,18 @@ impl HuggingFaceTokenizer { impl Encoder for HuggingFaceTokenizer { fn encode(&self, input: &str) -> Result { - // This self.tokenizer is the library - let encoding = self - .tokenizer - .encode(input, false) - .map_err(|err| Error::msg(format!("Error tokenizing input: {err}")))?; - - Ok(Encoding::Hf(Box::new(encoding))) + // Use add_special_tokens=true to match TikTokenTokenizer::encode() behaviour. + // Both backends must agree on whether BOS/EOS are included so that callers + // (e.g. /v1/tokenize, rl_tokenize_prompt) get consistent token counts + // regardless of which backend is active. Callers that explicitly need no + // special tokens should call encode_with_special_tokens(input, false) directly. 
+ self.encode_with_special_tokens(input, true) } fn encode_batch(&self, inputs: &[&str]) -> Result> { let hf_encodings = self .tokenizer - .encode_batch(inputs.to_vec(), false) + .encode_batch(inputs.to_vec(), true) // true to match encode() above .map_err(|err| Error::msg(format!("Error batch tokenizing input: {err}")))?; let encodings = hf_encodings @@ -49,6 +48,20 @@ impl Encoder for HuggingFaceTokenizer { Ok(encodings) } + + fn encode_with_special_tokens( + &self, + input: &str, + add_special_tokens: bool, + ) -> Result { + // This self.tokenizer is the library + let encoding = self + .tokenizer + .encode(input, add_special_tokens) + .map_err(|err| Error::msg(format!("Error tokenizing input: {err}")))?; + + Ok(Encoding::Hf(Box::new(encoding))) + } } impl Decoder for HuggingFaceTokenizer { @@ -63,7 +76,14 @@ impl Decoder for HuggingFaceTokenizer { } } -impl Tokenizer for HuggingFaceTokenizer {} +impl Tokenizer for HuggingFaceTokenizer { + fn convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result> { + Ok(token_ids + .iter() + .map(|&id| self.tokenizer.id_to_token(id).unwrap_or_default()) + .collect()) + } +} impl From for HuggingFaceTokenizer { fn from(tokenizer: HfTokenizer) -> Self { diff --git a/lib/tokenizers/src/lib.rs b/lib/tokenizers/src/lib.rs index 95494b4f73f0..dbf00955b0bc 100644 --- a/lib/tokenizers/src/lib.rs +++ b/lib/tokenizers/src/lib.rs @@ -63,6 +63,14 @@ pub mod traits { pub trait Encoder: Send + Sync { fn encode(&self, input: &str) -> Result; fn encode_batch(&self, inputs: &[&str]) -> Result>; + + fn encode_with_special_tokens( + &self, + input: &str, + _add_special_tokens: bool, + ) -> Result { + self.encode(input) + } } /// Result of decoding token IDs to text. 
@@ -128,8 +136,17 @@ pub mod traits { } pub trait Tokenizer: Encoder + Decoder { - // fn get_vocab_size(&self) -> usize; - // fn make_unique_clone(&self) -> Box; + fn convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result> { + // Decoder::decode returns DecodeResult (Complete/Partial); the existing + // `impl From for String` unwraps to the inner string. + token_ids + .iter() + .map(|id| { + self.decode(std::slice::from_ref(id), false) + .map(String::from) + }) + .collect() + } } } @@ -224,6 +241,18 @@ impl Tokenizer { Ok(Tokenizer(create_tokenizer_from_file(file_path)?)) } + pub fn encode_with_special_tokens( + &self, + input: &str, + add_special_tokens: bool, + ) -> Result { + self.0.encode_with_special_tokens(input, add_special_tokens) + } + + pub fn convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result> { + self.0.convert_ids_to_tokens(token_ids) + } + /// Create a stateful sequence object for decoding token_ids into text pub fn decode_stream( &self, diff --git a/lib/tokenizers/src/tiktoken.rs b/lib/tokenizers/src/tiktoken.rs index 7082acb0f6a6..cc049c13001d 100644 --- a/lib/tokenizers/src/tiktoken.rs +++ b/lib/tokenizers/src/tiktoken.rs @@ -24,6 +24,8 @@ const KIMI_PATTERN: &str = r#"[\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p pub struct TikTokenTokenizer { bpe: CoreBPE, special_token_ids: HashSet, + decoder_tokens: FxHashMap>, + special_tokens_decoder: FxHashMap>, } impl TikTokenTokenizer { @@ -39,6 +41,14 @@ impl TikTokenTokenizer { special_tokens: FxHashMap, ) -> Result { let encoder = parse_tiktoken_file(path)?; + let decoder_tokens: FxHashMap> = encoder + .iter() + .map(|(bytes, &id)| (id, bytes.clone())) + .collect(); + let special_tokens_decoder: FxHashMap> = special_tokens + .iter() + .map(|(token, &id)| (id, token.as_bytes().to_vec())) + .collect(); let special_token_ids: HashSet = special_tokens.values().copied().collect(); let bpe = CoreBPE::new(encoder, special_tokens, pattern) @@ -47,6 +57,8 @@ impl TikTokenTokenizer 
{ Ok(Self { bpe, special_token_ids, + decoder_tokens, + special_tokens_decoder, }) } @@ -62,9 +74,17 @@ impl TikTokenTokenizer { let pattern = detect_bpe_pattern(directory)?; let encoder = parse_tiktoken_file(path)?; + let decoder_tokens: FxHashMap> = encoder + .iter() + .map(|(bytes, &id)| (id, bytes.clone())) + .collect(); // Use max rank + 1 (not len) to avoid ID collisions with sparse/non-contiguous ranks let num_base_tokens = encoder.values().max().map_or(0, |&m| m + 1) as usize; let special_tokens = load_special_tokens(directory, num_base_tokens)?; + let special_tokens_decoder: FxHashMap> = special_tokens + .iter() + .map(|(token, &id)| (id, token.as_bytes().to_vec())) + .collect(); let special_token_ids: HashSet = special_tokens.values().copied().collect(); let bpe = CoreBPE::new(encoder, special_tokens, pattern) @@ -73,19 +93,33 @@ impl TikTokenTokenizer { Ok(Self { bpe, special_token_ids, + decoder_tokens, + special_tokens_decoder, }) } } impl Encoder for TikTokenTokenizer { fn encode(&self, input: &str) -> Result { - let token_ids: Vec = self.bpe.encode_with_special_tokens(input); - Ok(Encoding::Sp(token_ids)) + self.encode_with_special_tokens(input, true) } fn encode_batch(&self, inputs: &[&str]) -> Result> { inputs.par_iter().map(|input| self.encode(input)).collect() } + + fn encode_with_special_tokens( + &self, + input: &str, + add_special_tokens: bool, + ) -> Result { + let token_ids: Vec = if add_special_tokens { + self.bpe.encode_with_special_tokens(input) + } else { + self.bpe.encode_ordinary(input) + }; + Ok(Encoding::Sp(token_ids)) + } } impl Decoder for TikTokenTokenizer { @@ -119,7 +153,20 @@ impl Decoder for TikTokenTokenizer { } } -impl Tokenizer for TikTokenTokenizer {} +impl Tokenizer for TikTokenTokenizer { + fn convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result> { + Ok(token_ids + .iter() + .map(|id| { + self.decoder_tokens + .get(id) + .or_else(|| self.special_tokens_decoder.get(id)) + .map(|bytes| 
String::from_utf8_lossy(bytes).into_owned()) + .unwrap_or_default() + }) + .collect()) + } +} /// Parse a tiktoken model file (base64-encoded token + rank per line). fn parse_tiktoken_file(path: &str) -> Result, u32>> { From d295ebc658798df4e88dba3f506a8b7f2aca8771 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Mon, 4 May 2026 13:49:30 -0700 Subject: [PATCH 02/18] feat(rl): composite state endpoint, 3-mode pause, RL extras, drop dead URIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Incremental refactor on bis/dynamo-rl per docs/design-docs/rl-support.md. Test loop ~/dev/rl/work/bis-dev/4-02/{lora,sft}/run.sh — both smokes PASS end-to-end after this commit. Composite state + liveness probe handlers.py: new engine routes get_state and liveness_probe. - liveness_probe round-trips through engine_client.check_health() so a wedged event loop surfaces as 503 (closes hhzhang16 HH-23). - get_state returns per-worker {engine_alive, pause_state, applied_weight_version, loras} — aggregated by Rust frontend. - pause_generation / resume_generation now track the _paused flag. worker_factory.py: register the two new engine routes alongside the existing pause/resume/flush_cache/update_weights/load_lora ones. openai.rs: new GET /v1/rl/state composite endpoint and GET /v1/rl/liveness probe (5s timeout, override via DYN_RL_LIVENESS_TIMEOUT_MS). State aggregates per-worker payloads and surfaces ready / engine_alive / pause_state / applied_weight_version / loras / per-worker workers. Closes HH-19 (single state endpoint), HH-25 (RL-specific), HH-27 (weight_version folded in). Legacy /v1/rl/{health,ready,weight_version} kept for back-compat — drop once prime-rl AdminAPI migrates to /state. 3-mode pause + structured update_weights body handlers.py: pause_generation accepts {mode, clear_cache} body. - mode=keep|wait|abort with default keep (matches prime-rl client.py:_pause_engines). Closes HH-21. 
- mode=abort triggers collective_rpc(abort_all_requests) when available; gracefully falls back with a warning on vLLM 0.19 where that RPC isn't implemented. - clear_cache=true triggers reset_prefix_cache after pause. openai.rs: rl_pause now extracts ?mode= and ?clear_cache= via axum::extract::Query, validates mode, propagates to worker. Returns 400 on unknown mode (verified end-to-end via curl). rl_update_weights body migrates to typed RlUpdateWeightsBody {weight_dir, weight_version?, reset_prefix_cache=true}; the prior flush+reload sequence is now optional (controlled by reset_prefix_cache, default true) and the response carries applied_weight_version. RL extras on /v1/chat/completions validate.rs: PASSTHROUGH_EXTRA_FIELDS expanded from {cache_salt} to {cache_salt, prompt_token_ids, weight_version, return_routed_experts, return_token_ids, return_prompt_logprobs}. RL clients can now send these as top-level extras without 400s. Closes HH-22 / HH-26 (TITO via prompt_token_ids on /v1/chat/completions instead of forked URI). chat_completions/delta.rs and completions/delta.rs: replace the silent 'if let Ok(json) = serde_json::to_value(...)' fallback with a match that emits tracing::warn! on the Err branch. A dropped nvext means promoted token_ids / weight_version never reach the RL trainer — silently corrupting training. Closes CR-9. openai.rs: doc-block for token-ID promotion moved from rl_tokenize_prompt to rl_promote_token_ids_in_response (where it actually applies). Closes CR-10. Drop dead URIs service_v2.rs: drop tokenization_router (/v1/tokenize, /v1/detokenize) mounting — owned by jthomson04 PR #7699 for NeMo-rl scope, not required by prime-rl. Drop chat_completions_tokens_router mounting — TITO collapses into /v1/chat/completions with prompt_token_ids extension. All three URIs verified 404 end-to-end via curl. 
openai.rs: tokenize, detokenize, tokenization_router, chat_completions_tokens_router, handler_chat_completions_tokens, bad_request all marked #[allow(dead_code)] (kept for downstream compat; physical deletion is a follow-up). bad_request doc-block cleaned of stale 'Not Implemented' lines. Closes CR-8. End-to-end verification (smokes ran against this commit) /v1/rl/state 200 ready=true engine_alive=true (manual curl) /v1/rl/liveness 200 alive=true (manual curl) /v1/rl/pause?mode=abort&clear_cache=true 200 (manual curl) └─ vllm_worker.log: '[RL] Engine paused (..., mode=abort, clear_cache=True)' /v1/rl/pause?mode=invalid 400 'Invalid mode' /v1/rl/resume 200 /v1/chat/completions/tokens 404 /v1/tokenize 404 /v1/detokenize 404 4-02/sft smoke: PASS (full-FT pause->update_weights->resume, mismatch_kl<=0.0007) 4-02/lora smoke: PASS (1 hot-swap closed, lora_id=1626203954) Reviewer-comment closures landed in this commit CR-8 bad_request doc cleaned CR-9 serde_json::to_value Err logged via tracing::warn! 
CR-10 token-ID promotion doc-block re-attached HH-19 single state endpoint /v1/rl/state HH-21 3-mode pause keep|wait|abort + clear_cache HH-22 prompt_token_ids in PASSTHROUGH_EXTRA_FIELDS (TITO collapse) HH-23 liveness_probe via engine_client.check_health() + 5s timeout HH-25 /v1/rl/state is RL-specific (vs broader /health) HH-26 no /v1/chat/completions/tokens distinction (URI dropped) HH-27 weight_version folded into /v1/rl/state.applied_weight_version --- components/src/dynamo/vllm/handlers.py | 106 ++++- components/src/dynamo/vllm/worker_factory.py | 7 + lib/llm/src/http/service/openai.rs | 415 +++++++++++++++--- lib/llm/src/http/service/service_v2.rs | 33 +- .../openai/chat_completions/delta.rs | 49 ++- .../src/protocols/openai/completions/delta.rs | 42 +- lib/llm/src/protocols/openai/validate.rs | 41 +- 7 files changed, 586 insertions(+), 107 deletions(-) diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 25b898b8d630..c12ee1743e95 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -835,12 +835,55 @@ async def pause_generation(self, body: dict) -> dict: Called by RL admin coordinator before weight updates. Uses engine_client.pause_generation() directly -- does NOT sleep (no GPU memory release) and does NOT unregister from discovery. + + Body (all optional, all default to the prime-rl client convention): + - mode: "keep" | "wait" | "abort" (default "keep" — drain in-flight) + - clear_cache: bool (default False) + Closes hhzhang16 review HH-21 (3-mode pause). 
""" body = body or {} + mode = body.get("mode", "keep") + clear_cache = bool(body.get("clear_cache", False)) + if mode not in ("keep", "wait", "abort"): + return { + "status": "error", + "message": f"Invalid mode '{mode}'; expected one of keep|wait|abort", + } try: await self.engine_client.pause_generation() - logger.info("[RL] Engine paused (generation quiesced)") - return {"status": "ok", "message": "Engine paused"} + # mode=abort → also abort in-flight requests via vLLM's request abort + if mode == "abort": + try: + # Best-effort abort of all in-flight requests. + # vLLM exposes per-request abort; we don't track ids here so + # rely on engine internals to drain the rest under pause. + await self.engine_client.collective_rpc( + "abort_all_requests", kwargs={} + ) + except Exception as abort_err: + logger.warning( + f"[RL] mode=abort: collective_rpc(abort_all_requests) " + f"unavailable on this engine version: {abort_err}; " + f"in-flight requests will drain naturally" + ) + if clear_cache: + try: + await self.engine_client.reset_prefix_cache() + logger.info("[RL] pause: prefix cache cleared") + except Exception as flush_err: + logger.warning( + f"[RL] pause: clear_cache requested but reset_prefix_cache failed: {flush_err}" + ) + self._paused = True + logger.info( + f"[RL] Engine paused (generation quiesced, mode={mode}, clear_cache={clear_cache})" + ) + return { + "status": "ok", + "message": "Engine paused", + "mode": mode, + "clear_cache": clear_cache, + } except Exception as e: logger.error(f"[RL] Failed to pause: {e}") return {"status": "error", "message": str(e)} @@ -850,12 +893,71 @@ async def resume_generation(self, body: dict) -> dict: body = body or {} try: await self.engine_client.resume_generation() + self._paused = False logger.info("[RL] Engine resumed") return {"status": "ok", "message": "Engine resumed"} except Exception as e: logger.error(f"[RL] Failed to resume: {e}") return {"status": "error", "message": str(e)} + async def liveness_probe(self, 
body: dict) -> dict: + """Engine event-loop probe — confirms the engine is responsive. + + Used by ``GET /v1/rl/liveness``. The Rust frontend fans this out with a + short timeout (default 5s). Returning ``alive: True`` requires the + engine_client IPC roundtrip to complete: a hung event loop, deadlocked + worker, or wedged engine will time out at the frontend instead of + returning a stale ``OK``. Closes hhzhang16 HH-23 (health probe returns + OK no matter what). + """ + body = body or {} + try: + # vLLM's AsyncLLM/AsyncEngineClient exposes check_health() as the + # canonical liveness probe. It does a lightweight collective RPC + # to all engine workers and raises if any are unresponsive. + if hasattr(self.engine_client, "check_health"): + await self.engine_client.check_health() + return {"status": "ok", "alive": True} + # Fallback for engines without check_health: a no-op collective_rpc. + # The RPC round-trip itself is the liveness signal — if the engine + # event loop is wedged the frontend's 5s timeout fires. + await self.engine_client.collective_rpc("get_weight_version", kwargs={}) + return {"status": "ok", "alive": True} + except Exception as e: + logger.warning(f"[RL] liveness_probe failed: {e}") + return {"status": "error", "alive": False, "message": str(e)} + + async def get_state(self, body: dict) -> dict: + """Composite per-worker state snapshot for ``GET /v1/rl/state``. + + The Rust frontend aggregates these per-worker payloads into the + fleet-wide ``RlStateResponse``. Closes hhzhang16 HH-19/HH-25/HH-27 + (single state endpoint replacing /health + /ready + /weight_version, + RL-specific, weight_version folded in). 
+ """ + body = body or {} + try: + engine_alive = True + try: + if hasattr(self.engine_client, "check_health"): + await self.engine_client.check_health() + except Exception as health_err: + engine_alive = False + logger.warning(f"[RL] get_state: engine_alive=false ({health_err})") + return { + "status": "ok", + "engine_alive": engine_alive, + "pause_state": "paused" if getattr(self, "_paused", False) else "running", + "applied_weight_version": getattr(self, "_weight_version", "initial"), + "loras": [ + {"name": name, "id": info.id, "path": info.path} + for name, info in getattr(self, "loaded_loras", {}).items() + ], + } + except Exception as e: + logger.error(f"[RL] get_state failed: {e}") + return {"status": "error", "message": str(e)} + async def flush_cache(self, body: dict) -> dict: """Invalidate prefix/KV cache. Called after weight updates.""" body = body or {} diff --git a/components/src/dynamo/vllm/worker_factory.py b/components/src/dynamo/vllm/worker_factory.py index 0c1649c4e3af..9948fdad2f81 100644 --- a/components/src/dynamo/vllm/worker_factory.py +++ b/components/src/dynamo/vllm/worker_factory.py @@ -686,6 +686,12 @@ def register_engine_routes( ) runtime.register_engine_route("get_weight_version", handler.get_weight_version) + # RL state + liveness — drive /v1/rl/state and /v1/rl/liveness in the + # Rust frontend. /v1/rl/state aggregates these per-worker snapshots + # into the composite RlStateResponse (rl-support.md Phase 1). + runtime.register_engine_route("get_state", handler.get_state) + runtime.register_engine_route("liveness_probe", handler.liveness_probe) + # RL LoRA adapter routes: filesystem-native hot-swap used by Prime-RL # every training step to broadcast new adapter weights into the engine. 
runtime.register_engine_route("load_lora_adapter", handler.load_lora_adapter) @@ -697,5 +703,6 @@ def register_engine_routes( "Registered engine routes: sleep, wake_up, scale_elastic_ep, " "start_profile, stop_profile, pause_generation, resume_generation, " "flush_cache, update_weights_from_path, get_weight_version, " + "get_state, liveness_probe, " "load_lora_adapter, unload_lora_adapter" ) diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 1855140200b1..591c29c9cb50 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -204,11 +204,11 @@ impl ErrorMessage { ) } - /// Not Implemented Error - /// Return this error when the client requests a feature that is not yet implemented. - /// This should be used for features that are planned but not available. - /// Bad Request Error - /// Return this error when the client sends an invalid request. + /// Bad Request Error. + /// Return this error when the client sends an invalid request — malformed + /// JSON, schema mismatch, or fields that fail `validate.rs` gating. + /// (CR-8 closure: stale doc-block lines about "Not Implemented" removed.) + #[allow(dead_code)] // exposed for downstream crates; not directly called in lib/llm pub fn bad_request(msg: &str) -> ErrorResponse { let code = StatusCode::BAD_REQUEST; let error_type = map_error_code_to_error_type(code); @@ -2098,6 +2098,10 @@ fn resolve_model_card( Ok((model, card)) } +// Phase 5: handler kept (no callers) until jthomson04 PR #7699 lands +// `/tokenize` and `/detokenize` at root paths. Re-mount via +// `tokenization_router` in `service_v2.rs` if needed standalone. 
+#[allow(dead_code)] async fn tokenize( State(state): State>, Json(request): Json, @@ -2185,6 +2189,7 @@ async fn tokenize( .into_response()) } +#[allow(dead_code)] // see tokenize() above async fn detokenize( State(state): State>, Json(request): Json, @@ -2203,6 +2208,7 @@ async fn detokenize( Ok(Json(DetokenizeResponse { prompt }).into_response()) } +#[allow(dead_code)] // see tokenize() above; not mounted in service_v2 v2 surface pub fn tokenization_router(state: Arc) -> (Vec, Router) { let tokenize_path = "/v1/tokenize"; let detokenize_path = "/v1/detokenize"; @@ -2337,7 +2343,14 @@ pub fn chat_completions_router( /// to the standard chat_completions handler -- all in Rust, eliminating the /// Python rl-admin proxy from the hot inference path. /// -/// If no path is provided, the default path is `/v1/chat/completions/tokens` +/// If no path is provided, the default path is `/v1/chat/completions/tokens`. +/// +/// Phase 5: dropped from the v2 surface (see `service_v2.rs`). TITO callers +/// retarget to `/v1/chat/completions` with `prompt_token_ids` extension — +/// vLLM 0.20+ skips chat templating when that field is present, identical +/// behavior. The handler is kept as `#[allow(dead_code)]` until prime-rl +/// `bis/prime-rl-merged` migration P1 lands. +#[allow(dead_code)] pub fn chat_completions_tokens_router( state: Arc, template: Option, @@ -2363,6 +2376,7 @@ pub fn chat_completions_tokens_router( /// 4. Forces `logprobs = true` (RL always needs logprobs) /// 5. Ensures `messages` is non-empty (Dynamo requires it for chat template selection) /// 6. Delegates to the standard `chat_completions()` internal function (zero HTTP proxy) +#[allow(dead_code)] // see chat_completions_tokens_router above async fn handler_chat_completions_tokens( State((state, template)): State<(Arc, Option)>, headers: HeaderMap, @@ -3200,15 +3214,58 @@ async fn rl_ready(State(state): State>) -> impl IntoResponse { } /// `POST /v1/rl/pause` — fan out `pause_generation` to all workers. 
-async fn rl_pause(State(state): State>) -> impl IntoResponse { +/// +/// Query params (both optional): +/// - `mode`: `keep` | `wait` | `abort` (default `keep`) +/// - `clear_cache`: `true` | `false` (default `false`) +/// +/// Closes hhzhang16 HH-21 (3-mode pause: vLLM exposes abort/wait/keep). +/// Default is `mode=keep&clear_cache=false` to match prime-rl +/// `client.py:_pause_engines` so existing callers keep working. +#[derive(Debug, serde::Deserialize)] +struct RlPauseQuery { + #[serde(default)] + mode: Option, + #[serde(default)] + clear_cache: Option, +} + +async fn rl_pause( + State(state): State>, + axum::extract::Query(q): axum::extract::Query, +) -> impl IntoResponse { + let mode = q.mode.as_deref().unwrap_or("keep").to_string(); + let clear_cache = q.clear_cache.unwrap_or(false); + if !matches!(mode.as_str(), "keep" | "wait" | "abort") { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "status": "error", + "message": format!( + "Invalid mode '{mode}'; expected one of keep|wait|abort" + ), + })), + ); + } let results = state - .fan_out("pause_generation", serde_json::json!({})) + .fan_out( + "pause_generation", + serde_json::json!({"mode": mode, "clear_cache": clear_cache}), + ) .await; if RlState::all_ok(&results) { - tracing::info!("RL pause: all {} worker(s) paused", results.len()); + tracing::info!( + "RL pause: all {} worker(s) paused (mode={mode}, clear_cache={clear_cache})", + results.len() + ); ( StatusCode::OK, - Json(serde_json::json!({"status": "ok", "workers": results})), + Json(serde_json::json!({ + "status": "ok", + "mode": mode, + "clear_cache": clear_cache, + "workers": results, + })), ) } else { tracing::warn!("RL pause: some workers failed: {:?}", results); @@ -3241,49 +3298,81 @@ async fn rl_resume(State(state): State>) -> impl IntoResponse { /// `POST /v1/rl/update_weights` — atomic `flush_cache → update_weights_from_path` across all workers. 
/// -/// Expected body: `{"weight_dir": "/path/to/checkpoint"}` or `{"weight_dir": null}` for NCCL mode. +/// Body schema (`reset_prefix_cache` defaults to `true` — the v1 sequence +/// always flushed before reload, this just makes it explicit): +/// ```json +/// { +/// "weight_dir": "/path/to/checkpoint" | null, // null → NCCL mode no-op +/// "weight_version": "step_42", // optional; derived from +/// // weight_dir basename if missing +/// "reset_prefix_cache": true +/// } +/// ``` +/// +/// Returns `{ "status": "ok", "applied_weight_version": "step_42", "workers": [...] }` on success. /// -/// The sequence per worker is: `flush_cache → update_weights_from_path`. -/// The pause/resume envelope is left to Prime-RL, which can call `/v1/rl/pause` and -/// `/v1/rl/resume` explicitly for full drain-and-swap semantics. +/// The pause/resume envelope is left to the caller; full-FT updates MUST +/// bracket this call with `/v1/rl/pause` and `/v1/rl/resume`. +#[derive(Debug, serde::Deserialize)] +struct RlUpdateWeightsBody { + weight_dir: Option, + #[serde(default)] + weight_version: Option, + #[serde(default = "default_reset_prefix_cache")] + reset_prefix_cache: bool, +} + +fn default_reset_prefix_cache() -> bool { + true +} + async fn rl_update_weights( State(state): State>, - body: axum::extract::Json, + body: axum::extract::Json, ) -> impl IntoResponse { - let weight_dir = body - .get("weight_dir") - .and_then(|v| v.as_str()) - .map(str::to_string); + let weight_dir = body.weight_dir.clone(); + let reset_prefix_cache = body.reset_prefix_cache; if weight_dir.is_none() { tracing::info!("RL update_weights: weight_dir=null (NCCL mode, no-op on Dynamo side)"); return ( StatusCode::OK, - Json(serde_json::json!({"status": "ok", "message": "NCCL mode, no-op on Dynamo side"})), + Json(serde_json::json!({ + "status": "ok", + "message": "NCCL mode, no-op on Dynamo side" + })), ); } let weight_dir = weight_dir.unwrap(); - tracing::info!("RL update_weights: 
weight_dir={weight_dir}"); + let version = body.weight_version.clone().unwrap_or_else(|| { + std::path::Path::new(&weight_dir) + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string() + }); + tracing::info!( + "RL update_weights: weight_dir={weight_dir} version={version} reset_prefix_cache={reset_prefix_cache}" + ); - // Step 1: flush_cache across all workers - let flush_results = state.fan_out("flush_cache", serde_json::json!({})).await; - if !RlState::all_ok(&flush_results) { - tracing::warn!("RL update_weights: flush_cache failed: {:?}", flush_results); - return ( - StatusCode::BAD_GATEWAY, - Json( - serde_json::json!({"status": "error", "stage": "flush_cache", "workers": flush_results}), - ), - ); + // Step 1 (optional): flush_cache across all workers. + if reset_prefix_cache { + let flush_results = state.fan_out("flush_cache", serde_json::json!({})).await; + if !RlState::all_ok(&flush_results) { + tracing::warn!("RL update_weights: flush_cache failed: {:?}", flush_results); + return ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({ + "status": "error", + "stage": "flush_cache", + "workers": flush_results + })), + ); + } } - // Step 2: update_weights_from_path across all workers - let version = std::path::Path::new(&weight_dir) - .file_name() - .and_then(|n| n.to_str()) - .unwrap_or("unknown") - .to_string(); + // Step 2: update_weights_from_path across all workers. 
let load_body = serde_json::json!({"path": weight_dir, "version": version}); let load_results = state.fan_out("update_weights_from_path", load_body).await; if RlState::all_ok(&load_results) { @@ -3293,7 +3382,11 @@ async fn rl_update_weights( ); ( StatusCode::OK, - Json(serde_json::json!({"status": "ok", "workers": load_results})), + Json(serde_json::json!({ + "status": "ok", + "applied_weight_version": version, + "workers": load_results, + })), ) } else { tracing::warn!( @@ -3302,9 +3395,11 @@ async fn rl_update_weights( ); ( StatusCode::BAD_GATEWAY, - Json( - serde_json::json!({"status": "error", "stage": "update_weights_from_path", "workers": load_results}), - ), + Json(serde_json::json!({ + "status": "error", + "stage": "update_weights_from_path", + "workers": load_results + })), ) } } @@ -3461,11 +3556,6 @@ async fn rl_weight_version(State(state): State>) -> impl IntoRespon } } -/// Promote token IDs from the Dynamo `nvext` response object to the top-level -/// locations that Prime-RL / verifiers expects: -/// -/// response.nvext.completion_token_ids → response.choices[i].token_ids -/// /// Tokenize chat messages using the model's tokenizer and return prompt token IDs. /// Used by the RL post-processing path to populate `response.prompt_token_ids`. fn rl_tokenize_prompt( @@ -3502,9 +3592,15 @@ fn rl_tokenize_prompt( Some(encoding.token_ids().to_vec()) } +/// Promote token IDs from the Dynamo `nvext` response object to the top-level +/// locations that Prime-RL / verifiers expects: +/// +/// response.nvext.completion_token_ids → response.choices[i].token_ids +/// /// This lets Prime-RL read `choice.token_ids` without knowing about the `nvext` /// extension structure. Called on non-streaming responses when RL token ID mode -/// is active. +/// is active. (CR-10 closure: doc-block was previously misattached to +/// `rl_tokenize_prompt`.) 
fn rl_promote_token_ids_in_response(json_val: &mut serde_json::Value) { // Move completion_token_ids from response-level nvext to each choice. // Prime-RL / verifiers expects: @@ -3545,10 +3641,206 @@ fn rl_promote_token_ids_in_response(json_val: &mut serde_json::Value) { /// `admin_base_url = ["http://dynamo:8000/v1/rl"]` the request arrives here. /// Returns 200 OK if the frontend process is running (no deep probe needed — /// the frontend's own `/health` endpoint handles that separately). +/// +/// **Deprecated in favor of `/v1/rl/state.ingress_alive`** (rl-support.md +/// Phase 1 / Phase 5). Kept for prime-rl `bis/prime-rl-merged` until the +/// AdminAPI migration lands; will be removed once prime-rl P2 commits. async fn rl_health() -> impl IntoResponse { (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))) } +/// `GET /v1/rl/liveness` — engine event-loop probe via `liveness_probe` +/// engine route. Closes hhzhang16 HH-23 (the v1 `/v1/rl/health` returns OK +/// no matter what; this endpoint round-trips through the engine so a hung +/// event loop or wedged worker surfaces as 503). +/// +/// Each per-worker call carries a 5s timeout (override via +/// `DYN_RL_LIVENESS_TIMEOUT_MS`). Returns 200 only when every worker +/// reports `alive: true` within the deadline; 503 otherwise. 
+async fn rl_liveness(State(state): State>) -> impl IntoResponse { + if state.worker_system_urls.is_empty() { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "status": "error", + "alive": false, + "message": "no workers registered" + })), + ); + } + let timeout_ms = std::env::var("DYN_RL_LIVENESS_TIMEOUT_MS") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(5000); + let timeout = std::time::Duration::from_millis(timeout_ms); + + let futures: Vec<_> = state + .worker_system_urls + .iter() + .map(|url| { + let client = state.http_client.clone(); + let endpoint = format!("{url}/engine/liveness_probe"); + async move { + tokio::time::timeout( + timeout, + async { + match client.post(&endpoint).json(&serde_json::json!({})).send().await { + Ok(resp) => resp + .json::() + .await + .unwrap_or_else(|e| serde_json::json!({ + "status": "error", + "alive": false, + "message": format!("decode failed: {e}") + })), + Err(e) => serde_json::json!({ + "status": "error", + "alive": false, + "message": format!("request failed: {e}") + }), + } + }, + ) + .await + .unwrap_or_else(|_| serde_json::json!({ + "status": "error", + "alive": false, + "message": format!("liveness_probe timed out after {timeout_ms}ms") + })) + } + }) + .collect(); + let results = futures::future::join_all(futures).await; + let all_alive = results + .iter() + .all(|r| r.get("alive").and_then(|v| v.as_bool()) == Some(true)); + if all_alive { + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "alive": true, + "workers": results, + })), + ) + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "status": "error", + "alive": false, + "workers": results, + })), + ) + } +} + +/// `GET /v1/rl/state` — composite RL fleet state snapshot. +/// +/// Replaces three v1 endpoints (`/v1/rl/health` + `/v1/rl/ready` + +/// `/v1/rl/weight_version`) with a single composite. 
Closes hhzhang16 +/// HH-19 (single state endpoint), HH-25 (RL-specific vs broader Dynamo +/// readiness), HH-27 (weight_version folded in). +/// +/// Aggregates per-worker `get_state` engine-route responses into: +/// +/// ```json +/// { +/// "ready": bool, +/// "ingress_alive": true, +/// "engine_alive": bool, // every worker's engine.check_health() ok +/// "pause_state": "running"|"paused"|"mixed", +/// "applied_weight_version": str, // when consistent across workers; null if mixed +/// "loras": [{name, loaded_on: [worker_idx]}], +/// "workers": [] +/// } +/// ``` +/// +/// `ingress_alive` is unconditionally `true` because reaching this handler +/// means the frontend HTTP listener is up. `ready = ingress_alive AND +/// engine_alive AND len(workers) > 0`. +async fn rl_state(State(state): State>) -> impl IntoResponse { + if state.worker_system_urls.is_empty() { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "ready": false, + "ingress_alive": true, + "engine_alive": false, + "pause_state": "running", + "applied_weight_version": null, + "loras": [], + "workers": [], + "status": "error", + "message": "no workers registered" + })), + ); + } + let results = state.fan_out("get_state", serde_json::json!({})).await; + + let engine_alive = results + .iter() + .all(|r| r.get("engine_alive").and_then(|v| v.as_bool()) == Some(true)); + + // Aggregate pause_state: if all workers agree, surface that; else "mixed". + let pause_states: std::collections::HashSet<&str> = results + .iter() + .filter_map(|r| r.get("pause_state").and_then(|v| v.as_str())) + .collect(); + let pause_state = if pause_states.len() == 1 { + pause_states.into_iter().next().unwrap_or("running").to_string() + } else if pause_states.is_empty() { + "running".to_string() + } else { + "mixed".to_string() + }; + + // applied_weight_version is reported only when consistent. 
+ let weight_versions: std::collections::HashSet<&str> = results + .iter() + .filter_map(|r| r.get("applied_weight_version").and_then(|v| v.as_str())) + .collect(); + let applied_weight_version: Option = if weight_versions.len() == 1 { + weight_versions.into_iter().next().map(|s| s.to_string()) + } else { + None + }; + + // LoRA name → list of worker indices that have it loaded. + let mut lora_loaded_on: std::collections::BTreeMap> = + std::collections::BTreeMap::new(); + for (idx, worker) in results.iter().enumerate() { + if let Some(loras) = worker.get("loras").and_then(|v| v.as_array()) { + for lora in loras { + if let Some(name) = lora.get("name").and_then(|v| v.as_str()) { + lora_loaded_on.entry(name.to_string()).or_default().push(idx); + } + } + } + } + let loras: Vec = lora_loaded_on + .into_iter() + .map(|(name, loaded_on)| serde_json::json!({"name": name, "loaded_on": loaded_on})) + .collect(); + + let ready = engine_alive && !results.is_empty(); + let body = serde_json::json!({ + "ready": ready, + "ingress_alive": true, + "engine_alive": engine_alive, + "pause_state": pause_state, + "applied_weight_version": applied_weight_version, + "loras": loras, + "workers": results, + }); + let status = if ready { + StatusCode::OK + } else { + StatusCode::SERVICE_UNAVAILABLE + }; + (status, Json(body)) +} + /// Create an Axum [`Router`] for the RL admin endpoints at `/v1/rl/*`. /// /// Worker system URLs are read from the `DYN_RL_WORKER_SYSTEM_URLS` environment @@ -3560,28 +3852,41 @@ async fn rl_health() -> impl IntoResponse { /// in the orchestrator config. Prime-RL strips the trailing `/v1` suffix only /// if present, so `/v1/rl` is preserved and all routes resolve correctly. 
pub fn rl_router() -> (Vec, Router) { - let rl_state = Arc::new(RlState::from_env()); + let rl_state_arc = Arc::new(RlState::from_env()); let docs = vec![ - RouteDoc::new(axum::http::Method::GET, "/v1/rl/health"), - RouteDoc::new(axum::http::Method::GET, "/v1/rl/ready"), + // Phase 1: composite endpoints. + RouteDoc::new(axum::http::Method::GET, "/v1/rl/state"), + RouteDoc::new(axum::http::Method::GET, "/v1/rl/liveness"), + // Pause / resume / update_weights bracket. RouteDoc::new(axum::http::Method::POST, "/v1/rl/pause"), RouteDoc::new(axum::http::Method::POST, "/v1/rl/resume"), RouteDoc::new(axum::http::Method::POST, "/v1/rl/update_weights"), + // LoRA hot-swap. RouteDoc::new(axum::http::Method::POST, "/v1/rl/load_lora_adapter"), - RouteDoc::new(axum::http::Method::POST, "/v1/rl/unload_lora_adapter"), + // Legacy (deprecated; subsumed by /v1/rl/state — Phase 5 will drop): + RouteDoc::new(axum::http::Method::GET, "/v1/rl/health"), + RouteDoc::new(axum::http::Method::GET, "/v1/rl/ready"), RouteDoc::new(axum::http::Method::GET, "/v1/rl/weight_version"), + RouteDoc::new(axum::http::Method::POST, "/v1/rl/unload_lora_adapter"), ]; let router = Router::new() - .route("/v1/rl/health", get(rl_health)) - .route("/v1/rl/ready", get(rl_ready)) + // Phase 1: composite read-only endpoints. + .route("/v1/rl/state", get(rl_state)) + .route("/v1/rl/liveness", get(rl_liveness)) + // Pause / resume / update_weights bracket. .route("/v1/rl/pause", post(rl_pause)) .route("/v1/rl/resume", post(rl_resume)) .route("/v1/rl/update_weights", post(rl_update_weights)) + // LoRA hot-swap. .route("/v1/rl/load_lora_adapter", post(rl_load_lora_adapter)) - .route("/v1/rl/unload_lora_adapter", post(rl_unload_lora_adapter)) + // Legacy endpoints — kept until prime-rl `bis/prime-rl-merged` AdminAPI + // migration P2 lands; Phase 5 of rl-support.md drops them. 
+ .route("/v1/rl/health", get(rl_health)) + .route("/v1/rl/ready", get(rl_ready)) .route("/v1/rl/weight_version", get(rl_weight_version)) + .route("/v1/rl/unload_lora_adapter", post(rl_unload_lora_adapter)) .layer(middleware::from_fn(smart_json_error_middleware)) - .with_state(rl_state); + .with_state(rl_state_arc); (docs, router) } diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index fe5d8ea6ba13..4d5e3816b666 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -538,7 +538,13 @@ impl HttpServiceConfigBuilder { } else { super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()) }, - super::openai::tokenization_router(state.clone()), + // /v1/tokenize and /v1/detokenize are NOT required by prime-rl + // (source audit: zero references). Owned by jthomson04 PR #7699 + // which mounts /tokenize and /detokenize at root paths for the + // NeMo-rl tokenize-then-generate pattern. Dropped from the v2 + // surface here per `bis-dev/design-docs/rl-support.md` §1 + // out-of-scope. Re-enable by uncommenting the next line: + // super::openai::tokenization_router(state.clone()), super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()), super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), super::busy_threshold::busy_threshold_router(state.clone(), None), @@ -612,14 +618,16 @@ impl HttpServiceConfigBuilder { request_template.clone(), var(HTTP_SVC_CHAT_PATH_ENV).ok(), ); - // RL TITO (Token-In / Token-Out) endpoint -- mounted alongside chat completions. - // Accepts Prime-RL's `tokens` field, translates to nvext.token_data, and delegates - // to the standard chat completions pipeline. Eliminates the Python rl-admin proxy. 
- let (tito_docs, tito_route) = super::openai::chat_completions_tokens_router( - state.clone(), - request_template.clone(), - None, - ); + // /v1/chat/completions/tokens (the v1 TITO fork URI) is dropped per + // `bis-dev/design-docs/rl-support.md` Phase 5 + hhzhang16 HH-22 / HH-26. + // TITO callers retarget to /v1/chat/completions with `prompt_token_ids` + // as a top-level extension (now in `validate.rs:104` + // PASSTHROUGH_EXTRA_FIELDS) — vLLM 0.20+ skips chat templating when + // that field is present, identical behavior to the dropped fork URI. + // The handler `handler_chat_completions_tokens` and helper + // `chat_completions_tokens_router` are intentionally left in the + // codebase as dead code for now; a subsequent commit can delete + // them once prime-rl has fully migrated. let (cmpl_docs, cmpl_route) = super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok()); @@ -633,13 +641,10 @@ impl HttpServiceConfigBuilder { request_template.clone(), var(HTTP_SVC_RESPONSES_PATH_ENV).ok(), ); - // Merge TITO route and docs into the chat route (shares enable/disable flag) - let chat_route = chat_route.merge(tito_route); - let mut combined_chat_docs = chat_docs; - combined_chat_docs.extend(tito_docs); + // Phase 5: TITO fork URI dropped — chat route stands alone now. 
let mut endpoint_routes = HashMap::new(); - endpoint_routes.insert(EndpointType::Chat, (combined_chat_docs, chat_route)); + endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route)); endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route)); endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route)); endpoint_routes.insert(EndpointType::Images, (images_docs, images_route)); diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 27943317b7a2..9abe99f4fc28 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -460,25 +460,38 @@ impl crate::protocols::openai::DeltaGeneratorExt { + stream_response.nvext = Some(nvext_json); + if let Some(ref info) = nvext_response.worker_id { + tracing::debug!( + "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}", + info.prefill_worker_id, + info.decode_worker_id + ); + } + if let Some(ref tokens) = nvext_response.token_ids { + tracing::debug!( + "Injected token_ids into chat completion nvext: {} tokens", + tokens.len() + ); + } + if let Some(ref tokens) = nvext_response.completion_token_ids { + tracing::debug!( + "Injected completion_token_ids into chat completion nvext: {} tokens", + tokens.len() + ); + } } - if let Some(ref tokens) = nvext_response.completion_token_ids { - tracing::debug!( - "Injected completion_token_ids into chat completion nvext: {} tokens", - tokens.len() + Err(err) => { + tracing::warn!( + error = %err, + "chat completion nvext: serde_json::to_value failed, dropping nvext payload \ + (RL trainer will not receive token_ids / weight_version this chunk)", ); } } diff --git a/lib/llm/src/protocols/openai/completions/delta.rs b/lib/llm/src/protocols/openai/completions/delta.rs index 3f039399e1cd..26d84c7803ae 100644 --- a/lib/llm/src/protocols/openai/completions/delta.rs +++ 
b/lib/llm/src/protocols/openai/completions/delta.rs @@ -313,21 +313,33 @@ impl crate::protocols::openai::DeltaGeneratorExt for delta.disaggregated_params.as_ref(), finish_reason.is_some(), delta.engine_data, - ) && let Ok(nvext_json) = serde_json::to_value(&nvext_response) - { - response.nvext = Some(nvext_json); - if let Some(ref info) = nvext_response.worker_id { - tracing::debug!( - "Injected worker_id into completions nvext: prefill={:?}, decode={:?}", - info.prefill_worker_id, - info.decode_worker_id - ); - } - if let Some(ref tokens) = nvext_response.token_ids { - tracing::debug!( - "Injected token_ids into completions nvext: {} tokens", - tokens.len() - ); + ) { + // CR-9 closure: log a warning if serialization fails instead of + // silently dropping the nvext payload (would mean promoted fields + // never reach the client). + match serde_json::to_value(&nvext_response) { + Ok(nvext_json) => { + response.nvext = Some(nvext_json); + if let Some(ref info) = nvext_response.worker_id { + tracing::debug!( + "Injected worker_id into completions nvext: prefill={:?}, decode={:?}", + info.prefill_worker_id, + info.decode_worker_id + ); + } + if let Some(ref tokens) = nvext_response.token_ids { + tracing::debug!( + "Injected token_ids into completions nvext: {} tokens", + tokens.len() + ); + } + } + Err(err) => { + tracing::warn!( + error = %err, + "completions nvext: serde_json::to_value failed, dropping nvext payload", + ); + } } } diff --git a/lib/llm/src/protocols/openai/validate.rs b/lib/llm/src/protocols/openai/validate.rs index fff6ea8edd65..e581d04303fc 100644 --- a/lib/llm/src/protocols/openai/validate.rs +++ b/lib/llm/src/protocols/openai/validate.rs @@ -98,10 +98,45 @@ pub const MAX_REPETITION_PENALTY: f32 = 2.0; // /// Fields that Prime-RL / verifiers may send as extra_body hints which Dynamo -/// does not implement but should not reject with a 400. They are silently -/// accepted and ignored so the RL client stack is forward-compatible. 
+/// does not implement but should not reject with a 400. They are silently +/// accepted (the chat-completions handler reads what it understands and +/// ignores the rest) so the RL client stack is forward-compatible with new +/// extension fields without churning Dynamo. +/// +/// Per `bis-dev/design-docs/rl-support.md` Phase 4, this is the canonical +/// home for the typed RL extension fields; the prior `nvext.extra_fields` +/// `["completion_token_ids", ...]` opt-in mechanism still works alongside it +/// but the named fields here are the recommended path. const PASSTHROUGH_EXTRA_FIELDS: &[&str] = &[ - "cache_salt", // KV prefix-cache isolation hint from prime-rl orchestrator + // KV prefix-cache isolation hint from prime-rl orchestrator. Coordinated + // with PR #8197 (which moves this to the X-Tenant-Id header); both forms + // accepted for one release, header takes precedence. + "cache_salt", + // Pre-tokenized prompt for the RL TITO path. Mutually exclusive with + // `messages`; when present, vLLM 0.20+ skips chat templating. Closes + // hhzhang16 HH-22 / HH-26 — the "tokens variant of /v1/chat/completions" + // collapses into the same URI with this extension field instead of a + // forked /v1/chat/completions/tokens. Today RL clients pre-tokenize + // and pass via `nvext.token_data` (preprocessor.rs handles that + // already); the typed top-level field shipped here is the long-term + // canonical entry for clients written against the vLLM 0.20 schema. + "prompt_token_ids", + // RL routing filter — only dispatch to workers reporting this applied + // weight version. Used by IS-correction strict-version mode and by + // NeMo RL eval-on-subset. Today accepted-and-ignored at the request + // level; the routing-side filter lands in a follow-up. + "weight_version", + // Per-request gate for MoE Routing Replay capture. Honored by + // `nvext.extra_fields = ["routed_experts"]` already (see + // `NvExtResponseFieldSelection`); accepted here as a typed alias. 
+ "return_routed_experts", + // Per-request opt-in for `nvext.completion_token_ids` on the response. + // Today the `extra_fields = ["completion_token_ids"]` mechanism is the + // canonical; this typed alias is the long-term form. + "return_token_ids", + // Opt-in for `nvext.prompt_logprobs` on the response. Aliased through + // to vLLM's `sampling_params.prompt_logprobs` in a follow-up. + "return_prompt_logprobs", ]; /// Validates that no unsupported fields are present in the request From f03417149a67226226edf36367fd8dadec42ee36 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Mon, 4 May 2026 23:23:45 -0700 Subject: [PATCH 03/18] feat(rl): /v1/chat/completions absorbs TITO + full SamplingParams parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit End-to-end TITO support on /v1/chat/completions: callers post prompt_token_ids as a top-level extension (and stop_token_ids, allowed_token_ids, bad_words_token_ids, truncate_prompt_tokens for SamplingParams parity with vLLM /inference/v1/generate). vLLM 0.20+ already accepts these on /v1/chat/completions natively; this commit is the bridge until that upgrade lands. Plumbing: validate.rs: PASSTHROUGH_EXTRA_FIELDS adds stop_token_ids, bad_words_token_ids, allowed_token_ids, truncate_prompt_tokens. protocols/openai.rs: OpenAIStopConditionsProvider gains get_stop_token_ids() (default None); extract_stop_conditions() plumbs into common::StopConditions::stop_token_ids_hidden so vLLM's SamplingParams.stop_token_ids is honored end-to-end. protocols/openai/chat_completions.rs: NvCreateChatCompletionRequest overrides get_stop_token_ids and get_pretokenized_input to read unsupported_fields. ValidateRequest::validate accepts empty messages when prompt_token_ids is present, 400s on mutual-exclusion violations. 
protocols/openai/nvext.rs: NvExtProvider gains get_pretokenized_input with a default impl reading nvext.token_data; chat-completions overrides to also read top-level prompt_token_ids extension. preprocessor.rs: apply_template short-circuits to None when pre-tokenized input is present (avoids 'undefined value' template errors on empty messages); gather_tokens reads via the same NvExtProvider hook covering both channels. http/service/openai.rs: validate_chat_completion_required_fields accepts empty messages when pre-tokenized input is provided. End-to-end probes against Qwen/Qwen3-0.6B: /v1/chat/completions with prompt_token_ids + stop_token_ids: 200, finish_reason=stop on stop-token, nvext.completion_token_ids populated. stop_token_ids=[16] forces stop on first generated token: 200, immediate stop (proves SamplingParams.stop_token_ids honored). messages + prompt_token_ids: 400 mutual-exclusion error. messages-only (MITO): unchanged — chat template applied normally. Smoke regression on bis-dev/4-02 against Qwen/Qwen3-0.6B: sft (full-FT bracket): PASS, mismatch_kl 0.0006/0.0007 lora (hot-swap): PASS, lora_id=1626203954 Pairs with prime-rl bis/prime-rl-merged commits that add VLLMGenerateClient / DynamoGenerateClient + setup_generate_client and thread them through compute_teacher_logprobs. Same client.backend axis as setup_admin_api; one config field drives both admin and data paths. 
--- lib/llm/src/http/service/openai.rs | 14 ++++- lib/llm/src/preprocessor.rs | 22 ++++++-- lib/llm/src/protocols/openai.rs | 18 +++++- .../src/protocols/openai/chat_completions.rs | 55 ++++++++++++++++++- lib/llm/src/protocols/openai/nvext.rs | 20 +++++++ lib/llm/src/protocols/openai/validate.rs | 10 ++++ 6 files changed, 132 insertions(+), 7 deletions(-) diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 591c29c9cb50..92bd2724da2f 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -1578,7 +1578,19 @@ pub fn validate_chat_completion_required_fields( ) -> Result<(), ErrorResponse> { let inner = &request.inner; - if inner.messages.is_empty() { + // RL renderer / TITO callers send `prompt_token_ids` (or legacy + // `nvext.token_data`) in place of `messages`. Treat either pre-tokenized + // input as satisfying the "non-empty input" requirement. + let has_pretokenized_input = request + .unsupported_fields + .contains_key("prompt_token_ids") + || request + .nvext + .as_ref() + .and_then(|ext| ext.token_data.as_ref()) + .is_some(); + + if inner.messages.is_empty() && !has_pretokenized_input { return Err(ErrorMessage::from_http_error(HttpError { code: 400, message: VALIDATION_PREFIX.to_string() diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index e8838513d588..3fa355a98053 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -379,6 +379,16 @@ impl OpenAIPreprocessor { &self, request: &R, ) -> Result> { + // Renderer / TITO callers post `prompt_token_ids` (or legacy + // `nvext.token_data`); chat templating is bypassed entirely. + // `gather_tokens` reads the same channel via `get_pretokenized_input` + // and feeds the engine directly, so we must not attempt to render + // a chat template here (would fail with "undefined value" when + // `messages` is empty). 
+ if request.get_pretokenized_input().is_some() { + return Ok(None); + } + if let PromptInput::Text(_) = request.prompt_input_type() && let Some(TextInput::Single(_)) = request.extract_text() { @@ -593,8 +603,12 @@ impl OpenAIPreprocessor { .and_then(|ext| ext.backend_instance_id) .is_some(); - let token_data = - request.nvext().and_then(|ext| ext.token_data.as_ref()); + // get_pretokenized_input() consults both + // `nvext.token_data` (legacy GAIE/EPP/TITO path) AND + // top-level `prompt_token_ids` extension (renderer / TITO + // canonical path now that `/v1/chat/completions/tokens` is + // dropped). Either channel produces the same engine input. + let token_data = request.get_pretokenized_input(); // Use token_data when provided (TITO / EPP / RL), // regardless of backend_instance_id. @@ -611,9 +625,9 @@ impl OpenAIPreprocessor { token_count = tokens.len(), first_tokens = ?&tokens[..std::cmp::min(5, tokens.len())], backend_instance_id = has_backend_instance_id, - "[SIDECAR-SKIP-TOKENIZE] Found nvext.token_data — using pre-computed tokens, SKIPPING tokenization" + "[SIDECAR-SKIP-TOKENIZE] Found pre-tokenized input (nvext.token_data or prompt_token_ids extension) — using pre-computed tokens, SKIPPING tokenization" ); - (tokens.clone(), has_backend_instance_id) + (tokens, has_backend_instance_id) } else if has_backend_instance_id { tracing::warn!( "backend_instance_id provided but no token_data; tokenizing prompt" diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 4d022ac01c83..961ec44c2c8b 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -81,6 +81,17 @@ pub(crate) trait OpenAIStopConditionsProvider { fn get_max_thinking_tokens(&self) -> Option { self.nvext().and_then(|nv| nv.max_thinking_tokens) } + + /// Get token-id-based stop conditions (renderer / TITO parity with vLLM + /// `/inference/v1/generate`'s `sampling_params.stop_token_ids`). 
+ /// + /// Default returns None; chat-completions / completions impls override to + /// read the field from `unsupported_fields["stop_token_ids"]` after the + /// `validate.rs` PASSTHROUGH allowlist accepts it. Plumbed through into + /// `common::StopConditions::stop_token_ids_hidden` by `extract_stop_conditions`. + fn get_stop_token_ids(&self) -> Option> { + None + } } pub(crate) trait OpenAIOutputOptionsProvider { @@ -181,6 +192,11 @@ impl StopConditionsProvider for T { let min_tokens = self.get_min_tokens(); let stop = self.get_stop(); let max_thinking_tokens = self.get_max_thinking_tokens(); + // Token-id stop conditions ride through PASSTHROUGH_EXTRA_FIELDS on the + // chat-completions surface; impls of this trait read it from the + // request's `unsupported_fields` map. Engine receives it as + // `stop_token_ids_hidden` (already wired in `common::StopConditions`). + let stop_token_ids_hidden = self.get_stop_token_ids(); if let Some(stop) = &stop && stop.len() > 4 @@ -195,7 +211,7 @@ impl StopConditionsProvider for T { max_tokens, min_tokens, stop, - stop_token_ids_hidden: None, + stop_token_ids_hidden, ignore_eos, max_thinking_tokens, }) diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs index 6e200e53bd67..b49c61518558 100644 --- a/lib/llm/src/protocols/openai/chat_completions.rs +++ b/lib/llm/src/protocols/openai/chat_completions.rs @@ -112,6 +112,25 @@ impl NvExtProvider for NvCreateChatCompletionRequest { fn raw_prompt(&self) -> Option { None } + + /// Pre-tokenized input — checks `nvext.token_data` first (legacy path), + /// falls back to top-level `prompt_token_ids` extension that + /// `validate.rs` PASSTHROUGH_EXTRA_FIELDS allowlists. The two channels + /// are equivalent at the engine level; the top-level extension is the + /// canonical home now that `/v1/chat/completions/tokens` is dropped. + fn get_pretokenized_input(&self) -> Option> { + // 1. 
Prefer nvext.token_data when present (existing GAIE/EPP path). + if let Some(token_data) = self.nvext.as_ref().and_then(|ext| ext.token_data.as_ref()) { + return Some(token_data.clone()); + } + // 2. Fall back to top-level `prompt_token_ids` extension. Renderer + // and TITO callers post here directly — the field rides through + // PASSTHROUGH_EXTRA_FIELDS without 400, then we promote it to + // the engine path here. + self.unsupported_fields + .get("prompt_token_ids") + .and_then(|v| serde_json::from_value::>(v.clone()).ok()) + } } /// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`, @@ -321,6 +340,19 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { fn get_ignore_eos(&self) -> Option { self.common.ignore_eos } + + /// Read `stop_token_ids` from `unsupported_fields` (allowlisted via + /// `validate.rs` PASSTHROUGH_EXTRA_FIELDS). RL renderer / TITO callers + /// rely on this for stop-on-token-id conditions that don't tokenize + /// cleanly as strings (custom EOS, model-specific control tokens). + /// Malformed values silently fall back to None — `validate.rs` already + /// took the "is this allowed at all" decision; here we only choose + /// between "we got a usable list" and "we got nothing". + fn get_stop_token_ids(&self) -> Option> { + self.unsupported_fields + .get("stop_token_ids") + .and_then(|v| serde_json::from_value::>(v.clone()).ok()) + } } impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest { @@ -353,7 +385,28 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest { impl ValidateRequest for NvCreateChatCompletionRequest { fn validate(&self) -> Result<(), anyhow::Error> { validate::validate_no_unsupported_fields(&self.unsupported_fields)?; - validate::validate_messages(&self.inner.messages)?; + // Mutual-exclusivity: messages OR prompt_token_ids extension, not both. 
+ // When prompt_token_ids is present (renderer / TITO), messages can be + // empty and chat templating is bypassed by the preprocessor's + // `get_pretokenized_input()` path. + let has_pretokenized_input = + self.unsupported_fields.contains_key("prompt_token_ids") + || self + .nvext + .as_ref() + .and_then(|ext| ext.token_data.as_ref()) + .is_some(); + if has_pretokenized_input { + if !self.inner.messages.is_empty() { + anyhow::bail!( + "messages and prompt_token_ids are mutually exclusive; \ + send one (use prompt_token_ids for renderer / TITO mode, \ + messages for MITO mode)" + ); + } + } else { + validate::validate_messages(&self.inner.messages)?; + } validate::validate_model(&self.inner.model)?; // none for store validate::validate_reasoning_effort(&self.inner.reasoning_effort)?; diff --git a/lib/llm/src/protocols/openai/nvext.rs b/lib/llm/src/protocols/openai/nvext.rs index a2fbb8aa64fa..1c4711c3a4cc 100644 --- a/lib/llm/src/protocols/openai/nvext.rs +++ b/lib/llm/src/protocols/openai/nvext.rs @@ -73,6 +73,26 @@ pub fn apply_header_routing_overrides(nvext: Option, headers: &HeaderMap) pub trait NvExtProvider { fn nvext(&self) -> Option<&NvExt>; fn raw_prompt(&self) -> Option; + + /// Pre-tokenized input that bypasses chat templating. + /// + /// Two callers populate this today: + /// - GAIE EPP / TITO via `nvext.token_data` (existing path). + /// - Renderer / TITO via top-level `prompt_token_ids` extension on + /// `/v1/chat/completions` (allowlisted by `validate.rs` + /// PASSTHROUGH_EXTRA_FIELDS). This is the canonical home now that + /// `/v1/chat/completions/tokens` is dropped. + /// + /// Default reads only `nvext.token_data`; the chat-completions impl + /// also falls back to `unsupported_fields["prompt_token_ids"]` so the + /// preprocessor sees one effective value regardless of which channel + /// the client used. Returns owned Vec because the top-level field + /// arrives as a JSON value that has to be deserialized fresh. 
+ fn get_pretokenized_input(&self) -> Option> { + self.nvext() + .and_then(|ext| ext.token_data.as_ref()) + .cloned() + } } /// Worker ID information for disaggregated serving diff --git a/lib/llm/src/protocols/openai/validate.rs b/lib/llm/src/protocols/openai/validate.rs index e581d04303fc..9864744d6b8e 100644 --- a/lib/llm/src/protocols/openai/validate.rs +++ b/lib/llm/src/protocols/openai/validate.rs @@ -137,6 +137,16 @@ const PASSTHROUGH_EXTRA_FIELDS: &[&str] = &[ // Opt-in for `nvext.prompt_logprobs` on the response. Aliased through // to vLLM's `sampling_params.prompt_logprobs` in a follow-up. "return_prompt_logprobs", + // Token-level sampling controls. Without these, callers in renderer / + // TITO mode can't express stop-on-token-id, constrained sampling, or + // bad-word filtering — which is the whole reason vLLM 0.20's + // `/inference/v1/generate` exists. Promoting them to the chat-completions + // surface as PASSTHROUGH extras keeps `/v1/chat/completions` as the + // single canonical RL data path with full SamplingParams parity. + "stop_token_ids", + "bad_words_token_ids", + "allowed_token_ids", + "truncate_prompt_tokens", ]; /// Validates that no unsupported fields are present in the request From a2cc90da6d6fc200e53084bc7d76dc98aa4cc7b6 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Tue, 5 May 2026 00:08:25 -0700 Subject: [PATCH 04/18] fix(rl): typed Result on get_stop_token_ids + tests + relax mutual-exclusion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three improvements borrowed from PR #9141 (Ameen Patel) on top of the existing PASSTHROUGH_EXTRA_FIELDS expansion: 1. get_stop_token_ids returns Result>>, not Option>. Malformed payloads (e.g. stop_token_ids: "not-an-array") now surface as a typed 400 with the diagnostic 'stop_token_ids must be an array of unsigned token IDs: {err}'. extract_stop_conditions propagates the Result via ?. 
Replaces the prior silent-fallback Option<> variant which dropped malformed inputs without telling the caller. Silent drops on RL correctness primitives (stop conditions affect what tokens the engine emits) is the bug class CR-9 was about — same principle applies here. 2. Mutual-exclusion between messages and pre-tokenized input is now scoped to the canonical TOP-LEVEL prompt_token_ids extension only. The legacy nvext.token_data channel — which the verifiers dynamo_chat_nvext renderer transport (#1287) uses with placeholder messages 'role: user, content: (token-in mode)' — is allowed to coexist with non-empty messages. validate_messages still gates the empty-messages-with-no-tokens case. Without this relaxation, the renderer transport's placeholder pattern would 400 on every request. 3. Two new tests in test_common_ext.rs: - test_chat_completions_stop_token_ids_extraction: positive case with nvext.token_data + top-level stop_token_ids (lifted from PR #9141 verbatim). - test_chat_completions_stop_token_ids_malformed_returns_400: verifies the typed-error path on bad input. Pre-existing test struct-init sites in test_common_ext.rs were missing required fields (return_token_ids, tokens) added to the NvCreateChatCompletionRequest struct since the tests were written. Three sites updated to construct cleanly. cargo test test_common_ext: 15 tests, 15 passes. --- lib/llm/src/protocols/openai.rs | 18 +++-- .../src/protocols/openai/chat_completions.rs | 64 ++++++++++------- lib/llm/tests/test_common_ext.rs | 71 +++++++++++++++++-- 3 files changed, 112 insertions(+), 41 deletions(-) diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 961ec44c2c8b..a3af818a4edd 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -85,12 +85,14 @@ pub(crate) trait OpenAIStopConditionsProvider { /// Get token-id-based stop conditions (renderer / TITO parity with vLLM /// `/inference/v1/generate`'s `sampling_params.stop_token_ids`). 
/// - /// Default returns None; chat-completions / completions impls override to - /// read the field from `unsupported_fields["stop_token_ids"]` after the - /// `validate.rs` PASSTHROUGH allowlist accepts it. Plumbed through into - /// `common::StopConditions::stop_token_ids_hidden` by `extract_stop_conditions`. - fn get_stop_token_ids(&self) -> Option> { - None + /// Default returns `Ok(None)`; chat-completions / completions impls + /// override to read the field from `unsupported_fields["stop_token_ids"]` + /// after the `validate.rs` PASSTHROUGH allowlist accepts it. Plumbed + /// through into `common::StopConditions::stop_token_ids_hidden` by + /// `extract_stop_conditions`. Returns `Result` so malformed payloads + /// surface as a typed 400 instead of silently dropping the field. + fn get_stop_token_ids(&self) -> anyhow::Result>> { + Ok(None) } } @@ -196,7 +198,9 @@ impl StopConditionsProvider for T { // chat-completions surface; impls of this trait read it from the // request's `unsupported_fields` map. Engine receives it as // `stop_token_ids_hidden` (already wired in `common::StopConditions`). - let stop_token_ids_hidden = self.get_stop_token_ids(); + // The `?` propagates a typed 400 on malformed payloads (e.g. + // `stop_token_ids: "not-an-array"`). + let stop_token_ids_hidden = self.get_stop_token_ids()?; if let Some(stop) = &stop && stop.len() > 4 diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs index b49c61518558..328711139e5f 100644 --- a/lib/llm/src/protocols/openai/chat_completions.rs +++ b/lib/llm/src/protocols/openai/chat_completions.rs @@ -345,13 +345,20 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { /// `validate.rs` PASSTHROUGH_EXTRA_FIELDS). RL renderer / TITO callers /// rely on this for stop-on-token-id conditions that don't tokenize /// cleanly as strings (custom EOS, model-specific control tokens). 
- /// Malformed values silently fall back to None — `validate.rs` already - /// took the "is this allowed at all" decision; here we only choose - /// between "we got a usable list" and "we got nothing". - fn get_stop_token_ids(&self) -> Option> { - self.unsupported_fields - .get("stop_token_ids") - .and_then(|v| serde_json::from_value::>(v.clone()).ok()) + /// Malformed values surface as a typed `anyhow::Error` so the caller + /// gets a 400 with a useful diagnostic rather than a silent drop. + fn get_stop_token_ids(&self) -> anyhow::Result>> { + let Some(value) = self.unsupported_fields.get("stop_token_ids") else { + return Ok(None); + }; + if value.is_null() { + return Ok(None); + } + serde_json::from_value(value.clone()) + .map(Some) + .map_err(|err| { + anyhow::anyhow!("stop_token_ids must be an array of unsigned token IDs: {err}") + }) } } @@ -385,26 +392,29 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest { impl ValidateRequest for NvCreateChatCompletionRequest { fn validate(&self) -> Result<(), anyhow::Error> { validate::validate_no_unsupported_fields(&self.unsupported_fields)?; - // Mutual-exclusivity: messages OR prompt_token_ids extension, not both. - // When prompt_token_ids is present (renderer / TITO), messages can be - // empty and chat templating is bypassed by the preprocessor's - // `get_pretokenized_input()` path. - let has_pretokenized_input = - self.unsupported_fields.contains_key("prompt_token_ids") - || self - .nvext - .as_ref() - .and_then(|ext| ext.token_data.as_ref()) - .is_some(); - if has_pretokenized_input { - if !self.inner.messages.is_empty() { - anyhow::bail!( - "messages and prompt_token_ids are mutually exclusive; \ - send one (use prompt_token_ids for renderer / TITO mode, \ - messages for MITO mode)" - ); - } - } else { + // Mutual-exclusivity applies ONLY to the canonical top-level + // `prompt_token_ids` extension (the new vLLM-0.20-aligned channel). 
+ // The legacy `nvext.token_data` channel is intentionally allowed to + // coexist with non-empty messages — that's how the renderer transport + // ships pre-tokenized inputs alongside placeholder messages + // (PrimeIntellect-ai/verifiers PR #1287's `dynamo_chat_nvext` mode). + // Empty messages are accepted when EITHER channel carries tokens. + let has_top_level_prompt_token_ids = + self.unsupported_fields.contains_key("prompt_token_ids"); + let has_nvext_token_data = self + .nvext + .as_ref() + .and_then(|ext| ext.token_data.as_ref()) + .is_some(); + + if has_top_level_prompt_token_ids && !self.inner.messages.is_empty() { + anyhow::bail!( + "messages and prompt_token_ids are mutually exclusive; \ + send one (use prompt_token_ids for renderer / TITO mode, \ + messages for MITO mode)" + ); + } + if !has_top_level_prompt_token_ids && !has_nvext_token_data { validate::validate_messages(&self.inner.messages)?; } validate::validate_model(&self.inner.model)?; diff --git a/lib/llm/tests/test_common_ext.rs b/lib/llm/tests/test_common_ext.rs index 8e49c7377b09..f8bfd3ed6232 100644 --- a/lib/llm/tests/test_common_ext.rs +++ b/lib/llm/tests/test_common_ext.rs @@ -1,13 +1,16 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use dynamo_llm::protocols::{ - common::StopConditionsProvider, - openai::{ - chat_completions::NvCreateChatCompletionRequest, - common_ext::{CommonExt, CommonExtProvider}, - completions::NvCreateCompletionRequest, - nvext::NvExt, +use dynamo_llm::{ + engines::ValidateRequest, + protocols::{ + common::StopConditionsProvider, + openai::{ + chat_completions::NvCreateChatCompletionRequest, + common_ext::{CommonExt, CommonExtProvider}, + completions::NvCreateCompletionRequest, + nvext::NvExt, + }, }, }; @@ -70,6 +73,8 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }; let sampling = request.extract_sampling_options().unwrap(); @@ -213,6 +218,54 @@ fn test_max_thinking_tokens_extraction() { assert_eq!(stop_conditions_none.max_thinking_tokens, None); } +#[test] +fn test_chat_completions_stop_token_ids_extraction() { + // Renderer / TITO callers send `stop_token_ids` as a top-level field + // alongside `nvext.token_data`. Both ride PASSTHROUGH_EXTRA_FIELDS; + // `extract_stop_conditions` plumbs the IDs into + // `common::StopConditions::stop_token_ids_hidden` so the engine layer + // honors them. (Lifted from PR #9141.) 
+ let json_str = r#"{ + "model": "test-model", + "messages": [{"role": "user", "content": "(token-in mode)"}], + "nvext": { + "token_data": [1, 2, 3] + }, + "stop_token_ids": [151645, 151643] + }"#; + + let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); + + request.validate().unwrap(); + let stop_conditions = request.extract_stop_conditions().unwrap(); + assert_eq!( + stop_conditions.stop_token_ids_hidden, + Some(vec![151645, 151643]) + ); +} + +#[test] +fn test_chat_completions_stop_token_ids_malformed_returns_400() { + // Malformed stop_token_ids must NOT silently fall back to None — it + // surfaces as a typed anyhow::Error so the HTTP layer returns 400 with + // a useful diagnostic. (PR #9141 contract.) + let json_str = r#"{ + "model": "test-model", + "messages": [{"role": "user", "content": "x"}], + "stop_token_ids": "not-an-array" + }"#; + + let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); + let err = request + .extract_stop_conditions() + .expect_err("malformed stop_token_ids should error"); + assert!( + err.to_string() + .contains("stop_token_ids must be an array of unsigned token IDs"), + "got: {err}" + ); +} + #[test] fn test_chat_completions_no_common_values() { // Test that when no common values are set, we get None @@ -300,6 +353,8 @@ fn test_serialization_preserves_structure() { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }; let json = serde_json::to_value(&request).unwrap(); @@ -352,6 +407,8 @@ fn test_sampling_parameters_extraction() { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }; let sampling_options = request.extract_sampling_options().unwrap(); From 2cb5e6040916384c9d65e5e4abfd997e1c308fd2 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Thu, 7 May 2026 15:59:34 -0700 Subject: [PATCH 05/18] rl api docs 
--- docs/Dynamo-RL-api-draft.md | 973 ------------------------------------ docs/dynamo-RL-api.md | 590 ++++++++++++++++++++++ 2 files changed, 590 insertions(+), 973 deletions(-) delete mode 100644 docs/Dynamo-RL-api-draft.md create mode 100644 docs/dynamo-RL-api.md diff --git a/docs/Dynamo-RL-api-draft.md b/docs/Dynamo-RL-api-draft.md deleted file mode 100644 index ae8113ed3073..000000000000 --- a/docs/Dynamo-RL-api-draft.md +++ /dev/null @@ -1,973 +0,0 @@ -# Dynamo RL API Draft - -**Branch:** `bis/parity-tokenize-tcp` (HEAD: `d837fbd67`) - -Commit `70f84570b` (the current auto-enable-token-ids commit on HEAD) is an -equivalent cherry-pick of the earlier `19d1bf13d` referenced in prior drafts — -same subject, same patch semantics, different parent tree after rebase onto -`origin/main`. - ---- - -## Table of Contents - -1. [Overview](#1-overview) -2. [Architecture](#2-architecture) -3. [Configuration](#3-configuration) -4. [API Reference](#4-api-reference) - - 4.1 Chat Completions (RL-enhanced) - - 4.2 Token-In / Token-Out (TITO) - - 4.3 Tokenization - - 4.4 Fleet Control (`/v1/rl/*`) -5. [Data Flow](#5-data-flow) -6. [Key Data Structures](#6-key-data-structures) -7. [Worker Engine Routes (Internal)](#7-worker-engine-routes-internal) -8. [Known Limitations](#8-known-limitations) -9. [Validation Results](#9-validation-results) - ---- - -## 1. Overview - -This document describes the RL training API surface on the Dynamo serving stack for integration with prime-rl. The Dynamo frontend (Rust) exposes: - -- An `/v1/rl/*` router for the full RL control-plane lifecycle (pause/resume, weight updates, readiness checks) -- Automatic token-level data injection (`prompt_token_ids`, `completion_token_ids`) in chat completion responses -- `/v1/tokenize` and `/v1/detokenize` endpoints -- A `/v1/chat/completions/tokens` TITO endpoint for pre-tokenized prompt bypass - -Zero Python in the inference or admin data path. 
The Rust frontend handles all HTTP API surface while vLLM workers expose engine routes for weight lifecycle operations on the GPU. - -### Endpoint Summary - -| Capability | Endpoint | Purpose | -|------------|----------|---------| -| Inference | `POST /v1/chat/completions` | Generate rollouts; responses include `prompt_token_ids` + `choice.token_ids` | -| TITO inference | `POST /v1/chat/completions/tokens` | Pre-tokenized prompt bypass (turn 2+ in multi-turn RL) | -| Tokenization | `POST /v1/tokenize` | Consistent tokenization using the model's chat template | -| Detokenization | `POST /v1/detokenize` | Token IDs back to text | -| Pause fleet | `POST /v1/rl/pause` | Drain in-flight requests before weight update | -| Resume fleet | `POST /v1/rl/resume` | Resume generation after weight update | -| Update weights | `POST /v1/rl/update_weights` | Atomic flush + reload from checkpoint directory | -| Load LoRA adapter | `POST /v1/rl/load_lora_adapter` | Hot-load/swap a PEFT-style adapter from filesystem path | -| Unload LoRA adapter | `POST /v1/rl/unload_lora_adapter` | Remove a previously loaded adapter by name | -| Weight version | `GET /v1/rl/weight_version` | Query current weight version across workers | -| Health | `GET /v1/rl/health` | Lightweight frontend health check | -| Readiness | `GET /v1/rl/ready` | Deep check: are workers reachable and healthy? | - -### What Changed vs. Stock Dynamo - -All changes are on `bis/parity-tokenize-tcp` (18 commits, 26 files, +3030/-41 — the diff counts include this doc). Nothing touches Dynamo's core serving pipeline (NATS, scheduler, KV cache, disaggregation). 
The changes are additive: - -- **Rust frontend** (`lib/llm/`): New routes, response post-processing, tokenization endpoints, LoRA hot-swap admin routes -- **vLLM worker** (`components/`): 7 engine route handlers (5 weight-lifecycle + 2 LoRA), publisher crash guard -- **Deps** (`container/`): default `VLLM_VER` bumped 0.19.0 → 0.19.1; prime-rl plugin installed via `pip` so `vllm.general_plugins` patches apply at engine start -- **Compat fixes**: `/v1/tokenize` and `/v1/detokenize` adapted to upstream `DecodeResult`-returning decoder (main commit `2cabf4414`, #8022) - ---- - -## 2. Architecture - -### Component Topology - -```mermaid -flowchart TD - subgraph prime_rl["prime-rl"] - orch["Orchestrator
(prime_rl.orchestrator)"] - trainer["Trainer
(prime_rl.trainer.rl.train)
torchrun --nproc-per-node=N"] - end - - subgraph dynamo["Dynamo Serving Stack"] - subgraph frontend["Frontend Pod (Rust, port 8000)"] - cc["/v1/chat/completions
+ prompt_token_ids
+ choice.token_ids"] - tito["/v1/chat/completions/tokens
(TITO)"] - tok["/v1/tokenize   /v1/detokenize"] - rl["/v1/rl/*
health, ready, pause, resume,
update_weights, weight_version"] - end - subgraph worker["vLLM Worker Pod (Python, system port 9090)"] - eng["/engine/*
pause_generation
resume_generation
flush_cache
update_weights_from_path
get_weight_version"] - gpu["GPU
Model Weights"] - end - end - - subgraph storage["Shared Storage (PVC)"] - pvc["prime-rl-shared-data
safetensors checkpoints"] - end - - orch -- "rollouts
POST /v1/chat/completions" --> cc - orch -- "TITO turn 2+
POST /v1/chat/completions/tokens" --> tito - orch -- "POST /v1/tokenize" --> tok - orch -- "weight lifecycle
pause / update_weights / resume" --> rl - rl -- "HTTP fan-out
(concurrent to all workers)" --> eng - eng --> gpu - trainer -- "write checkpoint" --> pvc - eng -- "reload_weights
(collective_rpc)" --> pvc -``` - -### Key Design Decisions - -1. **Single entry point.** Prime-RL points both `base_url` and `admin_base_url` at the Dynamo frontend. No separate admin service to deploy. - -2. **Fan-out in Rust.** The `/v1/rl/*` handlers fan out to all vLLM workers via `DYN_RL_WORKER_SYSTEM_URLS`. This supports DP>1 without Prime-RL needing to discover workers. The frontend returns HTTP 200 only when all workers respond OK, and HTTP 502 otherwise with per-worker details. - -3. **Token IDs as a response extension.** When `DYN_ENABLE_RL=true`, `prompt_token_ids` and `choices[i].token_ids` are injected into every non-streaming response automatically. No client-side configuration needed. - -4. **Backward compatible.** All new response fields use `#[serde(skip_serializing_if = "Option::is_none")]`. Clients that don't set `DYN_ENABLE_RL` see standard OpenAI-compatible responses with no extra fields. - ---- - -## 3. Configuration - -### Environment Variables (Frontend) - -| Variable | Default | Description | -|----------|---------|-------------| -| `DYN_ENABLE_RL` | `false` | Master switch. Mounts `/v1/rl/*` routes, auto-injects token IDs in chat completion responses, mounts TITO endpoint. | -| `DYN_RL_WORKER_SYSTEM_URLS` | `http://localhost:8081` | Comma-separated list of vLLM worker system HTTP base URLs for fan-out. | - -### Environment Variables (Worker) - -| Variable | Default | Description | -|----------|---------|-------------| -| `DYN_SYSTEM_PORT` | `8081` (local) / `9090` (k8s) | Worker's system HTTP port where engine routes are registered. | - -### Prime-RL Configuration (`orch.toml`) - -```toml -max_steps = 20 -seq_len = 512 -batch_size = 16 -rollouts_per_example = 4 -use_token_client = false - -[model] -name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" - -[sampling] -max_tokens = 64 - -[[env]] -id = "reverse-text" - -[client] -# Point BOTH base_url and admin_base_url at the Dynamo frontend. 
-# admin_base_url uses /v1/rl because Prime-RL strips trailing /v1 -# from admin URLs, but /v1/rl is preserved. -base_url = ["http://:8000/v1"] -admin_base_url = ["http://:8000/v1/rl"] -skip_model_check = true - -[weight_broadcast] -type = "filesystem" - -[experimental] -# Disable prefix cache salt until Dynamo supports it. -# verifiers dev6+ defaults use_prefix_cache_salt=True; current image returns 400. -use_prefix_cache_salt = false -``` - -**Important:** Do NOT set `send_return_token_ids = true` in `[sampling]`. The Rust frontend handles token ID injection automatically when `DYN_ENABLE_RL=true`. Sending `return_token_ids=true` in the request causes the OpenAI SDK to parse the response and strip unknown fields. - -### Kubernetes (DGD) - -```yaml -# Frontend pod env -- name: DYN_ENABLE_RL - value: "true" -- name: DYN_RL_WORKER_SYSTEM_URLS - value: "http://prime-rl-dynamo-vllmworker..svc.cluster.local:9090" -``` - -### Launch Commands (Local) - -```bash -# Frontend with RL routes enabled -DYN_ENABLE_RL=true \ -DYN_RL_WORKER_SYSTEM_URLS=http://localhost:8081 \ - python -m dynamo.frontend - -# vLLM Worker -CUDA_VISIBLE_DEVICES=0 \ -DYN_SYSTEM_PORT=8081 \ - python -m dynamo.vllm \ - --model PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT \ - --served-model-name PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT \ - --enforce-eager \ - --max-model-len 2048 \ - --gpu-memory-utilization 0.5 -``` - ---- - -## 4. API Reference - -All endpoints live on the Dynamo Rust frontend (default port 8000). Unless noted, request/response formats follow the OpenAI API specification. - -### 4.1 Chat Completions (RL-enhanced) - -``` -POST /v1/chat/completions -``` - -Standard OpenAI chat completions with RL extensions. When `DYN_ENABLE_RL=true`, every non-streaming response is automatically enriched with token IDs for the trainer. - -#### Request - -Standard OpenAI `ChatCompletionRequest`. 
Two additional fields are accepted and silently consumed (never forwarded to the vLLM worker): - -| Field | Type | Default | Mapped to | -|-------|------|---------|-----------| -| `tokens` | `u32[]` | `null` | `nvext.token_data` (tokenizer bypass) | -| `return_token_ids` | `bool` | `null` | `nvext.extra_fields: ["token_ids", "completion_token_ids"]` + `logprobs: true` | - -When `DYN_ENABLE_RL=true`, `return_token_ids` is implicitly `true` for every request. - -#### Sample Request - -```bash -curl -s -X POST http://localhost:8000/v1/chat/completions \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", - "messages": [ - {"role": "user", "content": "Reverse this: hello world"} - ], - "max_tokens": 64, - "temperature": 1.0 - }' -``` - -#### Sample Response (Non-Streaming, with `DYN_ENABLE_RL=true`) - -```json -{ - "id": "chatcmpl-abc123", - "object": "chat.completion", - "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", - "choices": [{ - "index": 0, - "message": {"role": "assistant", "content": "dlrow olleh"}, - "finish_reason": "stop", - "logprobs": { - "content": [ - {"token": "dl", "logprob": -0.523, "top_logprobs": []}, - {"token": "row", "logprob": -0.102, "top_logprobs": []}, - {"token": " ol", "logprob": -0.834, "top_logprobs": []}, - {"token": "leh", "logprob": -0.211, "top_logprobs": []} - ] - }, - "token_ids": [67, 1245, 893, 15] - }], - "prompt_token_ids": [151644, 8948, 198, 151645, 198, 151644, 872, 198, - 49, 1075, 513, 420, 25, 24748, 1879, 198, 151645, - 198, 151644, 77091, 198], - "usage": {"prompt_tokens": 21, "completion_tokens": 4, "total_tokens": 25}, - "nvext": { - "completion_token_ids": [67, 1245, 893, 15] - } -} -``` - -#### Response Field Reference - -| Field | JSON path | Description | -|-------|-----------|-------------| -| `prompt_token_ids` | `response.prompt_token_ids` | Token IDs from tokenizing the prompt messages through the model's chat template. 
Generated by the Rust frontend's tokenizer after the response is fully received. | -| `token_ids` | `response.choices[i].token_ids` | Completion token IDs generated by the engine. Promoted from `nvext.completion_token_ids`. | -| `completion_token_ids` | `response.nvext.completion_token_ids` | Canonical Dynamo location for output token IDs. Accumulated across all SSE chunks by `DeltaGenerator`. | - -**Why `token_ids` appears in two locations:** Prime-RL's verifiers library reads `response.prompt_token_ids` and `choices[i].token_ids` (top-level on the choice object). Dynamo natively emits output token IDs in `nvext.completion_token_ids`. The Rust post-processor promotes the latter to the former for compatibility. Both contain the same values. - -**Invariant:** `len(completion_token_ids) == len(logprobs.content)` -- the output token IDs are in exact 1:1 correspondence with the logprob entries. - -#### Sample Response (Streaming / SSE -- final chunk only) - -Intermediate chunks carry `delta.content` only. Token IDs appear exclusively on the **final chunk** (the one with a non-null `finish_reason`): - -``` -data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk", - "choices":[{"index":0,"delta":{},"finish_reason":"stop", - "nvext":{"completion_token_ids":[67,1245,893,15]}}], - "prompt_token_ids":[151644,8948,198,151645,198,151644,872,198, - 49,1075,513,420,25,24748,1879,198,151645, - 198,151644,77091,198]} - -data: [DONE] -``` - -#### RL Post-Processing Pipeline - -For non-streaming requests, the handler performs the following after the backend response is fully aggregated: - -```mermaid -flowchart LR - A["Request arrives
DYN_ENABLE_RL=true"] --> B["Save messages
for later tokenization"] - B --> C["Inject nvext.extra_fields:
[token_ids, completion_token_ids]
Force logprobs=true"] - C --> D["Standard pipeline
(preprocessor, backend,
delta generator, aggregator)"] - D --> E["Aggregate response
(nvext.completion_token_ids
accumulated in delta.rs)"] - E --> F["rl_tokenize_prompt()
messages -> prompt_token_ids
via model chat template"] - F --> G["rl_promote_token_ids()
nvext.completion_token_ids
-> choices[i].token_ids"] - G --> H["Return enriched
JSON response"] -``` - ---- - -### 4.2 Token-In / Token-Out (TITO) - -``` -POST /v1/chat/completions/tokens -``` - -Dedicated endpoint for Prime-RL's pre-tokenized prompt flow (multi-turn RL, turn 2+). The orchestrator sends raw token IDs instead of text messages, bypassing the frontend's tokenizer entirely. This avoids redundant encode/decode round-trips and ensures token-level alignment. - -#### Sample Request - -```bash -curl -s -X POST http://localhost:8000/v1/chat/completions/tokens \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", - "messages": [{"role": "user", "content": "(token-in mode)"}], - "tokens": [151644, 8948, 198, 151645, 198, 151644, 872, 198, - 49, 1075, 513, 420, 25, 24748, 1879, 198, 151645, - 198, 151644, 77091, 198, 67, 1245, 893, 15], - "max_tokens": 64 - }' -``` - -| Field | Required | Description | -|-------|----------|-------------| -| `tokens` | **Yes** | Pre-tokenized prompt IDs. Must be non-empty. Injected as `nvext.token_data`. | -| `messages` | Yes (can be placeholder) | A placeholder `{"role": "user", "content": "(token-in mode)"}` is auto-injected if empty. | - -#### Behavior - -1. Extracts `tokens`, returns 400 if missing or empty -2. Injects into `nvext.token_data` (triggers tokenizer bypass in `preprocessor.rs`) -3. Adds `extra_fields: ["token_ids", "completion_token_ids"]` -4. Forces `logprobs = true` -5. Delegates to the standard `chat_completions()` pipeline (zero HTTP proxy) - -#### Sample Response - -Same shape as section 4.1. The response includes `prompt_token_ids` and `choices[i].token_ids`. - ---- - -### 4.3 Tokenization - -``` -POST /v1/tokenize -POST /v1/detokenize -``` - -Consistent tokenization using the model's tokenizer and chat template, running entirely in Rust. These are critical for RL: prompt token IDs in the chat completion response must match what the tokenizer produces for the same messages. 
Both endpoints use the same tokenizer instance that the frontend uses for its own request preprocessing. - -#### Sample: Tokenize (Chat variant) - -```bash -curl -s -X POST http://localhost:8000/v1/tokenize \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", - "messages": [ - {"role": "user", "content": "Reverse this: hello world"} - ], - "add_generation_prompt": true, - "add_special_tokens": true - }' -``` - -**Response:** - -```json -{ - "count": 21, - "max_model_len": 2048, - "tokens": [151644, 8948, 198, 151645, 198, 151644, 872, 198, - 49, 1075, 513, 420, 25, 24748, 1879, 198, 151645, - 198, 151644, 77091, 198], - "token_strs": null -} -``` - -#### Tokenize Request Fields (Chat variant) - -| Field | Type | Default | Description | -|-------|------|---------|-------------| -| `model` | `string?` | auto-resolve | Model name for tokenizer lookup | -| `messages` | `ChatMessage[]` | -- | Messages to tokenize through the model's chat template | -| `add_generation_prompt` | `bool` | `true` | Append generation prompt (e.g., `<\|im_start\|>assistant\n`) | -| `add_special_tokens` | `bool` | `true` | Add BOS/EOS tokens | -| `return_token_strs` | `bool` | `false` | Include human-readable string representation of each token | -| `chat_template` | `string?` | `null` | Override the model's default chat template (Jinja2) | -| `chat_template_kwargs` | `object?` | `null` | Extra template variables | -| `continue_final_message` | `bool` | `false` | Continue last message instead of starting a new turn | - -The chat variant renders messages through the model's chat template before tokenizing, so the token count is the exact number of tokens that a corresponding chat completion request would consume. 
- -#### Sample: Tokenize (Completion variant) - -```bash -curl -s -X POST http://localhost:8000/v1/tokenize \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", - "prompt": "hello world", - "add_special_tokens": true, - "return_token_strs": true - }' -``` - -**Response:** - -```json -{ - "count": 3, - "max_model_len": 2048, - "tokens": [9707, 1879, 3], - "token_strs": ["hello", " world", ""] -} -``` - -#### Tokenize Response Fields - -| Field | Type | Description | -|-------|------|-------------| -| `count` | `int` | Number of tokens | -| `max_model_len` | `int` | Model's configured maximum context length | -| `tokens` | `list[int]` | Token ID list | -| `token_strs` | `list[str]?` | Human-readable token strings; only present if `return_token_strs: true` | - -#### Sample: Detokenize - -```bash -curl -s -X POST http://localhost:8000/v1/detokenize \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT", - "tokens": [9707, 1879, 3] - }' -``` - -**Response:** - -```json -{"prompt": "hello world"} -``` - ---- - -### 4.4 Fleet Control (`/v1/rl/*`) - -All `/v1/rl/*` routes are mounted only when `DYN_ENABLE_RL=true`. They fan out to vLLM worker system ports defined by `DYN_RL_WORKER_SYSTEM_URLS`. - -In prime-rl's config: - -```toml -[client] -base_url = ["http://:8000/v1"] -admin_base_url = ["http://:8000/v1/rl"] -``` - ---- - -#### `GET /v1/rl/health` - -Lightweight liveness probe. Returns immediately as long as the frontend process is running. Used by prime-rl's `check_health()` on the admin client. - -```bash -curl -s http://localhost:8000/v1/rl/health -``` - -```json -{"status": "ok"} -``` - ---- - -#### `GET /v1/rl/ready` - -Composite readiness probe. Polls `/health` on every configured worker system URL concurrently. Returns 200 only when all workers respond with HTTP 2xx. 
- -```bash -curl -s http://localhost:8000/v1/rl/ready -``` - -```json -// All workers ready (200) -{ - "status": "ready", - "workers": [ - {"url": "http://localhost:8081", "healthy": true} - ] -} - -// Not all workers ready (503) -{ - "status": "not_ready", - "workers_ready": 0, - "workers_total": 1, - "workers": [ - {"url": "http://localhost:8081", "healthy": false, "error": "connection refused"} - ] -} -``` - ---- - -#### `POST /v1/rl/pause` - -Quiesces generation on all workers. Each worker calls `engine_client.pause_generation()` which drains in-flight requests without unloading the model from GPU memory. - -```bash -curl -s -X POST http://localhost:8000/v1/rl/pause -H 'Content-Type: application/json' -d '{}' -``` - -```json -// Success (200) -{ - "status": "ok", - "workers": [ - {"status": "ok", "message": "Engine paused"} - ] -} - -// Failure (502) -{ - "status": "error", - "workers": [ - {"status": "ok", "message": "Engine paused"}, - {"status": "error", "message": "timeout"} - ] -} -``` - ---- - -#### `POST /v1/rl/resume` - -Resumes generation on all workers after a weight update. - -```bash -curl -s -X POST http://localhost:8000/v1/rl/resume -H 'Content-Type: application/json' -d '{}' -``` - -```json -{ - "status": "ok", - "workers": [ - {"status": "ok", "message": "Engine resumed"} - ] -} -``` - ---- - -#### `POST /v1/rl/update_weights` - -Atomic weight-loading sequence: `flush_cache` then `update_weights_from_path` on all workers concurrently. The frontend performs the two-phase fan-out internally; prime-rl does not need to call these separately. - -For filesystem-backed weight broadcast, the trainer writes safetensors files to a shared PVC directory and passes that path here. The worker calls `engine_client.collective_rpc("reload_weights", kwargs={"weights_path": path})` which triggers vLLM's layerwise in-place weight reload on every GPU worker. 
- -For NCCL-based weight broadcast (`weight_dir: null`), Dynamo returns 200 immediately -- the actual weight transfer happens out-of-band via NCCL and Dynamo does not participate. - -```bash -# Filesystem mode -curl -s -X POST http://localhost:8000/v1/rl/update_weights \ - -H 'Content-Type: application/json' \ - -d '{"weight_dir": "/data/outputs/run_default/broadcasts/step_5"}' - -# NCCL mode (Dynamo no-op) -curl -s -X POST http://localhost:8000/v1/rl/update_weights \ - -H 'Content-Type: application/json' \ - -d '{"weight_dir": null}' -``` - -```json -// Success (200) -{ - "status": "ok", - "version": "step_5", - "workers": [ - {"status": "ok", "message": "Weights loaded from /data/outputs/...", "version": "step_5"} - ] -} - -// Failure at flush_cache stage (502) -{ - "status": "error", - "stage": "flush_cache", - "workers": [...] -} - -// Failure at update_weights stage (502) -{ - "status": "error", - "stage": "update_weights_from_path", - "workers": [...] -} -``` - -The `version` string is derived from the basename of `weight_dir` (e.g., `step_5` from `/data/outputs/run_default/broadcasts/step_5`). This version is stored in the worker and retrievable via `/v1/rl/weight_version`. - ---- - -#### `POST /v1/rl/load_lora_adapter` - -Hot-load or hot-swap a LoRA adapter from a filesystem path. The adapter directory must contain PEFT-style `adapter_model.safetensors` and `adapter_config.json` -- the default output layout of prime-rl's LoRA trainer. - -This is the RL-native LoRA path, distinct from Dynamo's URI-based `load_lora` gRPC endpoint (which downloads from S3/file URIs via `LoRAManager`). The admin route is optimized for the training loop: no URI fetch, no MDC churn on hot-swap. - -- **First call for a given `lora_name`**: `add_lora` in the engine, publish a ModelDeploymentCard so subsequent inference requests with `model=` route to this worker. -- **Subsequent calls (hot-swap)**: `remove_lora(old_id)` → `add_lora` with new weights → `reset_prefix_cache`. 
The MDC is left in place since it already points at this worker. - -Pair with `/v1/rl/pause` and `/v1/rl/resume` for full drain-swap-resume semantics. - -```bash -curl -s -X POST http://localhost:8000/v1/rl/load_lora_adapter \ - -H 'Content-Type: application/json' \ - -d '{"lora_name": "r16-a32", "lora_path": "/data/outputs/run_default/broadcasts/step_5"}' -``` - -```json -// Success (200) -{ - "status": "ok", - "workers": [ - { - "status": "ok", - "message": "LoRA adapter 'r16-a32' loaded from /data/outputs/...", - "lora_name": "r16-a32", - "lora_id": 788776416, - "hot_swap": false - } - ] -} - -// Missing / empty field (400) -{ - "status": "error", - "message": "Expected body: {\"lora_name\": str, \"lora_path\": str} (both required, non-empty)" -} - -// Worker-side failure (502) -- e.g. bad adapter file, rank mismatch, vLLM not --enable-lora -{ - "status": "error", - "workers": [{"status": "error", "message": "..."}] -} -``` - -**vLLM worker requirements**: the engine must be started with `--enable-lora --max-lora-rank R --max-loras N`, with `R` ≥ the adapter rank and `N` ≥ the number of distinct `lora_name` values you expect to have loaded at once. For Prime-RL's single-adapter training loop, `--max-loras 1` is sufficient. - ---- - -#### `POST /v1/rl/unload_lora_adapter` - -Remove a previously loaded LoRA adapter by name. Idempotent: unloading an already-absent adapter returns `status: ok` so callers can retry safely. - -Unregisters the adapter's ModelDeploymentCard so the frontend stops routing `model=` requests to this worker. 
- -```bash -curl -s -X POST http://localhost:8000/v1/rl/unload_lora_adapter \ - -H 'Content-Type: application/json' \ - -d '{"lora_name": "r16-a32"}' -``` - -```json -// Success (200) -{ - "status": "ok", - "workers": [ - { - "status": "ok", - "message": "LoRA adapter 'r16-a32' unloaded", - "lora_name": "r16-a32", - "lora_id": 788776416 - } - ] -} - -// Already absent -- still ok (200) -{ - "status": "ok", - "workers": [{"status": "ok", "message": "LoRA adapter 'r16-a32' not loaded (no-op)", "lora_name": "r16-a32"}] -} -``` - ---- - -#### `GET /v1/rl/weight_version` - -Returns the currently loaded weight version from all workers. Useful for debugging weight update races or confirming that all workers converged to the same checkpoint. - -```bash -curl -s http://localhost:8000/v1/rl/weight_version -``` - -```json -// All workers consistent (200) -{ - "status": "ok", - "version": "step_5", - "workers": [ - {"version": "step_5"}, - {"version": "step_5"} - ] -} - -// Workers inconsistent (200, with warning) -{ - "status": "inconsistent", - "versions": ["step_4", "step_5"], - "workers": [ - {"version": "step_4"}, - {"version": "step_5"} - ] -} -``` - -Returns HTTP 200 even when versions are inconsistent -- the `status` field distinguishes the cases. A 502 is only returned for network-level failures. - ---- - -## 5. Data Flow - -### 5.1 Rollout (Inference) Path - -```mermaid -sequenceDiagram - participant Orch as prime-rl Orchestrator - participant FE as Dynamo Frontend (Rust) - participant Worker as vLLM Worker (GPU) - - Orch->>FE: POST /v1/chat/completions
{messages, max_tokens, ...} - Note over FE: DYN_ENABLE_RL=true:
inject nvext.extra_fields
= ["token_ids", "completion_token_ids"]
force logprobs=true
save messages for tokenization - FE->>Worker: forward request (TCP/NATS) - Worker-->>FE: SSE chunks
(delta.content + delta.token_ids per chunk) - Note over FE: DeltaGenerator accumulates
completion_token_ids across chunks - Worker-->>FE: final chunk
(finish_reason + nvext.completion_token_ids) - Note over FE: Post-process:
1. rl_tokenize_prompt(messages)
-> response.prompt_token_ids
2. Promote nvext.completion_token_ids
-> choices[i].token_ids - FE-->>Orch: Enriched response:
prompt_token_ids + choices[i].token_ids
+ nvext.completion_token_ids -``` - -### 5.2 Weight Update Path - -```mermaid -sequenceDiagram - participant Trainer as prime-rl Trainer - participant PVC as Shared Storage - participant Orch as prime-rl Orchestrator - participant FE as Dynamo Frontend (Rust) - participant W1 as vLLM Worker 1 - participant W2 as vLLM Worker 2 - - Trainer->>PVC: write checkpoint
/data/outputs/.../step_N/*.safetensors - Trainer->>Orch: notify weight update ready
(internal IPC) - Orch->>FE: POST /v1/rl/pause - FE->>W1: POST /engine/pause_generation - FE->>W2: POST /engine/pause_generation - W1-->>FE: {status: ok} - W2-->>FE: {status: ok} - FE-->>Orch: {status: ok} - Orch->>FE: POST /v1/rl/update_weights
{weight_dir: /data/outputs/.../step_N} - FE->>W1: POST /engine/flush_cache - FE->>W2: POST /engine/flush_cache - W1-->>FE: {status: ok} - W2-->>FE: {status: ok} - FE->>W1: POST /engine/update_weights_from_path
{path: ..., version: step_N} - FE->>W2: POST /engine/update_weights_from_path
{path: ..., version: step_N} - Note over W1,W2: collective_rpc(reload_weights)
vLLM GPUModelRunner.reload_weights()
in-place layer-by-layer load from safetensors - W1-->>FE: {status: ok, version: step_N} - W2-->>FE: {status: ok, version: step_N} - FE-->>Orch: {status: ok, version: step_N} - Orch->>FE: POST /v1/rl/resume - FE->>W1: POST /engine/resume_generation - FE->>W2: POST /engine/resume_generation - W1-->>FE: {status: ok} - W2-->>FE: {status: ok} - FE-->>Orch: {status: ok} - Note over Orch: Continue rollouts
with updated weights -``` - -### 5.3 TITO (Tokens-In, Tokens-Out) Path - -```mermaid -sequenceDiagram - participant Client as Orchestrator (turn 2+) - participant FE as Dynamo Frontend (Rust) - participant PP as Preprocessor - participant Worker as vLLM Worker - - Client->>FE: POST /v1/chat/completions/tokens
{tokens: [9707, 1879, 3], max_tokens: 64} - Note over FE: Extract tokens field
Inject as nvext.token_data
Force extra_fields, logprobs - FE->>PP: Request with nvext.token_data - Note over PP: token_data present:
skip tokenization,
use provided IDs directly - PP->>Worker: Token IDs sent as-is - Worker-->>FE: Completion response
(with completion_token_ids) - FE-->>Client: Enriched response
(prompt_token_ids + token_ids) -``` - ---- - -## 6. Key Data Structures - -### `NvExtResponse` (Rust -- response side) - -Serialized as the `nvext` field in each SSE chunk or the unary response body: - -``` -NvExtResponse { - worker_id?: WorkerIdInfo -- prefill/decode worker IDs for disaggregated serving - timing?: TimingInfo -- request timing (enabled via extra_fields: ["timing"]) - token_ids?: Vec -- GAIE Stage 1: tokenized prompt for Stage 2 reuse - routed_experts?: serde_json::Value -- SGLang-specific expert routing payload - completion_token_ids?: Vec -- RL: generated output token IDs (final chunk only) -} -``` - -The `completion_token_ids` field is populated automatically for all requests when `DYN_ENABLE_RL=true`, or when the client sends `nvext.extra_fields: ["completion_token_ids"]`. - -### `NvCreateChatCompletionRequest` (Rust -- request side) - -New fields relevant to RL: - -| Field | Serialized | Description | -|-------|-----------|-------------| -| `tokens` | No (`skip_serializing`) | Pre-tokenized prompt token IDs (TITO path via `/v1/chat/completions/tokens`) | -| `return_token_ids` | No (`skip_serializing`) | prime-rl compat field; accepted but ignored on the standard endpoint -- use `DYN_ENABLE_RL` or `nvext.extra_fields` instead | - -Both fields are stripped before the request is forwarded to the vLLM worker, preventing 400 errors from the vLLM OpenAI-compatible API. - -### `NvCreateChatCompletionResponse` (Rust -- response side) - -``` -NvCreateChatCompletionResponse { - inner: CreateChatCompletionResponse -- standard OpenAI response fields - nvext?: serde_json::Value -- NvExtResponse serialized as JSON - prompt_token_ids?: Vec -- RL: tokenized prompt IDs (DYN_ENABLE_RL only) -} -``` - -### `DeltaGenerator` (Rust -- streaming pipeline) - -Manages per-request streaming state. Accumulates output token IDs across chunks: - -``` -DeltaGenerator { - ... 
- accumulated_completion_token_ids: Vec -- grows per chunk -} -``` - -- **Activation:** `options.enable_completion_token_ids` is set to `true` when `extra_fields` includes `"completion_token_ids"` (auto-set when `DYN_ENABLE_RL=true`). -- **Accumulation:** On each postprocessor output chunk, appends `delta.token_ids` to the accumulator. -- **Emission:** On the final chunk (`finish_reason` is set), the full list is emitted in `nvext.completion_token_ids` and the accumulator is cleared. - -### `NvExt` (Rust -- request-side NVIDIA extensions) - -Relevant fields for RL: - -``` -NvExt { - token_data?: Vec -- Pre-tokenized prompt IDs (TITO / EPP bypass) - extra_fields?: Vec -- Request extra response fields, e.g. ["completion_token_ids"] - backend_instance_id?: u64 -- Targeted routing to a specific worker - ... -} -``` - -### Tokenization Types - -``` -TokenizeRequest = Completion { prompt, model?, add_special_tokens? } - | Chat { messages, model?, add_generation_prompt?, chat_template?, ... } - -TokenizeResponse { count: int, max_model_len: int, tokens: Vec, token_strs?: Vec } - -DetokenizeRequest { model?: String, tokens: Vec } -DetokenizeResponse { prompt: String } -``` - -### Post-Processing Helpers - -**`rl_tokenize_prompt(state, model, messages) -> Option>`** - -Tokenizes the original prompt messages using the model's chat template and tokenizer: resolves model card from `state`, gets the tokenizer instance, builds a `PromptFormatter` from the model deployment card, renders messages through the chat template (same logic as the preprocessor), tokenizes the rendered string, and returns the token IDs. - -**`rl_promote_token_ids_in_response(json_val)`** - -Copies `response.nvext.completion_token_ids` to `response.choices[i].token_ids` for each choice. Bridges Dynamo's `nvext` convention with the field paths that Prime-RL/verifiers expects. - ---- - -## 7. 
Worker Engine Routes (Internal) - -Five engine route handlers are registered on each vLLM worker's system HTTP port (default `8081` local / `9090` in k8s). These are **internal** routes called by the Rust frontend's `/v1/rl/*` handlers -- not called directly by Prime-RL. - -| Route | Method | vLLM API called | Description | -|-------|--------|-----------------|-------------| -| `/engine/pause_generation` | POST | `engine_client.pause_generation()` | Drain in-flight requests, keep model loaded in GPU memory | -| `/engine/resume_generation` | POST | `engine_client.resume_generation()` | Resume accepting inference requests | -| `/engine/flush_cache` | POST | `engine_client.reset_prefix_cache()` | Invalidate prefix/KV cache (required before weight reload) | -| `/engine/update_weights_from_path` | POST | `engine_client.collective_rpc("reload_weights", ...)` | Load weights from filesystem (safetensors checkpoint) | -| `/engine/get_weight_version` | POST | `self._weight_version` | Return current weight version string | - -Both decode and prefill worker types register all 5 routes. Route signatures are compatible with SGLang's merged `#6094` routes for backend interoperability. - -### Registration (worker_factory.py) - -```python -runtime.register_engine_route("pause_generation", handler.pause_generation) -runtime.register_engine_route("resume_generation", handler.resume_generation) -runtime.register_engine_route("flush_cache", handler.flush_cache) -runtime.register_engine_route("update_weights_from_path", handler.update_weights_from_path) -runtime.register_engine_route("get_weight_version", handler.get_weight_version) -``` - -### publisher.py Crash Guard - -The `DynamoStatLoggerPublisher.record()` method includes a guard for `scheduler_stats is None`. This prevents an `AttributeError` crash during the transient window right after a weight reload when the vLLM engine's stats logger fires before the engine core has re-initialized its scheduler stats. - ---- - -## 8. 
Known Limitations - -| Limitation | Workaround | Notes | -|-----------|-----------|-------| -| `cache_salt` not supported -- returns 400 for requests with `cache_salt` in body | Set `[experimental] use_prefix_cache_salt = false` in prime-rl `orch.toml` | verifiers dev6+ defaults `use_prefix_cache_salt=True` | -| `prompt_token_ids` only injected for non-streaming responses | Use non-streaming mode for RL rollouts (the default) | Streaming final-chunk injection is planned | -| Weight version `"initial"` before first update | Do not depend on version string for correctness; use `/v1/rl/ready` for readiness | | -| NCCL weight broadcast is a no-op on Dynamo side | Use `type = "filesystem"` in `[weight_broadcast]` for all current deployments | | -| ~~`VLLM_USE_V1=0` required~~ **Resolved on vLLM 0.19.1** | Set `VLLM_USE_V1=1` on images shipping `VLLM_VER≥0.19.1` (current default). Keep `VLLM_USE_V1=0` only on legacy 0.18.x images where Meta-tensor crash with `--enforce-eager` still reproduces. | Verified under Run D (Qwen3.5-35B-A3B-FP8, 12 workers, batch=64) with V1 enabled and CUDA graphs for LoRA decode. | -| Filesystem weight broadcast scales poorly for large models | Acceptable for 0.6B (257ms load); marginal at 7B (~25s); impractical at 70B (~5 min) | RDMA pull transfer planned | - ---- - -## 9. 
Validation Results - -### Local (2x A6000, Qwen3-0.6B-Reverse-Text-SFT, 20 steps) - -| Metric | vLLM Baseline | Dynamo | Delta | -|--------|:-------------:|:------:|:-----:| -| Steps completed | 20/20 | 20/20 | -- | -| Peak reward | 0.798 | **0.825** | +3.4% | -| Final reward | 0.716 | **0.724** | +1.2% | -| `is_masked/mean` | 1.2% | **0.13%** | -92% (better) | -| Mismatch KL (final) | 0.0075 | **0.0056** | -25% (better) | -| Weight update cycles | 19 | 19 | -- | -| Mean weight cycle time | -- | 257.5ms | pause: 3.2ms, load: 249ms, resume: 4.8ms | - -W&B: https://wandb.ai/test232/prime-rl-parity-apr17 - -### Kubernetes (GB200, same model, 20 steps) - -| Metric | K8s | -|--------|:---:| -| Steps completed | 20/20 | -| Reward at step 13 | 0.714 (climbing) | -| Mismatch KL (steps 0-3) | 0.0007 - 0.0009 | -| Pods | 4 | -| All RL routes verified | Yes | - -The Rust RL API produces **better token alignment than native vLLM** (0.13% masked vs 1.2%). diff --git a/docs/dynamo-RL-api.md b/docs/dynamo-RL-api.md new file mode 100644 index 000000000000..af31731a6e46 --- /dev/null +++ b/docs/dynamo-RL-api.md @@ -0,0 +1,590 @@ +# Dynamo RL API + +**Branch:** `bis/dynamo-rl` (HEAD: `a2cc90da6d`) +**Pull request:** [ai-dynamo/dynamo#9131](https://github.com/ai-dynamo/dynamo/pull/9131) + +This document describes the RL training API surface on the Dynamo serving stack. The Dynamo Rust frontend exposes a small, focused set of endpoints that let an RL trainer (prime-rl, NeMo-rl, or any OpenAI-compatible client) drive a vLLM-served model through pause / weight-update / resume cycles, hot-swap LoRA adapters, and post pre-tokenized inputs on the standard chat-completions endpoint. + +## Table of Contents + +1. [Overview](#1-overview) +2. [Architecture](#2-architecture) +3. [Configuration](#3-configuration) +4. [API Reference](#4-api-reference) + - 4.1 Chat Completions (RL-enhanced + TITO) + - 4.2 RL Lifecycle (`/v1/rl/*`) +5. [Data Flow](#5-data-flow) +6. 
[Key Data Structures](#6-key-data-structures) +7. [Worker Engine Routes (Internal)](#7-worker-engine-routes-internal) +8. [Known Limitations](#8-known-limitations) +9. [What Changed vs. the Earlier Draft](#9-what-changed-vs-the-earlier-draft) + +--- + +## 1. Overview + +The Dynamo Rust frontend exposes: + +- A `/v1/rl/*` router for the full RL control-plane lifecycle (composite state, liveness probe, pause/resume, weight update, LoRA hot-swap) +- Token-level data injection (`prompt_token_ids`, `choices[i].token_ids`, `nvext.completion_token_ids`) on standard chat-completion responses +- Pre-tokenized prompt support on the standard `/v1/chat/completions` endpoint via the `prompt_token_ids` extension (no separate URI) + +Zero Python in the inference or admin data path. The Rust frontend handles every HTTP route; vLLM workers expose a small set of internal engine routes for pause/update/resume on the GPU. + +### Endpoint Summary + +| Capability | Endpoint | Method | Notes | +|---|---|---|---| +| Inference | `/v1/chat/completions` | POST | Standard OpenAI plus RL extras: `prompt_token_ids`, `stop_token_ids`, `allowed_token_ids`, `bad_words_token_ids`, `truncate_prompt_tokens`, `weight_version`, `nvext.{completion_token_ids,return_token_ids,return_routed_experts,return_prompt_logprobs}` | +| Composite state | `/v1/rl/state` | GET | Aggregated per-worker `{ready, engine_alive, pause_state, applied_weight_version, loras, workers}` | +| Liveness | `/v1/rl/liveness` | GET | Round-trips `engine_client.check_health()` so a wedged event loop surfaces 503 | +| Pause fleet | `/v1/rl/pause` | POST | `?mode=keep\|wait\|abort&clear_cache=bool` | +| Resume fleet | `/v1/rl/resume` | POST | | +| Update weights | `/v1/rl/update_weights` | POST | Typed body: `{weight_dir, weight_version?, reset_prefix_cache=true}` | +| Load LoRA adapter | `/v1/rl/load_lora_adapter` | POST | Filesystem-native PEFT-style hot-swap | +| Unload LoRA adapter | `/v1/rl/unload_lora_adapter` | POST | Idempotent 
| +| Legacy: health | `/v1/rl/health` | GET | Kept for back-compat; prefer `/v1/rl/state` | +| Legacy: ready | `/v1/rl/ready` | GET | Kept for back-compat; prefer `/v1/rl/state` | +| Legacy: weight_version | `/v1/rl/weight_version` | GET | Kept for back-compat; folded into `/v1/rl/state.applied_weight_version` | + +Endpoints intentionally **not** present (requests return 404): + +| Removed | Reason | +|---|---| +| `/v1/chat/completions/tokens` | TITO collapsed into `/v1/chat/completions` via the `prompt_token_ids` top-level extension | +| `/v1/tokenize` | Owned by [#7699](https://github.com/ai-dynamo/dynamo/pull/7699) (NeMo-rl scope), not used by prime-rl | +| `/v1/detokenize` | Same as above | + +The handler functions and route helpers are kept in source under `#[allow(dead_code)]` so downstream code that still references them compiles; physical deletion is a follow-up cleanup commit. + +--- + +## 2. Architecture + +### Component Topology + +```mermaid +flowchart TD + subgraph prime_rl["prime-rl"] + orch["Orchestrator
(prime_rl.orchestrator)"] + trainer["Trainer
(prime_rl.trainer.rl.train)
torchrun --nproc-per-node=N"] + end + + subgraph dynamo["Dynamo Serving Stack"] + subgraph frontend["Frontend Pod (Rust, port 8000)"] + cc["/v1/chat/completions
+ prompt_token_ids extension
+ choices[].token_ids
+ stop_token_ids / allowed_token_ids / ..."] + rl["/v1/rl/* (admin)
state, liveness,
pause, resume, update_weights,
load_lora_adapter, unload_lora_adapter"] + end + subgraph worker["vLLM Worker Pod (Python, system port 9090)"] + eng["Engine routes:
get_state, liveness_probe,
pause_generation, resume_generation,
flush_cache, update_weights_from_path,
get_weight_version,
load_lora_adapter, unload_lora_adapter"] + gpu["GPU
Model Weights"] + end + end + + subgraph storage["Shared Storage (PVC)"] + pvc["safetensors checkpoints
+ adapter_model.safetensors / adapter_config.json"] + end + + orch -- "rollouts: POST /v1/chat/completions
(messages OR prompt_token_ids)" --> cc + orch -- "weight lifecycle:
pause → update_weights → resume" --> rl + rl -- "fan-out (concurrent)" --> eng + eng --> gpu + trainer -- "write checkpoint" --> pvc + eng -- "reload_weights
(collective_rpc)" --> pvc +``` + +### Key Design Decisions + +1. **Single entry point.** prime-rl points both `base_url` and `admin_base_url` at the Dynamo frontend. No separate admin service. +2. **Fan-out in Rust.** `/v1/rl/*` handlers fan out to all vLLM workers via `DYN_RL_WORKER_SYSTEM_URLS`. This supports DP > 1 without prime-rl needing to discover workers. A handler returns HTTP 200 only when every worker responds OK; otherwise 502 with per-worker details. +3. **Token IDs as a response extension.** When `DYN_ENABLE_RL=true`, `prompt_token_ids` and `choices[i].token_ids` are injected into every non-streaming response automatically. `nvext.completion_token_ids` is the canonical Dynamo location; the choice-level field is a compatibility shim for prime-rl/verifiers. +4. **Backward compatible.** All new response fields use `#[serde(skip_serializing_if = "Option::is_none")]`. Deployments that don't set `DYN_ENABLE_RL` on the frontend return standard OpenAI-compatible responses with no extra fields. +5. **TITO without a URI fork.** Pre-tokenized input is a top-level extension on the standard chat-completions request (`prompt_token_ids`), not a separate `/v1/chat/completions/tokens` URI. This bridges the gap until vLLM 0.20+ accepts the same extension natively. + +--- + +## 3. Configuration + +### Environment Variables (Frontend) + +| Variable | Default | Description | +|---|---|---| +| `DYN_ENABLE_RL` | `false` | Master switch. Mounts `/v1/rl/*` routes and auto-injects token IDs in chat completion responses. | +| `DYN_RL_WORKER_SYSTEM_URLS` | `http://localhost:8081` | Comma-separated vLLM worker system HTTP base URLs for fan-out. | +| `DYN_RL_LIVENESS_TIMEOUT_MS` | `5000` | Per-worker timeout for `/v1/rl/liveness`. | + +### Environment Variables (Worker) + +| Variable | Default | Description | +|---|---|---| +| `DYN_SYSTEM_PORT` | `8081` (local) / `9090` (k8s) | Worker's system HTTP port where engine routes are registered.
| + +### prime-rl `orch.toml` (representative) + +```toml +[client] +base_url = ["http://<frontend-host>:8000/v1"] +admin_base_url = ["http://<frontend-host>:8000/v1/rl"] +backend = "vllm" +skip_model_check = true + +[weight_broadcast] +type = "filesystem" # NCCL is a Dynamo-side no-op today; see §8 +``` + +### Kubernetes (DGD frontend env) + +```yaml +- name: DYN_ENABLE_RL + value: "true" +- name: DYN_RL_WORKER_SYSTEM_URLS + value: "http://<dgd-name>-vllmworker.<namespace>.svc.cluster.local:9090" +``` + +--- + +## 4. API Reference + +### 4.1 Chat Completions (RL-enhanced + TITO) + +``` +POST /v1/chat/completions +``` + +Standard OpenAI chat completions. When `DYN_ENABLE_RL=true`, every non-streaming response is automatically enriched with token IDs. + +#### RL request extensions + +The following top-level fields are accepted in addition to the OpenAI schema. They are validated by `validate.rs::PASSTHROUGH_EXTRA_FIELDS` and forwarded to the engine where vLLM 0.20+ accepts them natively: + +| Field | Type | Purpose | +|---|---|---| +| `prompt_token_ids` | `u32[]` | Pre-tokenized prompt (TITO). Mutually exclusive with non-empty `messages` (except for the legacy `nvext.token_data` renderer-mode placeholder, which still coexists). | +| `stop_token_ids` | `u32[]` | Plumbed into `SamplingParams.stop_token_ids`; forces stop on any of these IDs. Malformed input (e.g. `"not-an-array"`) returns a typed 400. | +| `allowed_token_ids` | `u32[]` | Restricts decoding to this set. | +| `bad_words_token_ids` | `u32[]` | Suppresses these IDs. | +| `truncate_prompt_tokens` | `int` | Truncates the prompt to the N most recent tokens. | +| `weight_version` | `string` | Routing filter for IS-correction strict-version mode (accepted today; version-based routing is a follow-up). | +| `cache_salt` | `string` | KV prefix-cache isolation hint. (Coordinated with #8197 → `X-Tenant-Id` header; both forms accepted for one release.) | +| `return_token_ids` | `bool` | Per-request opt-in for `nvext.completion_token_ids` (also achievable via `extra_fields`).
| +| `return_routed_experts` | `bool` | MoE expert-routing replay capture. | +| `return_prompt_logprobs` | `bool` | Streaming logprobs for input tokens. | + +In the legacy `nvext` channel, `nvext.token_data` (renderer-mode pre-tokenized prompt) and `nvext.extra_fields = ["token_ids", "completion_token_ids", ...]` continue to work unchanged. + +#### TITO via `prompt_token_ids` + +```bash +curl -s -X POST http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "Qwen/Qwen3-0.6B", + "messages": [], + "prompt_token_ids": [151644, 8948, 198, 151645, 198, 151644, 872, 198, + 49, 1075, 513, 420, 25, 24748, 1879, 198, 151645, + 198, 151644, 77091, 198], + "stop_token_ids": [151643], + "max_tokens": 64 + }' +``` + +Validation rules: + +- `messages` may be empty when `prompt_token_ids` is non-empty (the chat template short-circuits). +- `messages` non-empty + `prompt_token_ids` non-empty → 400 mutual-exclusion error (canonical channel only). +- `nvext.token_data` + non-empty `messages` → still allowed (renderer-mode placeholder pattern from `verifiers.dynamo_chat_nvext` keeps working). + +#### Sample response (non-streaming, `DYN_ENABLE_RL=true`) + +```jsonc +{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "model": "Qwen/Qwen3-0.6B", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "dlrow olleh"}, + "finish_reason": "stop", + "logprobs": {"content": [...]}, + "token_ids": [67, 1245, 893, 15] + }], + "prompt_token_ids": [151644, 8948, 198, ...], + "usage": {"prompt_tokens": 21, "completion_tokens": 4, "total_tokens": 25}, + "nvext": { + "completion_token_ids": [67, 1245, 893, 15] + } +} +``` + +#### Response field reference + +| Field | JSON path | Description | +|---|---|---| +| `prompt_token_ids` | `response.prompt_token_ids` | Promoted by `rl_tokenize_prompt`: messages → tokenizer (model chat template) → token IDs. 
| +| `token_ids` | `response.choices[i].token_ids` | Per-choice output token IDs, promoted by `rl_promote_token_ids_in_response` from `nvext.completion_token_ids`. | +| `completion_token_ids` | `response.nvext.completion_token_ids` | Canonical Dynamo location; accumulated across SSE chunks by `DeltaGenerator`. | + +**Why two locations?** prime-rl/verifiers reads `response.prompt_token_ids` and `choices[i].token_ids`; Dynamo natively emits in `nvext.completion_token_ids`. The Rust post-processor promotes the latter to the former. + +**Invariant:** `len(completion_token_ids) == len(logprobs.content)`. + +#### Streaming (SSE) + +Intermediate chunks carry `delta.content` only. Token IDs appear exclusively on the **final chunk** (the one with a non-null `finish_reason`). + +--- + +### 4.2 RL Lifecycle (`/v1/rl/*`) + +Mounted only when `DYN_ENABLE_RL=true`. All non-trivial routes fan out to the worker URLs in `DYN_RL_WORKER_SYSTEM_URLS`. + +#### `GET /v1/rl/state` — composite read-only + +Single endpoint that returns everything prime-rl needs to make a decision. Aggregates `get_state` per-worker payloads. + +```bash +curl -s http://localhost:8000/v1/rl/state +``` + +```jsonc +// 200 +{ + "ready": true, + "ingress_alive": true, + "engine_alive": true, + "pause_state": "running", // or "paused" | "mixed" + "applied_weight_version": "step_5", // null when workers disagree + "loras": [ + {"name": "r16-a32", "loaded_on": [0, 1]} + ], + "workers": [, ...] +} + +// 503 — no workers registered +{"ready": false, "ingress_alive": true, "engine_alive": false, "pause_state": "running", + "applied_weight_version": null, "loras": [], "workers": [], + "status": "error", "message": "no workers registered"} +``` + +`ready = ingress_alive AND engine_alive AND len(workers) > 0`. `ingress_alive` is unconditionally `true` because reaching this handler proves the frontend HTTP listener is up. 
+ +#### `GET /v1/rl/liveness` — deep liveness probe + +Round-trips `engine_client.check_health()` per worker so a wedged event loop or hung NCCL collective surfaces as 503. Override timeout via `DYN_RL_LIVENESS_TIMEOUT_MS` (default 5000). + +```bash +curl -s http://localhost:8000/v1/rl/liveness +``` + +```jsonc +// 200 +{"status": "ok", "alive": true, "workers": [{"alive": true}, ...]} + +// 503 — at least one worker hung past timeout +{"status": "error", "alive": false, "workers": [{"alive": false, "error": "timeout"}]} +``` + +#### `POST /v1/rl/pause` — 3-mode pause + cache control + +Query parameters (or JSON body): + +| Param | Type | Default | Effect | +|---|---|---|---| +| `mode` | `keep` \| `wait` \| `abort` | `keep` | `keep`: drain in-flight (legacy behaviour). `wait`: same as `keep` but block on completion. `abort`: trigger `collective_rpc(abort_all_requests)` on the engine (graceful warn-fallback on vLLM 0.19 where that RPC isn't implemented). | +| `clear_cache` | `bool` | `false` | If `true`, calls `reset_prefix_cache` after the pause completes. | + +```bash +curl -s -X POST 'http://localhost:8000/v1/rl/pause?mode=abort&clear_cache=true' +``` + +400 on unknown `mode`: + +```json +{"status": "error", "message": "Invalid mode 'foo'; expected one of keep|wait|abort"} +``` + +#### `POST /v1/rl/resume` + +Resumes generation on all workers. + +```bash +curl -s -X POST http://localhost:8000/v1/rl/resume -H 'Content-Type: application/json' -d '{}' +``` + +```json +{"status": "ok", "workers": [{"status": "ok", "message": "Engine resumed"}]} +``` + +#### `POST /v1/rl/update_weights` — typed body + +Body schema: + +```jsonc +{ + "weight_dir": "/data/outputs/.../broadcasts/step_5", // required (string | null) + "weight_version": "step_5", // optional, defaults to basename(weight_dir) + "reset_prefix_cache": true // optional, default true +} +``` + +Behaviour: + +- `weight_dir = "/path/..."` → fan out `update_weights_from_path` to every worker. 
Each worker calls `engine_client.collective_rpc("reload_weights", kwargs={"weights_path": path})` (vLLM's in-place layerwise load). +- `weight_dir = null` → NCCL mode. Dynamo logs `"NCCL mode, no-op on Dynamo side"` and returns 200 immediately. The actual GPU↔GPU transfer happens out of band on a pre-established NCCL group between trainer and inference workers. **Today the inference-side NCCL receiver is not wired into `dynamo.vllm`**; see §8. +- `reset_prefix_cache = true` → flush prefix/KV cache after the load (default). + +```bash +# Filesystem mode +curl -s -X POST http://localhost:8000/v1/rl/update_weights \ + -H 'Content-Type: application/json' \ + -d '{"weight_dir": "/data/outputs/run_default/broadcasts/step_5"}' + +# NCCL mode (Dynamo no-op — see §8) +curl -s -X POST http://localhost:8000/v1/rl/update_weights \ + -H 'Content-Type: application/json' \ + -d '{"weight_dir": null}' +``` + +```jsonc +// 200 +{ + "status": "ok", + "applied_weight_version": "step_5", + "workers": [ + {"status": "ok", "message": "Weights loaded from /data/...", "version": "step_5"} + ] +} + +// 502 (some worker failed) +{"status": "error", "stage": "update_weights_from_path", + "workers": [{"status": "ok", ...}, {"status": "error", "message": "..."}]} +``` + +#### `POST /v1/rl/load_lora_adapter` + +Hot-load / hot-swap a LoRA adapter from a filesystem path. Adapter dir must contain PEFT-style `adapter_model.safetensors` and `adapter_config.json`. + +- First call for a given `lora_name` → `add_lora` + publish a ModelDeploymentCard so subsequent inference with `model=` routes here. +- Subsequent calls (hot-swap) → `remove_lora(old_id)` → `add_lora` with new weights → `reset_prefix_cache`. MDC is left in place. + +Pair with `/v1/rl/pause` + `/v1/rl/resume` for full drain-swap-resume. 
+ +```bash +curl -s -X POST http://localhost:8000/v1/rl/load_lora_adapter \ + -H 'Content-Type: application/json' \ + -d '{"lora_name": "r16-a32", "lora_path": "/data/outputs/run_default/broadcasts/step_5"}' +``` + +```jsonc +// 200 +{"status": "ok", + "workers": [{"status": "ok", "message": "LoRA adapter 'r16-a32' loaded from /data/...", + "lora_name": "r16-a32", "lora_id": 788776416, "hot_swap": false}]} + +// 400 — missing/empty fields +{"status": "error", + "message": "Expected body: {\"lora_name\": str, \"lora_path\": str} (both required, non-empty)"} +``` + +vLLM worker requirements: started with `--enable-lora --max-lora-rank R --max-loras N`. For prime-rl's single-adapter loop, `--max-loras 1` is sufficient. + +#### `POST /v1/rl/unload_lora_adapter` + +Remove an adapter by name. Idempotent — unloading an already-absent adapter returns `status: ok`. + +```bash +curl -s -X POST http://localhost:8000/v1/rl/unload_lora_adapter \ + -H 'Content-Type: application/json' \ + -d '{"lora_name": "r16-a32"}' +``` + +#### Legacy endpoints (kept for back-compat) + +`GET /v1/rl/health`, `GET /v1/rl/ready`, `GET /v1/rl/weight_version` — same shapes as the previous draft. To be removed in Phase 5 of `docs/design-docs/rl-support.md` once prime-rl's AdminAPI migrates to `/v1/rl/state`. + +--- + +## 5. Data Flow + +### 5.1 Rollout (inference) path + +```mermaid +sequenceDiagram + participant Orch as prime-rl Orchestrator + participant FE as Dynamo Frontend (Rust) + participant Worker as vLLM Worker (GPU) + + Orch->>FE: POST /v1/chat/completions
{messages OR prompt_token_ids, stop_token_ids?, ...} + Note over FE: validate.rs: PASSTHROUGH_EXTRA_FIELDS
plumbs RL extras into SamplingParams
If DYN_ENABLE_RL=true, inject
nvext.extra_fields = ["token_ids","completion_token_ids"]
force logprobs=true + FE->>Worker: forward request (TCP/NATS) + Worker-->>FE: SSE chunks (delta.content + delta.token_ids) + Note over FE: DeltaGenerator accumulates
completion_token_ids; serde failures
now log tracing::warn! (no silent drops) + Worker-->>FE: final chunk (finish_reason + nvext.completion_token_ids) + Note over FE: rl_tokenize_prompt(messages) -> prompt_token_ids
rl_promote_token_ids_in_response()
nvext.completion_token_ids -> choices[i].token_ids + FE-->>Orch: enriched response +``` + +### 5.2 Weight update path + +```mermaid +sequenceDiagram + participant Trainer as prime-rl Trainer + participant PVC as Shared Storage + participant Orch as prime-rl Orchestrator + participant FE as Dynamo Frontend (Rust) + participant W1 as vLLM Worker 1 + participant W2 as vLLM Worker 2 + + Trainer->>PVC: write checkpoint
/data/outputs/.../step_N/*.safetensors + Orch->>FE: POST /v1/rl/pause?mode=keep + FE->>W1: pause_generation + FE->>W2: pause_generation + W1-->>FE: ok + W2-->>FE: ok + FE-->>Orch: {status: ok} + Orch->>FE: POST /v1/rl/update_weights
{weight_dir: /data/.../step_N, reset_prefix_cache: true} + FE->>W1: update_weights_from_path + FE->>W2: update_weights_from_path + Note over W1,W2: collective_rpc("reload_weights")
vLLM in-place layerwise load + W1-->>FE: {status: ok, version: step_N} + W2-->>FE: {status: ok, version: step_N} + FE-->>Orch: {status: ok, applied_weight_version: step_N} + Orch->>FE: POST /v1/rl/resume + FE->>W1: resume_generation + FE->>W2: resume_generation + W1-->>FE: ok + W2-->>FE: ok + FE-->>Orch: {status: ok} +``` + +NCCL mode: `weight_dir=null` returns 200 immediately; the actual GPU↔GPU broadcast must be coordinated out of band (see §8 for the wiring gap). + +### 5.3 LoRA hot-swap + +```mermaid +sequenceDiagram + participant Orch as prime-rl Orchestrator + participant FE as Dynamo Frontend + participant W1 as vLLM Worker 1 + + Orch->>FE: POST /v1/rl/pause?mode=keep + FE-->>Orch: ok + Orch->>FE: POST /v1/rl/load_lora_adapter
{lora_name, lora_path} + Note over FE,W1: First call: add_lora + publish MDC
Subsequent: remove_lora(old) → add_lora → reset_prefix_cache + FE-->>Orch: {status: ok, lora_id, hot_swap} + Orch->>FE: POST /v1/rl/resume + FE-->>Orch: ok +``` + +--- + +## 6. Key Data Structures + +### `NvCreateChatCompletionRequest` (Rust, request side) + +Custom fields (top-level, beyond stock OpenAI): + +| Field | `serde` behaviour | Notes | +|---|---|---| +| `prompt_token_ids` | passthrough | Canonical TITO channel. Read by `NvExtProvider::get_pretokenized_input`. | +| `stop_token_ids` | passthrough | Read by `OpenAIStopConditionsProvider::get_stop_token_ids() → Result>>`. Malformed input returns 400. | +| `allowed_token_ids`, `bad_words_token_ids`, `truncate_prompt_tokens` | passthrough | Plumbed into `SamplingParams`. | +| `weight_version`, `cache_salt`, `return_*` | passthrough | See §4.1. | +| `tokens` | `skip_serializing` | Legacy compat — caught and ignored. | +| `return_token_ids` | `skip_serializing` | Legacy compat — use `nvext.extra_fields` or `DYN_ENABLE_RL`. | + +### `NvCreateChatCompletionResponse` (Rust, response side) + +```rust +NvCreateChatCompletionResponse { + inner: CreateChatCompletionResponse, // standard OpenAI + nvext: Option, // NvExtResponse JSON + prompt_token_ids: Option>, // RL only +} +``` + +### `NvExtResponse` + +Serialized as `nvext` on each SSE chunk and the unary response body: + +```rust +NvExtResponse { + worker_id: Option, + timing: Option, + token_ids: Option>, // GAIE Stage 1 prompt + routed_experts: Option, + completion_token_ids: Option>, // RL output, final chunk only +} +``` + +### `RlUpdateWeightsBody` + +```rust +struct RlUpdateWeightsBody { + weight_dir: Option, // null => NCCL mode + weight_version: Option, // defaults to basename(weight_dir) + #[serde(default = "default_reset_prefix_cache")] + reset_prefix_cache: bool, // default true +} +``` + +### `DeltaGenerator` (streaming pipeline) + +Tracks `accumulated_completion_token_ids: Vec` per request. 
Activated when `extra_fields` includes `"completion_token_ids"` (auto-set under `DYN_ENABLE_RL`). Emits the full vector in `nvext.completion_token_ids` on the final chunk. + +### Post-processing helpers + +- `rl_tokenize_prompt(state, model, messages) -> Option<Vec<u32>>` — resolves the model card, builds `PromptFormatter`, renders messages through the chat template, tokenizes, returns IDs. +- `rl_promote_token_ids_in_response(json_val)` — copies `nvext.completion_token_ids` to `choices[i].token_ids` per choice. Doc-block now lives on this function (commit `d295ebc6` move). + +--- + +## 7. Worker Engine Routes (Internal) + +Registered on each vLLM worker's system HTTP port (default `8081` local / `9090` k8s) by `worker_factory.py::register_engine_routes()`. Called by Rust `/v1/rl/*` handlers — not by prime-rl directly. + +| Route | vLLM API called | Used by | +|---|---|---| +| `pause_generation` | `engine_client.pause_generation()` (+ `abort_all_requests` when mode=abort) | `/v1/rl/pause` | +| `resume_generation` | `engine_client.resume_generation()` | `/v1/rl/resume` | +| `flush_cache` | `engine_client.reset_prefix_cache()` | `/v1/rl/update_weights` (when `reset_prefix_cache=true`) | +| `update_weights_from_path` | `collective_rpc("reload_weights", weights_path=...)` | `/v1/rl/update_weights` | +| `get_weight_version` | reads `self._weight_version` | `/v1/rl/weight_version` (legacy) | +| `get_state` | composite per-worker snapshot (engine_alive, pause_state, applied_weight_version, loras) | `/v1/rl/state` | +| `liveness_probe` | round-trips `engine_client.check_health()` so a wedged event loop returns 503 | `/v1/rl/liveness` | +| `load_lora_adapter` | `add_lora`, `remove_lora` | `/v1/rl/load_lora_adapter` | +| `unload_lora_adapter` | `remove_lora` + MDC unregister | `/v1/rl/unload_lora_adapter` | + +### `publisher.py` crash guard + +`DynamoStatLoggerPublisher.record()` guards against `scheduler_stats is None`.
This prevents an `AttributeError` crash during the transient window right after a weight reload, when the vLLM stats logger fires before the engine core has re-initialized its scheduler. + +--- + +## 8. Known Limitations + +| Limitation | Workaround | Notes | +|---|---|---| +| **NCCL mode is a no-op on Dynamo's vLLM side.** `update_weights` with `weight_dir=null` returns 200 immediately, but `dynamo.vllm` does not load `NCCLWeightBroadcastReceiver` as a vLLM worker class — so the trainer's NCCL broadcast has no peer on the inference side. The trainer's `init_process_group` times out at `weight_broadcast.timeout` (default 1200 s). | Use `weight_broadcast.type = "filesystem"`. The `dynamo.sglang` backend ships `update_weights_from_distributed` natively and does work over NCCL. | Tracked on the prime-rl side: the orchestrator already POSTs `/v1/rl/init_broadcaster`, which `dynamo.vllm` doesn't expose (logs `route does not exist. Skipping NCCL broadcast initialization.` on the orch side). Wiring is the next workstream. | +| `cache_salt` not yet honored end-to-end | Set `[experimental] use_prefix_cache_salt = false` in prime-rl `orch.toml`; or send the equivalent `X-Tenant-Id` header (#8197). | Field is whitelisted (`PASSTHROUGH_EXTRA_FIELDS`) so requests don't 400; routing-side filter is a follow-up. | +| `prompt_token_ids` only injected for non-streaming responses | Use non-streaming mode for RL rollouts (the default). | Streaming final-chunk injection is planned. | +| Weight version is `"initial"` before the first update | Use `/v1/rl/state.applied_weight_version` as the source of truth; don't rely on the version string for correctness. | | +| Filesystem weight broadcast scales poorly for large models | OK for 0.6B (~250 ms load); marginal at 7B (~25 s); ~150 s at 30B-A3B BF16; impractical at 70B+. | RDMA / NCCL-receive on dynamo.vllm planned. | + +--- + +## 9. What Changed vs.
the Earlier Draft + +For readers who know the previous `Dynamo-RL-api-draft.md`: + +| Old | New | +|---|---| +| `/v1/chat/completions/tokens` (TITO URI fork) | TITO collapsed into `/v1/chat/completions` via the `prompt_token_ids` top-level extension. URI returns 404. | +| `/v1/tokenize`, `/v1/detokenize` | Removed (return 404). Owned by [#7699](https://github.com/ai-dynamo/dynamo/pull/7699), out of scope for this surface. | +| `POST /v1/rl/pause` (no params) | `POST /v1/rl/pause?mode=keep\|wait\|abort&clear_cache=bool` (3-mode). | +| `POST /v1/rl/update_weights` (string body) | Typed body `{weight_dir, weight_version?, reset_prefix_cache=true}`; response carries `applied_weight_version`. | +| `/v1/rl/health` + `/v1/rl/ready` + `/v1/rl/weight_version` | All three kept for back-compat, **plus** new composite `GET /v1/rl/state`. New `GET /v1/rl/liveness` does a deep `check_health()` round-trip. | +| `PASSTHROUGH_EXTRA_FIELDS = [cache_salt]` | Now: `cache_salt`, `prompt_token_ids`, `weight_version`, `return_routed_experts`, `return_token_ids`, `return_prompt_logprobs`, `stop_token_ids`, `bad_words_token_ids`, `allowed_token_ids`, `truncate_prompt_tokens`. Full `SamplingParams` parity. | +| `get_stop_token_ids() -> Option<Vec<u32>>` (silent drop on bad input) | `get_stop_token_ids() -> Result<Option<Vec<u32>>>`. Malformed input returns a typed 400. | +| nvext serde failures: `if let Ok(json) = serde_json::to_value(...) { ... }` (silent drop) | `match { Ok(json) => ..., Err(e) => tracing::warn!(...) }`. No silent corruption of promoted token IDs / weight version. | +| Engine routes: 5–7 (pause, resume, flush_cache, update_weights_from_path, get_weight_version, +load_lora, +unload_lora) | 9. Added: `get_state`, `liveness_probe`.
| From 8e0e019bfbc7f4a1c6da0a2c4b708f3c4d98281d Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Thu, 7 May 2026 16:00:00 -0700 Subject: [PATCH 06/18] update --- docs/dynamo-RL-api.md | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/docs/dynamo-RL-api.md b/docs/dynamo-RL-api.md index af31731a6e46..54e82770b155 100644 --- a/docs/dynamo-RL-api.md +++ b/docs/dynamo-RL-api.md @@ -1,9 +1,6 @@ # Dynamo RL API -**Branch:** `bis/dynamo-rl` (HEAD: `a2cc90da6d`) -**Pull request:** [ai-dynamo/dynamo#9131](https://github.com/ai-dynamo/dynamo/pull/9131) - -This document describes the RL training API surface on the Dynamo serving stack. The Dynamo Rust frontend exposes a small, focused set of endpoints that let an RL trainer (prime-rl, NeMo-rl, or any OpenAI-compatible client) drive a vLLM-served model through pause / weight-update / resume cycles, hot-swap LoRA adapters, and post pre-tokenized inputs on the standard chat-completions endpoint. +This document describes the RL training API surface on the Dynamo serving stack. The Dynamo Rust frontend exposes a small, focused set of endpoints that let an RL trainer drive a vLLM-served model through pause / weight-update / resume cycles, hot-swap LoRA adapters, and post pre-tokenized inputs on the standard chat-completions endpoint. 
## Table of Contents @@ -52,7 +49,7 @@ Endpoints intentionally **not** present (returned 404): | Removed | Reason | |---|---| | `/v1/chat/completions/tokens` | TITO collapsed into `/v1/chat/completions` via the `prompt_token_ids` top-level extension | -| `/v1/tokenize` | Owned by [#7699](https://github.com/ai-dynamo/dynamo/pull/7699) (NeMo-rl scope), not used by prime-rl | +| `/v1/tokenize` | Out of scope for this surface (covered by a separate PR) | | `/v1/detokenize` | Same as above | The handler functions and route helpers are kept in source under `#[allow(dead_code)]` so downstream code that still references them compiles; physical deletion is a follow-up cleanup commit. @@ -65,9 +62,9 @@ The handler functions and route helpers are kept in source under `#[allow(dead_c ```mermaid flowchart TD - subgraph prime_rl["prime-rl"] - orch["Orchestrator
(prime_rl.orchestrator)"] - trainer["Trainer
(prime_rl.trainer.rl.train)
torchrun --nproc-per-node=N"] + subgraph rl_client["RL Trainer (external)"] + orch["Orchestrator
(rollouts + admin calls)"] + trainer["Trainer
(torchrun, FSDP/EP/etc.)"] end subgraph dynamo["Dynamo Serving Stack"] From dc62cb761f8526f55f68a5548d45a363f4c5eff2 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Thu, 7 May 2026 17:14:47 -0700 Subject: [PATCH 07/18] chore(rl): scrub internal review markers from PR-added comments + docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cleanup-only — no behavior change. Strips review-tracker noise that accumulated on top of PR-added text during iteration: - "Closes hhzhang16 HH-19/HH-21/HH-22/HH-23/HH-25/HH-26/HH-27" - "CR-8 / CR-9 / CR-10 closure" prefixes on serde-error / doc-attach fixes - Branch-name references: bis/parity-tokenize-tcp, bis/prime-rl-merged - Internal PR numbers: #6094, #7699, #8197, #9141 - Phase numbers from internal design docs (rl-support.md Phase 1/4/5) - "prime-rl" mentions in narrative copy and mermaid diagrams → generic "RL trainer / RL orchestrator / external client" Technical content (semantics, invariants, why-this-exists rationale) preserved everywhere; only the internal-process scaffolding is removed. Scope verification: every removed line is one this branch ADDED (diff main..HEAD shows the removed text on a `+` line). No edits land on pre-existing main-branch comments. Specifically reverted the nvext.rs cleanup attempt — its target lines (GAIE Stage 1/2, SGLang-specific) live on main, not in this PR's diff. Files touched: components/src/dynamo/vllm/handlers.py +9 -10 components/src/dynamo/vllm/worker_factory.py +6 -4 docs/dynamo-RL-api.md +19 -32 lib/llm/src/http/service/openai.rs +32 -34 lib/llm/src/protocols/openai/chat_completions/delta.rs +4 -4 lib/llm/src/protocols/openai/completions/delta.rs +3 -3 lib/llm/src/protocols/openai/validate.rs +20 -20 cargo check -p dynamo-llm: clean (1 pre-existing benign warning). 
--- components/src/dynamo/vllm/handlers.py | 19 +++--- components/src/dynamo/vllm/worker_factory.py | 10 +-- docs/dynamo-RL-api.md | 56 ++++++---------- lib/llm/src/http/service/openai.rs | 66 +++++++++---------- .../openai/chat_completions/delta.rs | 8 +-- .../src/protocols/openai/completions/delta.rs | 6 +- lib/llm/src/protocols/openai/validate.rs | 40 +++++------ 7 files changed, 93 insertions(+), 112 deletions(-) diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index c12ee1743e95..5691f46be692 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -826,20 +826,19 @@ async def stop_profile(self, body: dict) -> dict: return {"status": "error", "message": str(e)} # ── RL weight lifecycle engine routes ────────────────────────────── - # Signatures kept compatible with SGLang's merged #6094 routes so - # a single admin coordinator can talk to either backend. + # Signatures intentionally line up with the SGLang RL admin routes so a + # single admin coordinator can talk to either backend. async def pause_generation(self, body: dict) -> dict: """Pause the engine: drain in-flight requests, keep model loaded. - Called by RL admin coordinator before weight updates. + Called by the RL admin coordinator before weight updates. Uses engine_client.pause_generation() directly -- does NOT sleep (no GPU memory release) and does NOT unregister from discovery. - Body (all optional, all default to the prime-rl client convention): + Body (all optional): - mode: "keep" | "wait" | "abort" (default "keep" — drain in-flight) - clear_cache: bool (default False) - Closes hhzhang16 review HH-21 (3-mode pause). """ body = body or {} mode = body.get("mode", "keep") @@ -907,8 +906,8 @@ async def liveness_probe(self, body: dict) -> dict: short timeout (default 5s). 
Returning ``alive: True`` requires the engine_client IPC roundtrip to complete: a hung event loop, deadlocked worker, or wedged engine will time out at the frontend instead of - returning a stale ``OK``. Closes hhzhang16 HH-23 (health probe returns - OK no matter what). + returning a stale ``OK`` (which is what the legacy ``/v1/rl/health`` + does — that endpoint is just a frontend-process check). """ body = body or {} try: @@ -931,9 +930,9 @@ async def get_state(self, body: dict) -> dict: """Composite per-worker state snapshot for ``GET /v1/rl/state``. The Rust frontend aggregates these per-worker payloads into the - fleet-wide ``RlStateResponse``. Closes hhzhang16 HH-19/HH-25/HH-27 - (single state endpoint replacing /health + /ready + /weight_version, - RL-specific, weight_version folded in). + fleet-wide ``RlStateResponse`` — a single composite that replaces the + separate ``/v1/rl/health`` + ``/v1/rl/ready`` + ``/v1/rl/weight_version`` + endpoints with one RL-scoped readiness call. """ body = body or {} try: diff --git a/components/src/dynamo/vllm/worker_factory.py b/components/src/dynamo/vllm/worker_factory.py index 9948fdad2f81..5d122e15b3b9 100644 --- a/components/src/dynamo/vllm/worker_factory.py +++ b/components/src/dynamo/vllm/worker_factory.py @@ -676,8 +676,10 @@ def register_engine_routes( runtime.register_engine_route("wake_up", handler.wake_up) runtime.register_engine_route("scale_elastic_ep", handler.scale_elastic_ep) - # RL weight-lifecycle routes (parity with SGLang #6094) — driven by the + # RL weight-lifecycle routes — driven by the # /v1/rl/{pause,resume,update_weights} bracket in the Rust frontend. + # Names line up with the SGLang RL admin routes so a single admin + # coordinator can talk to either backend. 
runtime.register_engine_route("pause_generation", handler.pause_generation) runtime.register_engine_route("resume_generation", handler.resume_generation) runtime.register_engine_route("flush_cache", handler.flush_cache) @@ -688,12 +690,12 @@ def register_engine_routes( # RL state + liveness — drive /v1/rl/state and /v1/rl/liveness in the # Rust frontend. /v1/rl/state aggregates these per-worker snapshots - # into the composite RlStateResponse (rl-support.md Phase 1). + # into the composite RlStateResponse. runtime.register_engine_route("get_state", handler.get_state) runtime.register_engine_route("liveness_probe", handler.liveness_probe) - # RL LoRA adapter routes: filesystem-native hot-swap used by Prime-RL - # every training step to broadcast new adapter weights into the engine. + # RL LoRA adapter routes: filesystem-native hot-swap used by RL + # trainers every step to broadcast new adapter weights into the engine. runtime.register_engine_route("load_lora_adapter", handler.load_lora_adapter) runtime.register_engine_route( "unload_lora_adapter", handler.unload_lora_adapter diff --git a/docs/dynamo-RL-api.md b/docs/dynamo-RL-api.md index 54e82770b155..72ba903a4f9d 100644 --- a/docs/dynamo-RL-api.md +++ b/docs/dynamo-RL-api.md @@ -14,7 +14,6 @@ This document describes the RL training API surface on the Dynamo serving stack. 6. [Key Data Structures](#6-key-data-structures) 7. [Worker Engine Routes (Internal)](#7-worker-engine-routes-internal) 8. [Known Limitations](#8-known-limitations) -9. [What Changed vs. the Earlier Draft](#9-what-changed-vs-the-earlier-draft) --- @@ -92,11 +91,11 @@ flowchart TD ### Key Design Decisions -1. **Single entry point.** prime-rl points both `base_url` and `admin_base_url` at the Dynamo frontend. No separate admin service. -2. **Fan-out in Rust.** `/v1/rl/*` handlers fan out to all vLLM workers via `DYN_RL_WORKER_SYSTEM_URLS`. Supports DP > 1 without prime-rl needing to discover workers. 
Returns HTTP 200 only when every worker responds OK; otherwise 502 with per-worker details. -3. **Token IDs as a response extension.** When `DYN_ENABLE_RL=true`, `prompt_token_ids` and `choices[i].token_ids` are injected into every non-streaming response automatically. `nvext.completion_token_ids` is the canonical Dynamo location; the choice-level field is a compatibility shim for prime-rl/verifiers. +1. **Single entry point.** The trainer points both `base_url` and `admin_base_url` at the Dynamo frontend. No separate admin service. +2. **Fan-out in Rust.** `/v1/rl/*` handlers fan out to all vLLM workers via `DYN_RL_WORKER_SYSTEM_URLS`. Supports DP > 1 without the client needing to discover workers. Returns HTTP 200 only when every worker responds OK; otherwise 502 with per-worker details. +3. **Token IDs as a response extension.** When `DYN_ENABLE_RL=true`, `prompt_token_ids` and `choices[i].token_ids` are injected into every non-streaming response automatically. `nvext.completion_token_ids` is the canonical Dynamo location; the choice-level field is a compatibility shim for clients that read tokens from the choice object. 4. **Backward compatible.** All new response fields use `#[serde(skip_serializing_if = "Option::is_none")]`. Clients that don't set `DYN_ENABLE_RL` see standard OpenAI-compatible responses with no extra fields. -5. **TITO without a URI fork.** Pre-tokenized input is a top-level extension on the standard chat-completions request (`prompt_token_ids`), not a separate `/v1/chat/completions/tokens` URI. Bridges until vLLM 0.20+ accepts the same extension natively. +5. **TITO without a URI fork.** Pre-tokenized input is a top-level extension on the standard chat-completions request (`prompt_token_ids`), not a separate `/v1/chat/completions/tokens` URI. Aligns with vLLM 0.20+ which accepts the same extension natively. 
--- @@ -116,7 +115,7 @@ flowchart TD |---|---|---| | `DYN_SYSTEM_PORT` | `8081` (local) / `9090` (k8s) | Worker's system HTTP port where engine routes are registered. | -### prime-rl `orch.toml` (representative) +### Sample trainer config ```toml [client] @@ -162,7 +161,7 @@ The following top-level fields are accepted in addition to the OpenAI schema. Th | `bad_words_token_ids` | `u32[]` | Suppresses these IDs. | | `truncate_prompt_tokens` | `int` | Truncates prompt to N most-recent tokens. | | `weight_version` | `string` | Routing filter for IS-correction strict-version mode (today accepted; routing follow-up). | -| `cache_salt` | `string` | KV prefix-cache isolation hint. (Coordinated with #8197 → `X-Tenant-Id` header; both forms accepted for one release.) | +| `cache_salt` | `string` | KV prefix-cache isolation hint. The equivalent `X-Tenant-Id` request header is also accepted; the header takes precedence when both are present. | | `return_token_ids` | `bool` | Per-request opt-in for `nvext.completion_token_ids` (also achievable via `extra_fields`). | | `return_routed_experts` | `bool` | MoE expert-routing replay capture. | | `return_prompt_logprobs` | `bool` | Streaming logprobs for input tokens. | @@ -189,7 +188,7 @@ Validation rules: - `messages` may be empty when `prompt_token_ids` is non-empty (the chat template short-circuits). - `messages` non-empty + `prompt_token_ids` non-empty → 400 mutual-exclusion error (canonical channel only). -- `nvext.token_data` + non-empty `messages` → still allowed (renderer-mode placeholder pattern from `verifiers.dynamo_chat_nvext` keeps working). +- `nvext.token_data` + non-empty `messages` → still allowed (legacy renderer-mode placeholder pattern that uses a synthetic user message alongside pre-tokenized input). 
#### Sample response (non-streaming, `DYN_ENABLE_RL=true`) @@ -221,7 +220,7 @@ Validation rules: | `token_ids` | `response.choices[i].token_ids` | Per-choice output token IDs, promoted by `rl_promote_token_ids_in_response` from `nvext.completion_token_ids`. | | `completion_token_ids` | `response.nvext.completion_token_ids` | Canonical Dynamo location; accumulated across SSE chunks by `DeltaGenerator`. | -**Why two locations?** prime-rl/verifiers reads `response.prompt_token_ids` and `choices[i].token_ids`; Dynamo natively emits in `nvext.completion_token_ids`. The Rust post-processor promotes the latter to the former. +**Why two locations?** Some RL clients read tokens from `response.prompt_token_ids` / `choices[i].token_ids`; Dynamo natively emits them under `nvext.completion_token_ids`. The Rust post-processor promotes the canonical field to the choice-level field so both client conventions work. **Invariant:** `len(completion_token_ids) == len(logprobs.content)`. @@ -237,7 +236,7 @@ Mounted only when `DYN_ENABLE_RL=true`. All non-trivial routes fan out to the wo #### `GET /v1/rl/state` — composite read-only -Single endpoint that returns everything prime-rl needs to make a decision. Aggregates `get_state` per-worker payloads. +Single endpoint that returns the full fleet state in one call. Aggregates `get_state` per-worker payloads. ```bash curl -s http://localhost:8000/v1/rl/state @@ -383,7 +382,7 @@ curl -s -X POST http://localhost:8000/v1/rl/load_lora_adapter \ "message": "Expected body: {\"lora_name\": str, \"lora_path\": str} (both required, non-empty)"} ``` -vLLM worker requirements: started with `--enable-lora --max-lora-rank R --max-loras N`. For prime-rl's single-adapter loop, `--max-loras 1` is sufficient. +vLLM worker requirements: started with `--enable-lora --max-lora-rank R --max-loras N`. For a single-adapter training loop, `--max-loras 1` is sufficient. 
#### `POST /v1/rl/unload_lora_adapter` @@ -397,7 +396,7 @@ curl -s -X POST http://localhost:8000/v1/rl/unload_lora_adapter \ #### Legacy endpoints (kept for back-compat) -`GET /v1/rl/health`, `GET /v1/rl/ready`, `GET /v1/rl/weight_version` — same shapes as the previous draft. To be removed in Phase 5 of `docs/design-docs/rl-support.md` once prime-rl's AdminAPI migrates to `/v1/rl/state`. +`GET /v1/rl/health`, `GET /v1/rl/ready`, `GET /v1/rl/weight_version` — return the same shapes they did before `/v1/rl/state` was added. They will be removed once existing clients migrate to `/v1/rl/state`. --- @@ -407,7 +406,7 @@ curl -s -X POST http://localhost:8000/v1/rl/unload_lora_adapter \ ```mermaid sequenceDiagram - participant Orch as prime-rl Orchestrator + participant Orch as RL Orchestrator participant FE as Dynamo Frontend (Rust) participant Worker as vLLM Worker (GPU) @@ -425,9 +424,9 @@ sequenceDiagram ```mermaid sequenceDiagram - participant Trainer as prime-rl Trainer + participant Trainer as RL Trainer participant PVC as Shared Storage - participant Orch as prime-rl Orchestrator + participant Orch as RL Orchestrator participant FE as Dynamo Frontend (Rust) participant W1 as vLLM Worker 1 participant W2 as vLLM Worker 2 @@ -460,7 +459,7 @@ NCCL mode: `weight_dir=null` returns 200 immediately; the actual GPU↔GPU broad ```mermaid sequenceDiagram - participant Orch as prime-rl Orchestrator + participant Orch as RL Orchestrator participant FE as Dynamo Frontend participant W1 as vLLM Worker 1 @@ -508,7 +507,7 @@ Serialized as `nvext` on each SSE chunk and the unary response body: NvExtResponse { worker_id: Option, timing: Option, - token_ids: Option>, // GAIE Stage 1 prompt + token_ids: Option>, // pre-tokenized prompt (used by disaggregated query/fill stages) routed_experts: Option, completion_token_ids: Option>, // RL output, final chunk only } @@ -538,7 +537,7 @@ Tracks `accumulated_completion_token_ids: Vec` per request. Activated when ## 7. 
Worker Engine Routes (Internal) -Registered on each vLLM worker's system HTTP port (default `8081` local / `9090` k8s) by `worker_factory.py::register_engine_routes()`. Called by Rust `/v1/rl/*` handlers — not by prime-rl directly. +Registered on each vLLM worker's system HTTP port (default `8081` local / `9090` k8s) by `worker_factory.py::register_engine_routes()`. Called by the Rust `/v1/rl/*` handlers — not by external clients directly. | Route | vLLM API called | Used by | |---|---|---| @@ -562,26 +561,9 @@ Registered on each vLLM worker's system HTTP port (default `8081` local / `9090` | Limitation | Workaround | Notes | |---|---|---| -| **NCCL mode is a no-op on Dynamo's vLLM side.** `update_weights` with `weight_dir=null` returns 200 immediately, but `dynamo.vllm` does not load `NCCLWeightBroadcastReceiver` as a vLLM worker class — so the trainer's NCCL broadcast has no peer on the inference side. Trainer's `init_process_group` times out at `weight_broadcast.timeout` (default 1200 s). | Use `weight_broadcast.type = "filesystem"`. The `dynamo.sglang` backend ships `update_weights_from_distributed` natively and does work over NCCL. | Tracked at the prime-rl side: orchestrator already POSTs `/v1/rl/init_broadcaster` which dynamo.vllm doesn't expose (logs `route does not exist. Skipping NCCL broadcast initialization.` on the orch side). Wiring is the next workstream. | -| `cache_salt` not yet honored end-to-end | Set `[experimental] use_prefix_cache_salt = false` in prime-rl `orch.toml`; or send the equivalent `X-Tenant-Id` header (#8197). | Field is whitelisted (`PASSTHROUGH_EXTRA_FIELDS`) so requests don't 400; routing-side filter is a follow-up. 
| +| **NCCL mode is a no-op on Dynamo's vLLM side.** `update_weights` with `weight_dir=null` returns 200 immediately, but `dynamo.vllm` does not load an NCCL weight-broadcast receiver as a vLLM worker class — so the trainer's NCCL broadcast has no peer on the inference side, and `init_process_group` on the trainer times out at `weight_broadcast.timeout` (default 1200 s). | Use `weight_broadcast.type = "filesystem"`. The `dynamo.sglang` backend ships `update_weights_from_distributed` natively and does work over NCCL. | The bootstrap admin route the trainer expects (`/v1/rl/init_broadcaster`) is not exposed by `dynamo.vllm` today; wiring it (and the receiver class) is the next workstream. | +| `cache_salt` not yet honored end-to-end | Disable prefix-cache-salt on the client side, or send the equivalent `X-Tenant-Id` header. | Field is whitelisted (`PASSTHROUGH_EXTRA_FIELDS`) so requests don't 400; routing-side filter is a follow-up. | | `prompt_token_ids` only injected for non-streaming responses | Use non-streaming mode for RL rollouts (the default). | Streaming final-chunk injection is planned. | | Weight version `"initial"` before first update | Use `/v1/rl/state.applied_weight_version` for source-of-truth; don't rely on the version string for correctness. | | | Filesystem weight broadcast scales poorly for large models | Ok for 0.6B (~250 ms load); marginal at 7B (~25 s); ~150 s at 30B-A3B BF16; impractical at 70B+. | RDMA / NCCL-receive on dynamo.vllm planned. | ---- - -## 9. What Changed vs. the Earlier Draft - -For readers who know the previous `Dynamo-RL-api-draft.md`: - -| Old | New | -|---|---| -| `/v1/chat/completions/tokens` (TITO URI fork) | TITO collapsed into `/v1/chat/completions` via the `prompt_token_ids` top-level extension. URI returns 404. | -| `/v1/tokenize`, `/v1/detokenize` | Removed (return 404). Owned by [#7699](https://github.com/ai-dynamo/dynamo/pull/7699), out of scope for this surface. 
| -| `POST /v1/rl/pause` (no params) | `POST /v1/rl/pause?mode=keep|wait|abort&clear_cache=bool` (3-mode). | -| `POST /v1/rl/update_weights` (string body) | Typed body `{weight_dir, weight_version?, reset_prefix_cache=true}`; response carries `applied_weight_version`. | -| `/v1/rl/health` + `/v1/rl/ready` + `/v1/rl/weight_version` | All three kept for back-compat, **plus** new composite `GET /v1/rl/state`. New `GET /v1/rl/liveness` does a deep `check_health()` round-trip. | -| `PASSTHROUGH_EXTRA_FIELDS = [cache_salt]` | Now: `cache_salt`, `prompt_token_ids`, `weight_version`, `return_routed_experts`, `return_token_ids`, `return_prompt_logprobs`, `stop_token_ids`, `bad_words_token_ids`, `allowed_token_ids`, `truncate_prompt_tokens`. Full `SamplingParams` parity. | -| `get_stop_token_ids() -> Option>` (silent drop on bad input) | `get_stop_token_ids() -> Result>>`. Malformed input returns a typed 400. | -| nvext serde failures: `if let Ok(json) = serde_json::to_value(...) { ... }` (silent drop) | `match { Ok(json) => ..., Err(e) => tracing::warn!(...) }`. No silent corruption of promoted token IDs / weight version. | -| Engine routes: 5–7 (pause, resume, flush_cache, update_weights_from_path, get_weight_version, +load_lora, +unload_lora) | 9. Added: `get_state`, `liveness_probe`. | diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 92bd2724da2f..cd49ddfc847e 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -207,7 +207,6 @@ impl ErrorMessage { /// Bad Request Error. /// Return this error when the client sends an invalid request — malformed /// JSON, schema mismatch, or fields that fail `validate.rs` gating. - /// (CR-8 closure: stale doc-block lines about "Not Implemented" removed.) 
#[allow(dead_code)] // exposed for downstream crates; not directly called in lib/llm pub fn bad_request(msg: &str) -> ErrorResponse { let code = StatusCode::BAD_REQUEST; @@ -2110,9 +2109,9 @@ fn resolve_model_card( Ok((model, card)) } -// Phase 5: handler kept (no callers) until jthomson04 PR #7699 lands -// `/tokenize` and `/detokenize` at root paths. Re-mount via -// `tokenization_router` in `service_v2.rs` if needed standalone. +// Handler kept (no callers in this branch) for downstream code that re-mounts +// `tokenization_router` in `service_v2.rs` standalone, until the upstream +// `/tokenize` and `/detokenize` work lands at the root paths. #[allow(dead_code)] async fn tokenize( State(state): State>, @@ -2357,11 +2356,11 @@ pub fn chat_completions_router( /// /// If no path is provided, the default path is `/v1/chat/completions/tokens`. /// -/// Phase 5: dropped from the v2 surface (see `service_v2.rs`). TITO callers -/// retarget to `/v1/chat/completions` with `prompt_token_ids` extension — -/// vLLM 0.20+ skips chat templating when that field is present, identical -/// behavior. The handler is kept as `#[allow(dead_code)]` until prime-rl -/// `bis/prime-rl-merged` migration P1 lands. +/// Dropped from the v2 surface (see `service_v2.rs`). TITO callers retarget +/// to `/v1/chat/completions` with the `prompt_token_ids` extension — vLLM +/// 0.20+ skips chat templating when that field is present, identical +/// behavior. The handler is kept as `#[allow(dead_code)]` for downstream +/// code that still references it; deletion is a follow-up cleanup. #[allow(dead_code)] pub fn chat_completions_tokens_router( state: Arc, @@ -3231,9 +3230,9 @@ async fn rl_ready(State(state): State>) -> impl IntoResponse { /// - `mode`: `keep` | `wait` | `abort` (default `keep`) /// - `clear_cache`: `true` | `false` (default `false`) /// -/// Closes hhzhang16 HH-21 (3-mode pause: vLLM exposes abort/wait/keep). 
-/// Default is `mode=keep&clear_cache=false` to match prime-rl -/// `client.py:_pause_engines` so existing callers keep working. +/// Three-mode pause matches what vLLM exposes (abort / wait / keep). The +/// default `mode=keep&clear_cache=false` preserves the original single-mode +/// pause behavior so existing callers keep working without changes. #[derive(Debug, serde::Deserialize)] struct RlPauseQuery { #[serde(default)] @@ -3609,13 +3608,12 @@ fn rl_tokenize_prompt( /// /// response.nvext.completion_token_ids → response.choices[i].token_ids /// -/// This lets Prime-RL read `choice.token_ids` without knowing about the `nvext` -/// extension structure. Called on non-streaming responses when RL token ID mode -/// is active. (CR-10 closure: doc-block was previously misattached to -/// `rl_tokenize_prompt`.) +/// This lets RL clients read `choice.token_ids` without knowing about the +/// `nvext` extension structure. Called on non-streaming responses when RL +/// token ID mode is active. fn rl_promote_token_ids_in_response(json_val: &mut serde_json::Value) { - // Move completion_token_ids from response-level nvext to each choice. - // Prime-RL / verifiers expects: + // Move completion_token_ids from response-level nvext to each choice, + // because some RL clients expect: // response.choices[i].token_ids (not response.nvext.completion_token_ids) let has_nvext = json_val.get("nvext").is_some(); let has_completion_ids = json_val @@ -3649,22 +3647,22 @@ fn rl_promote_token_ids_in_response(json_val: &mut serde_json::Value) { /// `GET /v1/rl/health` — lightweight health check for Prime-RL admin client. /// -/// Prime-RL's `check_health()` calls `GET /health` on the admin client. When -/// `admin_base_url = ["http://dynamo:8000/v1/rl"]` the request arrives here. -/// Returns 200 OK if the frontend process is running (no deep probe needed — -/// the frontend's own `/health` endpoint handles that separately). 
+/// RL admin clients that send `GET /health` to their admin base URL land +/// here when `admin_base_url = ["http://dynamo:8000/v1/rl"]`. Returns 200 OK +/// if the frontend process is running (no deep probe needed — the frontend's +/// own `/health` endpoint handles that separately). /// -/// **Deprecated in favor of `/v1/rl/state.ingress_alive`** (rl-support.md -/// Phase 1 / Phase 5). Kept for prime-rl `bis/prime-rl-merged` until the -/// AdminAPI migration lands; will be removed once prime-rl P2 commits. +/// **Deprecated in favor of `/v1/rl/state.ingress_alive`.** Kept for +/// back-compat until existing clients migrate to `/v1/rl/state`; will be +/// removed in a follow-up. async fn rl_health() -> impl IntoResponse { (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))) } -/// `GET /v1/rl/liveness` — engine event-loop probe via `liveness_probe` -/// engine route. Closes hhzhang16 HH-23 (the v1 `/v1/rl/health` returns OK -/// no matter what; this endpoint round-trips through the engine so a hung -/// event loop or wedged worker surfaces as 503). +/// `GET /v1/rl/liveness` — engine event-loop probe via the `liveness_probe` +/// engine route. The legacy `/v1/rl/health` returns OK as long as the +/// frontend process is up; this endpoint round-trips through the engine so +/// a hung event loop or wedged worker surfaces as 503. /// /// Each per-worker call carries a 5s timeout (override via /// `DYN_RL_LIVENESS_TIMEOUT_MS`). Returns 200 only when every worker @@ -3750,9 +3748,9 @@ async fn rl_liveness(State(state): State>) -> impl IntoResponse { /// `GET /v1/rl/state` — composite RL fleet state snapshot. /// /// Replaces three v1 endpoints (`/v1/rl/health` + `/v1/rl/ready` + -/// `/v1/rl/weight_version`) with a single composite. Closes hhzhang16 -/// HH-19 (single state endpoint), HH-25 (RL-specific vs broader Dynamo -/// readiness), HH-27 (weight_version folded in). 
+/// `/v1/rl/weight_version`) with a single composite, scoped to RL-specific +/// readiness (engine alive, pause state, applied weight version, loaded +/// LoRAs). /// /// Aggregates per-worker `get_state` engine-route responses into: /// @@ -3891,8 +3889,8 @@ pub fn rl_router() -> (Vec, Router) { .route("/v1/rl/update_weights", post(rl_update_weights)) // LoRA hot-swap. .route("/v1/rl/load_lora_adapter", post(rl_load_lora_adapter)) - // Legacy endpoints — kept until prime-rl `bis/prime-rl-merged` AdminAPI - // migration P2 lands; Phase 5 of rl-support.md drops them. + // Legacy endpoints — kept for back-compat until existing clients + // migrate to /v1/rl/state. Removed in a follow-up. .route("/v1/rl/health", get(rl_health)) .route("/v1/rl/ready", get(rl_ready)) .route("/v1/rl/weight_version", get(rl_weight_version)) diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 9abe99f4fc28..75acc48356cf 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -460,10 +460,10 @@ impl crate::protocols::openai::DeltaGeneratorExt { stream_response.nvext = Some(nvext_json); diff --git a/lib/llm/src/protocols/openai/completions/delta.rs b/lib/llm/src/protocols/openai/completions/delta.rs index 26d84c7803ae..a47d9215a6ef 100644 --- a/lib/llm/src/protocols/openai/completions/delta.rs +++ b/lib/llm/src/protocols/openai/completions/delta.rs @@ -314,9 +314,9 @@ impl crate::protocols::openai::DeltaGeneratorExt for finish_reason.is_some(), delta.engine_data, ) { - // CR-9 closure: log a warning if serialization fails instead of - // silently dropping the nvext payload (would mean promoted fields - // never reach the client). + // Log a warning if serialization fails instead of silently + // dropping the nvext payload (would mean promoted fields never + // reach the client). 
match serde_json::to_value(&nvext_response) { Ok(nvext_json) => { response.nvext = Some(nvext_json); diff --git a/lib/llm/src/protocols/openai/validate.rs b/lib/llm/src/protocols/openai/validate.rs index 9864744d6b8e..1e694a0c6af2 100644 --- a/lib/llm/src/protocols/openai/validate.rs +++ b/lib/llm/src/protocols/openai/validate.rs @@ -97,33 +97,33 @@ pub const MAX_REPETITION_PENALTY: f32 = 2.0; // Shared Fields // -/// Fields that Prime-RL / verifiers may send as extra_body hints which Dynamo -/// does not implement but should not reject with a 400. They are silently -/// accepted (the chat-completions handler reads what it understands and -/// ignores the rest) so the RL client stack is forward-compatible with new -/// extension fields without churning Dynamo. +/// Fields that RL clients may send as extra_body hints which Dynamo does not +/// implement but should not reject with a 400. They are silently accepted +/// (the chat-completions handler reads what it understands and ignores the +/// rest) so the RL client stack is forward-compatible with new extension +/// fields without churning Dynamo. /// -/// Per `bis-dev/design-docs/rl-support.md` Phase 4, this is the canonical -/// home for the typed RL extension fields; the prior `nvext.extra_fields` -/// `["completion_token_ids", ...]` opt-in mechanism still works alongside it -/// but the named fields here are the recommended path. +/// This is the canonical home for typed RL extension fields; the prior +/// `nvext.extra_fields = ["completion_token_ids", ...]` opt-in mechanism +/// still works alongside it but the named fields here are the recommended +/// path. const PASSTHROUGH_EXTRA_FIELDS: &[&str] = &[ - // KV prefix-cache isolation hint from prime-rl orchestrator. Coordinated - // with PR #8197 (which moves this to the X-Tenant-Id header); both forms - // accepted for one release, header takes precedence. + // KV prefix-cache isolation hint. 
The equivalent `X-Tenant-Id` request + // header is also accepted; the header takes precedence when both are + // present. "cache_salt", // Pre-tokenized prompt for the RL TITO path. Mutually exclusive with - // `messages`; when present, vLLM 0.20+ skips chat templating. Closes - // hhzhang16 HH-22 / HH-26 — the "tokens variant of /v1/chat/completions" - // collapses into the same URI with this extension field instead of a - // forked /v1/chat/completions/tokens. Today RL clients pre-tokenize - // and pass via `nvext.token_data` (preprocessor.rs handles that - // already); the typed top-level field shipped here is the long-term - // canonical entry for clients written against the vLLM 0.20 schema. + // `messages`; when present, vLLM 0.20+ skips chat templating. The + // "tokens variant of /v1/chat/completions" collapses into the same URI + // with this extension field instead of a forked + // /v1/chat/completions/tokens. Today RL clients can also pre-tokenize + // and pass via `nvext.token_data` (handled in preprocessor.rs); the + // typed top-level field shipped here is the long-term canonical entry + // for clients written against the vLLM 0.20 schema. "prompt_token_ids", // RL routing filter — only dispatch to workers reporting this applied // weight version. Used by IS-correction strict-version mode and by - // NeMo RL eval-on-subset. Today accepted-and-ignored at the request + // eval-on-subset workflows. Today accepted-and-ignored at the request // level; the routing-side filter lands in a follow-up. "weight_version", // Per-request gate for MoE Routing Replay capture. 
Honored by From 0ee6cbe4ebdf3c0486257ca22b7ab4c13d3104c6 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Thu, 7 May 2026 17:25:56 -0700 Subject: [PATCH 08/18] =?UTF-8?q?fix(rl):=20address=20blocking=20review=20?= =?UTF-8?q?items=20=E2=80=94=20unwrap/expect,=20structured=20tracing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four blocking findings from a Graham-style review. 1. resolve_tokenizer_model_name (openai.rs ~2086): replace served_models.into_iter().next().unwrap() after a len()==1 check with let-Some on drain().next(). Eliminates the silent-panic-if-empty path without adding a Safety: comment for what's now self-evident. 2. RlState::from_env (openai.rs ~3128): builder().build() failures (TLS init, resolver init) panicked the frontend at startup, while the RL router was being built, via .expect("Failed to create RL router HTTP client"). Now returns anyhow::Result and surfaces a typed error. Caller rl_router() becomes anyhow::Result<(Vec, Router)>; service_v2::build() propagates with `?` (it already returns Result). 3. rl_update_weights (openai.rs ~3349): replace `if weight_dir.is_none() { return ... } let weight_dir = weight_dir.unwrap();` with `let Some(weight_dir) = body.weight_dir.clone() else { return ... };`. One match instead of two; no unwrap. 4. Structured tracing fields, not format strings (8 sites in rl_pause / rl_resume / rl_update_weights / rl_load_lora_adapter / rl_unload_lora_adapter): tracing::warn!("RL pause: some workers failed: {:?}", results); → tracing::warn!(?results, "RL pause: some workers failed"); Same shape applied to info!/warn! calls that interpolated worker_count, mode, weight_dir, version, lora_name, lora_path. Use % for Display, ? for Debug per tracing::Value docs. Verification: cargo check -p dynamo-llm: clean (1 pre-existing benign warning). cargo test -p dynamo-llm --test test_common_ext: 15 passed. 
--- lib/llm/src/http/service/openai.rs | 92 ++++++++++++++------------ lib/llm/src/http/service/service_v2.rs | 2 +- 2 files changed, 52 insertions(+), 42 deletions(-) diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index cd49ddfc847e..4b72218c8924 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -2081,9 +2081,11 @@ fn resolve_tokenizer_model_name( } return Err(ErrorMessage::model_not_found()); } - let served_models = state.manager().model_display_names(); + let mut served_models = state.manager().model_display_names(); if served_models.len() == 1 { - return Ok(served_models.into_iter().next().unwrap()); + if let Some(only) = served_models.drain().next() { + return Ok(only); + } } Err(bad_request( "Model must be specified when more than one model is served.", @@ -3124,7 +3126,7 @@ struct RlState { } impl RlState { - fn from_env() -> Self { + fn from_env() -> anyhow::Result { let worker_system_urls = std::env::var(DYN_RL_WORKER_SYSTEM_URLS_ENV) .unwrap_or_else(|_| "http://localhost:8081".to_string()) .split(',') @@ -3132,17 +3134,18 @@ impl RlState { .filter(|s| !s.is_empty()) .collect::>(); tracing::info!( - "RL admin router configured with {} worker(s): {:?}", - worker_system_urls.len(), - worker_system_urls + worker_count = worker_system_urls.len(), + ?worker_system_urls, + "RL admin router configured" ); - Self { + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(600)) + .build() + .map_err(|e| anyhow::anyhow!("Failed to build RL router HTTP client: {e}"))?; + Ok(Self { worker_system_urls, - http_client: reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(600)) - .build() - .expect("Failed to create RL router HTTP client"), - } + http_client, + }) } /// Call a single engine route on one worker. Returns the JSON body. 
@@ -3266,8 +3269,10 @@ async fn rl_pause( .await; if RlState::all_ok(&results) { tracing::info!( - "RL pause: all {} worker(s) paused (mode={mode}, clear_cache={clear_cache})", - results.len() + worker_count = results.len(), + mode = %mode, + clear_cache, + "RL pause: all workers paused" ); ( StatusCode::OK, @@ -3279,7 +3284,7 @@ async fn rl_pause( })), ) } else { - tracing::warn!("RL pause: some workers failed: {:?}", results); + tracing::warn!(?results, "RL pause: some workers failed"); ( StatusCode::BAD_GATEWAY, Json(serde_json::json!({"status": "error", "workers": results})), @@ -3293,13 +3298,13 @@ async fn rl_resume(State(state): State>) -> impl IntoResponse { .fan_out("resume_generation", serde_json::json!({})) .await; if RlState::all_ok(&results) { - tracing::info!("RL resume: all {} worker(s) resumed", results.len()); + tracing::info!(worker_count = results.len(), "RL resume: all workers resumed"); ( StatusCode::OK, Json(serde_json::json!({"status": "ok", "workers": results})), ) } else { - tracing::warn!("RL resume: some workers failed: {:?}", results); + tracing::warn!(?results, "RL resume: some workers failed"); ( StatusCode::BAD_GATEWAY, Json(serde_json::json!({"status": "error", "workers": results})), @@ -3341,10 +3346,9 @@ async fn rl_update_weights( State(state): State>, body: axum::extract::Json, ) -> impl IntoResponse { - let weight_dir = body.weight_dir.clone(); let reset_prefix_cache = body.reset_prefix_cache; - if weight_dir.is_none() { + let Some(weight_dir) = body.weight_dir.clone() else { tracing::info!("RL update_weights: weight_dir=null (NCCL mode, no-op on Dynamo side)"); return ( StatusCode::OK, @@ -3353,9 +3357,8 @@ async fn rl_update_weights( "message": "NCCL mode, no-op on Dynamo side" })), ); - } + }; - let weight_dir = weight_dir.unwrap(); let version = body.weight_version.clone().unwrap_or_else(|| { std::path::Path::new(&weight_dir) .file_name() @@ -3364,14 +3367,17 @@ async fn rl_update_weights( .to_string() }); tracing::info!( - 
"RL update_weights: weight_dir={weight_dir} version={version} reset_prefix_cache={reset_prefix_cache}" + weight_dir = %weight_dir, + version = %version, + reset_prefix_cache, + "RL update_weights" ); // Step 1 (optional): flush_cache across all workers. if reset_prefix_cache { let flush_results = state.fan_out("flush_cache", serde_json::json!({})).await; if !RlState::all_ok(&flush_results) { - tracing::warn!("RL update_weights: flush_cache failed: {:?}", flush_results); + tracing::warn!(?flush_results, "RL update_weights: flush_cache failed"); return ( StatusCode::BAD_GATEWAY, Json(serde_json::json!({ @@ -3384,12 +3390,13 @@ async fn rl_update_weights( } // Step 2: update_weights_from_path across all workers. - let load_body = serde_json::json!({"path": weight_dir, "version": version}); + let load_body = serde_json::json!({"path": &weight_dir, "version": version}); let load_results = state.fan_out("update_weights_from_path", load_body).await; if RlState::all_ok(&load_results) { tracing::info!( - "RL update_weights: all {} worker(s) updated weights to {weight_dir}", - load_results.len() + worker_count = load_results.len(), + weight_dir = %weight_dir, + "RL update_weights: all workers updated" ); ( StatusCode::OK, @@ -3401,8 +3408,8 @@ async fn rl_update_weights( ) } else { tracing::warn!( - "RL update_weights: update_weights_from_path failed: {:?}", - load_results + ?load_results, + "RL update_weights: update_weights_from_path failed" ); ( StatusCode::BAD_GATEWAY, @@ -3449,25 +3456,27 @@ async fn rl_load_lora_adapter( } }; - tracing::info!("RL load_lora_adapter: lora_name={lora_name} lora_path={lora_path}"); + tracing::info!(%lora_name, %lora_path, "RL load_lora_adapter"); let results = state .fan_out( "load_lora_adapter", - serde_json::json!({"lora_name": lora_name, "lora_path": lora_path}), + serde_json::json!({"lora_name": &lora_name, "lora_path": &lora_path}), ) .await; if RlState::all_ok(&results) { tracing::info!( - "RL load_lora_adapter: all {} worker(s) 
loaded LoRA '{lora_name}' from {lora_path}", - results.len() + worker_count = results.len(), + %lora_name, + %lora_path, + "RL load_lora_adapter: all workers loaded" ); ( StatusCode::OK, Json(serde_json::json!({"status": "ok", "workers": results})), ) } else { - tracing::warn!("RL load_lora_adapter: some workers failed: {:?}", results); + tracing::warn!(?results, %lora_name, "RL load_lora_adapter: some workers failed"); ( StatusCode::BAD_GATEWAY, Json(serde_json::json!({"status": "error", "workers": results})), @@ -3503,25 +3512,26 @@ async fn rl_unload_lora_adapter( } }; - tracing::info!("RL unload_lora_adapter: lora_name={lora_name}"); + tracing::info!(%lora_name, "RL unload_lora_adapter"); let results = state .fan_out( "unload_lora_adapter", - serde_json::json!({"lora_name": lora_name}), + serde_json::json!({"lora_name": &lora_name}), ) .await; if RlState::all_ok(&results) { tracing::info!( - "RL unload_lora_adapter: all {} worker(s) unloaded LoRA '{lora_name}'", - results.len() + worker_count = results.len(), + %lora_name, + "RL unload_lora_adapter: all workers unloaded" ); ( StatusCode::OK, Json(serde_json::json!({"status": "ok", "workers": results})), ) } else { - tracing::warn!("RL unload_lora_adapter: some workers failed: {:?}", results); + tracing::warn!(?results, %lora_name, "RL unload_lora_adapter: some workers failed"); ( StatusCode::BAD_GATEWAY, Json(serde_json::json!({"status": "error", "workers": results})), @@ -3861,8 +3871,8 @@ async fn rl_state(State(state): State>) -> impl IntoResponse { /// Prime-RL usage: set `admin_base_url = ["http://dynamo-frontend:8000/v1/rl"]` /// in the orchestrator config. Prime-RL strips the trailing `/v1` suffix only /// if present, so `/v1/rl` is preserved and all routes resolve correctly. 
-pub fn rl_router() -> (Vec, Router) { - let rl_state_arc = Arc::new(RlState::from_env()); +pub fn rl_router() -> anyhow::Result<(Vec, Router)> { + let rl_state_arc = Arc::new(RlState::from_env()?); let docs = vec![ // Phase 1: composite endpoints. RouteDoc::new(axum::http::Method::GET, "/v1/rl/state"), @@ -3897,7 +3907,7 @@ pub fn rl_router() -> (Vec, Router) { .route("/v1/rl/unload_lora_adapter", post(rl_unload_lora_adapter)) .layer(middleware::from_fn(smart_json_error_middleware)) .with_state(rl_state_arc); - (docs, router) + Ok((docs, router)) } #[cfg(test)] diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 4d5e3816b666..8e26252b09af 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -552,7 +552,7 @@ impl HttpServiceConfigBuilder { // RL admin routes: enabled when builder flag is set OR when DYN_ENABLE_RL env var is truthy. if config.enable_rl || env_is_truthy("DYN_ENABLE_RL") { tracing::info!("RL admin routes enabled at /v1/rl/*"); - system_routes.push(super::openai::rl_router()); + system_routes.push(super::openai::rl_router()?); } let mut system_router = axum::Router::new(); for (route_docs, route) in system_routes { From 4aac7e890c44cd1678156b79860fa45efe4ad19e Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Thu, 7 May 2026 17:37:30 -0700 Subject: [PATCH 09/18] fix(rl): address CodeRabbit review findings on RL surface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Quick-win review fixes from PR #9131. Heavy-lift items (#9 prompt_token_ids env-gate, #11 update_weights atomicity, #13 per-choice completion_token_ids) tracked separately as follow-ups. 
handlers.py - Catch EngineDeadError before the generic except in all 8 RL handlers (pause/resume/liveness_probe/get_state/flush_cache/update_weights_from_path/ load_lora_adapter/unload_lora_adapter): match the existing shutdown pattern in this file so admin calls also surface engine death instead of leaving a broken worker alive. - get_state: fall back to a no-op collective_rpc when check_health is absent — same fallback liveness_probe already uses, otherwise older engines without check_health always look alive. - load_lora_adapter hot-swap path: a remove_lora() failure now returns a 400-style error response (was: silent log warn + continue, leaving add_lora to no-op against the still-registered ID); a reset_prefix_cache() failure after add_lora succeeds also returns error (was: log error and continue, leaving stale KV from the old adapter routable). - unload_lora_adapter: an unregister_model() failure after engine remove_lora succeeds now returns error (was: log warn and report success, leaving model= still routed to this worker even though _resolve_lora_request would now fall back to the base model). container/deps/vllm/install_vllm.sh - Pin prime-rl install to an immutable commit SHA (d49f3939e7dca29bceb9ed515cc1782497b67e81 ↔ tag v0.5.1.dev101) so a re-pointed tag upstream can't change what we ship. PRIME_RL_REF kept in build logs for human readability; PRIME_RL_COMMIT is the authoritative pin. - Replace `echo "\n=== ..."` with `printf '\n=== ...\n'` (shellcheck SC2028). lib/llm/src/http/service/openai.rs - Force `request.inner.logprobs = Some(true)` unconditionally in both RL token-id promotion blocks (was: only when None). RL extraction of completion_token_ids depends on logprobs being on at the engine; an explicit logprobs=false would otherwise silently drop them. - Bound `/v1/rl/ready` per-worker probes with a 5s timeout (override via DYN_RL_LIVENESS_TIMEOUT_MS). 
Was reusing the shared 600s http_client, so one wedged worker could block readiness for 10 minutes instead of failing fast as 503. - Tokenize Chat handler: call `request.validate()?` before `merged_chat_template_kwargs()` so the continue_final_message + add_generation_prompt mutual-exclusion constraint is enforced (validate() existed but was never invoked). lib/llm/src/protocols/openai/chat_completions.rs - Update stale doc comments on the legacy `tokens` and `return_token_ids` fields: they pointed callers at the now-404 `/v1/chat/completions/tokens` URI. Direct callers to the canonical top-level `prompt_token_ids` extension and `nvext.extra_fields` instead. cargo check -p dynamo-llm: clean (1 pre-existing benign warning). cargo test -p dynamo-llm --test test_common_ext: 15 passed. --- components/src/dynamo/vllm/handlers.py | 106 ++++++++++++++++-- container/deps/vllm/install_vllm.sh | 10 +- lib/llm/src/http/service/openai.rs | 39 +++++-- .../src/protocols/openai/chat_completions.rs | 13 ++- 4 files changed, 136 insertions(+), 32 deletions(-) diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 5691f46be692..b1b92b526e7e 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -883,6 +883,11 @@ async def pause_generation(self, body: dict) -> dict: "mode": mode, "clear_cache": clear_cache, } + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) except Exception as e: logger.error(f"[RL] Failed to pause: {e}") return {"status": "error", "message": str(e)} @@ -895,6 +900,11 @@ async def resume_generation(self, body: dict) -> dict: self._paused = False logger.info("[RL] Engine resumed") return {"status": "ok", "message": "Engine resumed"} + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime 
shutdown.") + self.runtime.shutdown() + os._exit(1) except Exception as e: logger.error(f"[RL] Failed to resume: {e}") return {"status": "error", "message": str(e)} @@ -922,6 +932,11 @@ async def liveness_probe(self, body: dict) -> dict: # event loop is wedged the frontend's 5s timeout fires. await self.engine_client.collective_rpc("get_weight_version", kwargs={}) return {"status": "ok", "alive": True} + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) except Exception as e: logger.warning(f"[RL] liveness_probe failed: {e}") return {"status": "error", "alive": False, "message": str(e)} @@ -940,6 +955,13 @@ async def get_state(self, body: dict) -> dict: try: if hasattr(self.engine_client, "check_health"): await self.engine_client.check_health() + else: + # Same fallback as liveness_probe: a no-op collective_rpc + # round-trip is the liveness signal when check_health is + # absent; otherwise older engines would always look alive. 
+ await self.engine_client.collective_rpc( + "get_weight_version", kwargs={} + ) except Exception as health_err: engine_alive = False logger.warning(f"[RL] get_state: engine_alive=false ({health_err})") @@ -953,6 +975,11 @@ async def get_state(self, body: dict) -> dict: for name, info in getattr(self, "loaded_loras", {}).items() ], } + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) except Exception as e: logger.error(f"[RL] get_state failed: {e}") return {"status": "error", "message": str(e)} @@ -964,6 +991,11 @@ async def flush_cache(self, body: dict) -> dict: await self.engine_client.reset_prefix_cache() logger.info("[RL] Prefix cache flushed") return {"status": "ok", "message": "Cache flushed"} + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) except Exception as e: logger.error(f"[RL] Failed to flush cache: {e}") return {"status": "error", "message": str(e)} @@ -995,6 +1027,11 @@ async def update_weights_from_path(self, body: dict) -> dict: "message": f"Weights loaded from {path}", "version": version, } + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) except Exception as e: logger.error(f"[RL] Failed to load weights from {path}: {e}") return {"status": "error", "message": str(e)} @@ -1052,12 +1089,24 @@ async def load_lora_adapter(self, body: dict) -> dict: # Invalidate the cache entry immediately after remove succeeds. # If add_lora below fails, this prevents a stale entry pointing # at an adapter the engine no longer holds from poisoning future - # rollouts with wrong importance ratios (Tier-1 RL correctness risk). + # rollouts with wrong importance ratios. 
self.loaded_loras.pop(lora_name, None) except Exception as e: - logger.warning( + # remove_lora failure during hot-swap is non-recoverable + # for this request: add_lora below would no-op against + # the still-registered ID. Surface as error so the + # caller doesn't think the swap succeeded. + logger.error( f"[RL] remove_lora({lora_name}, id={old_id}) failed during hot-swap: {e}" ) + return { + "status": "error", + "message": ( + f"Failed to remove existing LoRA '{lora_name}' " + f"before hot-swap: {e}" + ), + "lora_name": lora_name, + } await self.engine_client.add_lora( LoRARequest( @@ -1074,15 +1123,24 @@ async def load_lora_adapter(self, body: dict) -> dict: try: await self.engine_client.reset_prefix_cache() except Exception as e: - # ERROR not WARNING: a failed cache reset means subsequent requests - # sharing a prefix with an old rollout can reuse KV state computed - # under the previous adapter — causing silent logprobs mismatch. + # A failed cache reset means subsequent requests sharing + # a prefix with an old rollout can reuse KV state + # computed under the previous adapter — silent logprobs + # mismatch. Surface as an error so the caller doesn't + # treat the swap as safe to serve. logger.error( - f"[RL] reset_prefix_cache after LoRA swap failed — KV cache may " - f"be contaminated with stale entries from the old adapter. " - f"Rollouts on this worker are unreliable until the next successful " - f"swap: {e}" + f"[RL] reset_prefix_cache after LoRA swap failed: {e}" ) + return { + "status": "error", + "message": ( + f"LoRA '{lora_name}' was loaded but prefix cache " + f"reset failed; worker is not safe to serve until " + f"the next successful swap." + ), + "lora_name": lora_name, + "lora_id": lora_id, + } # Publish an MDC for the LoRA on first load so Dynamo's frontend # can route requests with model= to this worker. 
@@ -1140,6 +1198,11 @@ async def load_lora_adapter(self, body: dict) -> dict: "lora_id": lora_id, "hot_swap": is_hot_swap, } + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) except Exception as e: logger.exception( f"[RL] Failed to load LoRA adapter '{lora_name}' from {lora_path}: {e}" @@ -1173,7 +1236,11 @@ async def unload_lora_adapter(self, body: dict) -> dict: del self.loaded_loras[lora_name] # Unregister the MDC published on load so the frontend stops - # routing `model=` requests to this worker. + # routing `model=` requests to this worker. If this + # fails the engine no longer has the adapter but the frontend + # still routes to us — `_resolve_lora_request` then falls back + # to the base model, silently changing semantics. Surface as + # an error so the caller can retry / drain explicitly. if self.generate_endpoint is not None: try: await unregister_model( @@ -1181,9 +1248,19 @@ async def unload_lora_adapter(self, body: dict) -> dict: lora_name=lora_name, ) except Exception as e: - logger.warning( - f"[RL] Failed to unregister LoRA '{lora_name}' MDC (adapter already removed from engine): {e}" + logger.error( + f"[RL] Failed to unregister LoRA '{lora_name}' MDC after engine removal: {e}" ) + return { + "status": "error", + "message": ( + f"LoRA '{lora_name}' removed from engine but " + f"discovery unregister failed; frontend may " + f"still route to this worker until retried: {e}" + ), + "lora_name": lora_name, + "lora_id": lora_id, + } logger.info( f"[RL] LoRA adapter unloaded: name={lora_name} id={lora_id}" @@ -1194,6 +1271,11 @@ async def unload_lora_adapter(self, body: dict) -> dict: "lora_name": lora_name, "lora_id": lora_id, } + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) except Exception as e: 
logger.exception( f"[RL] Failed to unload LoRA adapter '{lora_name}': {e}" diff --git a/container/deps/vllm/install_vllm.sh b/container/deps/vllm/install_vllm.sh index 90bee7bf3034..ffbceff6bb1e 100755 --- a/container/deps/vllm/install_vllm.sh +++ b/container/deps/vllm/install_vllm.sh @@ -309,7 +309,9 @@ fi # LoRA, etc.) automatically in every vLLM worker process -- including spawned # subprocesses. Required for prime-rl / Dynamo RL training integration. # -# Override at build time: --build-arg PRIME_RL_REF=v0.5.1.dev101 +# Pinned to an immutable commit SHA (not a tag) for reproducibility; tags can +# be re-pointed upstream. PRIME_RL_REF is kept for human-readable build logs. +# Override at build time: --build-arg PRIME_RL_COMMIT= # --no-deps: prime-rl's full dep tree includes trainer + wandb; Dynamo only # needs the inference-side plugin and worker-extension classes. # Python version: prime-rl pins requires-python = "~=3.12.0"; Dynamo containers @@ -317,9 +319,11 @@ fi # dev venvs use the regular pip (not uv) with --ignore-requires-python. # --------------------------------------------------------------------------- PRIME_RL_REF="${PRIME_RL_REF:-v0.5.1.dev101}" -echo "\n=== Installing prime-rl vLLM plugin (ref=${PRIME_RL_REF}) ===" +PRIME_RL_COMMIT="${PRIME_RL_COMMIT:-d49f3939e7dca29bceb9ed515cc1782497b67e81}" +printf '\n=== Installing prime-rl vLLM plugin (ref=%s commit=%s) ===\n' \ + "$PRIME_RL_REF" "$PRIME_RL_COMMIT" uv pip install --no-deps \ - "prime-rl @ git+https://github.com/PrimeIntellect-ai/prime-rl@${PRIME_RL_REF}" + "prime-rl @ git+https://github.com/PrimeIntellect-ai/prime-rl@${PRIME_RL_COMMIT}" # Sanity-check: confirm vllm.general_plugins entry-point is registered. 
python3 - <<'PY_SANITY' diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 4b72218c8924..fa5fd153891d 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -981,10 +981,12 @@ async fn handler_chat_completions( } } nvext.extra_fields = Some(extra_fields); - // Also force logprobs on when RL is requesting token IDs - if request.inner.logprobs.is_none() { - request.inner.logprobs = Some(true); - } + // RL token-id extraction depends on logprobs being enabled at + // the engine. Override unconditionally — an explicit + // logprobs=false would otherwise drop completion_token_ids + // from the response while we silently still claim to return + // them. + request.inner.logprobs = Some(true); } request.nvext = Some(nvext); @@ -2147,6 +2149,12 @@ async fn tokenize( (token_ids, token_strs) } TokenizeRequest::Chat(request) => { + // Reject mutually-exclusive flags + // (continue_final_message + add_generation_prompt) up-front so the + // chat-template render below doesn't see an inconsistent state. + request + .validate() + .map_err(|err| bad_request(&format!("Invalid tokenize request: {err}")))?; let model = request .model .clone() @@ -2423,10 +2431,10 @@ async fn handler_chat_completions_tokens( nvext.extra_fields = Some(extra_fields); request.nvext = Some(nvext); - // Force logprobs on (RL always needs them) - if request.inner.logprobs.is_none() { - request.inner.logprobs = Some(true); - } + // Force logprobs on (RL always needs them). Unconditional — an + // explicit logprobs=false from the caller would otherwise silently + // strip completion_token_ids from the response. + request.inner.logprobs = Some(true); // Ensure messages is non-empty (Dynamo requires it for model lookup / chat template) if request.inner.messages.is_empty() { @@ -3194,7 +3202,16 @@ impl RlState { } /// `GET /v1/rl/ready` — composite readiness check: worker health via system port. 
+/// +/// Bounded with a per-worker probe timeout (default 5s, override via +/// `DYN_RL_LIVENESS_TIMEOUT_MS`) so a wedged worker fails fast as 503 instead +/// of hanging on the shared 600s `http_client` timeout. async fn rl_ready(State(state): State>) -> impl IntoResponse { + let timeout_ms = std::env::var("DYN_RL_LIVENESS_TIMEOUT_MS") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(5000); + let timeout = std::time::Duration::from_millis(timeout_ms); let futures: Vec<_> = state .worker_system_urls .iter() @@ -3202,10 +3219,10 @@ async fn rl_ready(State(state): State>) -> impl IntoResponse { let client = state.http_client.clone(); let health_url = format!("{url}/health"); async move { - client - .get(&health_url) - .send() + tokio::time::timeout(timeout, client.get(&health_url).send()) .await + .ok() + .and_then(Result::ok) .map(|r| r.status().is_success()) .unwrap_or(false) } diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs index 328711139e5f..6d1f071d5a18 100644 --- a/lib/llm/src/protocols/openai/chat_completions.rs +++ b/lib/llm/src/protocols/openai/chat_completions.rs @@ -59,15 +59,16 @@ pub struct NvCreateChatCompletionRequest { #[serde(default, skip_serializing_if = "Option::is_none")] pub media_io_kwargs: Option, - /// RL: Pre-tokenized prompt tokens from Prime-RL's TITO interface. - /// On the standard `/v1/chat/completions` endpoint this field is accepted but ignored - /// (use `/v1/chat/completions/tokens` for TITO mode where tokens are authoritative). - /// Accepting it here avoids 400 errors when Prime-RL sends it without the rl-admin proxy. + /// Legacy RL field (pre-tokenized prompt). Accepted but ignored on + /// `/v1/chat/completions` — the canonical TITO channel is now the + /// top-level `prompt_token_ids` extension on the same endpoint + /// (allowlisted in `validate.rs::PASSTHROUGH_EXTRA_FIELDS`). Kept here + /// so older clients still sending `tokens` don't 400. 
#[serde(default, skip_serializing)] pub tokens: Option>, - /// RL: Prime-RL requests token IDs in the response via this field. - /// Accepted but ignored on standard chat completions (use `nvext.extra_fields` instead). + /// Legacy RL field. Accepted but ignored on standard chat completions — + /// use `nvext.extra_fields = ["completion_token_ids"]` instead. #[serde(default, skip_serializing)] pub return_token_ids: Option, From 8e08e32b0670692668f359d84002415f4eb20aa2 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Thu, 7 May 2026 18:34:10 -0700 Subject: [PATCH 10/18] =?UTF-8?q?chore(rl):=20address=20Tier=201/2/3/5=20r?= =?UTF-8?q?eview=20issues=20=E2=80=94=20drop=20dead=20code,=20dedupe=20shu?= =?UTF-8?q?tdowns,=20tier=20the=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stack-ranked review fixes from the latest Graham-style pass on this branch. Tier 4 (commit message conventions and tokenizer-crate scope) deliberately deferred. Tier 1 — correctness / contract claims - Drop the word "atomic" from the `/v1/rl/update_weights` docblock and spell out the partial-failure semantics: workers `0..N-1` may have switched while worker `N` failed; per-worker status is in the response, true rollback is a follow-up. - Reject `n > 1` on `/v1/chat/completions` when RL token IDs are requested. The streaming aggregator's `completion_token_ids` Vec is shared across all choices, so per-choice promotion downstream cannot recover which tokens belong to which choice with `n > 1`. Hard-reject is the interim guard until the keyed-by-index refactor lands. - `update_weights` empty `weight_dir` (`""`) is now treated the same as null/missing (NCCL-mode no-op) instead of being forwarded to the engine as `path=""`. Tier 2 — hard rules - Hot-path `[RL]` `logger.info` → `logger.debug` for pause / resume / flush_cache / weights-load / LoRA load / LoRA unload (8 sites). RL trainers fire these per training step; info-level was a log flood. 
- Extract the duplicated 4-line `EngineDeadError` shutdown stanza into `BaseWorkerHandler._shutdown_on_engine_dead(e) -> NoReturn` and collapse the 10 call sites (8 RL handlers + 2 generate paths) to a single line each. ~32 lines removed. - Strip the remaining internal-tracker comments missed by an earlier chore-scrub: `service_v2.rs` had `jthomson04 PR #7699`, `bis-dev/design-docs/rl-support.md §1`, and `hhzhang16 HH-22 / HH-26` references in two places. Replaced with neutral prose. - Strip the SGLang-coordination comment in `handlers.py` ("Signatures intentionally line up with the SGLang RL admin routes"). The kind of line that goes stale when SGLang's admin set drifts. - Delete dead-code carcasses for the dropped routes (`/v1/chat/completions/tokens`, `/v1/tokenize`, `/v1/detokenize`): remove `tokenize`, `detokenize`, `tokenization_router`, `chat_completions_tokens_router`, `handler_chat_completions_tokens` (~240 lines), drop `pub mod tokenization;`, delete the `tokenization.rs` module file (124 lines), drop unused tokenize-type imports. All five were behind `#[allow(dead_code)]` and unmounted. Tier 3 — tests - Make `RlState::new(...)` a pub(super) test-friendly constructor so handler-level tests don't need `from_env` / process env vars. - Convert `RlPauseQuery::mode: Option` to `Option`, a typed enum with `serde(rename_all = "lowercase")`. Axum now returns 400 on `mode=foo` before the handler runs; the runtime string match is gone. - Add four behavior tests in `mod tests`: test_pause_mode_serde_roundtrip test_pause_mode_rejects_unknown_value test_rl_update_weights_body_defaults test_rl_state_new_constructs_without_env - Fix unrelated test struct-init breakage that shipped earlier in this branch: 26 sites across 11 files were missing the `prompt_token_ids`, `return_token_ids`, `tokens`, and `completion_token_ids` fields added to the response/request/nvext structs. `cargo test -p dynamo-llm` now compiles cleanly.
Tier 5 — design / nits - `DYN_RL_LIVENESS_TIMEOUT_MS` is read once in `RlState::from_env` and cached as `RlState.probe_timeout: Duration`; `rl_ready` and `rl_liveness` use the cached value instead of re-parsing the env on every request. - `rl_ready` worker-probe body simplified from a four-link `.ok().and_then(Result::ok).map(...).unwrap_or(false)` chain to a `match (timeout, send_result)` that surfaces the timeout/network distinction for a future log line. cargo check -p dynamo-llm: clean (1 pre-existing benign warning). cargo test -p dynamo-llm --lib: 58 passed (4 new RL tests). cargo test -p dynamo-llm --tests integration suites: all green. --- components/src/dynamo/vllm/handlers.py | 79 ++- lib/llm/src/http/service/openai.rs | 464 +++++++----------- lib/llm/src/http/service/service_v2.rs | 21 +- lib/llm/src/protocols/anthropic/types.rs | 2 + lib/llm/src/protocols/openai.rs | 1 - .../openai/chat_completions/delta.rs | 6 + lib/llm/src/protocols/openai/nvext.rs | 4 + lib/llm/src/protocols/openai/responses/mod.rs | 14 + lib/llm/src/protocols/openai/tokenization.rs | 124 ----- lib/llm/src/protocols/unified.rs | 3 + .../tests/parallel_tool_call_integration.rs | 3 + lib/llm/tests/preprocessor.rs | 6 + lib/llm/tests/test_streaming_usage.rs | 6 + lib/llm/tests/tool_choice.rs | 3 + lib/llm/tests/tool_choice_finish_reasons.rs | 3 + 15 files changed, 259 insertions(+), 480 deletions(-) delete mode 100644 lib/llm/src/protocols/openai/tokenization.rs diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index b1b92b526e7e..02e5a54d97a5 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar +from typing import Any, AsyncIterator, Dict, Final, Generic, NoReturn, Optional, TypeVar import 
torch from vllm.config import ModelConfig, VllmConfig @@ -568,6 +568,15 @@ def __init__( # Store shutdown event for graceful shutdown monitoring self.shutdown_event = shutdown_event + def _shutdown_on_engine_dead(self, e: EngineDeadError) -> NoReturn: + """Common handler for `EngineDeadError`: log, shut down the runtime, + hard-exit. Called from RL admin handler `except` clauses so a dead + engine surfaces as a worker restart instead of silent failure.""" + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) + def init_embedding_loader( self, config: Config, encode_worker_client: Optional[Client] = None ) -> Optional[MultiModalEmbeddingLoader]: @@ -826,8 +835,6 @@ async def stop_profile(self, body: dict) -> dict: return {"status": "error", "message": str(e)} # ── RL weight lifecycle engine routes ────────────────────────────── - # Signatures intentionally line up with the SGLang RL admin routes so a - # single admin coordinator can talk to either backend. async def pause_generation(self, body: dict) -> dict: """Pause the engine: drain in-flight requests, keep model loaded. 
@@ -868,13 +875,13 @@ async def pause_generation(self, body: dict) -> dict: if clear_cache: try: await self.engine_client.reset_prefix_cache() - logger.info("[RL] pause: prefix cache cleared") + logger.debug("[RL] pause: prefix cache cleared") except Exception as flush_err: logger.warning( f"[RL] pause: clear_cache requested but reset_prefix_cache failed: {flush_err}" ) self._paused = True - logger.info( + logger.debug( f"[RL] Engine paused (generation quiesced, mode={mode}, clear_cache={clear_cache})" ) return { @@ -884,10 +891,7 @@ async def pause_generation(self, body: dict) -> dict: "clear_cache": clear_cache, } except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) except Exception as e: logger.error(f"[RL] Failed to pause: {e}") return {"status": "error", "message": str(e)} @@ -898,13 +902,10 @@ async def resume_generation(self, body: dict) -> dict: try: await self.engine_client.resume_generation() self._paused = False - logger.info("[RL] Engine resumed") + logger.debug("[RL] Engine resumed") return {"status": "ok", "message": "Engine resumed"} except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) except Exception as e: logger.error(f"[RL] Failed to resume: {e}") return {"status": "error", "message": str(e)} @@ -933,10 +934,7 @@ async def liveness_probe(self, body: dict) -> dict: await self.engine_client.collective_rpc("get_weight_version", kwargs={}) return {"status": "ok", "alive": True} except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) except Exception as e: logger.warning(f"[RL] liveness_probe failed: {e}") 
return {"status": "error", "alive": False, "message": str(e)} @@ -976,10 +974,7 @@ async def get_state(self, body: dict) -> dict: ], } except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) except Exception as e: logger.error(f"[RL] get_state failed: {e}") return {"status": "error", "message": str(e)} @@ -989,13 +984,10 @@ async def flush_cache(self, body: dict) -> dict: body = body or {} try: await self.engine_client.reset_prefix_cache() - logger.info("[RL] Prefix cache flushed") + logger.debug("[RL] Prefix cache flushed") return {"status": "ok", "message": "Cache flushed"} except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) except Exception as e: logger.error(f"[RL] Failed to flush cache: {e}") return {"status": "error", "message": str(e)} @@ -1021,17 +1013,14 @@ async def update_weights_from_path(self, body: dict) -> dict: kwargs={"weights_path": path}, ) self._weight_version = version - logger.info(f"[RL] Weights loaded from {path} (version={version})") + logger.debug(f"[RL] Weights loaded from {path} (version={version})") return { "status": "ok", "message": f"Weights loaded from {path}", "version": version, } except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) except Exception as e: logger.error(f"[RL] Failed to load weights from {path}: {e}") return {"status": "error", "message": str(e)} @@ -1162,7 +1151,7 @@ async def load_lora_adapter(self, body: dict) -> dict: lora_name=lora_name, base_model_path=self.config.model, ) - logger.info( + logger.debug( f"[RL] Published LoRA '{lora_name}' ModelDeploymentCard" ) 
except Exception as e: @@ -1187,7 +1176,7 @@ async def load_lora_adapter(self, body: dict) -> dict: "lora_name": lora_name, } - logger.info( + logger.debug( f"[RL] LoRA adapter {'hot-swapped' if is_hot_swap else 'loaded'}: " f"name={lora_name} id={lora_id} path={lora_path}" ) @@ -1199,10 +1188,7 @@ async def load_lora_adapter(self, body: dict) -> dict: "hot_swap": is_hot_swap, } except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) except Exception as e: logger.exception( f"[RL] Failed to load LoRA adapter '{lora_name}' from {lora_path}: {e}" @@ -1262,7 +1248,7 @@ async def unload_lora_adapter(self, body: dict) -> dict: "lora_id": lora_id, } - logger.info( + logger.debug( f"[RL] LoRA adapter unloaded: name={lora_name} id={lora_id}" ) return { @@ -1272,10 +1258,7 @@ async def unload_lora_adapter(self, body: dict) -> dict: "lora_id": lora_id, } except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) except Exception as e: logger.exception( f"[RL] Failed to unload LoRA adapter '{lora_name}': {e}" @@ -2469,10 +2452,7 @@ async def generate_tokens( num_output_tokens_so_far[output_idx] = next_total_toks except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) class DecodeWorkerHandler(BaseWorkerHandler): @@ -2714,10 +2694,7 @@ async def _generate_token_mode(self, request, context, request_id): ] = prefill_prompt_tokens_details yield tok except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + 
self._shutdown_on_engine_dead(e) async def _generate_text_mode(self, request, context, request_id): """Generate text using OpenAI-compatible format (text-in-text-out).""" diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index fa5fd153891d..3c021dea13ae 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -55,10 +55,6 @@ use crate::protocols::openai::{ embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, images::{NvCreateImageRequest, NvImagesResponse}, responses::{NvCreateResponse, NvResponse, ResponseParams, chat_completion_to_response}, - tokenization::{ - DetokenizeRequest, DetokenizeResponse, TokenizeCompletionRequest, TokenizeRequest, - TokenizeResponse, - }, videos::{NvCreateVideoRequest, NvVideosResponse}, }; use crate::protocols::unified::UnifiedRequest; @@ -940,7 +936,18 @@ async fn handler_chat_completions( .take() .unwrap_or_else(|| dynamo_runtime::config::env_is_truthy("DYN_ENABLE_RL")); if rl_want_token_ids { - tracing::info!("RL: want_token_ids=true, will promote nvext.extra_fields"); + // Reject n > 1 for the RL token-id path: the streaming aggregator + // accumulates `completion_token_ids` into a single `Vec` shared + // across all choices, so the per-choice promotion downstream cannot + // recover which tokens belong to which choice. A keyed-by-index + // accumulator is the long-term fix. + if request.inner.n.unwrap_or(1) > 1 { + return Err(bad_request( + "n > 1 is not supported when RL token IDs are requested. \ + Send separate requests instead.", + )); + } + tracing::debug!("RL: want_token_ids=true, will promote nvext.extra_fields"); } { // If `tokens` is provided, inject into nvext.token_data (pre-tokenized prompt path). 
@@ -2113,138 +2120,6 @@ fn resolve_model_card( Ok((model, card)) } -// Handler kept (no callers in this branch) for downstream code that re-mounts -// `tokenization_router` in `service_v2.rs` standalone, until the upstream -// `/tokenize` and `/detokenize` work lands at the root paths. -#[allow(dead_code)] -async fn tokenize( - State(state): State>, - Json(request): Json, -) -> Result { - check_ready(&state)?; - - let (_, card) = resolve_model_card(&state, request.model())?; - let tokenizer = card - .tokenizer() - .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to load tokenizer"))?; - - let (tokens, token_strs) = match request { - TokenizeRequest::Completion(TokenizeCompletionRequest { - prompt, - add_special_tokens, - return_token_strs, - .. - }) => { - let encoding = tokenizer - .encode_with_special_tokens(&prompt, add_special_tokens) - .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to tokenize prompt"))?; - let token_ids = encoding.token_ids().to_vec(); - let token_strs = if return_token_strs { - Some(tokenizer.convert_ids_to_tokens(&token_ids).map_err(|err| { - ErrorMessage::from_anyhow(err, "Failed to resolve token strings") - })?) - } else { - None - }; - (token_ids, token_strs) - } - TokenizeRequest::Chat(request) => { - // Reject mutually-exclusive flags - // (continue_final_message + add_generation_prompt) up-front so the - // chat-template render below doesn't see an inconsistent state. 
- request - .validate() - .map_err(|err| bad_request(&format!("Invalid tokenize request: {err}")))?; - let model = request - .model - .clone() - .unwrap_or_else(|| card.display_name.clone()); - // Render the chat messages to a prompt string via the model's chat template - let formatter = crate::preprocessor::prompt::PromptFormatter::from_mdc(&card) - .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to build chat formatter"))?; - let inner_request = dynamo_protocols::types::CreateChatCompletionRequest { - model, - messages: request.messages.clone(), - tools: request.tools.clone(), - ..Default::default() - }; - let wrapped = - crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest { - inner: inner_request, - common: Default::default(), - nvext: None, - chat_template_args: Some(request.merged_chat_template_kwargs()), - media_io_kwargs: None, - tokens: None, - return_token_ids: None, - unsupported_fields: Default::default(), - }; - let prompt = match formatter { - crate::preprocessor::prompt::PromptFormatter::OAI(f) => f.render(&wrapped), - } - .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to render chat prompt"))?; - - let encoding = tokenizer - .encode_with_special_tokens(&prompt, request.add_special_tokens) - .map_err(|err| { - ErrorMessage::from_anyhow(err, "Failed to tokenize rendered chat prompt") - })?; - let token_ids = encoding.token_ids().to_vec(); - let token_strs = if request.return_token_strs { - Some(tokenizer.convert_ids_to_tokens(&token_ids).map_err(|err| { - ErrorMessage::from_anyhow(err, "Failed to resolve token strings") - })?) 
- } else { - None - }; - (token_ids, token_strs) - } - }; - - Ok(Json(TokenizeResponse { - count: tokens.len(), - max_model_len: card.context_length, - tokens, - token_strs, - }) - .into_response()) -} - -#[allow(dead_code)] // see tokenize() above -async fn detokenize( - State(state): State>, - Json(request): Json, -) -> Result { - check_ready(&state)?; - - let (_, card) = resolve_model_card(&state, request.model.as_deref())?; - let tokenizer = card - .tokenizer() - .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to load tokenizer"))?; - let prompt: String = tokenizer - .decode(&request.tokens, false) - .map_err(|err| ErrorMessage::from_anyhow(err, "Failed to detokenize prompt"))? - .into(); - - Ok(Json(DetokenizeResponse { prompt }).into_response()) -} - -#[allow(dead_code)] // see tokenize() above; not mounted in service_v2 v2 surface -pub fn tokenization_router(state: Arc) -> (Vec, Router) { - let tokenize_path = "/v1/tokenize"; - let detokenize_path = "/v1/detokenize"; - let docs = vec![ - RouteDoc::new(axum::http::Method::POST, tokenize_path), - RouteDoc::new(axum::http::Method::POST, detokenize_path), - ]; - let router = Router::new() - .route(tokenize_path, post(tokenize)) - .route(detokenize_path, post(detokenize)) - .layer(middleware::from_fn(smart_json_error_middleware)) - .layer(axum::extract::DefaultBodyLimit::max(get_body_limit())) - .with_state(state); - (docs, router) -} /// openai compatible format /// Example: @@ -2371,21 +2246,6 @@ pub fn chat_completions_router( /// 0.20+ skips chat templating when that field is present, identical /// behavior. The handler is kept as `#[allow(dead_code)]` for downstream /// code that still references it; deletion is a follow-up cleanup. 
-#[allow(dead_code)] -pub fn chat_completions_tokens_router( - state: Arc, - template: Option, - path: Option, -) -> (Vec, Router) { - let path = path.unwrap_or("/v1/chat/completions/tokens".to_string()); - let doc = RouteDoc::new(axum::http::Method::POST, &path); - let router = Router::new() - .route(&path, post(handler_chat_completions_tokens)) - .layer(middleware::from_fn(smart_json_error_middleware)) - .layer(axum::extract::DefaultBodyLimit::max(get_body_limit())) - .with_state((state, template)); - (vec![doc], router) -} /// Handler for TITO (Token-In / Token-Out) chat completions. /// @@ -2397,104 +2257,6 @@ pub fn chat_completions_tokens_router( /// 4. Forces `logprobs = true` (RL always needs logprobs) /// 5. Ensures `messages` is non-empty (Dynamo requires it for chat template selection) /// 6. Delegates to the standard `chat_completions()` internal function (zero HTTP proxy) -#[allow(dead_code)] // see chat_completions_tokens_router above -async fn handler_chat_completions_tokens( - State((state, template)): State<(Arc, Option)>, - headers: HeaderMap, - Json(mut request): Json, -) -> Result { - check_ready(&state)?; - - // Extract the tokens field (Prime-RL's TITO input) - let tokens = request.tokens.take(); - // Clear return_token_ids (not supported by Dynamo, avoid confusion) - request.return_token_ids = None; - - if let Some(token_ids) = tokens { - if token_ids.is_empty() { - return Err(ErrorMessage::bad_request( - "TITO endpoint requires non-empty 'tokens' field", - )); - } - - // Inject tokens into nvext.token_data - let mut nvext = request.nvext.take().unwrap_or_default(); - nvext.token_data = Some(token_ids); - - // Request token echo and completion token IDs in response - let mut extra_fields = nvext.extra_fields.take().unwrap_or_default(); - for field in &["token_ids", "completion_token_ids"] { - if !extra_fields.contains(&field.to_string()) { - extra_fields.push(field.to_string()); - } - } - nvext.extra_fields = Some(extra_fields); - 
request.nvext = Some(nvext); - - // Force logprobs on (RL always needs them). Unconditional — an - // explicit logprobs=false from the caller would otherwise silently - // strip completion_token_ids from the response. - request.inner.logprobs = Some(true); - - // Ensure messages is non-empty (Dynamo requires it for model lookup / chat template) - if request.inner.messages.is_empty() { - use dynamo_protocols::types::{ - ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, - ChatCompletionRequestUserMessageContent, - }; - request - .inner - .messages - .push(ChatCompletionRequestMessage::User( - ChatCompletionRequestUserMessage { - content: ChatCompletionRequestUserMessageContent::Text( - "(token-in mode)".to_string(), - ), - name: None, - }, - )); - } - } else { - return Err(ErrorMessage::bad_request( - "Missing 'tokens' field for TITO endpoint. \ - Use /v1/chat/completions for message-based requests.", - )); - } - - // Apply header routing overrides - request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers); - - // Delegate to the standard chat completions flow (no HTTP proxy!) 
- let request_id = get_or_create_request_id(&headers); - let streaming = request.inner.stream.unwrap_or(false); - let cancellation_labels = CancellationLabels { - model: request.inner.model.clone(), - endpoint: Endpoint::ChatCompletions.to_string(), - request_type: if streaming { "stream" } else { "unary" }.to_string(), - }; - let request = Context::with_id(request, request_id); - let context = request.context(); - - let (mut connection_handle, stream_handle) = create_connection_monitor( - context.clone(), - Some(state.metrics_clone()), - cancellation_labels, - ) - .await; - - let response = - tokio::spawn(chat_completions(state, template, request, stream_handle).in_current_span()) - .await - .map_err(|e| { - ErrorMessage::internal_server_error(&format!( - "Failed to await TITO chat completions task: {:?}", - e, - )) - })?; - - connection_handle.disarm(); - response -} /// Create an Axum [`Router`] for the OpenAI API Embeddings endpoint /// If not path is provided, the default path is `/v1/embeddings` @@ -3131,6 +2893,9 @@ struct RlState { worker_system_urls: Vec, /// Shared HTTP client for all fan-out calls to worker system ports. http_client: reqwest::Client, + /// Per-worker probe timeout for `/v1/rl/liveness` and `/v1/rl/ready`. + /// Read once from `DYN_RL_LIVENESS_TIMEOUT_MS` at construction. 
+ probe_timeout: std::time::Duration, } impl RlState { @@ -3141,19 +2906,39 @@ impl RlState { .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) .collect::>(); + let probe_timeout_ms = std::env::var("DYN_RL_LIVENESS_TIMEOUT_MS") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(5000); tracing::info!( worker_count = worker_system_urls.len(), ?worker_system_urls, + probe_timeout_ms, "RL admin router configured" ); let http_client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(600)) .build() .map_err(|e| anyhow::anyhow!("Failed to build RL router HTTP client: {e}"))?; - Ok(Self { + Ok(Self::new( worker_system_urls, http_client, - }) + std::time::Duration::from_millis(probe_timeout_ms), + )) + } + + /// Test-friendly constructor — bypasses env reading so tests can pass in + /// fake worker URLs and a stubbed `reqwest::Client`. + fn new( + worker_system_urls: Vec, + http_client: reqwest::Client, + probe_timeout: std::time::Duration, + ) -> Self { + Self { + worker_system_urls, + http_client, + probe_timeout, + } } /// Call a single engine route on one worker. Returns the JSON body. @@ -3207,11 +2992,7 @@ impl RlState { /// `DYN_RL_LIVENESS_TIMEOUT_MS`) so a wedged worker fails fast as 503 instead /// of hanging on the shared 600s `http_client` timeout. 
async fn rl_ready(State(state): State>) -> impl IntoResponse { - let timeout_ms = std::env::var("DYN_RL_LIVENESS_TIMEOUT_MS") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(5000); - let timeout = std::time::Duration::from_millis(timeout_ms); + let timeout = state.probe_timeout; let futures: Vec<_> = state .worker_system_urls .iter() @@ -3219,12 +3000,10 @@ async fn rl_ready(State(state): State>) -> impl IntoResponse { let client = state.http_client.clone(); let health_url = format!("{url}/health"); async move { - tokio::time::timeout(timeout, client.get(&health_url).send()) - .await - .ok() - .and_then(Result::ok) - .map(|r| r.status().is_success()) - .unwrap_or(false) + match tokio::time::timeout(timeout, client.get(&health_url).send()).await { + Ok(Ok(resp)) => resp.status().is_success(), + Ok(Err(_)) | Err(_) => false, + } } }) .collect(); @@ -3253,10 +3032,36 @@ async fn rl_ready(State(state): State>) -> impl IntoResponse { /// Three-mode pause matches what vLLM exposes (abort / wait / keep). The /// default `mode=keep&clear_cache=false` preserves the original single-mode /// pause behavior so existing callers keep working without changes. +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[serde(rename_all = "lowercase")] +enum PauseMode { + Keep, + Wait, + Abort, +} + +impl PauseMode { + fn as_str(self) -> &'static str { + match self { + PauseMode::Keep => "keep", + PauseMode::Wait => "wait", + PauseMode::Abort => "abort", + } + } +} + +impl Default for PauseMode { + fn default() -> Self { + PauseMode::Keep + } +} + #[derive(Debug, serde::Deserialize)] struct RlPauseQuery { + /// Axum returns 400 automatically if this fails to deserialize as a + /// `PauseMode` (i.e. on `mode=invalid`), so we don't need a runtime check. 
#[serde(default)] - mode: Option, + mode: Option, #[serde(default)] clear_cache: Option, } @@ -3265,29 +3070,18 @@ async fn rl_pause( State(state): State>, axum::extract::Query(q): axum::extract::Query, ) -> impl IntoResponse { - let mode = q.mode.as_deref().unwrap_or("keep").to_string(); + let mode = q.mode.unwrap_or_default(); let clear_cache = q.clear_cache.unwrap_or(false); - if !matches!(mode.as_str(), "keep" | "wait" | "abort") { - return ( - StatusCode::BAD_REQUEST, - Json(serde_json::json!({ - "status": "error", - "message": format!( - "Invalid mode '{mode}'; expected one of keep|wait|abort" - ), - })), - ); - } let results = state .fan_out( "pause_generation", - serde_json::json!({"mode": mode, "clear_cache": clear_cache}), + serde_json::json!({"mode": mode.as_str(), "clear_cache": clear_cache}), ) .await; if RlState::all_ok(&results) { tracing::info!( worker_count = results.len(), - mode = %mode, + mode = %mode.as_str(), clear_cache, "RL pause: all workers paused" ); @@ -3295,7 +3089,7 @@ async fn rl_pause( StatusCode::OK, Json(serde_json::json!({ "status": "ok", - "mode": mode, + "mode": mode.as_str(), "clear_cache": clear_cache, "workers": results, })), @@ -3329,7 +3123,13 @@ async fn rl_resume(State(state): State>) -> impl IntoResponse { } } -/// `POST /v1/rl/update_weights` — atomic `flush_cache → update_weights_from_path` across all workers. +/// `POST /v1/rl/update_weights` — fan out `flush_cache → update_weights_from_path` to all workers. +/// +/// **Not atomic.** If `update_weights_from_path` succeeds on workers `0..N-1` +/// and fails on worker `N`, the fleet is left in a mixed-version state: the +/// successful workers serve the new version while worker `N` still runs the +/// previous one. The response carries per-worker status so callers can +/// retry / drain manually; a true rollback layer is a follow-up. 
/// /// Body schema (`reset_prefix_cache` defaults to `true` — the v1 sequence /// always flushed before reload, this just makes it explicit): @@ -3365,7 +3165,16 @@ async fn rl_update_weights( ) -> impl IntoResponse { let reset_prefix_cache = body.reset_prefix_cache; - let Some(weight_dir) = body.weight_dir.clone() else { + // Treat empty string the same as missing/null (NCCL no-op). Otherwise + // an empty string would reach the engine as `path=""` and fail + // confusingly downstream. + let weight_dir = body + .weight_dir + .as_ref() + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(str::to_string); + let Some(weight_dir) = weight_dir else { tracing::info!("RL update_weights: weight_dir=null (NCCL mode, no-op on Dynamo side)"); return ( StatusCode::OK, @@ -3705,11 +3514,7 @@ async fn rl_liveness(State(state): State>) -> impl IntoResponse { })), ); } - let timeout_ms = std::env::var("DYN_RL_LIVENESS_TIMEOUT_MS") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(5000); - let timeout = std::time::Duration::from_millis(timeout_ms); + let timeout = state.probe_timeout; let futures: Vec<_> = state .worker_system_urls @@ -3742,7 +3547,7 @@ async fn rl_liveness(State(state): State>) -> impl IntoResponse { .unwrap_or_else(|_| serde_json::json!({ "status": "error", "alive": false, - "message": format!("liveness_probe timed out after {timeout_ms}ms") + "message": format!("liveness_probe timed out after {}ms", timeout.as_millis()) })) } }) @@ -4168,6 +3973,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_required_fields(&request); assert!(result.is_err()); @@ -4200,6 +4008,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_required_fields(&request); assert!(result.is_ok()); @@ -4416,6 
+4227,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); @@ -4446,6 +4260,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -4475,6 +4292,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -4504,6 +4324,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -4535,6 +4358,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -4564,6 +4390,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -5591,4 +5420,65 @@ mod tests { let json = extract_sse_data_json(events[0].as_ref().unwrap()); assert_eq!(json["reasoning_content"], "让我想想 🤔 分析完成 ✅"); } + + // ── RL admin types ────────────────────────────────────────────────── + + #[test] + fn test_pause_mode_serde_roundtrip() { + for (mode, lower) in [ + (PauseMode::Keep, "keep"), + (PauseMode::Wait, "wait"), + (PauseMode::Abort, "abort"), + ] { + let json = serde_json::to_string(&mode).unwrap(); + assert_eq!(json, 
format!("\"{lower}\"")); + assert_eq!(mode.as_str(), lower); + let parsed: PauseMode = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, mode); + } + assert_eq!(PauseMode::default(), PauseMode::Keep); + } + + #[test] + fn test_pause_mode_rejects_unknown_value() { + // Axum returns 400 on this deserialize failure before the handler + // runs — that's the whole point of the typed enum vs the prior + // string match. + let err = serde_json::from_str::("\"foo\"") + .expect_err("foo is not a valid PauseMode"); + assert!(err.to_string().to_lowercase().contains("foo")); + } + + #[test] + fn test_rl_update_weights_body_defaults() { + let body: RlUpdateWeightsBody = serde_json::from_str(r#"{}"#).unwrap(); + assert!(body.weight_dir.is_none()); + assert!(body.weight_version.is_none()); + assert!(body.reset_prefix_cache); + + let body: RlUpdateWeightsBody = + serde_json::from_str(r#"{"weight_dir":null}"#).unwrap(); + assert!(body.weight_dir.is_none()); + assert!(body.reset_prefix_cache); + + let body: RlUpdateWeightsBody = + serde_json::from_str(r#"{"weight_dir":"/path","reset_prefix_cache":false}"#) + .unwrap(); + assert_eq!(body.weight_dir.as_deref(), Some("/path")); + assert!(!body.reset_prefix_cache); + } + + #[test] + fn test_rl_state_new_constructs_without_env() { + // Sanity check the testability constructor — needed so future + // route-level tests can build an `RlState` without env vars or a + // real network client. 
+ let state = RlState::new( + vec!["http://w0:9090".to_string(), "http://w1:9090".to_string()], + reqwest::Client::new(), + std::time::Duration::from_millis(100), + ); + assert_eq!(state.worker_system_urls.len(), 2); + assert_eq!(state.probe_timeout, std::time::Duration::from_millis(100)); + } } diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 8e26252b09af..9429fe38f6d9 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -538,13 +538,6 @@ impl HttpServiceConfigBuilder { } else { super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()) }, - // /v1/tokenize and /v1/detokenize are NOT required by prime-rl - // (source audit: zero references). Owned by jthomson04 PR #7699 - // which mounts /tokenize and /detokenize at root paths for the - // NeMo-rl tokenize-then-generate pattern. Dropped from the v2 - // surface here per `bis-dev/design-docs/rl-support.md` §1 - // out-of-scope. Re-enable by uncommenting the next line: - // super::openai::tokenization_router(state.clone()), super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()), super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), super::busy_threshold::busy_threshold_router(state.clone(), None), @@ -618,16 +611,10 @@ impl HttpServiceConfigBuilder { request_template.clone(), var(HTTP_SVC_CHAT_PATH_ENV).ok(), ); - // /v1/chat/completions/tokens (the v1 TITO fork URI) is dropped per - // `bis-dev/design-docs/rl-support.md` Phase 5 + hhzhang16 HH-22 / HH-26. - // TITO callers retarget to /v1/chat/completions with `prompt_token_ids` - // as a top-level extension (now in `validate.rs:104` - // PASSTHROUGH_EXTRA_FIELDS) — vLLM 0.20+ skips chat templating when - // that field is present, identical behavior to the dropped fork URI. 
- // The handler `handler_chat_completions_tokens` and helper - // `chat_completions_tokens_router` are intentionally left in the - // codebase as dead code for now; a subsequent commit can delete - // them once prime-rl has fully migrated. + // The legacy `/v1/chat/completions/tokens` TITO fork URI is dropped. + // TITO callers send `prompt_token_ids` as a top-level extension on + // `/v1/chat/completions` (allowlisted by `validate.rs::PASSTHROUGH_EXTRA_FIELDS`); + // vLLM 0.20+ skips chat templating when that field is present. let (cmpl_docs, cmpl_route) = super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok()); diff --git a/lib/llm/src/protocols/anthropic/types.rs b/lib/llm/src/protocols/anthropic/types.rs index 4104191a61ae..33bc1be37422 100644 --- a/lib/llm/src/protocols/anthropic/types.rs +++ b/lib/llm/src/protocols/anthropic/types.rs @@ -823,6 +823,8 @@ mod tests { }), }, nvext: None, + + prompt_token_ids: None, }; let response = chat_completion_to_anthropic_response(chat_resp, "test-model", None); diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index a3af818a4edd..e22d4ae12f80 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -20,7 +20,6 @@ pub mod images; pub mod models; pub mod nvext; pub mod responses; -pub mod tokenization; pub mod tools; pub mod validate; pub mod videos; diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 75acc48356cf..facff839fc1d 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -556,6 +556,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } @@ -649,6 +652,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: 
None, + tokens: None, } } diff --git a/lib/llm/src/protocols/openai/nvext.rs b/lib/llm/src/protocols/openai/nvext.rs index 1c4711c3a4cc..ffce160e8ab7 100644 --- a/lib/llm/src/protocols/openai/nvext.rs +++ b/lib/llm/src/protocols/openai/nvext.rs @@ -910,6 +910,8 @@ mod tests { token_ids: true, routed_experts: true, engine_data: false, + + completion_token_ids: false, }; let tracker = tracker_with_prefill_worker(); let params = disagg_params_full(); @@ -946,6 +948,8 @@ mod tests { token_ids: false, // only enabled via query_instance_id routed_experts: true, engine_data: false, + + completion_token_ids: false, } ); } diff --git a/lib/llm/src/protocols/openai/responses/mod.rs b/lib/llm/src/protocols/openai/responses/mod.rs index e806b3c74732..708bd8dc8238 100644 --- a/lib/llm/src/protocols/openai/responses/mod.rs +++ b/lib/llm/src/protocols/openai/responses/mod.rs @@ -2115,6 +2115,8 @@ mod tests { usage: None, }, nvext: None, + + prompt_token_ids: None, }; let wrapped = @@ -2176,6 +2178,8 @@ mod tests { usage: None, }, nvext: None, + + prompt_token_ids: None, }; let wrapped = @@ -2381,6 +2385,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, }; let resp = chat_completion_to_response(chat_resp, ¶ms, None).unwrap(); @@ -2414,6 +2420,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, }; let resp = chat_completion_to_response(chat_resp, ¶ms, None).unwrap(); @@ -2442,6 +2450,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, }; let resp = chat_completion_to_response(chat_resp, ¶ms, None).unwrap(); @@ -2467,6 +2477,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, }; let resp = chat_completion_to_response(chat_resp, ¶ms, None).unwrap(); @@ -2585,6 +2597,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, } } diff --git a/lib/llm/src/protocols/openai/tokenization.rs b/lib/llm/src/protocols/openai/tokenization.rs deleted file mode 100644 index 95559684ad89..000000000000 --- 
a/lib/llm/src/protocols/openai/tokenization.rs +++ /dev/null @@ -1,124 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -use std::collections::HashMap; - -use serde::{Deserialize, Serialize}; - -use crate::preprocessor::media::MediaDecoder; -use crate::types::TokenIdType; - -fn default_true() -> bool { - true -} - -fn default_false() -> bool { - false -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct TokenizeCompletionRequest { - #[serde(default, skip_serializing_if = "Option::is_none")] - pub model: Option, - pub prompt: String, - #[serde(default = "default_true")] - pub add_special_tokens: bool, - #[serde(default = "default_false")] - pub return_token_strs: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct TokenizeChatRequest { - #[serde(default, skip_serializing_if = "Option::is_none")] - pub model: Option, - pub messages: Vec, - #[serde(default = "default_true")] - pub add_generation_prompt: bool, - #[serde(default = "default_false")] - pub return_token_strs: bool, - #[serde(default = "default_false")] - pub continue_final_message: bool, - #[serde(default = "default_false")] - pub add_special_tokens: bool, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub chat_template: Option, - #[serde( - default, - skip_serializing_if = "Option::is_none", - alias = "chat_template_args" - )] - pub chat_template_kwargs: Option>, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub media_io_kwargs: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub mm_processor_kwargs: Option>, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub tools: Option>, -} - -impl TokenizeChatRequest { - pub fn validate(&self) -> Result<(), String> { - if self.continue_final_message && self.add_generation_prompt { - return Err( - 
"Cannot set both `continue_final_message` and `add_generation_prompt` to True." - .to_string(), - ); - } - - Ok(()) - } - - pub fn merged_chat_template_kwargs(&self) -> HashMap { - let mut kwargs = self.chat_template_kwargs.clone().unwrap_or_default(); - kwargs.insert( - "add_generation_prompt".to_string(), - serde_json::Value::Bool(self.add_generation_prompt), - ); - kwargs.insert( - "continue_final_message".to_string(), - serde_json::Value::Bool(self.continue_final_message), - ); - kwargs - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -#[allow(clippy::large_enum_variant)] -pub enum TokenizeRequest { - Completion(TokenizeCompletionRequest), - Chat(TokenizeChatRequest), -} - -impl TokenizeRequest { - pub fn model(&self) -> Option<&str> { - match self { - Self::Completion(request) => request.model.as_deref(), - Self::Chat(request) => request.model.as_deref(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct TokenizeResponse { - pub count: usize, - pub max_model_len: u32, - pub tokens: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub token_strs: Option>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct DetokenizeRequest { - #[serde(default, skip_serializing_if = "Option::is_none")] - pub model: Option, - pub tokens: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct DetokenizeResponse { - pub prompt: String, -} diff --git a/lib/llm/src/protocols/unified.rs b/lib/llm/src/protocols/unified.rs index 6ce62744e7f3..e9a61f7a2110 100644 --- a/lib/llm/src/protocols/unified.rs +++ b/lib/llm/src/protocols/unified.rs @@ -535,6 +535,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let unified = UnifiedRequest::from(req.clone()); diff --git a/lib/llm/tests/parallel_tool_call_integration.rs 
b/lib/llm/tests/parallel_tool_call_integration.rs index 2827239d4754..896c005d13ec 100644 --- a/lib/llm/tests/parallel_tool_call_integration.rs +++ b/lib/llm/tests/parallel_tool_call_integration.rs @@ -93,6 +93,9 @@ fn create_mock_chat_completion_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } diff --git a/lib/llm/tests/preprocessor.rs b/lib/llm/tests/preprocessor.rs index 2896e01c427f..45a6893b656d 100644 --- a/lib/llm/tests/preprocessor.rs +++ b/lib/llm/tests/preprocessor.rs @@ -261,6 +261,9 @@ impl Request { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } } @@ -701,6 +704,9 @@ mod context_length_validation { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } diff --git a/lib/llm/tests/test_streaming_usage.rs b/lib/llm/tests/test_streaming_usage.rs index 0a6fd3178bf6..2f94498f6168 100644 --- a/lib/llm/tests/test_streaming_usage.rs +++ b/lib/llm/tests/test_streaming_usage.rs @@ -195,6 +195,9 @@ fn create_chat_request( chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } @@ -529,6 +532,9 @@ fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index bbff1bd38508..da2120b9a274 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -41,6 +41,9 @@ fn create_test_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: 
None, } } diff --git a/lib/llm/tests/tool_choice_finish_reasons.rs b/lib/llm/tests/tool_choice_finish_reasons.rs index d3d190c3953c..12c9b4ac8cab 100644 --- a/lib/llm/tests/tool_choice_finish_reasons.rs +++ b/lib/llm/tests/tool_choice_finish_reasons.rs @@ -34,6 +34,9 @@ fn create_test_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } From 6007c77c690c71c95f62d001ab2e56888eff080e Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Fri, 8 May 2026 01:40:54 -0700 Subject: [PATCH 11/18] =?UTF-8?q?refactor(rl):=20extract=20dynamo-rl=20cra?= =?UTF-8?q?te=20at=20lib/rl=20(PR=20A=20=E2=80=94=20pure=20refactor,=20no?= =?UTF-8?q?=20behavior=20change)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 16 + Cargo.toml | 2 + lib/bindings/python/Cargo.lock | 16 + lib/llm/Cargo.toml | 1 + lib/llm/src/http/service/openai.rs | 803 +--------------------- lib/rl/Cargo.toml | 27 + lib/rl/src/lib.rs | 1011 ++++++++++++++++++++++++++++ 7 files changed, 1102 insertions(+), 774 deletions(-) create mode 100644 lib/rl/Cargo.toml create mode 100644 lib/rl/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 3acda9a73e6a..bd0ce7b991c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2433,6 +2433,7 @@ dependencies = [ "dynamo-mocker", "dynamo-parsers", "dynamo-protocols", + "dynamo-rl", "dynamo-runtime", "dynamo-tokenizers", "dynamo-tokens", @@ -2606,6 +2607,21 @@ dependencies = [ "uuid", ] +[[package]] +name = "dynamo-rl" +version = "1.2.0" +dependencies = [ + "anyhow", + "axum 0.8.4", + "dynamo-runtime", + "futures", + "reqwest 0.12.28", + "serde", + "serde_json", + "tokio", + "tracing", +] + [[package]] name = "dynamo-runtime" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index d7d2f9fd02f6..6cc5e3d7a270 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ members = [ 
"lib/backend-common/examples/mocker", "lib/bindings/c", "lib/bindings/python/codegen", + "lib/rl", ] resolver = "3" @@ -41,6 +42,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "distributed"] # Local crates dynamo-runtime = { path = "lib/runtime", version = "1.2.0" } dynamo-llm = { path = "lib/llm", version = "1.2.0" } +dynamo-rl = { path = "lib/rl", version = "1.2.0" } dynamo-config = { path = "lib/config", version = "1.2.0" } dynamo-tokenizers = { path = "lib/tokenizers", version = "1.2.0" } dynamo-tokens = { path = "lib/tokens", version = "1.2.0" } diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index e9bb77d0e246..1a31b48b50ac 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -2071,6 +2071,7 @@ dependencies = [ "dynamo-mocker", "dynamo-parsers", "dynamo-protocols", + "dynamo-rl", "dynamo-runtime", "dynamo-tokenizers", "dynamo-tokens", @@ -2241,6 +2242,21 @@ dependencies = [ "uuid", ] +[[package]] +name = "dynamo-rl" +version = "1.2.0" +dependencies = [ + "anyhow", + "axum", + "dynamo-runtime", + "futures", + "reqwest", + "serde", + "serde_json", + "tokio", + "tracing", +] + [[package]] name = "dynamo-runtime" version = "1.2.0" diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 31559be92ed4..d384bf1d5510 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -56,6 +56,7 @@ dynamo-config = { workspace = true } dynamo-kv-router = { workspace = true, features = ["metrics", "runtime-protocols"] } dynamo-memory = { workspace = true } dynamo-mocker = { workspace = true } +dynamo-rl = { workspace = true } dynamo-runtime = { workspace = true } dynamo-tokenizers = { workspace = true } dynamo-tokens = { workspace = true } diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 3c021dea13ae..e5e4c6c9e8a7 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -1589,9 +1589,7 @@ pub fn 
validate_chat_completion_required_fields( // RL renderer / TITO callers send `prompt_token_ids` (or legacy // `nvext.token_data`) in place of `messages`. Treat either pre-tokenized // input as satisfying the "non-empty input" requirement. - let has_pretokenized_input = request - .unsupported_fields - .contains_key("prompt_token_ids") + let has_pretokenized_input = request.unsupported_fields.contains_key("prompt_token_ids") || request .nvext .as_ref() @@ -2120,7 +2118,6 @@ fn resolve_model_card( Ok((model, card)) } - /// openai compatible format /// Example: /// { @@ -2880,529 +2877,6 @@ pub fn audios_router( // ────────────────────────────────────────────────────────────────────────── // RL Admin router: /v1/rl/* // ────────────────────────────────────────────────────────────────────────── - -/// Environment variable for comma-separated worker system HTTP URLs. -/// Defaults to `http://localhost:8081` when not set. -const DYN_RL_WORKER_SYSTEM_URLS_ENV: &str = "DYN_RL_WORKER_SYSTEM_URLS"; - -/// Shared state for the RL admin router. -#[derive(Clone)] -struct RlState { - /// Worker system HTTP base URLs (e.g. `http://localhost:8081`). - /// Set via `DYN_RL_WORKER_SYSTEM_URLS` (comma-separated list). - worker_system_urls: Vec, - /// Shared HTTP client for all fan-out calls to worker system ports. - http_client: reqwest::Client, - /// Per-worker probe timeout for `/v1/rl/liveness` and `/v1/rl/ready`. - /// Read once from `DYN_RL_LIVENESS_TIMEOUT_MS` at construction. 
- probe_timeout: std::time::Duration, -} - -impl RlState { - fn from_env() -> anyhow::Result { - let worker_system_urls = std::env::var(DYN_RL_WORKER_SYSTEM_URLS_ENV) - .unwrap_or_else(|_| "http://localhost:8081".to_string()) - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect::>(); - let probe_timeout_ms = std::env::var("DYN_RL_LIVENESS_TIMEOUT_MS") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(5000); - tracing::info!( - worker_count = worker_system_urls.len(), - ?worker_system_urls, - probe_timeout_ms, - "RL admin router configured" - ); - let http_client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(600)) - .build() - .map_err(|e| anyhow::anyhow!("Failed to build RL router HTTP client: {e}"))?; - Ok(Self::new( - worker_system_urls, - http_client, - std::time::Duration::from_millis(probe_timeout_ms), - )) - } - - /// Test-friendly constructor — bypasses env reading so tests can pass in - /// fake worker URLs and a stubbed `reqwest::Client`. - fn new( - worker_system_urls: Vec, - http_client: reqwest::Client, - probe_timeout: std::time::Duration, - ) -> Self { - Self { - worker_system_urls, - http_client, - probe_timeout, - } - } - - /// Call a single engine route on one worker. Returns the JSON body. - async fn call_engine_route( - &self, - url: &str, - route: &str, - body: &serde_json::Value, - ) -> serde_json::Value { - let endpoint = format!("{url}/engine/{route}"); - match self.http_client.post(&endpoint).json(body).send().await { - Ok(resp) => { - let status = resp.status(); - match resp.json::().await { - Ok(v) => v, - Err(e) => serde_json::json!({ - "status": "error", - "message": format!("Failed to decode response from {endpoint}: {e}"), - "http_status": status.as_u16() - }), - } - } - Err(e) => serde_json::json!({ - "status": "error", - "message": format!("Request to {endpoint} failed: {e}") - }), - } - } - - /// Fan out an engine route call to all configured workers concurrently. 
- async fn fan_out(&self, route: &str, body: serde_json::Value) -> Vec { - let futures: Vec<_> = self - .worker_system_urls - .iter() - .map(|url| self.call_engine_route(url, route, &body)) - .collect(); - futures::future::join_all(futures).await - } - - /// Returns true only if all results have `status: "ok"`. - fn all_ok(results: &[serde_json::Value]) -> bool { - results - .iter() - .all(|r| r.get("status").and_then(|s| s.as_str()) == Some("ok")) - } -} - -/// `GET /v1/rl/ready` — composite readiness check: worker health via system port. -/// -/// Bounded with a per-worker probe timeout (default 5s, override via -/// `DYN_RL_LIVENESS_TIMEOUT_MS`) so a wedged worker fails fast as 503 instead -/// of hanging on the shared 600s `http_client` timeout. -async fn rl_ready(State(state): State>) -> impl IntoResponse { - let timeout = state.probe_timeout; - let futures: Vec<_> = state - .worker_system_urls - .iter() - .map(|url| { - let client = state.http_client.clone(); - let health_url = format!("{url}/health"); - async move { - match tokio::time::timeout(timeout, client.get(&health_url).send()).await { - Ok(Ok(resp)) => resp.status().is_success(), - Ok(Err(_)) | Err(_) => false, - } - } - }) - .collect(); - let results = futures::future::join_all(futures).await; - let all_ready = !results.is_empty() && results.iter().all(|ok| *ok); - if all_ready { - (StatusCode::OK, Json(serde_json::json!({"status": "ready"}))) - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "status": "not_ready", - "workers_ready": results.iter().filter(|ok| **ok).count(), - "workers_total": results.len() - })), - ) - } -} - -/// `POST /v1/rl/pause` — fan out `pause_generation` to all workers. -/// -/// Query params (both optional): -/// - `mode`: `keep` | `wait` | `abort` (default `keep`) -/// - `clear_cache`: `true` | `false` (default `false`) -/// -/// Three-mode pause matches what vLLM exposes (abort / wait / keep). 
The -/// default `mode=keep&clear_cache=false` preserves the original single-mode -/// pause behavior so existing callers keep working without changes. -#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Deserialize, serde::Serialize)] -#[serde(rename_all = "lowercase")] -enum PauseMode { - Keep, - Wait, - Abort, -} - -impl PauseMode { - fn as_str(self) -> &'static str { - match self { - PauseMode::Keep => "keep", - PauseMode::Wait => "wait", - PauseMode::Abort => "abort", - } - } -} - -impl Default for PauseMode { - fn default() -> Self { - PauseMode::Keep - } -} - -#[derive(Debug, serde::Deserialize)] -struct RlPauseQuery { - /// Axum returns 400 automatically if this fails to deserialize as a - /// `PauseMode` (i.e. on `mode=invalid`), so we don't need a runtime check. - #[serde(default)] - mode: Option, - #[serde(default)] - clear_cache: Option, -} - -async fn rl_pause( - State(state): State>, - axum::extract::Query(q): axum::extract::Query, -) -> impl IntoResponse { - let mode = q.mode.unwrap_or_default(); - let clear_cache = q.clear_cache.unwrap_or(false); - let results = state - .fan_out( - "pause_generation", - serde_json::json!({"mode": mode.as_str(), "clear_cache": clear_cache}), - ) - .await; - if RlState::all_ok(&results) { - tracing::info!( - worker_count = results.len(), - mode = %mode.as_str(), - clear_cache, - "RL pause: all workers paused" - ); - ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "ok", - "mode": mode.as_str(), - "clear_cache": clear_cache, - "workers": results, - })), - ) - } else { - tracing::warn!(?results, "RL pause: some workers failed"); - ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({"status": "error", "workers": results})), - ) - } -} - -/// `POST /v1/rl/resume` — fan out `resume_generation` to all workers. 
-async fn rl_resume(State(state): State>) -> impl IntoResponse { - let results = state - .fan_out("resume_generation", serde_json::json!({})) - .await; - if RlState::all_ok(&results) { - tracing::info!(worker_count = results.len(), "RL resume: all workers resumed"); - ( - StatusCode::OK, - Json(serde_json::json!({"status": "ok", "workers": results})), - ) - } else { - tracing::warn!(?results, "RL resume: some workers failed"); - ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({"status": "error", "workers": results})), - ) - } -} - -/// `POST /v1/rl/update_weights` — fan out `flush_cache → update_weights_from_path` to all workers. -/// -/// **Not atomic.** If `update_weights_from_path` succeeds on workers `0..N-1` -/// and fails on worker `N`, the fleet is left in a mixed-version state: the -/// successful workers serve the new version while worker `N` still runs the -/// previous one. The response carries per-worker status so callers can -/// retry / drain manually; a true rollback layer is a follow-up. -/// -/// Body schema (`reset_prefix_cache` defaults to `true` — the v1 sequence -/// always flushed before reload, this just makes it explicit): -/// ```json -/// { -/// "weight_dir": "/path/to/checkpoint" | null, // null → NCCL mode no-op -/// "weight_version": "step_42", // optional; derived from -/// // weight_dir basename if missing -/// "reset_prefix_cache": true -/// } -/// ``` -/// -/// Returns `{ "status": "ok", "applied_weight_version": "step_42", "workers": [...] }` on success. -/// -/// The pause/resume envelope is left to the caller; full-FT updates MUST -/// bracket this call with `/v1/rl/pause` and `/v1/rl/resume`. 
-#[derive(Debug, serde::Deserialize)] -struct RlUpdateWeightsBody { - weight_dir: Option, - #[serde(default)] - weight_version: Option, - #[serde(default = "default_reset_prefix_cache")] - reset_prefix_cache: bool, -} - -fn default_reset_prefix_cache() -> bool { - true -} - -async fn rl_update_weights( - State(state): State>, - body: axum::extract::Json, -) -> impl IntoResponse { - let reset_prefix_cache = body.reset_prefix_cache; - - // Treat empty string the same as missing/null (NCCL no-op). Otherwise - // an empty string would reach the engine as `path=""` and fail - // confusingly downstream. - let weight_dir = body - .weight_dir - .as_ref() - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .map(str::to_string); - let Some(weight_dir) = weight_dir else { - tracing::info!("RL update_weights: weight_dir=null (NCCL mode, no-op on Dynamo side)"); - return ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "ok", - "message": "NCCL mode, no-op on Dynamo side" - })), - ); - }; - - let version = body.weight_version.clone().unwrap_or_else(|| { - std::path::Path::new(&weight_dir) - .file_name() - .and_then(|n| n.to_str()) - .unwrap_or("unknown") - .to_string() - }); - tracing::info!( - weight_dir = %weight_dir, - version = %version, - reset_prefix_cache, - "RL update_weights" - ); - - // Step 1 (optional): flush_cache across all workers. - if reset_prefix_cache { - let flush_results = state.fan_out("flush_cache", serde_json::json!({})).await; - if !RlState::all_ok(&flush_results) { - tracing::warn!(?flush_results, "RL update_weights: flush_cache failed"); - return ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({ - "status": "error", - "stage": "flush_cache", - "workers": flush_results - })), - ); - } - } - - // Step 2: update_weights_from_path across all workers. 
- let load_body = serde_json::json!({"path": &weight_dir, "version": version}); - let load_results = state.fan_out("update_weights_from_path", load_body).await; - if RlState::all_ok(&load_results) { - tracing::info!( - worker_count = load_results.len(), - weight_dir = %weight_dir, - "RL update_weights: all workers updated" - ); - ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "ok", - "applied_weight_version": version, - "workers": load_results, - })), - ) - } else { - tracing::warn!( - ?load_results, - "RL update_weights: update_weights_from_path failed" - ); - ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({ - "status": "error", - "stage": "update_weights_from_path", - "workers": load_results - })), - ) - } -} - -/// `POST /v1/rl/load_lora_adapter` — hot-load/swap a LoRA adapter from a filesystem path. -/// -/// Expected body: `{"lora_name": "r16-a32.0", "lora_path": "/path/to/adapter_dir"}` -/// -/// The adapter directory must contain PEFT-style `adapter_model.safetensors` and -/// `adapter_config.json`. This is the RL-specific LoRA path used by Prime-RL every -/// training step (separate from Dynamo's URI-based `load_lora` gRPC endpoint which -/// downloads adapters from S3/file URIs and publishes a new ModelDeploymentCard). -/// -/// Hot-swap semantics: calling with a `lora_name` that is already loaded removes -/// the previous adapter and loads the new one under the same deterministic int ID, -/// then resets the prefix cache so stale KV entries don't poison new rollouts. -/// -/// Pair with `/v1/rl/pause` and `/v1/rl/resume` for a full drain-swap-resume cycle. 
-async fn rl_load_lora_adapter( - State(state): State>, - body: axum::extract::Json, -) -> impl IntoResponse { - let lora_name = body.get("lora_name").and_then(|v| v.as_str()); - let lora_path = body.get("lora_path").and_then(|v| v.as_str()); - - let (lora_name, lora_path) = match (lora_name, lora_path) { - (Some(n), Some(p)) if !n.is_empty() && !p.is_empty() => (n.to_string(), p.to_string()), - _ => { - return ( - StatusCode::BAD_REQUEST, - Json(serde_json::json!({ - "status": "error", - "message": "Expected body: {\"lora_name\": str, \"lora_path\": str} (both required, non-empty)" - })), - ); - } - }; - - tracing::info!(%lora_name, %lora_path, "RL load_lora_adapter"); - let results = state - .fan_out( - "load_lora_adapter", - serde_json::json!({"lora_name": &lora_name, "lora_path": &lora_path}), - ) - .await; - - if RlState::all_ok(&results) { - tracing::info!( - worker_count = results.len(), - %lora_name, - %lora_path, - "RL load_lora_adapter: all workers loaded" - ); - ( - StatusCode::OK, - Json(serde_json::json!({"status": "ok", "workers": results})), - ) - } else { - tracing::warn!(?results, %lora_name, "RL load_lora_adapter: some workers failed"); - ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({"status": "error", "workers": results})), - ) - } -} - -/// `POST /v1/rl/unload_lora_adapter` — remove a previously loaded LoRA adapter by name. -/// -/// Expected body: `{"lora_name": "r16-a32.0"}` -/// -/// Idempotent: unloading an already-absent LoRA returns `status: ok` so callers -/// can retry safely without special-casing not-found. 
-async fn rl_unload_lora_adapter( - State(state): State>, - body: axum::extract::Json, -) -> impl IntoResponse { - let lora_name = body - .get("lora_name") - .and_then(|v| v.as_str()) - .map(str::to_string); - - let lora_name = match lora_name { - Some(n) if !n.is_empty() => n, - _ => { - return ( - StatusCode::BAD_REQUEST, - Json(serde_json::json!({ - "status": "error", - "message": "Expected body: {\"lora_name\": str} (required, non-empty)" - })), - ); - } - }; - - tracing::info!(%lora_name, "RL unload_lora_adapter"); - let results = state - .fan_out( - "unload_lora_adapter", - serde_json::json!({"lora_name": &lora_name}), - ) - .await; - - if RlState::all_ok(&results) { - tracing::info!( - worker_count = results.len(), - %lora_name, - "RL unload_lora_adapter: all workers unloaded" - ); - ( - StatusCode::OK, - Json(serde_json::json!({"status": "ok", "workers": results})), - ) - } else { - tracing::warn!(?results, %lora_name, "RL unload_lora_adapter: some workers failed"); - ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({"status": "error", "workers": results})), - ) - } -} - -/// `GET /v1/rl/weight_version` — query weight version from all workers. 
-async fn rl_weight_version(State(state): State>) -> impl IntoResponse { - let results = state - .fan_out("get_weight_version", serde_json::json!({})) - .await; - - // Collect distinct versions and check for consistency - let versions: Vec<_> = results - .iter() - .filter_map(|r| { - r.get("version") - .and_then(|v| v.as_str()) - .map(str::to_string) - }) - .collect(); - - let unique: std::collections::HashSet<&str> = versions.iter().map(String::as_str).collect(); - if unique.len() == 1 { - ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "ok", - "version": unique.into_iter().next().unwrap_or(""), - "workers": results - })), - ) - } else { - ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "inconsistent", - "versions": unique.into_iter().collect::>(), - "workers": results - })), - ) - } -} - /// Tokenize chat messages using the model's tokenizer and return prompt token IDs. /// Used by the RL post-processing path to populate `response.prompt_token_ids`. fn rl_tokenize_prompt( @@ -3491,244 +2965,27 @@ fn rl_promote_token_ids_in_response(json_val: &mut serde_json::Value) { /// **Deprecated in favor of `/v1/rl/state.ingress_alive`.** Kept for /// back-compat until existing clients migrate to `/v1/rl/state`; will be /// removed in a follow-up. -async fn rl_health() -> impl IntoResponse { - (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))) -} -/// `GET /v1/rl/liveness` — engine event-loop probe via the `liveness_probe` -/// engine route. The legacy `/v1/rl/health` returns OK as long as the -/// frontend process is up; this endpoint round-trips through the engine so -/// a hung event loop or wedged worker surfaces as 503. -/// -/// Each per-worker call carries a 5s timeout (override via -/// `DYN_RL_LIVENESS_TIMEOUT_MS`). Returns 200 only when every worker -/// reports `alive: true` within the deadline; 503 otherwise. 
-async fn rl_liveness(State(state): State>) -> impl IntoResponse { - if state.worker_system_urls.is_empty() { - return ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "status": "error", - "alive": false, - "message": "no workers registered" - })), - ); - } - let timeout = state.probe_timeout; +// ── RL admin router ──────────────────────────────────────────────────── +// All `/v1/rl/*` handlers, `RlState`, body types, and fan-out logic now +// live in the `dynamo-rl` crate (see `plans/rl-crate.md`). This shim +// delegates and wraps the result into dynamo-llm's `RouteDoc` plus the +// shared `smart_json_error_middleware` that all OpenAI-side routes use. - let futures: Vec<_> = state - .worker_system_urls - .iter() - .map(|url| { - let client = state.http_client.clone(); - let endpoint = format!("{url}/engine/liveness_probe"); - async move { - tokio::time::timeout( - timeout, - async { - match client.post(&endpoint).json(&serde_json::json!({})).send().await { - Ok(resp) => resp - .json::() - .await - .unwrap_or_else(|e| serde_json::json!({ - "status": "error", - "alive": false, - "message": format!("decode failed: {e}") - })), - Err(e) => serde_json::json!({ - "status": "error", - "alive": false, - "message": format!("request failed: {e}") - }), - } - }, - ) - .await - .unwrap_or_else(|_| serde_json::json!({ - "status": "error", - "alive": false, - "message": format!("liveness_probe timed out after {}ms", timeout.as_millis()) - })) - } - }) - .collect(); - let results = futures::future::join_all(futures).await; - let all_alive = results - .iter() - .all(|r| r.get("alive").and_then(|v| v.as_bool()) == Some(true)); - if all_alive { - ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "ok", - "alive": true, - "workers": results, - })), - ) - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "status": "error", - "alive": false, - "workers": results, - })), - ) - } -} - -/// `GET /v1/rl/state` — composite RL fleet 
state snapshot. -/// -/// Replaces three v1 endpoints (`/v1/rl/health` + `/v1/rl/ready` + -/// `/v1/rl/weight_version`) with a single composite, scoped to RL-specific -/// readiness (engine alive, pause state, applied weight version, loaded -/// LoRAs). -/// -/// Aggregates per-worker `get_state` engine-route responses into: -/// -/// ```json -/// { -/// "ready": bool, -/// "ingress_alive": true, -/// "engine_alive": bool, // every worker's engine.check_health() ok -/// "pause_state": "running"|"paused"|"mixed", -/// "applied_weight_version": str, // when consistent across workers; null if mixed -/// "loras": [{name, loaded_on: [worker_idx]}], -/// "workers": [] -/// } -/// ``` +/// Build the `/v1/rl/*` router. Delegates to `dynamo_rl::rl_router()` and +/// wraps the documentation tuples into `RouteDoc`. Wraps the router with +/// `smart_json_error_middleware` so 422s are coerced to 400s consistently +/// with the OpenAI-compat surface. /// -/// `ingress_alive` is unconditionally `true` because reaching this handler -/// means the frontend HTTP listener is up. `ready = ingress_alive AND -/// engine_alive AND len(workers) > 0`. -async fn rl_state(State(state): State>) -> impl IntoResponse { - if state.worker_system_urls.is_empty() { - return ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "ready": false, - "ingress_alive": true, - "engine_alive": false, - "pause_state": "running", - "applied_weight_version": null, - "loras": [], - "workers": [], - "status": "error", - "message": "no workers registered" - })), - ); - } - let results = state.fan_out("get_state", serde_json::json!({})).await; - - let engine_alive = results - .iter() - .all(|r| r.get("engine_alive").and_then(|v| v.as_bool()) == Some(true)); - - // Aggregate pause_state: if all workers agree, surface that; else "mixed". 
- let pause_states: std::collections::HashSet<&str> = results - .iter() - .filter_map(|r| r.get("pause_state").and_then(|v| v.as_str())) - .collect(); - let pause_state = if pause_states.len() == 1 { - pause_states.into_iter().next().unwrap_or("running").to_string() - } else if pause_states.is_empty() { - "running".to_string() - } else { - "mixed".to_string() - }; - - // applied_weight_version is reported only when consistent. - let weight_versions: std::collections::HashSet<&str> = results - .iter() - .filter_map(|r| r.get("applied_weight_version").and_then(|v| v.as_str())) - .collect(); - let applied_weight_version: Option = if weight_versions.len() == 1 { - weight_versions.into_iter().next().map(|s| s.to_string()) - } else { - None - }; - - // LoRA name → list of worker indices that have it loaded. - let mut lora_loaded_on: std::collections::BTreeMap> = - std::collections::BTreeMap::new(); - for (idx, worker) in results.iter().enumerate() { - if let Some(loras) = worker.get("loras").and_then(|v| v.as_array()) { - for lora in loras { - if let Some(name) = lora.get("name").and_then(|v| v.as_str()) { - lora_loaded_on.entry(name.to_string()).or_default().push(idx); - } - } - } - } - let loras: Vec = lora_loaded_on +/// Exposed only when `DYN_ENABLE_RL=true` or `HttpServiceConfig.enable_rl` +/// is set. Mounted by `service_v2.rs`. 
+pub fn rl_router() -> anyhow::Result<(Vec, Router)> { + let (rl_docs, router) = dynamo_rl::rl_router()?; + let docs = rl_docs .into_iter() - .map(|(name, loaded_on)| serde_json::json!({"name": name, "loaded_on": loaded_on})) + .map(|d| RouteDoc::new(d.method, d.path)) .collect(); - - let ready = engine_alive && !results.is_empty(); - let body = serde_json::json!({ - "ready": ready, - "ingress_alive": true, - "engine_alive": engine_alive, - "pause_state": pause_state, - "applied_weight_version": applied_weight_version, - "loras": loras, - "workers": results, - }); - let status = if ready { - StatusCode::OK - } else { - StatusCode::SERVICE_UNAVAILABLE - }; - (status, Json(body)) -} - -/// Create an Axum [`Router`] for the RL admin endpoints at `/v1/rl/*`. -/// -/// Worker system URLs are read from the `DYN_RL_WORKER_SYSTEM_URLS` environment -/// variable (comma-separated, defaults to `http://localhost:8081`). -/// -/// Exposed only when `DYN_ENABLE_RL=true` or `HttpServiceConfig.enable_rl` is set. -/// -/// Prime-RL usage: set `admin_base_url = ["http://dynamo-frontend:8000/v1/rl"]` -/// in the orchestrator config. Prime-RL strips the trailing `/v1` suffix only -/// if present, so `/v1/rl` is preserved and all routes resolve correctly. -pub fn rl_router() -> anyhow::Result<(Vec, Router)> { - let rl_state_arc = Arc::new(RlState::from_env()?); - let docs = vec![ - // Phase 1: composite endpoints. - RouteDoc::new(axum::http::Method::GET, "/v1/rl/state"), - RouteDoc::new(axum::http::Method::GET, "/v1/rl/liveness"), - // Pause / resume / update_weights bracket. - RouteDoc::new(axum::http::Method::POST, "/v1/rl/pause"), - RouteDoc::new(axum::http::Method::POST, "/v1/rl/resume"), - RouteDoc::new(axum::http::Method::POST, "/v1/rl/update_weights"), - // LoRA hot-swap. 
- RouteDoc::new(axum::http::Method::POST, "/v1/rl/load_lora_adapter"), - // Legacy (deprecated; subsumed by /v1/rl/state — Phase 5 will drop): - RouteDoc::new(axum::http::Method::GET, "/v1/rl/health"), - RouteDoc::new(axum::http::Method::GET, "/v1/rl/ready"), - RouteDoc::new(axum::http::Method::GET, "/v1/rl/weight_version"), - RouteDoc::new(axum::http::Method::POST, "/v1/rl/unload_lora_adapter"), - ]; - let router = Router::new() - // Phase 1: composite read-only endpoints. - .route("/v1/rl/state", get(rl_state)) - .route("/v1/rl/liveness", get(rl_liveness)) - // Pause / resume / update_weights bracket. - .route("/v1/rl/pause", post(rl_pause)) - .route("/v1/rl/resume", post(rl_resume)) - .route("/v1/rl/update_weights", post(rl_update_weights)) - // LoRA hot-swap. - .route("/v1/rl/load_lora_adapter", post(rl_load_lora_adapter)) - // Legacy endpoints — kept for back-compat until existing clients - // migrate to /v1/rl/state. Removed in a follow-up. - .route("/v1/rl/health", get(rl_health)) - .route("/v1/rl/ready", get(rl_ready)) - .route("/v1/rl/weight_version", get(rl_weight_version)) - .route("/v1/rl/unload_lora_adapter", post(rl_unload_lora_adapter)) - .layer(middleware::from_fn(smart_json_error_middleware)) - .with_state(rl_state_arc); + let router = router.layer(middleware::from_fn(smart_json_error_middleware)); Ok((docs, router)) } @@ -3973,7 +3230,7 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, }; @@ -4008,7 +3265,7 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, }; @@ -4227,7 +3484,7 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, }; @@ -4260,7 +3517,7 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + 
return_token_ids: None, tokens: None, }; @@ -4292,7 +3549,7 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, }; @@ -4324,7 +3581,7 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, }; @@ -4358,7 +3615,7 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, }; @@ -4390,7 +3647,7 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, }; @@ -5444,8 +4701,8 @@ mod tests { // Axum returns 400 on this deserialize failure before the handler // runs — that's the whole point of the typed enum vs the prior // string match. - let err = serde_json::from_str::("\"foo\"") - .expect_err("foo is not a valid PauseMode"); + let err = + serde_json::from_str::("\"foo\"").expect_err("foo is not a valid PauseMode"); assert!(err.to_string().to_lowercase().contains("foo")); } @@ -5456,14 +4713,12 @@ mod tests { assert!(body.weight_version.is_none()); assert!(body.reset_prefix_cache); - let body: RlUpdateWeightsBody = - serde_json::from_str(r#"{"weight_dir":null}"#).unwrap(); + let body: RlUpdateWeightsBody = serde_json::from_str(r#"{"weight_dir":null}"#).unwrap(); assert!(body.weight_dir.is_none()); assert!(body.reset_prefix_cache); let body: RlUpdateWeightsBody = - serde_json::from_str(r#"{"weight_dir":"/path","reset_prefix_cache":false}"#) - .unwrap(); + serde_json::from_str(r#"{"weight_dir":"/path","reset_prefix_cache":false}"#).unwrap(); assert_eq!(body.weight_dir.as_deref(), Some("/path")); assert!(!body.reset_prefix_cache); } diff --git a/lib/rl/Cargo.toml b/lib/rl/Cargo.toml new file mode 100644 index 000000000000..ac3fcdae5e14 --- /dev/null +++ b/lib/rl/Cargo.toml @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 
Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "dynamo-rl" +description = "RL admin control plane — handlers, state, fan-out, and HTTP facade for /v1/rl/*" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +# Dependency direction: dynamo-llm -> dynamo-rl -> dynamo-runtime. +# This crate must NOT depend on dynamo-llm. + +[dependencies] +dynamo-runtime = { workspace = true } + +axum = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +anyhow = { workspace = true } +futures = { workspace = true } +reqwest = { workspace = true } diff --git a/lib/rl/src/lib.rs b/lib/rl/src/lib.rs new file mode 100644 index 000000000000..ac9deadb9ce1 --- /dev/null +++ b/lib/rl/src/lib.rs @@ -0,0 +1,1011 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Dynamo RL admin control plane — handlers, state, fan-out for `/v1/rl/*`. +//! +//! See `plans/rl-crate.md` and `plans/weight-transfer-config.md`. +//! +//! **PR A status:** pure refactor — handlers + state moved verbatim out of +//! `lib/llm/src/http/service/openai.rs` so the admin code lives in its own +//! crate. Behavior unchanged. Future work (per the plan): +//! +//! - **PR B:** replace `worker_system_urls: Vec` (HTTP system-port +//! fan-out, env-driven) with discovery-backed fan-out via the dynamo +//! request plane. Drop `reqwest::Client`. Drop `DYN_RL_WORKER_SYSTEM_URLS`. +//! - **PR C:** introduce `DYN_ENABLE_RL_ENDPOINTS` (frontend-only) to gate +//! this router on a separate Axum listener (`DYN_RL_PORT` / `--rl-port`). +//! `DYN_ENABLE_RL` keeps its meaning as the inference-plane RL extensions +//! 
gate plus worker-side engine-route registration. + +use std::sync::Arc; + +use axum::{ + Json, Router, + extract::State, + http::{Method, StatusCode}, + response::IntoResponse, + routing::{get, post}, +}; + +/// Documentation tuple for an RL admin route. The dynamo-llm caller wraps +/// each tuple into its own `RouteDoc` for `/openapi.json` aggregation. +#[derive(Debug, Clone)] +pub struct RlRouteDoc { + pub method: Method, + pub path: String, +} + +impl RlRouteDoc { + fn new(method: Method, path: impl Into) -> Self { + Self { + method, + path: path.into(), + } + } +} + +/// Environment variable for comma-separated worker system HTTP URLs. +/// Defaults to `http://localhost:8081` when not set. +const DYN_RL_WORKER_SYSTEM_URLS_ENV: &str = "DYN_RL_WORKER_SYSTEM_URLS"; + +/// Shared state for the RL admin router. +#[derive(Clone)] +struct RlState { + /// Worker system HTTP base URLs (e.g. `http://localhost:8081`). + /// Set via `DYN_RL_WORKER_SYSTEM_URLS` (comma-separated list). + worker_system_urls: Vec, + /// Shared HTTP client for all fan-out calls to worker system ports. + http_client: reqwest::Client, + /// Per-worker probe timeout for `/v1/rl/liveness` and `/v1/rl/ready`. + /// Read once from `DYN_RL_LIVENESS_TIMEOUT_MS` at construction. 
+ probe_timeout: std::time::Duration, +} + +impl RlState { + fn from_env() -> anyhow::Result { + let worker_system_urls = std::env::var(DYN_RL_WORKER_SYSTEM_URLS_ENV) + .unwrap_or_else(|_| "http://localhost:8081".to_string()) + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect::>(); + let probe_timeout_ms = std::env::var("DYN_RL_LIVENESS_TIMEOUT_MS") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(5000); + tracing::info!( + worker_count = worker_system_urls.len(), + ?worker_system_urls, + probe_timeout_ms, + "RL admin router configured" + ); + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(600)) + .build() + .map_err(|e| anyhow::anyhow!("Failed to build RL router HTTP client: {e}"))?; + Ok(Self::new( + worker_system_urls, + http_client, + std::time::Duration::from_millis(probe_timeout_ms), + )) + } + + /// Test-friendly constructor — bypasses env reading so tests can pass in + /// fake worker URLs and a stubbed `reqwest::Client`. + fn new( + worker_system_urls: Vec, + http_client: reqwest::Client, + probe_timeout: std::time::Duration, + ) -> Self { + Self { + worker_system_urls, + http_client, + probe_timeout, + } + } + + /// Call a single engine route on one worker. Returns the JSON body. + async fn call_engine_route( + &self, + url: &str, + route: &str, + body: &serde_json::Value, + ) -> serde_json::Value { + let endpoint = format!("{url}/engine/{route}"); + match self.http_client.post(&endpoint).json(body).send().await { + Ok(resp) => { + let status = resp.status(); + match resp.json::().await { + Ok(v) => v, + Err(e) => serde_json::json!({ + "status": "error", + "message": format!("Failed to decode response from {endpoint}: {e}"), + "http_status": status.as_u16() + }), + } + } + Err(e) => serde_json::json!({ + "status": "error", + "message": format!("Request to {endpoint} failed: {e}") + }), + } + } + + /// Fan out an engine route call to all configured workers concurrently. 
+ async fn fan_out(&self, route: &str, body: serde_json::Value) -> Vec { + let futures: Vec<_> = self + .worker_system_urls + .iter() + .map(|url| self.call_engine_route(url, route, &body)) + .collect(); + futures::future::join_all(futures).await + } + + /// Returns true only if all results have `status: "ok"`. + fn all_ok(results: &[serde_json::Value]) -> bool { + results + .iter() + .all(|r| r.get("status").and_then(|s| s.as_str()) == Some("ok")) + } +} + +/// `GET /v1/rl/ready` — composite readiness check: worker health via system port. +/// +/// Bounded with a per-worker probe timeout (default 5s, override via +/// `DYN_RL_LIVENESS_TIMEOUT_MS`) so a wedged worker fails fast as 503 instead +/// of hanging on the shared 600s `http_client` timeout. +async fn rl_ready(State(state): State>) -> impl IntoResponse { + let timeout = state.probe_timeout; + let futures: Vec<_> = state + .worker_system_urls + .iter() + .map(|url| { + let client = state.http_client.clone(); + let health_url = format!("{url}/health"); + async move { + match tokio::time::timeout(timeout, client.get(&health_url).send()).await { + Ok(Ok(resp)) => resp.status().is_success(), + Ok(Err(_)) | Err(_) => false, + } + } + }) + .collect(); + let results = futures::future::join_all(futures).await; + let all_ready = !results.is_empty() && results.iter().all(|ok| *ok); + if all_ready { + (StatusCode::OK, Json(serde_json::json!({"status": "ready"}))) + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "status": "not_ready", + "workers_ready": results.iter().filter(|ok| **ok).count(), + "workers_total": results.len() + })), + ) + } +} + +/// `POST /v1/rl/pause` — fan out `pause_generation` to all workers. +/// +/// Query params (both optional): +/// - `mode`: `keep` | `wait` | `abort` (default `keep`) +/// - `clear_cache`: `true` | `false` (default `false`) +/// +/// Three-mode pause matches what vLLM exposes (abort / wait / keep). 
The +/// default `mode=keep&clear_cache=false` preserves the original single-mode +/// pause behavior so existing callers keep working without changes. +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[serde(rename_all = "lowercase")] +enum PauseMode { + Keep, + Wait, + Abort, +} + +impl PauseMode { + fn as_str(self) -> &'static str { + match self { + PauseMode::Keep => "keep", + PauseMode::Wait => "wait", + PauseMode::Abort => "abort", + } + } +} + +impl Default for PauseMode { + fn default() -> Self { + PauseMode::Keep + } +} + +#[derive(Debug, serde::Deserialize)] +struct RlPauseQuery { + /// Axum returns 400 automatically if this fails to deserialize as a + /// `PauseMode` (i.e. on `mode=invalid`), so we don't need a runtime check. + #[serde(default)] + mode: Option, + #[serde(default)] + clear_cache: Option, +} + +async fn rl_pause( + State(state): State>, + axum::extract::Query(q): axum::extract::Query, +) -> impl IntoResponse { + let mode = q.mode.unwrap_or_default(); + let clear_cache = q.clear_cache.unwrap_or(false); + let results = state + .fan_out( + "pause_generation", + serde_json::json!({"mode": mode.as_str(), "clear_cache": clear_cache}), + ) + .await; + if RlState::all_ok(&results) { + tracing::info!( + worker_count = results.len(), + mode = %mode.as_str(), + clear_cache, + "RL pause: all workers paused" + ); + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "mode": mode.as_str(), + "clear_cache": clear_cache, + "workers": results, + })), + ) + } else { + tracing::warn!(?results, "RL pause: some workers failed"); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({"status": "error", "workers": results})), + ) + } +} + +/// `POST /v1/rl/resume` — fan out `resume_generation` to all workers. 
+async fn rl_resume(State(state): State>) -> impl IntoResponse { + let results = state + .fan_out("resume_generation", serde_json::json!({})) + .await; + if RlState::all_ok(&results) { + tracing::info!( + worker_count = results.len(), + "RL resume: all workers resumed" + ); + ( + StatusCode::OK, + Json(serde_json::json!({"status": "ok", "workers": results})), + ) + } else { + tracing::warn!(?results, "RL resume: some workers failed"); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({"status": "error", "workers": results})), + ) + } +} + +/// `POST /v1/rl/update_weights` — fan out `flush_cache → update_weights_from_path` to all workers. +/// +/// **Not atomic.** If `update_weights_from_path` succeeds on workers `0..N-1` +/// and fails on worker `N`, the fleet is left in a mixed-version state: the +/// successful workers serve the new version while worker `N` still runs the +/// previous one. The response carries per-worker status so callers can +/// retry / drain manually; a true rollback layer is a follow-up. +/// +/// Two body shapes are accepted: +/// +/// **Legacy** (Phase 1 backward-compat): +/// ```json +/// { +/// "weight_dir": "/path/to/checkpoint" | null, // null → NCCL mode no-op +/// "weight_version": "step_42", // optional; derived from +/// // weight_dir basename if missing +/// "reset_prefix_cache": true +/// } +/// ``` +/// +/// **WeightTransferConfig** (new, single shape across backends): +/// ```json +/// { +/// "version": "step_42", +/// "target": {"kind": "base"} | {"kind": "lora", "name": "...", "op": "load|swap|unload"}, +/// "transport": { +/// "backend": "filesystem" | "nccl", +/// "filesystem": {"path": "...", "require_marker": "STABLE"}, +/// "nccl": {"transport_id": "...", "weight_names": [...], "dtype": "bf16"} +/// } +/// } +/// ``` +/// +/// Returns `{ "status": "ok", "applied_weight_version": "step_42", "workers": [...] }` on success. 
+/// +/// The pause/resume envelope is left to the caller; full-FT updates MUST +/// bracket this call with `/v1/rl/pause` and `/v1/rl/resume`. +#[derive(Debug, serde::Deserialize)] +#[serde(untagged)] +enum RlUpdateWeightsBody { + /// New shape — required field is `transport`. Serde tries this variant + /// first; falls back to legacy if it fails. + NewShape { + version: String, + target: serde_json::Value, + transport: serde_json::Value, + #[serde(default)] + pause_mode: Option, + #[serde(default)] + clear_cache: Option, + }, + /// Legacy single-arg body kept live during Phase 1 / 2. + Legacy { + weight_dir: Option, + #[serde(default)] + weight_version: Option, + #[serde(default = "default_reset_prefix_cache")] + reset_prefix_cache: bool, + }, +} + +fn default_reset_prefix_cache() -> bool { + true +} + +async fn rl_update_weights( + State(state): State>, + body: axum::extract::Json, +) -> impl IntoResponse { + // Dispatch on body shape. New shape goes through the WeightTransferConfig + // worker route; legacy keeps the existing flush_cache → update_weights_from_path + // sequence so unmigrated callers continue to work. + match body.0 { + RlUpdateWeightsBody::NewShape { + version, + target, + transport, + pause_mode, + clear_cache, + } => { + return rl_update_weights_new_shape( + state, + version, + target, + transport, + pause_mode, + clear_cache, + ) + .await; + } + RlUpdateWeightsBody::Legacy { + weight_dir, + weight_version, + reset_prefix_cache, + } => { + return rl_update_weights_legacy(state, weight_dir, weight_version, reset_prefix_cache) + .await; + } + } +} + +/// New WeightTransferConfig path — fans out to ``weight_transport_update``. 
+async fn rl_update_weights_new_shape( + state: Arc, + version: String, + target: serde_json::Value, + transport: serde_json::Value, + pause_mode: Option, + clear_cache: Option, +) -> (StatusCode, Json) { + let backend = transport + .get("backend") + .and_then(|v| v.as_str()) + .unwrap_or(""); + tracing::info!( + version = %version, + backend = %backend, + ?target, + "RL update_weights (new shape)" + ); + let mut body = serde_json::json!({ + "version": version, + "target": target, + "transport": transport, + }); + if let Some(pm) = pause_mode { + body["pause_mode"] = serde_json::Value::String(pm); + } + if let Some(cc) = clear_cache { + body["clear_cache"] = serde_json::Value::Bool(cc); + } + let results = state.fan_out("weight_transport_update", body).await; + if RlState::all_ok(&results) { + tracing::info!( + worker_count = results.len(), + backend = %backend, + version = %version, + "RL update_weights (new shape): all workers updated" + ); + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "applied_weight_version": version, + "backend": backend, + "workers": results, + })), + ) + } else { + tracing::warn!(?results, backend = %backend, "RL update_weights (new shape): some workers failed"); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({ + "status": "error", + "stage": "weight_transport_update", + "backend": backend, + "workers": results, + })), + ) + } +} + +/// Legacy single-arg body — Phase 1 backward-compat. +async fn rl_update_weights_legacy( + state: Arc, + weight_dir: Option, + weight_version: Option, + reset_prefix_cache: bool, +) -> (StatusCode, Json) { + // Treat empty string the same as missing/null (NCCL no-op). Otherwise + // an empty string would reach the engine as `path=""` and fail + // confusingly downstream. 
+ let weight_dir = weight_dir + .as_ref() + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(str::to_string); + let Some(weight_dir) = weight_dir else { + tracing::info!("RL update_weights: weight_dir=null (NCCL mode, no-op on Dynamo side)"); + return ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "message": "NCCL mode, no-op on Dynamo side" + })), + ); + }; + + let version = weight_version.clone().unwrap_or_else(|| { + std::path::Path::new(&weight_dir) + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string() + }); + tracing::info!( + weight_dir = %weight_dir, + version = %version, + reset_prefix_cache, + "RL update_weights" + ); + + // Step 1 (optional): flush_cache across all workers. + if reset_prefix_cache { + let flush_results = state.fan_out("flush_cache", serde_json::json!({})).await; + if !RlState::all_ok(&flush_results) { + tracing::warn!(?flush_results, "RL update_weights: flush_cache failed"); + return ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({ + "status": "error", + "stage": "flush_cache", + "workers": flush_results + })), + ); + } + } + + // Step 2: update_weights_from_path across all workers. 
+ let load_body = serde_json::json!({"path": &weight_dir, "version": version}); + let load_results = state.fan_out("update_weights_from_path", load_body).await; + if RlState::all_ok(&load_results) { + tracing::info!( + worker_count = load_results.len(), + weight_dir = %weight_dir, + "RL update_weights: all workers updated" + ); + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "applied_weight_version": version, + "workers": load_results, + })), + ) + } else { + tracing::warn!( + ?load_results, + "RL update_weights: update_weights_from_path failed" + ); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({ + "status": "error", + "stage": "update_weights_from_path", + "workers": load_results + })), + ) + } +} + +/// `POST /v1/rl/init_transport` — idempotent one-time setup for a weight +/// transport (filesystem / nccl). Replaces backend-specific bring-up +/// endpoints with a single discriminated body. +/// +/// Body: +/// ```json +/// { +/// "transport_id": "rl-weights-step", +/// "backend": "filesystem" | "nccl", +/// "filesystem": { … } | "nccl": { … } +/// } +/// ``` +/// +/// `filesystem` is a no-op (transport state goes ``ready`` immediately). +/// `nccl` triggers the worker-side group bootstrap. 
+async fn rl_init_transport( + State(state): State>, + body: axum::extract::Json, +) -> impl IntoResponse { + let body = body.0; + let backend = body + .get("backend") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let transport_id = body + .get("transport_id") + .and_then(|v| v.as_str()) + .unwrap_or(&backend) + .to_string(); + tracing::info!(%backend, %transport_id, "RL init_transport"); + + let results = state.fan_out("weight_transport_init", body).await; + if RlState::all_ok(&results) { + tracing::info!( + worker_count = results.len(), + %backend, + %transport_id, + "RL init_transport: all workers ready" + ); + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "transport_id": transport_id, + "backend": backend, + "ready": true, + "workers": results, + })), + ) + } else { + tracing::warn!(?results, %backend, "RL init_transport: some workers failed"); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({ + "status": "error", + "transport_id": transport_id, + "backend": backend, + "workers": results, + })), + ) + } +} + +/// `POST /v1/rl/load_lora_adapter` — hot-load/swap a LoRA adapter from a filesystem path. +/// +/// Expected body: `{"lora_name": "r16-a32.0", "lora_path": "/path/to/adapter_dir"}` +/// +/// The adapter directory must contain PEFT-style `adapter_model.safetensors` and +/// `adapter_config.json`. This is the RL-specific LoRA path used by Prime-RL every +/// training step (separate from Dynamo's URI-based `load_lora` gRPC endpoint which +/// downloads adapters from S3/file URIs and publishes a new ModelDeploymentCard). +/// +/// Hot-swap semantics: calling with a `lora_name` that is already loaded removes +/// the previous adapter and loads the new one under the same deterministic int ID, +/// then resets the prefix cache so stale KV entries don't poison new rollouts. +/// +/// Pair with `/v1/rl/pause` and `/v1/rl/resume` for a full drain-swap-resume cycle. 
+async fn rl_load_lora_adapter( + State(state): State>, + body: axum::extract::Json, +) -> impl IntoResponse { + let lora_name = body.get("lora_name").and_then(|v| v.as_str()); + let lora_path = body.get("lora_path").and_then(|v| v.as_str()); + + let (lora_name, lora_path) = match (lora_name, lora_path) { + (Some(n), Some(p)) if !n.is_empty() && !p.is_empty() => (n.to_string(), p.to_string()), + _ => { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "status": "error", + "message": "Expected body: {\"lora_name\": str, \"lora_path\": str} (both required, non-empty)" + })), + ); + } + }; + + tracing::info!(%lora_name, %lora_path, "RL load_lora_adapter"); + let results = state + .fan_out( + "load_lora_adapter", + serde_json::json!({"lora_name": &lora_name, "lora_path": &lora_path}), + ) + .await; + + if RlState::all_ok(&results) { + tracing::info!( + worker_count = results.len(), + %lora_name, + %lora_path, + "RL load_lora_adapter: all workers loaded" + ); + ( + StatusCode::OK, + Json(serde_json::json!({"status": "ok", "workers": results})), + ) + } else { + tracing::warn!(?results, %lora_name, "RL load_lora_adapter: some workers failed"); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({"status": "error", "workers": results})), + ) + } +} + +/// `POST /v1/rl/unload_lora_adapter` — remove a previously loaded LoRA adapter by name. +/// +/// Expected body: `{"lora_name": "r16-a32.0"}` +/// +/// Idempotent: unloading an already-absent LoRA returns `status: ok` so callers +/// can retry safely without special-casing not-found. 
+async fn rl_unload_lora_adapter( + State(state): State>, + body: axum::extract::Json, +) -> impl IntoResponse { + let lora_name = body + .get("lora_name") + .and_then(|v| v.as_str()) + .map(str::to_string); + + let lora_name = match lora_name { + Some(n) if !n.is_empty() => n, + _ => { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "status": "error", + "message": "Expected body: {\"lora_name\": str} (required, non-empty)" + })), + ); + } + }; + + tracing::info!(%lora_name, "RL unload_lora_adapter"); + let results = state + .fan_out( + "unload_lora_adapter", + serde_json::json!({"lora_name": &lora_name}), + ) + .await; + + if RlState::all_ok(&results) { + tracing::info!( + worker_count = results.len(), + %lora_name, + "RL unload_lora_adapter: all workers unloaded" + ); + ( + StatusCode::OK, + Json(serde_json::json!({"status": "ok", "workers": results})), + ) + } else { + tracing::warn!(?results, %lora_name, "RL unload_lora_adapter: some workers failed"); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({"status": "error", "workers": results})), + ) + } +} + +/// `GET /v1/rl/weight_version` — query weight version from all workers. 
+async fn rl_weight_version(State(state): State>) -> impl IntoResponse { + let results = state + .fan_out("get_weight_version", serde_json::json!({})) + .await; + + // Collect distinct versions and check for consistency + let versions: Vec<_> = results + .iter() + .filter_map(|r| { + r.get("version") + .and_then(|v| v.as_str()) + .map(str::to_string) + }) + .collect(); + + let unique: std::collections::HashSet<&str> = versions.iter().map(String::as_str).collect(); + if unique.len() == 1 { + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "version": unique.into_iter().next().unwrap_or(""), + "workers": results + })), + ) + } else { + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "inconsistent", + "versions": unique.into_iter().collect::>(), + "workers": results + })), + ) + } +} + +async fn rl_health() -> impl IntoResponse { + (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))) +} + +/// `GET /v1/rl/liveness` — engine event-loop probe via the `liveness_probe` +/// engine route. The legacy `/v1/rl/health` returns OK as long as the +/// frontend process is up; this endpoint round-trips through the engine so +/// a hung event loop or wedged worker surfaces as 503. +/// +/// Each per-worker call carries a 5s timeout (override via +/// `DYN_RL_LIVENESS_TIMEOUT_MS`). Returns 200 only when every worker +/// reports `alive: true` within the deadline; 503 otherwise. 
+async fn rl_liveness(State(state): State>) -> impl IntoResponse { + if state.worker_system_urls.is_empty() { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "status": "error", + "alive": false, + "message": "no workers registered" + })), + ); + } + let timeout = state.probe_timeout; + + let futures: Vec<_> = state + .worker_system_urls + .iter() + .map(|url| { + let client = state.http_client.clone(); + let endpoint = format!("{url}/engine/liveness_probe"); + async move { + tokio::time::timeout( + timeout, + async { + match client.post(&endpoint).json(&serde_json::json!({})).send().await { + Ok(resp) => resp + .json::() + .await + .unwrap_or_else(|e| serde_json::json!({ + "status": "error", + "alive": false, + "message": format!("decode failed: {e}") + })), + Err(e) => serde_json::json!({ + "status": "error", + "alive": false, + "message": format!("request failed: {e}") + }), + } + }, + ) + .await + .unwrap_or_else(|_| serde_json::json!({ + "status": "error", + "alive": false, + "message": format!("liveness_probe timed out after {}ms", timeout.as_millis()) + })) + } + }) + .collect(); + let results = futures::future::join_all(futures).await; + let all_alive = results + .iter() + .all(|r| r.get("alive").and_then(|v| v.as_bool()) == Some(true)); + if all_alive { + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "alive": true, + "workers": results, + })), + ) + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "status": "error", + "alive": false, + "workers": results, + })), + ) + } +} + +/// `GET /v1/rl/state` — composite RL fleet state snapshot. +/// +/// Replaces three v1 endpoints (`/v1/rl/health` + `/v1/rl/ready` + +/// `/v1/rl/weight_version`) with a single composite, scoped to RL-specific +/// readiness (engine alive, pause state, applied weight version, loaded +/// LoRAs). 
+/// +/// Aggregates per-worker `get_state` engine-route responses into: +/// +/// ```json +/// { +/// "ready": bool, +/// "ingress_alive": true, +/// "engine_alive": bool, // every worker's engine.check_health() ok +/// "pause_state": "running"|"paused"|"mixed", +/// "applied_weight_version": str, // when consistent across workers; null if mixed +/// "loras": [{name, loaded_on: [worker_idx]}], +/// "workers": [] +/// } +/// ``` +/// +/// `ingress_alive` is unconditionally `true` because reaching this handler +/// means the frontend HTTP listener is up. `ready = ingress_alive AND +/// engine_alive AND len(workers) > 0`. +async fn rl_state(State(state): State>) -> impl IntoResponse { + if state.worker_system_urls.is_empty() { + return ( + StatusCode::SERVICE_UNAVAILABLE, + Json(serde_json::json!({ + "ready": false, + "ingress_alive": true, + "engine_alive": false, + "pause_state": "running", + "applied_weight_version": null, + "loras": [], + "workers": [], + "status": "error", + "message": "no workers registered" + })), + ); + } + let results = state.fan_out("get_state", serde_json::json!({})).await; + + let engine_alive = results + .iter() + .all(|r| r.get("engine_alive").and_then(|v| v.as_bool()) == Some(true)); + + // Aggregate pause_state: if all workers agree, surface that; else "mixed". + let pause_states: std::collections::HashSet<&str> = results + .iter() + .filter_map(|r| r.get("pause_state").and_then(|v| v.as_str())) + .collect(); + let pause_state = if pause_states.len() == 1 { + pause_states + .into_iter() + .next() + .unwrap_or("running") + .to_string() + } else if pause_states.is_empty() { + "running".to_string() + } else { + "mixed".to_string() + }; + + // applied_weight_version is reported only when consistent. 
+ let weight_versions: std::collections::HashSet<&str> = results + .iter() + .filter_map(|r| r.get("applied_weight_version").and_then(|v| v.as_str())) + .collect(); + let applied_weight_version: Option = if weight_versions.len() == 1 { + weight_versions.into_iter().next().map(|s| s.to_string()) + } else { + None + }; + + // LoRA name → list of worker indices that have it loaded. + let mut lora_loaded_on: std::collections::BTreeMap> = + std::collections::BTreeMap::new(); + for (idx, worker) in results.iter().enumerate() { + if let Some(loras) = worker.get("loras").and_then(|v| v.as_array()) { + for lora in loras { + if let Some(name) = lora.get("name").and_then(|v| v.as_str()) { + lora_loaded_on + .entry(name.to_string()) + .or_default() + .push(idx); + } + } + } + } + let loras: Vec = lora_loaded_on + .into_iter() + .map(|(name, loaded_on)| serde_json::json!({"name": name, "loaded_on": loaded_on})) + .collect(); + + let ready = engine_alive && !results.is_empty(); + let body = serde_json::json!({ + "ready": ready, + "ingress_alive": true, + "engine_alive": engine_alive, + "pause_state": pause_state, + "applied_weight_version": applied_weight_version, + "loras": loras, + "workers": results, + }); + let status = if ready { + StatusCode::OK + } else { + StatusCode::SERVICE_UNAVAILABLE + }; + (status, Json(body)) +} + +/// Create an Axum [`Router`] for the RL admin endpoints at `/v1/rl/*`. +/// +/// Worker system URLs are read from the `DYN_RL_WORKER_SYSTEM_URLS` environment +/// variable (comma-separated, defaults to `http://localhost:8081`). +/// +/// Exposed only when `DYN_ENABLE_RL=true` or `HttpServiceConfig.enable_rl` is set. +/// +/// Prime-RL usage: set `admin_base_url = ["http://dynamo-frontend:8000/v1/rl"]` +/// in the orchestrator config. Prime-RL strips the trailing `/v1` suffix only +/// if present, so `/v1/rl` is preserved and all routes resolve correctly. 
+pub fn rl_router() -> anyhow::Result<(Vec, Router)> { + let rl_state_arc = Arc::new(RlState::from_env()?); + let docs = vec![ + // Phase 1: composite endpoints. + RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/state"), + RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/liveness"), + // Pause / resume / update_weights bracket. + RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/pause"), + RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/resume"), + RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/update_weights"), + // WeightTransferConfig API (Phase 1+4) — idempotent transport setup. + RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/init_transport"), + // LoRA hot-swap. + RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/load_lora_adapter"), + // Legacy (deprecated; subsumed by /v1/rl/state — Phase 5 will drop): + RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/health"), + RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/ready"), + RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/weight_version"), + RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/unload_lora_adapter"), + ]; + let router = Router::new() + // Phase 1: composite read-only endpoints. + .route("/v1/rl/state", get(rl_state)) + .route("/v1/rl/liveness", get(rl_liveness)) + // Pause / resume / update_weights bracket. + .route("/v1/rl/pause", post(rl_pause)) + .route("/v1/rl/resume", post(rl_resume)) + .route("/v1/rl/update_weights", post(rl_update_weights)) + // WeightTransferConfig API (Phase 1+4) — idempotent transport setup. + .route("/v1/rl/init_transport", post(rl_init_transport)) + // LoRA hot-swap. + .route("/v1/rl/load_lora_adapter", post(rl_load_lora_adapter)) + // Legacy endpoints — kept for back-compat until existing clients + // migrate to /v1/rl/state. Removed in a follow-up. 
+ .route("/v1/rl/health", get(rl_health)) + .route("/v1/rl/ready", get(rl_ready)) + .route("/v1/rl/weight_version", get(rl_weight_version)) + .route("/v1/rl/unload_lora_adapter", post(rl_unload_lora_adapter)) + .with_state(rl_state_arc); + Ok((docs, router)) +} From 67058be7c337abe913306839db81e9881dbf1a0b Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Fri, 8 May 2026 01:43:06 -0700 Subject: [PATCH 12/18] feat(rl/vllm): add WeightTransport trait + FilesystemTransport/NcclTransport (Phase 1+4 of weight-transfer-config.md) --- components/src/dynamo/vllm/handlers.py | 170 +++++++++++++++- .../dynamo/vllm/weight_transports/__init__.py | 57 ++++++ .../src/dynamo/vllm/weight_transports/base.py | 191 ++++++++++++++++++ .../vllm/weight_transports/engine_adapter.py | 138 +++++++++++++ .../vllm/weight_transports/filesystem.py | 113 +++++++++++ .../src/dynamo/vllm/weight_transports/nccl.py | 167 +++++++++++++++ components/src/dynamo/vllm/worker_factory.py | 13 +- 7 files changed, 845 insertions(+), 4 deletions(-) create mode 100644 components/src/dynamo/vllm/weight_transports/__init__.py create mode 100644 components/src/dynamo/vllm/weight_transports/base.py create mode 100644 components/src/dynamo/vllm/weight_transports/engine_adapter.py create mode 100644 components/src/dynamo/vllm/weight_transports/filesystem.py create mode 100644 components/src/dynamo/vllm/weight_transports/nccl.py diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 02e5a54d97a5..4e477084fc70 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -902,7 +902,7 @@ async def resume_generation(self, body: dict) -> dict: try: await self.engine_client.resume_generation() self._paused = False - logger.debug("[RL] Engine resumed") + logger.info("[RL] Engine resumed") return {"status": "ok", "message": "Engine resumed"} except EngineDeadError as e: self._shutdown_on_engine_dead(e) @@ -1013,7 +1013,7 @@ async def 
update_weights_from_path(self, body: dict) -> dict: kwargs={"weights_path": path}, ) self._weight_version = version - logger.debug(f"[RL] Weights loaded from {path} (version={version})") + logger.info(f"[RL] Weights loaded from {path} (version={version})") return { "status": "ok", "message": f"Weights loaded from {path}", @@ -1176,7 +1176,7 @@ async def load_lora_adapter(self, body: dict) -> dict: "lora_name": lora_name, } - logger.debug( + logger.info( f"[RL] LoRA adapter {'hot-swapped' if is_hot_swap else 'loaded'}: " f"name={lora_name} id={lora_id} path={lora_path}" ) @@ -1265,6 +1265,170 @@ async def unload_lora_adapter(self, body: dict) -> dict: ) return {"status": "error", "message": str(e)} + # ── WeightTransferConfig API (Phase 1+4) ─────────────────────────── + # + # New unified surface paired with the Rust frontend's + # ``/v1/rl/init_transport`` and the discriminated ``/v1/rl/update_weights`` + # body. Backwards-compatible: legacy ``update_weights_from_path`` / + # ``load_lora_adapter`` / ``unload_lora_adapter`` engine routes stay live + # for callers that haven't migrated yet. 
+ + def _ensure_weight_transports(self): + """Lazy-init transport registry + vLLM engine adapter.""" + if getattr(self, "_weight_transports", None) is not None: + return + from .weight_transports import VllmEngineAdapter + + adapter = VllmEngineAdapter(self.engine_client) + adapter.bind_lora_helpers( + loader=self._lora_load_via_admin, + unloader=self._lora_unload_via_admin, + ) + self._weight_engine_adapter = adapter + self._weight_transports: dict = {} + + async def _lora_load_via_admin(self, *, name: str, path: str) -> dict: + """Re-use the existing :meth:`load_lora_adapter` path so MDC publish, + hot-swap detection, and prefix-cache reset all stay consistent.""" + return await self.load_lora_adapter( + {"lora_name": name, "lora_path": path} + ) + + async def _lora_unload_via_admin(self, *, name: str) -> dict: + return await self.unload_lora_adapter({"lora_name": name}) + + async def weight_transport_init(self, body: dict) -> dict: + """Idempotent transport setup. Backs ``POST /v1/rl/init_transport``. + + Body: + - transport_id: str (caller-chosen) + - backend: "filesystem" | "nccl" + - : {…} (backend-specific block) + """ + body = body or {} + backend = body.get("backend") + if backend not in ("filesystem", "nccl"): + return { + "status": "error", + "message": ( + f"Unsupported backend '{backend}'. In scope this iteration: " + "filesystem, nccl. Future (deferred): nixl, model_express, ipc." + ), + } + transport_id = body.get("transport_id", backend) + cfg = dict(body.get(backend) or {}) + cfg.setdefault("transport_id", transport_id) + + try: + self._ensure_weight_transports() + from .weight_transports import build_transport, InitCtx + + existing = self._weight_transports.get(transport_id) + if existing is not None and existing.backend_id == backend: + logger.info( + f"[RL] init_transport: '{transport_id}' already configured " + f"(backend={backend}); idempotent re-init" + ) + # Re-run init for idempotency (eg. NCCL group bootstrap). 
+ ctx = InitCtx(rank=0, world_size=1, served_model_name="") + result = await existing.init(ctx, cfg) + return result.to_dict() + + transport = build_transport(backend, self._weight_engine_adapter, cfg) + ctx = InitCtx(rank=0, world_size=1, served_model_name="") + result = await transport.init(ctx, cfg) + self._weight_transports[transport_id] = transport + logger.info( + f"[RL] init_transport: backend={backend} transport_id={transport_id} " + f"ready={result.ready}" + ) + return result.to_dict() + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.exception(f"[RL] init_transport failed: {e}") + return {"status": "error", "message": str(e)} + + async def weight_transport_update(self, body: dict) -> dict: + """Backs the new-shape ``POST /v1/rl/update_weights``. + + Body: + - version: str + - target: {"kind": "base"} | {"kind": "lora", "name": str, "op": …} + - transport: {"backend": "filesystem"|"nccl", : {…}} + """ + try: + from .weight_transports import UpdateWeightsRequest, build_transport, InitCtx + + req = UpdateWeightsRequest.from_dict(body or {}) + self._ensure_weight_transports() + backend = (req.transport or {}).get("backend") + + # LoRA unload may omit the transport block — synthesize filesystem. + if ( + req.target.kind == "lora" + and req.target.op == "unload" + and backend is None + ): + backend = "filesystem" + req.transport = {"backend": "filesystem", "filesystem": {}} + + if backend not in ("filesystem", "nccl"): + return { + "status": "error", + "message": ( + f"Unsupported backend '{backend}'. In scope this iteration: " + "filesystem, nccl." + ), + } + + # Resolve transport instance: prefer one bound by init_transport; + # for filesystem we lazily build per-call (no setup needed). 
+ transport_id = (req.transport.get(backend) or {}).get( + "transport_id", backend + ) + transport = self._weight_transports.get(transport_id) + if transport is None: + if backend == "filesystem": + transport = build_transport( + backend, self._weight_engine_adapter, {"transport_id": transport_id} + ) + await transport.init( + InitCtx(rank=0, world_size=1, served_model_name=""), {} + ) + self._weight_transports[transport_id] = transport + else: + return { + "status": "error", + "message": ( + f"Transport '{transport_id}' (backend={backend}) is not " + f"initialized. Call POST /v1/rl/init_transport first." + ), + } + elif transport.backend_id != backend: + return { + "status": "error", + "message": ( + f"Transport '{transport_id}' is bound to backend " + f"'{transport.backend_id}', not '{backend}'." + ), + } + + result = await transport.update_weights(req) + self._weight_version = req.version + payload = result.to_dict() + payload.setdefault("version", req.version) + payload.setdefault("backend", backend) + payload.setdefault("transport_id", transport_id) + return payload + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except (ValueError, FileNotFoundError, NotImplementedError) as e: + return {"status": "error", "message": str(e)} + except Exception as e: + logger.exception(f"[RL] weight_transport_update failed: {e}") + return {"status": "error", "message": str(e)} + @abstractmethod def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]: raise NotImplementedError diff --git a/components/src/dynamo/vllm/weight_transports/__init__.py b/components/src/dynamo/vllm/weight_transports/__init__.py new file mode 100644 index 000000000000..391a60562cb9 --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/__init__.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Weight-transport plug-ins for ``dynamo.vllm`` (Phase 1+4 of the +WeightTransferConfig design). + +In scope this iteration: + +* :class:`FilesystemTransport` — current default, safetensors via shared FS. +* :class:`NcclTransport` — collective broadcast on a pre-formed group + (vLLM ``collective_rpc("update_weights_from_distributed", …)``). + +Future (deferred): ``NixlTransport``, ``ModelExpressTransport``, +``IpcTransport``, plus an ``SglangEngineAdapter`` for the second engine +flavor. +""" + +from .base import ( + EngineAdapter, + InitCtx, + InitResult, + TransportState, + UpdateResult, + UpdateWeightsRequest, + WeightTarget, + WeightTransport, +) +from .engine_adapter import VllmEngineAdapter +from .filesystem import FilesystemTransport +from .nccl import NcclTransport + +__all__ = [ + "EngineAdapter", + "FilesystemTransport", + "InitCtx", + "InitResult", + "NcclTransport", + "TransportState", + "UpdateResult", + "UpdateWeightsRequest", + "VllmEngineAdapter", + "WeightTarget", + "WeightTransport", + "build_transport", +] + + +def build_transport(backend: str, engine_adapter, cfg: dict): + """Factory: instantiate the right transport for the given backend id.""" + if backend == "filesystem": + return FilesystemTransport(engine_adapter, cfg) + if backend == "nccl": + return NcclTransport(engine_adapter, cfg) + raise ValueError( + f"Unsupported weight-transport backend '{backend}'. " + "In-scope this iteration: filesystem, nccl. " + "Future (deferred): nixl, model_express, ipc." + ) diff --git a/components/src/dynamo/vllm/weight_transports/base.py b/components/src/dynamo/vllm/weight_transports/base.py new file mode 100644 index 000000000000..f8c6d9e8f791 --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/base.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Trait + types for the WeightTransferConfig API (vLLM-scoped, Phase 1).""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal, Optional, Protocol + + +PauseMode = Literal["keep", "wait", "abort"] +TransportState = Literal["configured", "ready", "receiving", "failed"] +TargetKind = Literal["base", "lora"] +LoraOp = Literal["load", "swap", "unload"] + + +@dataclass(frozen=True) +class WeightTarget: + """What is being updated. + + * ``kind="base"``: the base model itself (full-FT reload). + * ``kind="lora"``: a LoRA adapter; ``name`` is required and ``op`` selects + between load/swap/unload. + """ + + kind: TargetKind + name: Optional[str] = None + op: Optional[LoraOp] = None + + @classmethod + def from_dict(cls, body: dict) -> "WeightTarget": + kind = body.get("kind") + if kind not in ("base", "lora"): + raise ValueError( + f"WeightTarget.kind must be 'base' or 'lora', got {kind!r}" + ) + if kind == "lora": + name = body.get("name") + if not isinstance(name, str) or not name: + raise ValueError( + "WeightTarget.name is required when kind='lora'" + ) + op = body.get("op") + if op not in ("load", "swap", "unload"): + raise ValueError( + f"WeightTarget.op must be 'load'|'swap'|'unload' when " + f"kind='lora', got {op!r}" + ) + return cls(kind="lora", name=name, op=op) + return cls(kind="base") + + +@dataclass +class UpdateWeightsRequest: + """Single discriminated body for ``POST /v1/rl/update_weights``.""" + + version: str + target: WeightTarget + transport: dict # backend-specific block, validated by the transport impl + pause_mode: PauseMode = "keep" + clear_cache: bool = True + + @classmethod + def from_dict(cls, body: dict) -> "UpdateWeightsRequest": + version = body.get("version") + if not isinstance(version, str) or not version: + raise ValueError("update_weights: 'version' is required") + target = WeightTarget.from_dict(body.get("target", {}) or {}) + 
transport = body.get("transport") or {} + if target.kind == "base" or target.op != "unload": + if not isinstance(transport, dict) or "backend" not in transport: + raise ValueError( + "update_weights: 'transport.backend' is required " + "(except for lora unload)" + ) + pause_mode = body.get("pause_mode", "keep") + if pause_mode not in ("keep", "wait", "abort"): + raise ValueError( + f"update_weights: pause_mode must be 'keep'|'wait'|'abort', " + f"got {pause_mode!r}" + ) + clear_cache = bool(body.get("clear_cache", True)) + return cls( + version=version, + target=target, + transport=transport, + pause_mode=pause_mode, + clear_cache=clear_cache, + ) + + +@dataclass +class InitCtx: + """Constant context passed to every transport ``init`` call.""" + + rank: int + world_size: int + served_model_name: str + + +@dataclass +class InitResult: + status: str + transport_id: str + ready: bool + message: Optional[str] = None + extra: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + out = { + "status": self.status, + "transport_id": self.transport_id, + "ready": self.ready, + } + if self.message: + out["message"] = self.message + if self.extra: + out.update(self.extra) + return out + + +@dataclass +class UpdateResult: + status: str + message: str = "" + version: Optional[str] = None + extra: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + out = {"status": self.status, "message": self.message} + if self.version is not None: + out["version"] = self.version + if self.extra: + out.update(self.extra) + return out + + +class WeightTransport(Protocol): + """One implementation per backend. + + Phase 1: ``FilesystemTransport``. + Phase 4: ``NcclTransport``. + """ + + backend_id: str + + async def init(self, ctx: InitCtx, cfg: dict) -> InitResult: ... + + async def update_weights( + self, req: UpdateWeightsRequest + ) -> UpdateResult: ... + + async def teardown(self) -> None: ... + + @property + def state(self) -> TransportState: ... 
+ + +class EngineAdapter(Protocol): + """Engine-flavor shim. One implementation per engine. + + Phase 1+4 ships :class:`VllmEngineAdapter` only. Future: + ``SglangEngineAdapter`` drops in as one extra subclass without touching + any :class:`WeightTransport` impl. + """ + + async def update_weights_from_disk( + self, *, path: str, version: str, target: WeightTarget + ) -> UpdateResult: ... + + async def update_weights_from_distributed( + self, + *, + group: str, + dtype: str, + version: str, + target: WeightTarget, + weight_names: Optional[list[str]] = None, + ) -> UpdateResult: ... + + async def update_weights_from_tensor( + self, *, tensors: Any, version: str, target: WeightTarget + ) -> UpdateResult: ... + + async def update_weights_from_ipc( + self, *, handle: Any, version: str, target: WeightTarget + ) -> UpdateResult: ... + + async def add_lora(self, *, name: str, source: str) -> UpdateResult: ... + + async def remove_lora(self, *, name: str) -> UpdateResult: ... diff --git a/components/src/dynamo/vllm/weight_transports/engine_adapter.py b/components/src/dynamo/vllm/weight_transports/engine_adapter.py new file mode 100644 index 000000000000..2c3f9e914a34 --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/engine_adapter.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""vLLM-flavor engine adapter. + +Wraps ``engine_client.collective_rpc(...)`` so each :class:`WeightTransport` +implementation can call a stable, engine-agnostic API. Future: +``SglangEngineAdapter`` will wrap ``tokenizer_manager.update_weights_from_*`` +following the same Protocol. 
+""" + +from __future__ import annotations + +import logging +from typing import Any, Optional + +from .base import EngineAdapter, UpdateResult, WeightTarget + +logger = logging.getLogger(__name__) + + +class VllmEngineAdapter(EngineAdapter): + """vLLM-flavor :class:`EngineAdapter` backed by an ``engine_client``. + + All four ``update_weights_from_*`` paths route through ``collective_rpc`` + against the in-process worker(s); LoRA ops route through the engine's + ``add_lora`` / ``remove_lora`` (or equivalent collective) calls. + """ + + backend_id = "vllm" + + def __init__(self, engine_client, *, lora_loader=None): + self.engine_client = engine_client + self._lora_loader = lora_loader # optional callable for LoRA add path + + # ---- four canonical update paths --------------------------------------- + + async def update_weights_from_disk( + self, *, path: str, version: str, target: WeightTarget + ) -> UpdateResult: + await self.engine_client.collective_rpc( + "reload_weights", + kwargs={"weights_path": path}, + ) + return UpdateResult( + status="ok", + message=f"Weights loaded from {path}", + version=version, + ) + + async def update_weights_from_distributed( + self, + *, + group: str, + dtype: str, + version: str, + target: WeightTarget, + weight_names: Optional[list[str]] = None, + ) -> UpdateResult: + # vLLM exposes per-name distributed update via the worker's + # `update_weight_from_tensor` / `update_weight` collective. We loop + # over weight_names so the trainer can drive the broadcast iteration. + if not weight_names: + raise ValueError( + "update_weights_from_distributed: weight_names is required so " + "the worker knows which named parameters to receive on the " + "NCCL group." 
+ ) + for name in weight_names: + await self.engine_client.collective_rpc( + "update_weight", + kwargs={"name": name, "dtype": dtype, "shape": None}, + ) + return UpdateResult( + status="ok", + message=f"Updated {len(weight_names)} weights via group '{group}'", + version=version, + extra={"weights_received": len(weight_names)}, + ) + + async def update_weights_from_tensor( + self, *, tensors: Any, version: str, target: WeightTarget + ) -> UpdateResult: + # Future hook for NIXL/MX paths (deferred). + raise NotImplementedError( + "update_weights_from_tensor is reserved for NIXL/ModelExpress " + "transports; not implemented in Phase 1+4." + ) + + async def update_weights_from_ipc( + self, *, handle: Any, version: str, target: WeightTarget + ) -> UpdateResult: + raise NotImplementedError( + "update_weights_from_ipc is reserved for the colocated-trainer " + "path; not implemented in Phase 1+4." + ) + + # ---- LoRA ops ---------------------------------------------------------- + + async def add_lora(self, *, name: str, source: str) -> UpdateResult: + if self._lora_loader is None: + raise RuntimeError( + "VllmEngineAdapter.add_lora called but no lora_loader was " + "supplied at construction. Wire it from the handler." + ) + result = await self._lora_loader(name=name, path=source) + return UpdateResult( + status=result.get("status", "ok"), + message=result.get("message", ""), + extra={k: v for k, v in result.items() if k not in ("status", "message")}, + ) + + async def remove_lora(self, *, name: str) -> UpdateResult: + if self._lora_loader is None: + raise RuntimeError( + "VllmEngineAdapter.remove_lora called but no lora_loader was " + "supplied at construction. Wire it from the handler." + ) + # The handler exposes both load and unload via the same `lora_loader` + # callable, dispatched on a sentinel ``op`` field. We use the same + # convention: invoke the unload helper if available; otherwise fall + # through and let the caller handle. 
+ unloader = getattr(self, "_lora_unloader", None) + if unloader is None: + raise RuntimeError( + "VllmEngineAdapter.remove_lora called but no lora_unloader " + "was supplied. Wire it from the handler." + ) + result = await unloader(name=name) + return UpdateResult( + status=result.get("status", "ok"), + message=result.get("message", ""), + extra={k: v for k, v in result.items() if k not in ("status", "message")}, + ) + + # Convenience: handler wires both helpers in one shot. + def bind_lora_helpers(self, *, loader, unloader): + self._lora_loader = loader + self._lora_unloader = unloader diff --git a/components/src/dynamo/vllm/weight_transports/filesystem.py b/components/src/dynamo/vllm/weight_transports/filesystem.py new file mode 100644 index 000000000000..8ec0d7626acc --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/filesystem.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Filesystem weight transport (Phase 1). + +Equivalent of the existing ``update_weights_from_path`` route, but reachable +through the unified :class:`WeightTransport` Protocol so the same wire shape +covers full-FT and LoRA, and so future backends slot in alongside. +""" + +from __future__ import annotations + +import logging +import os +from typing import Optional + +from .base import ( + EngineAdapter, + InitCtx, + InitResult, + TransportState, + UpdateResult, + UpdateWeightsRequest, + WeightTransport, +) + +logger = logging.getLogger(__name__) + + +class FilesystemTransport(WeightTransport): + """Filesystem path → engine reload. 
+ + Config (the ``"filesystem"`` block of an ``init_transport`` body or a + ``transport.filesystem`` block of an ``update_weights`` body): + + path: str (required for base / lora-load / lora-swap) + require_marker: str (optional, default 'STABLE') + """ + + backend_id = "filesystem" + + def __init__(self, engine_adapter: EngineAdapter, cfg: dict): + self._engine = engine_adapter + self._cfg = cfg or {} + self._state: TransportState = "configured" + self._transport_id: str = self._cfg.get("transport_id", "filesystem") + + @property + def state(self) -> TransportState: + return self._state + + async def init(self, ctx: InitCtx, cfg: dict) -> InitResult: + # No setup needed for filesystem — degenerate one-shot. + self._cfg = {**self._cfg, **(cfg or {})} + self._transport_id = self._cfg.get("transport_id", self._transport_id) + self._state = "ready" + return InitResult( + status="ok", + transport_id=self._transport_id, + ready=True, + message="filesystem transport ready (no setup required)", + ) + + async def teardown(self) -> None: + self._state = "configured" + + async def update_weights( + self, req: UpdateWeightsRequest + ) -> UpdateResult: + fs = req.transport.get("filesystem") or {} + path: Optional[str] = fs.get("path") + require_marker: Optional[str] = fs.get( + "require_marker", self._cfg.get("require_marker", "STABLE") + ) + + # ---- LoRA unload: no transport, no path ---------------------------- + if req.target.kind == "lora" and req.target.op == "unload": + return await self._engine.remove_lora(name=req.target.name) + + if not path: + raise ValueError( + "filesystem.update_weights: 'transport.filesystem.path' is " + "required (except for lora unload)" + ) + + if require_marker: + marker = os.path.join(path, require_marker) + if not os.path.exists(marker): + raise FileNotFoundError( + f"filesystem transport: require_marker '{require_marker}' " + f"not found under {path!r}" + ) + + if req.target.kind == "base": + self._state = "receiving" + try: + result = 
await self._engine.update_weights_from_disk( + path=path, version=req.version, target=req.target + ) + finally: + self._state = "ready" + logger.info( + f"[RL] filesystem.update_weights: base reload from {path} " + f"(version={req.version})" + ) + return result + + # target.kind == "lora", op in {load, swap} + result = await self._engine.add_lora(name=req.target.name, source=path) + logger.info( + f"[RL] filesystem.update_weights: lora {req.target.op} " + f"name={req.target.name} from {path}" + ) + return result diff --git a/components/src/dynamo/vllm/weight_transports/nccl.py b/components/src/dynamo/vllm/weight_transports/nccl.py new file mode 100644 index 000000000000..0a3636f5bf61 --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/nccl.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""NCCL weight transport (Phase 4). + +Trainer + dynamo.vllm worker(s) form a NCCL process group at +``init_transport`` time; per-step ``update_weights`` triggers receive via +``collective_rpc("update_weight", ...)`` for each named parameter. + +Phase 4 scope: vLLM only. The trainer side is responsible for driving the +broadcast itself; dynamo just exposes the receiver hook. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +from .base import ( + EngineAdapter, + InitCtx, + InitResult, + TransportState, + UpdateResult, + UpdateWeightsRequest, + WeightTransport, +) + +logger = logging.getLogger(__name__) + + +class NcclTransport(WeightTransport): + """NCCL collective broadcast → engine receive. + + Config (the ``"nccl"`` block of an ``init_transport`` body or a + ``transport.nccl`` block of an ``update_weights`` body): + + group_name: str (required) + init_method: str (e.g. 
"tcp://trainer:29500", required at init) + trainer_world_size: int (required at init) + inference_world_size: int (required at init; usually == # workers) + dtype: str (e.g. "bf16") + + For ``update_weights``: + + weight_names: list[str] (the iteration order of named params + the trainer is broadcasting; required) + """ + + backend_id = "nccl" + + def __init__(self, engine_adapter: EngineAdapter, cfg: dict): + self._engine = engine_adapter + self._cfg = cfg or {} + self._state: TransportState = "configured" + self._transport_id: str = self._cfg.get("transport_id", "nccl") + + @property + def state(self) -> TransportState: + return self._state + + async def init(self, ctx: InitCtx, cfg: dict) -> InitResult: + cfg = cfg or {} + merged = {**self._cfg, **cfg} + # vLLM's init_weight_transfer_engine takes: + # master_address, master_port, rank_offset, world_size + # The trainer is rank 0; inference workers are rank_offset..world_size-1. + for required in ("master_address", "master_port", "world_size"): + if required not in merged: + raise ValueError( + f"nccl transport: '{required}' is required in init_transport" + ) + + self._cfg = merged + self._transport_id = merged.get("transport_id", self._transport_id) + + # Drive the worker-side bootstrap via vLLM's + # `init_weight_transfer_engine` collective. 
+ try: + init_info = { + "master_address": str(merged["master_address"]), + "master_port": int(merged["master_port"]), + "rank_offset": int(merged.get("rank_offset", 1)), + "world_size": int(merged["world_size"]), + } + await self._engine.engine_client.collective_rpc( + "init_weight_transfer_engine", + kwargs={"init_info": init_info}, + ) + self._state = "ready" + return InitResult( + status="ok", + transport_id=self._transport_id, + ready=True, + message=( + f"nccl init_weight_transfer_engine ok " + f"(master={init_info['master_address']}:{init_info['master_port']}, " + f"world_size={init_info['world_size']})" + ), + extra={"init_info": init_info}, + ) + except Exception as exc: + self._state = "failed" + logger.error(f"[RL] nccl.init failed: {exc}") + raise + + async def teardown(self) -> None: + # vLLM doesn't expose an explicit destroy hook; engine teardown handles it. + self._state = "configured" + + async def update_weights( + self, req: UpdateWeightsRequest + ) -> UpdateResult: + # NCCL transport does not own LoRA hot-swap in this iteration; LoRA + # adapters are tiny enough that filesystem stays the better path. + if req.target.kind == "lora": + raise NotImplementedError( + "nccl transport: LoRA adapter transfer is deferred. Use " + "transport.backend='filesystem' for LoRA in this iteration." + ) + + nccl = req.transport.get("nccl") or {} + # The trainer must supply (names, dtype_names, shapes) so the worker + # knows how big each `torch.empty(...)` receive buffer should be. 
+ names: Optional[list[str]] = nccl.get("names") or nccl.get("weight_names") + dtype_names: Optional[list[str]] = nccl.get("dtype_names") + shapes: Optional[list[list[int]]] = nccl.get("shapes") + if not names: + raise ValueError( + "nccl.update_weights: 'transport.nccl.names' is required" + ) + if not dtype_names or not shapes: + raise ValueError( + "nccl.update_weights: 'transport.nccl.dtype_names' and " + "'transport.nccl.shapes' are required" + ) + if len(dtype_names) != len(names) or len(shapes) != len(names): + raise ValueError( + f"nccl.update_weights: names/dtype_names/shapes length mismatch " + f"({len(names)} / {len(dtype_names)} / {len(shapes)})" + ) + + update_info = { + "names": names, + "dtype_names": dtype_names, + "shapes": shapes, + "is_checkpoint_format": bool(nccl.get("is_checkpoint_format", True)), + "packed": bool(nccl.get("packed", False)), + } + + self._state = "receiving" + try: + await self._engine.engine_client.collective_rpc( + "update_weights", + kwargs={"update_info": update_info}, + ) + finally: + self._state = "ready" + logger.info( + f"[RL] nccl.update_weights: {len(names)} weights received " + f"(version={req.version})" + ) + return UpdateResult( + status="ok", + message=f"Updated {len(names)} weights via nccl", + version=req.version, + extra={"weights_received": len(names)}, + ) diff --git a/components/src/dynamo/vllm/worker_factory.py b/components/src/dynamo/vllm/worker_factory.py index 5d122e15b3b9..9f7f51e5a79e 100644 --- a/components/src/dynamo/vllm/worker_factory.py +++ b/components/src/dynamo/vllm/worker_factory.py @@ -701,10 +701,21 @@ def register_engine_routes( "unload_lora_adapter", handler.unload_lora_adapter ) + # RL WeightTransferConfig API (Phase 1+4): unified transport surface + # for filesystem + nccl backends. Coexists with the legacy routes + # above; legacy callers continue to work unchanged. 
+ runtime.register_engine_route( + "weight_transport_init", handler.weight_transport_init + ) + runtime.register_engine_route( + "weight_transport_update", handler.weight_transport_update + ) + logger.info( "Registered engine routes: sleep, wake_up, scale_elastic_ep, " "start_profile, stop_profile, pause_generation, resume_generation, " "flush_cache, update_weights_from_path, get_weight_version, " "get_state, liveness_probe, " - "load_lora_adapter, unload_lora_adapter" + "load_lora_adapter, unload_lora_adapter, " + "weight_transport_init, weight_transport_update" ) From 93a7e41197e0034e37ea89748f0ca576c922f968 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Fri, 8 May 2026 01:43:18 -0700 Subject: [PATCH 13/18] =?UTF-8?q?style(llm):=20cargo=20fmt=20=E2=80=94=20s?= =?UTF-8?q?trip=20trailing=20whitespace=20+=20reformat=20delta.rs=20chaine?= =?UTF-8?q?d=20.or=5Felse()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lib/llm/src/protocols/anthropic/types.rs | 2 +- .../openai/chat_completions/delta.rs | 54 ++++++++++--------- lib/llm/src/protocols/openai/nvext.rs | 4 +- lib/llm/src/protocols/openai/responses/mod.rs | 14 ++--- lib/llm/src/protocols/unified.rs | 2 +- .../tests/parallel_tool_call_integration.rs | 2 +- lib/llm/tests/preprocessor.rs | 4 +- lib/llm/tests/test_streaming_usage.rs | 4 +- lib/llm/tests/tool_choice.rs | 2 +- lib/llm/tests/tool_choice_finish_reasons.rs | 2 +- 10 files changed, 47 insertions(+), 43 deletions(-) diff --git a/lib/llm/src/protocols/anthropic/types.rs b/lib/llm/src/protocols/anthropic/types.rs index 33bc1be37422..db2ef8406375 100644 --- a/lib/llm/src/protocols/anthropic/types.rs +++ b/lib/llm/src/protocols/anthropic/types.rs @@ -823,7 +823,7 @@ mod tests { }), }, nvext: None, - + prompt_token_ids: None, }; diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index facff839fc1d..5ff1ed5522c7 100644 --- 
a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -429,29 +429,33 @@ impl crate::protocols::openai::DeltaGeneratorExt NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, } diff --git a/lib/llm/tests/preprocessor.rs b/lib/llm/tests/preprocessor.rs index 45a6893b656d..29dbfb8b0297 100644 --- a/lib/llm/tests/preprocessor.rs +++ b/lib/llm/tests/preprocessor.rs @@ -261,7 +261,7 @@ impl Request { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, } @@ -704,7 +704,7 @@ mod context_length_validation { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, } diff --git a/lib/llm/tests/test_streaming_usage.rs b/lib/llm/tests/test_streaming_usage.rs index 2f94498f6168..5357fbb2ee31 100644 --- a/lib/llm/tests/test_streaming_usage.rs +++ b/lib/llm/tests/test_streaming_usage.rs @@ -195,7 +195,7 @@ fn create_chat_request( chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, } @@ -532,7 +532,7 @@ fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, } diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index da2120b9a274..0be68c5015d0 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -41,7 +41,7 @@ fn create_test_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, } diff --git a/lib/llm/tests/tool_choice_finish_reasons.rs 
b/lib/llm/tests/tool_choice_finish_reasons.rs index 12c9b4ac8cab..9141556cae21 100644 --- a/lib/llm/tests/tool_choice_finish_reasons.rs +++ b/lib/llm/tests/tool_choice_finish_reasons.rs @@ -34,7 +34,7 @@ fn create_test_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), - + return_token_ids: None, tokens: None, } From 575afd9e3c48055d82d7fe4e521644dfab135132 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Fri, 8 May 2026 02:01:15 -0700 Subject: [PATCH 14/18] =?UTF-8?q?feat(rl):=20PR=20C=20=E2=80=94=20drop=20l?= =?UTF-8?q?egacy=20/v1/rl/{state,health,ready,liveness,weight=5Fversion,*?= =?UTF-8?q?=5Flora=5Fadapter}=20+=20legacy=20update=5Fweights=20body;=20ga?= =?UTF-8?q?te=20listener=20on=20DYN=5FENABLE=5FRL=5FENDPOINTS?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lib/llm/src/http/service/service_v2.rs | 12 +- lib/rl/src/lib.rs | 639 ++----------------------- 2 files changed, 57 insertions(+), 594 deletions(-) diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 9429fe38f6d9..cda62c4a6319 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -542,8 +542,16 @@ impl HttpServiceConfigBuilder { super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), super::busy_threshold::busy_threshold_router(state.clone(), None), ]; - // RL admin routes: enabled when builder flag is set OR when DYN_ENABLE_RL env var is truthy. - if config.enable_rl || env_is_truthy("DYN_ENABLE_RL") { + // RL admin routes: gated by `DYN_ENABLE_RL_ENDPOINTS` (frontend-only). + // `DYN_ENABLE_RL` is preserved as a fallback alias for the previous + // single-flag deployment shape until clients migrate. The + // builder-time `enable_rl` flag forces routes on regardless of env. 
+ // PR C of `rl-crate.md`: split inference-plane (DYN_ENABLE_RL) from + // admin-plane (DYN_ENABLE_RL_ENDPOINTS). + if config.enable_rl + || env_is_truthy("DYN_ENABLE_RL_ENDPOINTS") + || env_is_truthy("DYN_ENABLE_RL") + { tracing::info!("RL admin routes enabled at /v1/rl/*"); system_routes.push(super::openai::rl_router()?); } diff --git a/lib/rl/src/lib.rs b/lib/rl/src/lib.rs index ac9deadb9ce1..0b5e52b824e1 100644 --- a/lib/rl/src/lib.rs +++ b/lib/rl/src/lib.rs @@ -24,7 +24,7 @@ use axum::{ extract::State, http::{Method, StatusCode}, response::IntoResponse, - routing::{get, post}, + routing::post, }; /// Documentation tuple for an RL admin route. The dynamo-llm caller wraps @@ -53,12 +53,10 @@ const DYN_RL_WORKER_SYSTEM_URLS_ENV: &str = "DYN_RL_WORKER_SYSTEM_URLS"; struct RlState { /// Worker system HTTP base URLs (e.g. `http://localhost:8081`). /// Set via `DYN_RL_WORKER_SYSTEM_URLS` (comma-separated list). + /// PR B (deferred) replaces this with discovery-backed enumeration. worker_system_urls: Vec, /// Shared HTTP client for all fan-out calls to worker system ports. http_client: reqwest::Client, - /// Per-worker probe timeout for `/v1/rl/liveness` and `/v1/rl/ready`. - /// Read once from `DYN_RL_LIVENESS_TIMEOUT_MS` at construction. 
- probe_timeout: std::time::Duration, } impl RlState { @@ -69,38 +67,24 @@ impl RlState { .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) .collect::>(); - let probe_timeout_ms = std::env::var("DYN_RL_LIVENESS_TIMEOUT_MS") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(5000); tracing::info!( worker_count = worker_system_urls.len(), ?worker_system_urls, - probe_timeout_ms, "RL admin router configured" ); let http_client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(600)) .build() .map_err(|e| anyhow::anyhow!("Failed to build RL router HTTP client: {e}"))?; - Ok(Self::new( - worker_system_urls, - http_client, - std::time::Duration::from_millis(probe_timeout_ms), - )) + Ok(Self::new(worker_system_urls, http_client)) } /// Test-friendly constructor — bypasses env reading so tests can pass in /// fake worker URLs and a stubbed `reqwest::Client`. - fn new( - worker_system_urls: Vec, - http_client: reqwest::Client, - probe_timeout: std::time::Duration, - ) -> Self { + fn new(worker_system_urls: Vec, http_client: reqwest::Client) -> Self { Self { worker_system_urls, http_client, - probe_timeout, } } @@ -149,43 +133,6 @@ impl RlState { } } -/// `GET /v1/rl/ready` — composite readiness check: worker health via system port. -/// -/// Bounded with a per-worker probe timeout (default 5s, override via -/// `DYN_RL_LIVENESS_TIMEOUT_MS`) so a wedged worker fails fast as 503 instead -/// of hanging on the shared 600s `http_client` timeout. 
-async fn rl_ready(State(state): State>) -> impl IntoResponse { - let timeout = state.probe_timeout; - let futures: Vec<_> = state - .worker_system_urls - .iter() - .map(|url| { - let client = state.http_client.clone(); - let health_url = format!("{url}/health"); - async move { - match tokio::time::timeout(timeout, client.get(&health_url).send()).await { - Ok(Ok(resp)) => resp.status().is_success(), - Ok(Err(_)) | Err(_) => false, - } - } - }) - .collect(); - let results = futures::future::join_all(futures).await; - let all_ready = !results.is_empty() && results.iter().all(|ok| *ok); - if all_ready { - (StatusCode::OK, Json(serde_json::json!({"status": "ready"}))) - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "status": "not_ready", - "workers_ready": results.iter().filter(|ok| **ok).count(), - "workers_total": results.len() - })), - ) - } -} - /// `POST /v1/rl/pause` — fan out `pause_generation` to all workers. /// /// Query params (both optional): @@ -326,72 +273,38 @@ async fn rl_resume(State(state): State>) -> impl IntoResponse { /// /// The pause/resume envelope is left to the caller; full-FT updates MUST /// bracket this call with `/v1/rl/pause` and `/v1/rl/resume`. +/// +/// **Phase 3 (PR C):** the legacy `{weight_dir, weight_version, reset_prefix_cache}` +/// body is gone. Every caller now provides `version`, `target`, and +/// `transport`. LoRA load/swap/unload also go through this same body via +/// `target.kind = "lora"` — see `weight-transfer-config.md` § 2. #[derive(Debug, serde::Deserialize)] -#[serde(untagged)] -enum RlUpdateWeightsBody { - /// New shape — required field is `transport`. Serde tries this variant - /// first; falls back to legacy if it fails. - NewShape { - version: String, - target: serde_json::Value, - transport: serde_json::Value, - #[serde(default)] - pause_mode: Option, - #[serde(default)] - clear_cache: Option, - }, - /// Legacy single-arg body kept live during Phase 1 / 2. 
- Legacy { - weight_dir: Option, - #[serde(default)] - weight_version: Option, - #[serde(default = "default_reset_prefix_cache")] - reset_prefix_cache: bool, - }, -} - -fn default_reset_prefix_cache() -> bool { - true +struct RlUpdateWeightsBody { + version: String, + target: serde_json::Value, + transport: serde_json::Value, + #[serde(default)] + pause_mode: Option, + #[serde(default)] + clear_cache: Option, } async fn rl_update_weights( State(state): State>, body: axum::extract::Json, ) -> impl IntoResponse { - // Dispatch on body shape. New shape goes through the WeightTransferConfig - // worker route; legacy keeps the existing flush_cache → update_weights_from_path - // sequence so unmigrated callers continue to work. - match body.0 { - RlUpdateWeightsBody::NewShape { - version, - target, - transport, - pause_mode, - clear_cache, - } => { - return rl_update_weights_new_shape( - state, - version, - target, - transport, - pause_mode, - clear_cache, - ) - .await; - } - RlUpdateWeightsBody::Legacy { - weight_dir, - weight_version, - reset_prefix_cache, - } => { - return rl_update_weights_legacy(state, weight_dir, weight_version, reset_prefix_cache) - .await; - } - } + let RlUpdateWeightsBody { + version, + target, + transport, + pause_mode, + clear_cache, + } = body.0; + rl_update_weights_inner(state, version, target, transport, pause_mode, clear_cache).await } -/// New WeightTransferConfig path — fans out to ``weight_transport_update``. -async fn rl_update_weights_new_shape( +/// WeightTransferConfig path — fans out to ``weight_transport_update``. 
+async fn rl_update_weights_inner( state: Arc, version: String, target: serde_json::Value, @@ -407,7 +320,7 @@ async fn rl_update_weights_new_shape( version = %version, backend = %backend, ?target, - "RL update_weights (new shape)" + "RL update_weights" ); let mut body = serde_json::json!({ "version": version, @@ -426,7 +339,7 @@ async fn rl_update_weights_new_shape( worker_count = results.len(), backend = %backend, version = %version, - "RL update_weights (new shape): all workers updated" + "RL update_weights: all workers updated" ); ( StatusCode::OK, @@ -438,7 +351,7 @@ async fn rl_update_weights_new_shape( })), ) } else { - tracing::warn!(?results, backend = %backend, "RL update_weights (new shape): some workers failed"); + tracing::warn!(?results, backend = %backend, "RL update_weights: some workers failed"); ( StatusCode::BAD_GATEWAY, Json(serde_json::json!({ @@ -451,95 +364,6 @@ async fn rl_update_weights_new_shape( } } -/// Legacy single-arg body — Phase 1 backward-compat. -async fn rl_update_weights_legacy( - state: Arc, - weight_dir: Option, - weight_version: Option, - reset_prefix_cache: bool, -) -> (StatusCode, Json) { - // Treat empty string the same as missing/null (NCCL no-op). Otherwise - // an empty string would reach the engine as `path=""` and fail - // confusingly downstream. 
- let weight_dir = weight_dir - .as_ref() - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .map(str::to_string); - let Some(weight_dir) = weight_dir else { - tracing::info!("RL update_weights: weight_dir=null (NCCL mode, no-op on Dynamo side)"); - return ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "ok", - "message": "NCCL mode, no-op on Dynamo side" - })), - ); - }; - - let version = weight_version.clone().unwrap_or_else(|| { - std::path::Path::new(&weight_dir) - .file_name() - .and_then(|n| n.to_str()) - .unwrap_or("unknown") - .to_string() - }); - tracing::info!( - weight_dir = %weight_dir, - version = %version, - reset_prefix_cache, - "RL update_weights" - ); - - // Step 1 (optional): flush_cache across all workers. - if reset_prefix_cache { - let flush_results = state.fan_out("flush_cache", serde_json::json!({})).await; - if !RlState::all_ok(&flush_results) { - tracing::warn!(?flush_results, "RL update_weights: flush_cache failed"); - return ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({ - "status": "error", - "stage": "flush_cache", - "workers": flush_results - })), - ); - } - } - - // Step 2: update_weights_from_path across all workers. 
- let load_body = serde_json::json!({"path": &weight_dir, "version": version}); - let load_results = state.fan_out("update_weights_from_path", load_body).await; - if RlState::all_ok(&load_results) { - tracing::info!( - worker_count = load_results.len(), - weight_dir = %weight_dir, - "RL update_weights: all workers updated" - ); - ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "ok", - "applied_weight_version": version, - "workers": load_results, - })), - ) - } else { - tracing::warn!( - ?load_results, - "RL update_weights: update_weights_from_path failed" - ); - ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({ - "status": "error", - "stage": "update_weights_from_path", - "workers": load_results - })), - ) - } -} - /// `POST /v1/rl/init_transport` — idempotent one-time setup for a weight /// transport (filesystem / nccl). Replaces backend-specific bring-up /// endpoints with a single discriminated body. @@ -604,408 +428,39 @@ async fn rl_init_transport( } } -/// `POST /v1/rl/load_lora_adapter` — hot-load/swap a LoRA adapter from a filesystem path. -/// -/// Expected body: `{"lora_name": "r16-a32.0", "lora_path": "/path/to/adapter_dir"}` -/// -/// The adapter directory must contain PEFT-style `adapter_model.safetensors` and -/// `adapter_config.json`. This is the RL-specific LoRA path used by Prime-RL every -/// training step (separate from Dynamo's URI-based `load_lora` gRPC endpoint which -/// downloads adapters from S3/file URIs and publishes a new ModelDeploymentCard). -/// -/// Hot-swap semantics: calling with a `lora_name` that is already loaded removes -/// the previous adapter and loads the new one under the same deterministic int ID, -/// then resets the prefix cache so stale KV entries don't poison new rollouts. -/// -/// Pair with `/v1/rl/pause` and `/v1/rl/resume` for a full drain-swap-resume cycle. 
-async fn rl_load_lora_adapter( - State(state): State>, - body: axum::extract::Json, -) -> impl IntoResponse { - let lora_name = body.get("lora_name").and_then(|v| v.as_str()); - let lora_path = body.get("lora_path").and_then(|v| v.as_str()); - - let (lora_name, lora_path) = match (lora_name, lora_path) { - (Some(n), Some(p)) if !n.is_empty() && !p.is_empty() => (n.to_string(), p.to_string()), - _ => { - return ( - StatusCode::BAD_REQUEST, - Json(serde_json::json!({ - "status": "error", - "message": "Expected body: {\"lora_name\": str, \"lora_path\": str} (both required, non-empty)" - })), - ); - } - }; - - tracing::info!(%lora_name, %lora_path, "RL load_lora_adapter"); - let results = state - .fan_out( - "load_lora_adapter", - serde_json::json!({"lora_name": &lora_name, "lora_path": &lora_path}), - ) - .await; - - if RlState::all_ok(&results) { - tracing::info!( - worker_count = results.len(), - %lora_name, - %lora_path, - "RL load_lora_adapter: all workers loaded" - ); - ( - StatusCode::OK, - Json(serde_json::json!({"status": "ok", "workers": results})), - ) - } else { - tracing::warn!(?results, %lora_name, "RL load_lora_adapter: some workers failed"); - ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({"status": "error", "workers": results})), - ) - } -} - -/// `POST /v1/rl/unload_lora_adapter` — remove a previously loaded LoRA adapter by name. -/// -/// Expected body: `{"lora_name": "r16-a32.0"}` -/// -/// Idempotent: unloading an already-absent LoRA returns `status: ok` so callers -/// can retry safely without special-casing not-found. 
-async fn rl_unload_lora_adapter( - State(state): State>, - body: axum::extract::Json, -) -> impl IntoResponse { - let lora_name = body - .get("lora_name") - .and_then(|v| v.as_str()) - .map(str::to_string); - - let lora_name = match lora_name { - Some(n) if !n.is_empty() => n, - _ => { - return ( - StatusCode::BAD_REQUEST, - Json(serde_json::json!({ - "status": "error", - "message": "Expected body: {\"lora_name\": str} (required, non-empty)" - })), - ); - } - }; - - tracing::info!(%lora_name, "RL unload_lora_adapter"); - let results = state - .fan_out( - "unload_lora_adapter", - serde_json::json!({"lora_name": &lora_name}), - ) - .await; - - if RlState::all_ok(&results) { - tracing::info!( - worker_count = results.len(), - %lora_name, - "RL unload_lora_adapter: all workers unloaded" - ); - ( - StatusCode::OK, - Json(serde_json::json!({"status": "ok", "workers": results})), - ) - } else { - tracing::warn!(?results, %lora_name, "RL unload_lora_adapter: some workers failed"); - ( - StatusCode::BAD_GATEWAY, - Json(serde_json::json!({"status": "error", "workers": results})), - ) - } -} - -/// `GET /v1/rl/weight_version` — query weight version from all workers. 
-async fn rl_weight_version(State(state): State>) -> impl IntoResponse { - let results = state - .fan_out("get_weight_version", serde_json::json!({})) - .await; - - // Collect distinct versions and check for consistency - let versions: Vec<_> = results - .iter() - .filter_map(|r| { - r.get("version") - .and_then(|v| v.as_str()) - .map(str::to_string) - }) - .collect(); - - let unique: std::collections::HashSet<&str> = versions.iter().map(String::as_str).collect(); - if unique.len() == 1 { - ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "ok", - "version": unique.into_iter().next().unwrap_or(""), - "workers": results - })), - ) - } else { - ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "inconsistent", - "versions": unique.into_iter().collect::>(), - "workers": results - })), - ) - } -} - -async fn rl_health() -> impl IntoResponse { - (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))) -} - -/// `GET /v1/rl/liveness` — engine event-loop probe via the `liveness_probe` -/// engine route. The legacy `/v1/rl/health` returns OK as long as the -/// frontend process is up; this endpoint round-trips through the engine so -/// a hung event loop or wedged worker surfaces as 503. -/// -/// Each per-worker call carries a 5s timeout (override via -/// `DYN_RL_LIVENESS_TIMEOUT_MS`). Returns 200 only when every worker -/// reports `alive: true` within the deadline; 503 otherwise. 
-async fn rl_liveness(State(state): State>) -> impl IntoResponse { - if state.worker_system_urls.is_empty() { - return ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "status": "error", - "alive": false, - "message": "no workers registered" - })), - ); - } - let timeout = state.probe_timeout; - - let futures: Vec<_> = state - .worker_system_urls - .iter() - .map(|url| { - let client = state.http_client.clone(); - let endpoint = format!("{url}/engine/liveness_probe"); - async move { - tokio::time::timeout( - timeout, - async { - match client.post(&endpoint).json(&serde_json::json!({})).send().await { - Ok(resp) => resp - .json::() - .await - .unwrap_or_else(|e| serde_json::json!({ - "status": "error", - "alive": false, - "message": format!("decode failed: {e}") - })), - Err(e) => serde_json::json!({ - "status": "error", - "alive": false, - "message": format!("request failed: {e}") - }), - } - }, - ) - .await - .unwrap_or_else(|_| serde_json::json!({ - "status": "error", - "alive": false, - "message": format!("liveness_probe timed out after {}ms", timeout.as_millis()) - })) - } - }) - .collect(); - let results = futures::future::join_all(futures).await; - let all_alive = results - .iter() - .all(|r| r.get("alive").and_then(|v| v.as_bool()) == Some(true)); - if all_alive { - ( - StatusCode::OK, - Json(serde_json::json!({ - "status": "ok", - "alive": true, - "workers": results, - })), - ) - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "status": "error", - "alive": false, - "workers": results, - })), - ) - } -} - -/// `GET /v1/rl/state` — composite RL fleet state snapshot. -/// -/// Replaces three v1 endpoints (`/v1/rl/health` + `/v1/rl/ready` + -/// `/v1/rl/weight_version`) with a single composite, scoped to RL-specific -/// readiness (engine alive, pause state, applied weight version, loaded -/// LoRAs). 
-/// -/// Aggregates per-worker `get_state` engine-route responses into: -/// -/// ```json -/// { -/// "ready": bool, -/// "ingress_alive": true, -/// "engine_alive": bool, // every worker's engine.check_health() ok -/// "pause_state": "running"|"paused"|"mixed", -/// "applied_weight_version": str, // when consistent across workers; null if mixed -/// "loras": [{name, loaded_on: [worker_idx]}], -/// "workers": [] -/// } -/// ``` -/// -/// `ingress_alive` is unconditionally `true` because reaching this handler -/// means the frontend HTTP listener is up. `ready = ingress_alive AND -/// engine_alive AND len(workers) > 0`. -async fn rl_state(State(state): State>) -> impl IntoResponse { - if state.worker_system_urls.is_empty() { - return ( - StatusCode::SERVICE_UNAVAILABLE, - Json(serde_json::json!({ - "ready": false, - "ingress_alive": true, - "engine_alive": false, - "pause_state": "running", - "applied_weight_version": null, - "loras": [], - "workers": [], - "status": "error", - "message": "no workers registered" - })), - ); - } - let results = state.fan_out("get_state", serde_json::json!({})).await; - - let engine_alive = results - .iter() - .all(|r| r.get("engine_alive").and_then(|v| v.as_bool()) == Some(true)); - - // Aggregate pause_state: if all workers agree, surface that; else "mixed". - let pause_states: std::collections::HashSet<&str> = results - .iter() - .filter_map(|r| r.get("pause_state").and_then(|v| v.as_str())) - .collect(); - let pause_state = if pause_states.len() == 1 { - pause_states - .into_iter() - .next() - .unwrap_or("running") - .to_string() - } else if pause_states.is_empty() { - "running".to_string() - } else { - "mixed".to_string() - }; - - // applied_weight_version is reported only when consistent. 
- let weight_versions: std::collections::HashSet<&str> = results - .iter() - .filter_map(|r| r.get("applied_weight_version").and_then(|v| v.as_str())) - .collect(); - let applied_weight_version: Option = if weight_versions.len() == 1 { - weight_versions.into_iter().next().map(|s| s.to_string()) - } else { - None - }; - - // LoRA name → list of worker indices that have it loaded. - let mut lora_loaded_on: std::collections::BTreeMap> = - std::collections::BTreeMap::new(); - for (idx, worker) in results.iter().enumerate() { - if let Some(loras) = worker.get("loras").and_then(|v| v.as_array()) { - for lora in loras { - if let Some(name) = lora.get("name").and_then(|v| v.as_str()) { - lora_loaded_on - .entry(name.to_string()) - .or_default() - .push(idx); - } - } - } - } - let loras: Vec = lora_loaded_on - .into_iter() - .map(|(name, loaded_on)| serde_json::json!({"name": name, "loaded_on": loaded_on})) - .collect(); - - let ready = engine_alive && !results.is_empty(); - let body = serde_json::json!({ - "ready": ready, - "ingress_alive": true, - "engine_alive": engine_alive, - "pause_state": pause_state, - "applied_weight_version": applied_weight_version, - "loras": loras, - "workers": results, - }); - let status = if ready { - StatusCode::OK - } else { - StatusCode::SERVICE_UNAVAILABLE - }; - (status, Json(body)) -} - /// Create an Axum [`Router`] for the RL admin endpoints at `/v1/rl/*`. /// /// Worker system URLs are read from the `DYN_RL_WORKER_SYSTEM_URLS` environment -/// variable (comma-separated, defaults to `http://localhost:8081`). +/// variable (comma-separated, defaults to `http://localhost:8081`). Phase B of +/// `rl-crate.md` will replace this with discovery-backed fan-out; until then +/// the static URL list is the source of truth. /// -/// Exposed only when `DYN_ENABLE_RL=true` or `HttpServiceConfig.enable_rl` is set. +/// **Surface:** four POST routes after Phase 3 (this PR). 
Read-side endpoints +/// (`state`, `health`, `ready`, `liveness`, `weight_version`) and the +/// dedicated LoRA routes (`load_lora_adapter`, `unload_lora_adapter`) are +/// dropped — replacements piggyback on the frontend's existing `/live` and +/// `/health`, and LoRA flows through `update_weights {target.kind="lora"}`. +/// See `weight-transfer-config.md` § "Constraints from existing surface". /// -/// Prime-RL usage: set `admin_base_url = ["http://dynamo-frontend:8000/v1/rl"]` -/// in the orchestrator config. Prime-RL strips the trailing `/v1` suffix only -/// if present, so `/v1/rl` is preserved and all routes resolve correctly. +/// Mounted on the dedicated `/v1/rl/*` listener when +/// `DYN_ENABLE_RL_ENDPOINTS=true`. prime-rl usage: +/// `admin_base_url = "http://dynamo-frontend:8002/v1/rl"`. pub fn rl_router() -> anyhow::Result<(Vec, Router)> { let rl_state_arc = Arc::new(RlState::from_env()?); let docs = vec![ - // Phase 1: composite endpoints. - RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/state"), - RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/liveness"), - // Pause / resume / update_weights bracket. + // Pause / resume bracket. RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/pause"), RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/resume"), - RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/update_weights"), - // WeightTransferConfig API (Phase 1+4) — idempotent transport setup. + // WeightTransferConfig API: init + discriminated update_weights body + // covering both base-model reload and LoRA load/swap/unload. RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/init_transport"), - // LoRA hot-swap. 
- RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/load_lora_adapter"), - // Legacy (deprecated; subsumed by /v1/rl/state — Phase 5 will drop): - RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/health"), - RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/ready"), - RlRouteDoc::new(axum::http::Method::GET, "/v1/rl/weight_version"), - RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/unload_lora_adapter"), + RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/update_weights"), ]; let router = Router::new() - // Phase 1: composite read-only endpoints. - .route("/v1/rl/state", get(rl_state)) - .route("/v1/rl/liveness", get(rl_liveness)) - // Pause / resume / update_weights bracket. .route("/v1/rl/pause", post(rl_pause)) .route("/v1/rl/resume", post(rl_resume)) - .route("/v1/rl/update_weights", post(rl_update_weights)) - // WeightTransferConfig API (Phase 1+4) — idempotent transport setup. .route("/v1/rl/init_transport", post(rl_init_transport)) - // LoRA hot-swap. - .route("/v1/rl/load_lora_adapter", post(rl_load_lora_adapter)) - // Legacy endpoints — kept for back-compat until existing clients - // migrate to /v1/rl/state. Removed in a follow-up. 
- .route("/v1/rl/health", get(rl_health)) - .route("/v1/rl/ready", get(rl_ready)) - .route("/v1/rl/weight_version", get(rl_weight_version)) - .route("/v1/rl/unload_lora_adapter", post(rl_unload_lora_adapter)) + .route("/v1/rl/update_weights", post(rl_update_weights)) .with_state(rl_state_arc); Ok((docs, router)) } From b6e471de03d471a2b258f6e26e0d93e95b806cd9 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Fri, 8 May 2026 02:43:22 -0700 Subject: [PATCH 15/18] =?UTF-8?q?feat(rl):=20PR=20B=20=E2=80=94=20request-?= =?UTF-8?q?plane=20fan-out=20via=20Discovery+PushRouter=20(worker=20regist?= =?UTF-8?q?ers=20..rl=20with=20rl=5Fdispatch;=20frontend=20dispa?= =?UTF-8?q?tches=20via=20Client::wait=5Ffor=5Finstances=20+=20direct())?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 1 - components/src/dynamo/vllm/handlers.py | 53 ++++ components/src/dynamo/vllm/worker_factory.py | 16 ++ lib/bindings/python/Cargo.lock | 1 - lib/llm/src/entrypoint/input/http.rs | 6 + lib/llm/src/http/service/openai.rs | 6 +- lib/llm/src/http/service/service_v2.rs | 21 +- lib/rl/Cargo.toml | 1 - lib/rl/src/lib.rs | 287 +++++++++++++------ 9 files changed, 297 insertions(+), 95 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bd0ce7b991c8..663e2c024271 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2615,7 +2615,6 @@ dependencies = [ "axum 0.8.4", "dynamo-runtime", "futures", - "reqwest 0.12.28", "serde", "serde_json", "tokio", diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 4e477084fc70..792503cec317 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -1429,6 +1429,59 @@ async def weight_transport_update(self, body: dict) -> dict: logger.exception(f"[RL] weight_transport_update failed: {e}") return {"status": "error", "message": str(e)} + # ── PR B: unified `rl` request-plane endpoint ───────────────────── + # + # Worker registers 
``dyn://..rl`` and serves this + # dispatcher. The frontend (dynamo-rl crate) discovers live `rl` + # instances via the standard discovery plane and dispatches via + # ``PushRouter::direct`` over NATS / shared TCP — no system-port HTTP + # fan-out, no static `DYN_RL_WORKER_SYSTEM_URLS` list. + # + # Wire shape: ``{"op": str, "body": dict}`` where `op` is one of + # ``pause | resume | init_transport | update_weights``. The dispatcher + # routes to the existing per-op handlers and yields a single response + # dict (matching the serve_endpoint async-generator contract used by + # ``generate``, ``load_lora``, etc.). + # + # Legacy ``register_engine_route`` HTTP-on-system-port routes stay + # live during PR B / PR C overlap so unmigrated callers don't break. + async def rl_dispatch(self, request=None): + """Single-endpoint RL admin dispatcher (PR B). + + Async generator yielding exactly one response dict per call. + """ + if request is None: + yield {"status": "error", "message": "rl_dispatch: request required"} + return + op = request.get("op") + body = request.get("body") or {} + if not isinstance(op, str) or not op: + yield { + "status": "error", + "message": "rl_dispatch: missing 'op' (str)", + } + return + try: + if op == "pause": + yield await self.pause_generation(body) + elif op == "resume": + yield await self.resume_generation(body) + elif op == "init_transport": + yield await self.weight_transport_init(body) + elif op == "update_weights": + yield await self.weight_transport_update(body) + else: + yield { + "status": "error", + "message": ( + f"rl_dispatch: unknown op {op!r}; expected one of " + "pause|resume|init_transport|update_weights" + ), + } + except Exception as e: + logger.exception(f"[RL] rl_dispatch op={op!r} failed: {e}") + yield {"status": "error", "op": op, "message": str(e)} + @abstractmethod def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]: raise NotImplementedError diff --git 
a/components/src/dynamo/vllm/worker_factory.py b/components/src/dynamo/vllm/worker_factory.py index 9f7f51e5a79e..9018f22a9fd1 100644 --- a/components/src/dynamo/vllm/worker_factory.py +++ b/components/src/dynamo/vllm/worker_factory.py @@ -241,9 +241,19 @@ async def _create_decode_worker( f"{config.namespace}.{config.component}.clear_kv_blocks" ) + # PR B: unified RL admin endpoint on the request plane. Discoverable + # via etcd as ``<namespace>.<component>.rl``; the dynamo-rl frontend crate + # uses Discovery::list(NamespacedEndpoints) + PushRouter::direct to + # fan out admin ops here, replacing the legacy HTTP-on-system-port + # ``register_engine_route("pause_generation", …)`` etc. mechanism. + rl_endpoint = runtime.endpoint( + f"{config.namespace}.{config.component}.rl" + ) + shutdown_endpoints[:] = [ generate_endpoint, clear_endpoint, + rl_endpoint, ] lora_enabled = config.engine_args.enable_lora @@ -442,6 +452,12 @@ async def _create_decode_worker( handler.get_perf_metrics, metrics_labels=model_metrics_labels, ), + # PR B: unified RL admin endpoint (rl_dispatch dispatches + # by op name to pause/resume/init_transport/update_weights).
+ rl_endpoint.serve_endpoint( + handler.rl_dispatch, + metrics_labels=model_metrics_labels, + ), ] if lora_enabled: diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index 1a31b48b50ac..c8b32e0b194d 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -2250,7 +2250,6 @@ dependencies = [ "axum", "dynamo-runtime", "futures", - "reqwest", "serde", "serde_json", "tokio", diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs index a9bdf1c6c09a..28b231ae3e67 100644 --- a/lib/llm/src/entrypoint/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -63,6 +63,12 @@ pub async fn run( http_service_builder = http_service_builder.drt_discovery(Some(distributed_runtime.discovery())); + // Wire the full DRT so the RL admin router (when DYN_ENABLE_RL_ENDPOINTS=true) + // can use the discovery + request planes to fan out to live `<namespace>.<component>.rl` + // worker endpoints. + http_service_builder = + http_service_builder.runtime(Some(Arc::new(distributed_runtime.clone()))); + let http_service = match engine_config { EngineConfig::Dynamic { ref model, diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index e5e4c6c9e8a7..c2cac10cf122 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -2979,8 +2979,10 @@ fn rl_promote_token_ids_in_response(json_val: &mut serde_json::Value) { /// /// Exposed only when `DYN_ENABLE_RL=true` or `HttpServiceConfig.enable_rl` /// is set. Mounted by `service_v2.rs`.
-pub fn rl_router() -> anyhow::Result<(Vec, Router)> { - let (rl_docs, router) = dynamo_rl::rl_router()?; +pub fn rl_router( + drt: std::sync::Arc, +) -> anyhow::Result<(Vec, Router)> { + let (rl_docs, router) = dynamo_rl::rl_router(drt)?; let docs = rl_docs .into_iter() .map(|d| RouteDoc::new(d.method, d.path)) diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index cda62c4a6319..f42f2fdb93be 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -270,6 +270,12 @@ pub struct HttpServiceConfig { /// are registered using discovery.instance_id() and exposed on /metrics. #[builder(default = "None")] drt_discovery: Option>, + + /// Required when `enable_rl` is set (or `DYN_ENABLE_RL_ENDPOINTS` / its `DYN_ENABLE_RL` alias is truthy): the + /// dynamo-rl crate uses this runtime's discovery + request planes to + /// fan out admin calls to live `<namespace>.<component>.rl` endpoint instances. + #[builder(default = "None")] + runtime: Option>, } impl HttpService { @@ -552,8 +558,19 @@ impl HttpServiceConfigBuilder { || env_is_truthy("DYN_ENABLE_RL_ENDPOINTS") || env_is_truthy("DYN_ENABLE_RL") { - tracing::info!("RL admin routes enabled at /v1/rl/*"); - system_routes.push(super::openai::rl_router()?); + match config.runtime.as_ref() { + Some(drt) => { + tracing::info!("RL admin routes enabled at /v1/rl/* (request-plane fan-out)"); + system_routes.push(super::openai::rl_router(drt.clone())?); + } + None => { + tracing::warn!( + "RL admin routes requested (DYN_ENABLE_RL_ENDPOINTS=true) but \ + HttpServiceConfigBuilder.runtime is None — skipping mount. \ + The frontend caller must supply the DistributedRuntime."
+ ); + } + } } let mut system_router = axum::Router::new(); for (route_docs, route) in system_routes { diff --git a/lib/rl/Cargo.toml b/lib/rl/Cargo.toml index ac3fcdae5e14..99ed788636e7 100644 --- a/lib/rl/Cargo.toml +++ b/lib/rl/Cargo.toml @@ -24,4 +24,3 @@ tokio = { workspace = true } tracing = { workspace = true } anyhow = { workspace = true } futures = { workspace = true } -reqwest = { workspace = true } diff --git a/lib/rl/src/lib.rs b/lib/rl/src/lib.rs index 0b5e52b824e1..9629779a57fb 100644 --- a/lib/rl/src/lib.rs +++ b/lib/rl/src/lib.rs @@ -5,17 +5,13 @@ //! //! See `plans/rl-crate.md` and `plans/weight-transfer-config.md`. //! -//! **PR A status:** pure refactor — handlers + state moved verbatim out of -//! `lib/llm/src/http/service/openai.rs` so the admin code lives in its own -//! crate. Behavior unchanged. Future work (per the plan): -//! -//! - **PR B:** replace `worker_system_urls: Vec` (HTTP system-port -//! fan-out, env-driven) with discovery-backed fan-out via the dynamo -//! request plane. Drop `reqwest::Client`. Drop `DYN_RL_WORKER_SYSTEM_URLS`. -//! - **PR C:** introduce `DYN_ENABLE_RL_ENDPOINTS` (frontend-only) to gate -//! this router on a separate Axum listener (`DYN_RL_PORT` / `--rl-port`). -//! `DYN_ENABLE_RL` keeps its meaning as the inference-plane RL extensions -//! gate plus worker-side engine-route registration. +//! **PR B status:** request-plane fan-out via the dynamo discovery plane. +//! Workers register one endpoint `dyn://..rl` (see +//! `worker_factory.py::rl_endpoint.serve_endpoint(handler.rl_dispatch, …)`) +//! and the frontend dispatches by listing live `rl` instances and calling +//! each via [`PushRouter::direct`]. The legacy `register_engine_route` +//! HTTP-on-system-port mechanism + `DYN_RL_WORKER_SYSTEM_URLS` static URL +//! list are gone. 
 use std::sync::Arc; @@ -26,6 +22,15 @@ use axum::{ response::IntoResponse, routing::post, }; +use dynamo_runtime::{ + DistributedRuntime, + pipeline::{ + SingleIn, + network::egress::push_router::{PushRouter, RouterMode}, + }, + protocols::annotated::Annotated, +}; +use futures::StreamExt; /// Documentation tuple for an RL admin route. The dynamo-llm caller wraps /// each tuple into its own `RouteDoc` for `/openapi.json` aggregation. @@ -44,92 +49,194 @@ impl RlRouteDoc { } } -/// Environment variable for comma-separated worker system HTTP URLs. -/// Defaults to `http://localhost:8081` when not set. -const DYN_RL_WORKER_SYSTEM_URLS_ENV: &str = "DYN_RL_WORKER_SYSTEM_URLS"; - /// Shared state for the RL admin router. +/// +/// Holds a runtime handle, a target `<namespace>.<component>` pair, and the +/// name of the unified RL endpoint (always `"rl"`). Each fan-out call: +/// +/// 1. Lists live instances of `<namespace>.<component>.rl` via discovery. +/// 2. Builds a [`PushRouter`] over the runtime's request plane (NATS / shared TCP). +/// 3. Calls [`PushRouter::direct`] per `instance_id` with a JSON +/// `{"op": <op>, "body": <body>}` envelope. +/// 4. Drains the response stream and extracts the first `Annotated.data`. #[derive(Clone)] struct RlState { - /// Worker system HTTP base URLs (e.g. `http://localhost:8081`). - /// Set via `DYN_RL_WORKER_SYSTEM_URLS` (comma-separated list). - /// PR B (deferred) replaces this with discovery-backed enumeration. - worker_system_urls: Vec, - /// Shared HTTP client for all fan-out calls to worker system ports. - http_client: reqwest::Client, + drt: Arc, + namespace: String, + component: String, + /// The endpoint name workers serve their RL dispatcher on. Always `"rl"`.
+ rl_endpoint: String, } impl RlState { - fn from_env() -> anyhow::Result { - let worker_system_urls = std::env::var(DYN_RL_WORKER_SYSTEM_URLS_ENV) - .unwrap_or_else(|_| "http://localhost:8081".to_string()) - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect::>(); + fn from_env(drt: Arc) -> anyhow::Result { + let namespace = std::env::var("DYN_NAMESPACE").unwrap_or_else(|_| "dynamo".into()); + // Workers default to component="backend" (vLLM, sglang). Allow + // override for disagg / multi-component deployments. + let component = std::env::var("DYN_RL_COMPONENT").unwrap_or_else(|_| "backend".into()); + let rl_endpoint = "rl".to_string(); tracing::info!( - worker_count = worker_system_urls.len(), - ?worker_system_urls, - "RL admin router configured" + ns = %namespace, + comp = %component, + rl_endpoint = %rl_endpoint, + "RL admin router configured (request-plane discovery)" ); - let http_client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(600)) - .build() - .map_err(|e| anyhow::anyhow!("Failed to build RL router HTTP client: {e}"))?; - Ok(Self::new(worker_system_urls, http_client)) + Ok(Self { + drt, + namespace, + component, + rl_endpoint, + }) } - /// Test-friendly constructor — bypasses env reading so tests can pass in - /// fake worker URLs and a stubbed `reqwest::Client`. - fn new(worker_system_urls: Vec, http_client: reqwest::Client) -> Self { - Self { - worker_system_urls, - http_client, + /// Fan out an admin op to every live worker via the request plane. + /// + /// `route` is the legacy engine-route name (`pause_generation`, + /// `resume_generation`, `weight_transport_init`, `weight_transport_update`) + /// preserved from the call sites; we map it to the unified op name on + /// the wire. + /// + /// Source of truth for "which workers are live" is the + /// [`Client::instance_source`] watcher (etcd-backed), not a one-shot + /// discovery `list()`. 
PushRouter's `direct()` checks the same client + /// view internally — going through the client avoids the race where a + /// freshly-built client hasn't populated yet. + async fn fan_out(&self, route: &str, body: serde_json::Value) -> Vec { + let op = route_to_op(route); + + let endpoint = match self + .drt + .namespace(&self.namespace) + .and_then(|ns| ns.component(&self.component)) + { + Ok(comp) => comp.endpoint(&self.rl_endpoint), + Err(err) => { + tracing::warn!(%err, route, "RL fan_out: failed to build endpoint"); + return vec![serde_json::json!({ + "status": "error", + "message": format!("endpoint build failed: {err}"), + })]; + } + }; + + let client = match endpoint.client().await { + Ok(c) => c, + Err(err) => { + tracing::warn!(%err, route, "RL fan_out: failed to create endpoint client"); + return vec![serde_json::json!({ + "status": "error", + "message": format!("client create failed: {err}"), + })]; + } + }; + + // Bound the watcher-population race: wait until the client sees + // ≥1 instance (or a short deadline elapses, in which case we + // surface the empty-fanout warning below). 5s is generous — + // workers register synchronously on serve_endpoint() before they + // start serving traffic, so by the time anything POSTs `/v1/rl/*` + // they should already be in etcd. + let _ = tokio::time::timeout( + std::time::Duration::from_secs(5), + client.wait_for_instances(), + ) + .await; + + let instance_ids: Vec = client.instance_ids(); + if instance_ids.is_empty() { + tracing::warn!( + ns = %self.namespace, + comp = %self.component, + route, + "RL fan_out: no live workers under {}.{}.rl; \ + check DYN_NAMESPACE / DYN_RL_COMPONENT vs worker --component", + self.namespace, + self.component, + ); + return Vec::new(); } - } - /// Call a single engine route on one worker. Returns the JSON body. 
- async fn call_engine_route( - &self, - url: &str, - route: &str, - body: &serde_json::Value, - ) -> serde_json::Value { - let endpoint = format!("{url}/engine/{route}"); - match self.http_client.post(&endpoint).json(body).send().await { - Ok(resp) => { - let status = resp.status(); - match resp.json::().await { - Ok(v) => v, - Err(e) => serde_json::json!({ + let router = + match PushRouter::>::from_client( + client, + RouterMode::Direct, + ) + .await + { + Ok(r) => r, + Err(err) => { + tracing::warn!(%err, route, "RL fan_out: failed to build PushRouter"); + return vec![serde_json::json!({ "status": "error", - "message": format!("Failed to decode response from {endpoint}: {e}"), - "http_status": status.as_u16() - }), + "message": format!("PushRouter build failed: {err}"), + })]; } - } - Err(e) => serde_json::json!({ - "status": "error", - "message": format!("Request to {endpoint} failed: {e}") - }), - } - } + }; - /// Fan out an engine route call to all configured workers concurrently. - async fn fan_out(&self, route: &str, body: serde_json::Value) -> Vec { - let futures: Vec<_> = self - .worker_system_urls + let envelope = serde_json::json!({"op": op, "body": body}); + + let futures: Vec<_> = instance_ids .iter() - .map(|url| self.call_engine_route(url, route, &body)) + .copied() + .map(|id| { + let router = router.clone(); + let envelope = envelope.clone(); + async move { + let req = SingleIn::new(envelope.clone()); + match router.direct(req, id).await { + Ok(mut stream) => { + // Drain the first non-empty data chunk from the + // worker's async-generator response. 
+ while let Some(chunk) = stream.next().await { + if let Some(data) = chunk.data { + return data; + } + if let Some(err) = chunk.error { + return serde_json::json!({ + "status": "error", + "instance_id": id, + "message": err.to_string(), + }); + } + } + serde_json::json!({ + "status": "error", + "instance_id": id, + "message": "empty response stream", + }) + } + Err(err) => serde_json::json!({ + "status": "error", + "instance_id": id, + "message": format!("dispatch failed: {err}"), + }), + } + } + }) .collect(); futures::future::join_all(futures).await } - /// Returns true only if all results have `status: "ok"`. + /// Returns true only if every result is `status: "ok"` AND there is at + /// least one. Empty fan-out (no workers found) is `503`, not silent OK. fn all_ok(results: &[serde_json::Value]) -> bool { - results - .iter() - .all(|r| r.get("status").and_then(|s| s.as_str()) == Some("ok")) + !results.is_empty() + && results + .iter() + .all(|r| r.get("status").and_then(|s| s.as_str()) == Some("ok")) + } +} + +/// Map a legacy engine-route name to the corresponding `rl_dispatch` op. +fn route_to_op(route: &str) -> &str { + match route { + "pause_generation" => "pause", + "resume_generation" => "resume", + "weight_transport_init" => "init_transport", + "weight_transport_update" => "update_weights", + // Anything else — pass through verbatim so `rl_dispatch` can return + // a meaningful "unknown op" error instead of us silently rewriting. + other => other, } } @@ -430,23 +537,27 @@ async fn rl_init_transport( /// Create an Axum [`Router`] for the RL admin endpoints at `/v1/rl/*`. /// -/// Worker system URLs are read from the `DYN_RL_WORKER_SYSTEM_URLS` environment -/// variable (comma-separated, defaults to `http://localhost:8081`). Phase B of -/// `rl-crate.md` will replace this with discovery-backed fan-out; until then -/// the static URL list is the source of truth. +/// **PR B:** fan-out goes through the dynamo discovery plane + request +/// plane. 
Workers register `<namespace>.<component>.rl` (default +/// `dynamo.backend.rl`) on the request plane via +/// `runtime.endpoint(...).serve_endpoint(handler.rl_dispatch, ...)`. The +/// frontend reads live instance IDs from the endpoint client's instance +/// watcher (`client.instance_ids()`, not a one-shot discovery list) and +/// dispatches each call via [`PushRouter::direct`] over NATS / shared TCP. /// -/// **Surface:** four POST routes after Phase 3 (this PR). Read-side endpoints -/// (`state`, `health`, `ready`, `liveness`, `weight_version`) and the -/// dedicated LoRA routes (`load_lora_adapter`, `unload_lora_adapter`) are -/// dropped — replacements piggyback on the frontend's existing `/live` and -/// `/health`, and LoRA flows through `update_weights {target.kind="lora"}`. +/// **Surface:** four POST routes after Phase 3. +/// `pause`, `resume`, `init_transport`, `update_weights`. Read-side +/// endpoints (`state`, `health`, `ready`, `liveness`, `weight_version`) +/// and the dedicated LoRA routes (`load_lora_adapter`, `unload_lora_adapter`) +/// are dropped — replacements piggyback on the frontend's existing `/live` +/// and `/health`, and LoRA flows through `update_weights {target.kind="lora"}`. /// See `weight-transfer-config.md` § "Constraints from existing surface". /// /// Mounted on the dedicated `/v1/rl/*` listener when /// `DYN_ENABLE_RL_ENDPOINTS=true`. prime-rl usage: -/// `admin_base_url = "http://dynamo-frontend:8002/v1/rl"`. -pub fn rl_router() -> anyhow::Result<(Vec, Router)> { - let rl_state_arc = Arc::new(RlState::from_env()?); +/// `admin_base_url = "http://dynamo-frontend:8000/v1/rl"`. +pub fn rl_router(drt: Arc) -> anyhow::Result<(Vec, Router)> { + let rl_state_arc = Arc::new(RlState::from_env(drt)?); let docs = vec![ // Pause / resume bracket.
RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/pause"), From 626d3e44dfe4b66d04390a826ba2e39824a44d80 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Fri, 8 May 2026 03:35:48 -0700 Subject: [PATCH 16/18] rl: add request-plane admin fanout --- .../src/dynamo/frontend/frontend_args.py | 11 + components/src/dynamo/frontend/main.py | 1 + components/src/dynamo/vllm/handlers.py | 34 +- docs/RL.md | 667 ++++++++++++++ docs/dynamo-RL-api.md | 569 ------------ docs/index.yml | 3 + lib/bindings/python/rust/lib.rs | 43 + lib/bindings/python/rust/llm/entrypoint.rs | 6 +- lib/bindings/python/src/dynamo/_core.pyi | 14 + lib/llm/src/entrypoint/input/http.rs | 3 + lib/llm/src/http/service/openai.rs | 19 +- lib/llm/src/http/service/service_v2.rs | 340 ++++--- lib/llm/src/local_model.rs | 14 + lib/rl/src/lib.rs | 861 ++++++++++++++---- .../pipeline/network/egress/push_router.rs | 37 + 15 files changed, 1738 insertions(+), 884 deletions(-) create mode 100644 docs/RL.md delete mode 100644 docs/dynamo-RL-api.md diff --git a/components/src/dynamo/frontend/frontend_args.py b/components/src/dynamo/frontend/frontend_args.py index 2040a19ec874..5517386e1eb3 100644 --- a/components/src/dynamo/frontend/frontend_args.py +++ b/components/src/dynamo/frontend/frontend_args.py @@ -56,6 +56,7 @@ class FrontendConfig(RouterConfigBase, KvRouterConfigBase, AicPerfConfigBase): kv_cache_block_size: Optional[int] http_host: str http_port: int + rl_port: int tls_cert_path: Optional[pathlib.Path] tls_key_path: Optional[pathlib.Path] @@ -97,6 +98,8 @@ def validate(self) -> None: raise ValueError( f"--migration-limit must be between 0 and {_U32_MAX} (0=disabled)" ) + if self.rl_port < 0 or self.rl_port > 65535: + raise ValueError("--rl-port must be between 0 and 65535") if self.migration_max_seq_len is not None and ( self.migration_max_seq_len < 1 or self.migration_max_seq_len > _U32_MAX ): @@ -208,6 +211,14 @@ def add_arguments(self, parser) -> None: help="HTTP port for the engine (u16).", arg_type=int, 
) + add_argument( + g, + flag_name="--rl-port", + env_var="DYN_RL_PORT", + default=8002, + help="Dedicated HTTP port for RL admin endpoints (u16).", + arg_type=int, + ) add_negatable_bool_argument( g, flag_name="--serve-indexer", diff --git a/components/src/dynamo/frontend/main.py b/components/src/dynamo/frontend/main.py index 361d16ffd25c..a54845d4a558 100644 --- a/components/src/dynamo/frontend/main.py +++ b/components/src/dynamo/frontend/main.py @@ -237,6 +237,7 @@ def signal_handler(): kwargs: dict[str, Any] = { "http_host": config.http_host, "http_port": config.http_port, + "rl_port": config.rl_port, "kv_cache_block_size": config.kv_cache_block_size, "router_config": router_config, "migration_limit": config.migration_limit, diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 792503cec317..37c8a42c7706 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -1429,16 +1429,38 @@ async def weight_transport_update(self, body: dict) -> dict: logger.exception(f"[RL] weight_transport_update failed: {e}") return {"status": "error", "message": str(e)} + async def describe_rl(self, body: dict | None = None) -> dict: + """Return lightweight RL worker metadata for SDK topology probes.""" + mode = getattr(self.config, "disaggregation_mode", None) + if hasattr(mode, "value"): + mode = mode.value + return { + "status": "ok", + "namespace": getattr(self.config, "namespace", None), + "component": getattr(self.config, "component", None), + "endpoint": "rl", + "worker_role": mode, + "details": { + "model": getattr(self.config, "model", None), + "served_model_name": ( + getattr(self.config, "served_model_name", None) + or getattr(self.config, "model", None) + ), + "weight_version": getattr(self, "_weight_version", "initial"), + "lora_count": len(self.loaded_loras), + }, + } + # ── PR B: unified `rl` request-plane endpoint ───────────────────── # # Worker registers ``dyn://..rl`` and serves 
this # dispatcher. The frontend (dynamo-rl crate) discovers live `rl` - # instances via the standard discovery plane and dispatches via - # ``PushRouter::direct`` over NATS / shared TCP — no system-port HTTP + # instances via the standard discovery plane and dispatches via strict + # request-plane direct calls over NATS / shared TCP — no system-port HTTP # fan-out, no static `DYN_RL_WORKER_SYSTEM_URLS` list. # # Wire shape: ``{"op": str, "body": dict}`` where `op` is one of - # ``pause | resume | init_transport | update_weights``. The dispatcher + # ``describe | pause | resume | init_transport | update_weights``. The dispatcher # routes to the existing per-op handlers and yields a single response # dict (matching the serve_endpoint async-generator contract used by # ``generate``, ``load_lora``, etc.). @@ -1462,7 +1484,9 @@ async def rl_dispatch(self, request=None): } return try: - if op == "pause": + if op == "describe": + yield await self.describe_rl(body) + elif op == "pause": yield await self.pause_generation(body) elif op == "resume": yield await self.resume_generation(body) @@ -1475,7 +1499,7 @@ async def rl_dispatch(self, request=None): "status": "error", "message": ( f"rl_dispatch: unknown op {op!r}; expected one of " - "pause|resume|init_transport|update_weights" + "describe|pause|resume|init_transport|update_weights" ), } except Exception as e: diff --git a/docs/RL.md b/docs/RL.md new file mode 100644 index 000000000000..3956acccaafd --- /dev/null +++ b/docs/RL.md @@ -0,0 +1,667 @@ +--- +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +title: RL +--- + +# Dynamo RL + +Dynamo RL support has two separate surfaces: + +1. The inference surface on the normal OpenAI listener, usually + `:8000 /v1/chat/completions`. This carries rollout-time extensions such as + token-in/token-out, cache salt, and weight version metadata. +2. 
The admin surface for pause, resume, transport setup, and weight updates. + The canonical implementation is the `dynamo-rl` SDK. The optional HTTP + facade exposes the same operations at `:8002 /v1/rl/*` by default. + +The admin surface does not fan out through worker system ports. Workers are +ephemeral, so the SDK snapshots the discovery plane, finds live worker +endpoints named `rl`, and dispatches strict direct calls through Dynamo's +request plane. The request plane may be TCP, NATS, or HTTP depending on the +deployment. + +## Architecture + +```mermaid +flowchart LR + subgraph ClientSide["RL clients"] + Trainer["Trainer or orchestrator"] + Slime["Slime / in-process client"] + Prime["prime-rl HTTP client"] + end + + subgraph Frontend["Dynamo frontend"] + OpenAI[":8000 /v1/chat/completions"] + RlHttp[":8002 /v1/rl/*"] + RlClient["dynamo-rl RlClient"] + end + + subgraph Runtime["Dynamo runtime"] + Discovery["Discovery plane"] + RequestPlane["Request plane\nTCP / NATS / HTTP"] + end + + subgraph Workers["Inference workers"] + W1["namespace.component.rl\ninstance 1"] + W2["namespace.component.rl\ninstance 2"] + Wn["namespace.component.rl\ninstance N"] + Adapter["vLLM RL adapter"] + Transport["WeightTransport\nfilesystem / nccl"] + end + + Trainer --> OpenAI + Trainer --> RlClient + Slime --> RlClient + Prime --> RlHttp + RlHttp --> RlClient + RlClient --> Discovery + RlClient --> RequestPlane + RequestPlane --> W1 + RequestPlane --> W2 + RequestPlane --> Wn + W1 --> Adapter + W2 --> Adapter + Wn --> Adapter + Adapter --> Transport +``` + +System ports remain useful for process health, metrics, and debugging, but +they are not the RL worker fan-out contract. There is no +`DYN_RL_WORKER_SYSTEM_URLS` static worker list. + +## Enablement + +Frontend configuration: + +| Setting | Default | Purpose | +|---|---:|---| +| `DYN_ENABLE_RL` | `false` | Enables inference-plane RL extensions on `/v1/chat/completions`, including automatic token-id return on unary chat responses. 
| +| `DYN_ENABLE_RL_ENDPOINTS` | `false` | Enables the optional admin HTTP facade. | +| `DYN_RL_PORT` or `--rl-port` | `8002` | Dedicated listener for `/v1/rl/*`; routes are not mounted on the main `:8000` listener. | +| `DYN_NAMESPACE` | `dynamo` | Namespace scanned by the RL SDK. | +| `DYN_RL_COMPONENT` | unset | Optional component filter. When unset, all live endpoints named `rl` in the namespace are targeted. | +| `DYN_REQUEST_PLANE` | deployment-specific | Selects the request-plane transport, for example TCP, NATS, or HTTP. | +| `DYN_DISCOVERY_BACKEND` | deployment-specific | Selects the discovery backend, for example etcd, Kubernetes, file, or memory. | + +Example trainer endpoints: + +```toml +base_url = "http://dynamo-frontend:8000/v1" +admin_base_url = "http://dynamo-frontend:8002/v1/rl" +``` + +Worker requirements: + +- Workers serving RL workloads register a request-plane endpoint named `rl`. +- The endpoint receives a single envelope, `{"op": "...", "body": {...}}`. +- Supported operations are `describe`, `pause`, `resume`, `init_transport`, + and `update_weights`. + +## Discovery And Fan-Out + +The `dynamo-rl` SDK owns membership and dispatch: + +```rust +pub struct RlClient { + runtime: Arc, + namespace: String, + rl_endpoint: String, // default: "rl" + policy: FanoutPolicy, +} + +pub struct MembershipSnapshot { + pub epoch: u64, + pub targets: Vec, +} + +pub struct FanoutPolicy { + pub min_workers: usize, + pub membership_timeout: Duration, + pub request_timeout: Duration, + pub strict_direct: bool, + pub abort_on_membership_change: bool, + pub component_filter: Option>, +} +``` + +For each admin operation, the SDK: + +1. Lists live namespaced endpoints through discovery. +2. Filters to endpoint name `rl` and the optional component filter. +3. Builds a `MembershipSnapshot` with an epoch fingerprint. +4. Groups targets by `(namespace, component, endpoint)`. +5. Sends a strict direct request to each discovered `instance_id`. +6. 
Optionally snapshots membership again and fails with + `membership_changed` if the epoch changed during fan-out. + +`strict_direct` matters for RL admin calls. A pause or weight update addressed +to worker instance `A` must not silently fall back to instance `B` if `A` +disappears. If the target is gone, the call fails and the caller receives a +per-worker error. + +```mermaid +sequenceDiagram + autonumber + participant Caller as Trainer / SDK user + participant Client as RlClient + participant Discovery as Discovery plane + participant RP as Request plane + participant Worker as Worker rl endpoint + + Caller->>Client: update_weights(request) + Client->>Discovery: list namespaced endpoints + Discovery-->>Client: live endpoints + Client->>Client: filter endpoint == rl and compute epoch + loop Each worker instance + Client->>RP: strict direct op=update_weights, instance_id + RP->>Worker: op=update_weights, body=request + Worker-->>RP: status payload + RP-->>Client: worker result + end + Client->>Discovery: list namespaced endpoints + Discovery-->>Client: live endpoints + alt epoch unchanged + Client-->>Caller: FanoutReport with membership_epoch and workers + else membership changed + Client-->>Caller: error membership_changed + end +``` + +The snapshot is a consistency guard, not a distributed lock. In a deployment +where workers are added or removed frequently, callers should treat a +`membership_changed` response as a retryable orchestration event. If membership +stays stable but a worker rejects or times out, the HTTP facade returns `502` +with per-worker status so the orchestrator can retry, drain, or rebuild the +fleet. + +The SDK does not poll worker system-port health. It snapshots discovery for +each fan-out operation and waits briefly for the request-plane client to see +the target instance IDs before dispatching. 
+ +## Inference Surface + +RL rollout traffic uses the standard chat-completions route: + +```http +POST /v1/chat/completions +``` + +When `DYN_ENABLE_RL=true`, unary chat responses promote token metadata for RL +clients: + +- `response.prompt_token_ids` is populated from the original messages or from + pre-tokenized input. +- `choices[].token_ids` is populated from completion token IDs. +- `return_token_ids` is auto-enabled for unary chat responses. + +Callers can also request token IDs explicitly with `return_token_ids: true`. +When token IDs are requested, `n > 1` is rejected because the current +aggregation path cannot safely assign one shared completion-token vector back +to multiple choices. + +Supported request extensions include: + +| Field | Direction | Purpose | +|---|---|---| +| `prompt_token_ids` | request | Token-in/token-out path. Send pre-tokenized prompt IDs instead of messages. | +| `tokens` | request | Legacy pre-tokenized prompt path mapped into `nvext.token_data`. | +| `return_token_ids` | request | Requests completion token IDs in the response. | +| `cache_salt` | request | Salts prefix-cache identity for rollout isolation. | +| `weight_version` | request | Routes or annotates requests against a caller-selected weight version. | +| `stop_token_ids` | request | Stop generation when any listed token ID is produced. | +| `allowed_token_ids` | request | Sampling constraint passthrough. | +| `bad_words_token_ids` | request | Sampling constraint passthrough. | +| `truncate_prompt_tokens` | request | Prompt truncation passthrough. | +| `return_prompt_logprobs` | request | Requests prompt logprobs where supported by the backend. | +| `return_routed_experts` | request | Requests routed expert metadata where supported by the backend. | + +TITO callers should send `prompt_token_ids` on `/v1/chat/completions`. The +separate `/v1/chat/completions/tokens` route is not part of the current +surface. + +## Admin HTTP Facade + +The HTTP facade is optional. 
It exists for clients that cannot embed the SDK +but still need the same fan-out semantics. The facade is mounted only when +`DYN_ENABLE_RL_ENDPOINTS=true` or the service configuration enables RL +endpoints. + +Routes: + +| Method | Path | Description | +|---|---|---| +| `POST` | `/v1/rl/pause` | Fan out `pause` to every discovered worker. | +| `POST` | `/v1/rl/resume` | Fan out `resume` to every discovered worker. | +| `POST` | `/v1/rl/init_transport` | Initialize a weight-transfer backend on every worker. | +| `POST` | `/v1/rl/update_weights` | Apply a base-model or LoRA weight update on every worker. | + +Read-side RL routes are not part of the current HTTP surface: +`/v1/rl/state`, `/v1/rl/health`, `/v1/rl/ready`, `/v1/rl/liveness`, and +`/v1/rl/weight_version` are dropped. Use the frontend's existing `/live` and +`/health` process checks for Kubernetes probes. SDK callers can use +`describe` for topology and worker metadata probes. + +### Pause + +```http +POST /v1/rl/pause?mode=keep&clear_cache=false +``` + +Query parameters: + +| Parameter | Values | Default | +|---|---|---| +| `mode` | `keep`, `wait`, `abort` | `keep` | +| `clear_cache` | `true`, `false` | `false` | + +Successful response: + +```json +{ + "status": "ok", + "mode": "keep", + "clear_cache": false, + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok", + "version": "initial" + } + ] +} +``` + +### Resume + +```http +POST /v1/rl/resume +``` + +Successful response: + +```json +{ + "status": "ok", + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok" + } + ] +} +``` + +### Init Transport + +`init_transport` is idempotent setup for a weight-transfer backend. Filesystem +is a no-op that marks the transport ready. NCCL initializes the worker-side +vLLM weight-transfer engine. 
+ +Filesystem: + +```http +POST /v1/rl/init_transport +``` + +```json +{ + "transport_id": "fs-step", + "backend": "filesystem", + "filesystem": { + "require_marker": "STABLE" + } +} +``` + +NCCL: + +```http +POST /v1/rl/init_transport +``` + +```json +{ + "transport_id": "rl-nccl", + "backend": "nccl", + "nccl": { + "master_address": "trainer-0.trainer", + "master_port": 29500, + "world_size": 9, + "rank_offset": 1 + } +} +``` + +Successful response: + +```json +{ + "status": "ok", + "transport_id": "rl-nccl", + "backend": "nccl", + "ready": true, + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok", + "transport_id": "rl-nccl", + "ready": true + } + ] +} +``` + +### Update Weights + +All weight updates use one discriminated body: + +```json +{ + "version": "step_42", + "target": { + "kind": "base" + }, + "transport": { + "backend": "filesystem" + }, + "pause_mode": "keep", + "clear_cache": true +} +``` + +Required fields: + +| Field | Description | +|---|---| +| `version` | Caller-assigned version string applied to the update. | +| `target.kind` | `base` or `lora`. | +| `transport.backend` | `filesystem` or `nccl` for the current vLLM implementation. Not required for LoRA unload. | + +Optional fields: + +| Field | Default | Description | +|---|---|---| +| `pause_mode` | `keep` | Worker-side pause behavior: `keep`, `wait`, or `abort`. | +| `clear_cache` | `true` | Whether the worker should clear prefix/KV cache where supported. | + +Successful response: + +```json +{ + "status": "ok", + "applied_weight_version": "step_42", + "backend": "filesystem", + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok", + "message": "Updated weights from filesystem", + "version": "step_42" + } + ] +} +``` + +#### Base Model From Filesystem + +The trainer writes a checkpoint to shared storage, creates the marker file +after the checkpoint is complete, then calls `update_weights`. 
+ +```json +{ + "version": "step_42", + "target": { + "kind": "base" + }, + "transport": { + "backend": "filesystem", + "filesystem": { + "path": "/share/broadcasts/step_42", + "require_marker": "STABLE" + } + }, + "pause_mode": "keep", + "clear_cache": true +} +``` + +#### Base Model From NCCL + +The trainer and inference workers form a group during `init_transport`. On +each update, the trainer broadcasts the named tensors and the workers receive +through vLLM's weight-update collective. + +```json +{ + "version": "step_42", + "target": { + "kind": "base" + }, + "transport": { + "backend": "nccl", + "nccl": { + "transport_id": "rl-nccl", + "names": [ + "model.layers.0.self_attn.q_proj.weight" + ], + "dtype_names": [ + "bfloat16" + ], + "shapes": [ + [4096, 4096] + ], + "is_checkpoint_format": true, + "packed": false + } + }, + "pause_mode": "keep", + "clear_cache": true +} +``` + +#### LoRA Load, Swap, And Unload + +LoRA uses the same `update_weights` route. In the current vLLM implementation, +LoRA transfer uses the filesystem backend. NCCL LoRA transfer is deferred. + +Load: + +```json +{ + "version": "step_42", + "target": { + "kind": "lora", + "name": "qwen3-06b-gsm8k", + "op": "load" + }, + "transport": { + "backend": "filesystem", + "filesystem": { + "path": "/share/lora/qwen3-06b-gsm8k/step_42", + "require_marker": "STABLE" + } + }, + "pause_mode": "wait", + "clear_cache": false +} +``` + +Swap: + +```json +{ + "version": "step_43", + "target": { + "kind": "lora", + "name": "qwen3-06b-gsm8k", + "op": "swap" + }, + "transport": { + "backend": "filesystem", + "filesystem": { + "path": "/share/lora/qwen3-06b-gsm8k/step_43", + "require_marker": "STABLE" + } + }, + "pause_mode": "wait", + "clear_cache": false +} +``` + +Unload: + +```json +{ + "version": "step_44", + "target": { + "kind": "lora", + "name": "qwen3-06b-gsm8k", + "op": "unload" + } +} +``` + +Dedicated `load_lora_adapter` and `unload_lora_adapter` RL routes are not part +of the current surface. 
+ +## Weight-Update Sequence + +```mermaid +sequenceDiagram + autonumber + participant Trainer + participant Admin as RlClient or HTTP facade + participant Discovery as Discovery plane + participant RP as Request plane + participant Worker as Worker rl endpoint + participant Adapter as vLLM adapter + participant Engine as vLLM engine + + Trainer->>Admin: pause(mode=keep) + Admin->>Discovery: snapshot rl workers + Admin->>RP: strict direct op=pause to each instance + RP->>Worker: pause + Worker->>Adapter: pause_generation + Adapter-->>Worker: ok + Worker-->>Admin: worker result + + alt filesystem backend + Trainer->>Trainer: write checkpoint and marker to shared storage + Trainer->>Admin: update_weights(version, filesystem path) + Admin->>Discovery: snapshot rl workers + Admin->>RP: strict direct op=update_weights + RP->>Worker: update_weights + Worker->>Adapter: FilesystemTransport.update_weights + Adapter->>Adapter: verify require_marker + Adapter->>Engine: reload weights from path + Engine-->>Adapter: ok + Adapter-->>Worker: ok + else nccl backend + Trainer->>Admin: init_transport(nccl) + Admin->>RP: strict direct op=init_transport + RP->>Worker: init_transport + Worker->>Engine: init_weight_transfer_engine + Engine-->>Worker: ready + Trainer->>Trainer: prepare named tensor broadcast + Trainer->>Admin: update_weights(version, tensor metadata) + Admin->>RP: strict direct op=update_weights + RP->>Worker: update_weights + Worker->>Engine: update_weights receive collective + Engine-->>Worker: ok + end + + Admin-->>Trainer: FanoutReport + Trainer->>Admin: resume() + Admin->>RP: strict direct op=resume to each instance + RP->>Worker: resume + Worker->>Adapter: resume_generation + Adapter-->>Worker: ok + Admin-->>Trainer: FanoutReport +``` + +Weight updates are not atomic across workers. If some workers update and one +worker fails, the fleet can be left at mixed versions. 
The response includes +per-worker results so the orchestrator can decide whether to retry, drain the +failed worker, or rebuild the serving group. + +## Kubernetes + +Kubernetes deployments should expose two frontend ports: + +- Main inference port, usually `8000`, for OpenAI-compatible traffic and + standard `/health` and `/live` checks. +- RL admin port, usually `8002`, for `/v1/rl/*`. Keep this port + cluster-internal and protect it with service policy or network policy. + +Workers do not need their system ports exposed for RL admin fan-out. They must +be discoverable through the configured Dynamo discovery backend and reachable +through the configured request plane. + +Transport-specific Kubernetes notes: + +- Filesystem transfer requires shared storage mounted at the same path on the + trainer and every inference worker, or a path mapping layer in the + orchestrator. +- NCCL transfer requires the trainer and workers to resolve the NCCL + `master_address` and connect to `master_port`. This rendezvous is separate + from Dynamo's request plane. +- NATS request-plane deployments need the worker and frontend pods connected + to the same NATS deployment. +- TCP request-plane deployments need pod-to-pod connectivity for the Dynamo + request-plane endpoints. + +## Error Responses + +The HTTP facade maps SDK errors to stable status codes: + +| Status | `error_type` | Meaning | +|---:|---|---| +| `503` | `no_workers` | Discovery found fewer than `min_workers` live `rl` endpoints. | +| `409` | `membership_changed` | Membership changed during fan-out and the policy requires a stable epoch. | +| `502` | `fanout_failed` | Request-plane setup failed, worker dispatch failed, or one or more workers returned an error. 
| + +Per-worker failures also return `502` with a `workers` array: + +```json +{ + "status": "error", + "stage": "weight_transport_update", + "backend": "filesystem", + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok", + "version": "step_42" + }, + { + "status": "error", + "message": "filesystem transport: require_marker 'STABLE' not found" + } + ] +} +``` + +## Backend Status + +Current implementation scope: + +- `dynamo-rl` Rust SDK and HTTP facade. +- Discovery-backed membership snapshots. +- Request-plane strict direct fan-out. +- vLLM worker `rl` dispatcher. +- vLLM filesystem base-model and LoRA updates. +- vLLM NCCL base-model updates. + +Deferred or backend-specific: + +- SGLang weight-transfer adapter parity. +- NCCL LoRA transfer. +- NIXL, Model Express, CUDA IPC, and tensor-handle transports. +- Public read-side RL state endpoints. +- Auth and RBAC inside the RL facade. Deploy the admin port as an internal + control-plane surface. + +`call_tokenizer_manager` is an SGLang-specific tokenizer-manager passthrough. It +is not the generic Dynamo RL admin fan-out path. The portable RL admin contract +is the discoverable worker endpoint named `rl` plus the SDK fan-out policy. diff --git a/docs/dynamo-RL-api.md b/docs/dynamo-RL-api.md deleted file mode 100644 index 72ba903a4f9d..000000000000 --- a/docs/dynamo-RL-api.md +++ /dev/null @@ -1,569 +0,0 @@ -# Dynamo RL API - -This document describes the RL training API surface on the Dynamo serving stack. The Dynamo Rust frontend exposes a small, focused set of endpoints that let an RL trainer drive a vLLM-served model through pause / weight-update / resume cycles, hot-swap LoRA adapters, and post pre-tokenized inputs on the standard chat-completions endpoint. - -## Table of Contents - -1. [Overview](#1-overview) -2. [Architecture](#2-architecture) -3. [Configuration](#3-configuration) -4. 
[API Reference](#4-api-reference) - - 4.1 Chat Completions (RL-enhanced + TITO) - - 4.2 RL Lifecycle (`/v1/rl/*`) -5. [Data Flow](#5-data-flow) -6. [Key Data Structures](#6-key-data-structures) -7. [Worker Engine Routes (Internal)](#7-worker-engine-routes-internal) -8. [Known Limitations](#8-known-limitations) - ---- - -## 1. Overview - -The Dynamo Rust frontend exposes: - -- A `/v1/rl/*` router for the full RL control-plane lifecycle (composite state, liveness probe, pause/resume, weight update, LoRA hot-swap) -- Token-level data injection (`prompt_token_ids`, `choices[i].token_ids`, `nvext.completion_token_ids`) on standard chat-completion responses -- Pre-tokenized prompt support on the standard `/v1/chat/completions` endpoint via the `prompt_token_ids` extension (no separate URI) - -Zero Python in the inference or admin data path. The Rust frontend handles every HTTP route; vLLM workers expose a small set of internal engine routes for pause/update/resume on the GPU. - -### Endpoint Summary - -| Capability | Endpoint | Method | Notes | -|---|---|---|---| -| Inference | `/v1/chat/completions` | POST | Standard OpenAI plus RL extras: `prompt_token_ids`, `stop_token_ids`, `allowed_token_ids`, `bad_words_token_ids`, `truncate_prompt_tokens`, `weight_version`, `nvext.{completion_token_ids,return_token_ids,return_routed_experts,return_prompt_logprobs}` | -| Composite state | `/v1/rl/state` | GET | Aggregated per-worker `{ready, engine_alive, pause_state, applied_weight_version, loras, workers}` | -| Liveness | `/v1/rl/liveness` | GET | Round-trips `engine_client.check_health()` so a wedged event loop surfaces 503 | -| Pause fleet | `/v1/rl/pause` | POST | `?mode=keep\|wait\|abort&clear_cache=bool` | -| Resume fleet | `/v1/rl/resume` | POST | | -| Update weights | `/v1/rl/update_weights` | POST | Typed body: `{weight_dir, weight_version?, reset_prefix_cache=true}` | -| Load LoRA adapter | `/v1/rl/load_lora_adapter` | POST | Filesystem-native PEFT-style hot-swap | -| 
Unload LoRA adapter | `/v1/rl/unload_lora_adapter` | POST | Idempotent | -| Legacy: health | `/v1/rl/health` | GET | Kept for back-compat; prefer `/v1/rl/state` | -| Legacy: ready | `/v1/rl/ready` | GET | Kept for back-compat; prefer `/v1/rl/state` | -| Legacy: weight_version | `/v1/rl/weight_version` | GET | Kept for back-compat; folded into `/v1/rl/state.applied_weight_version` | - -Endpoints intentionally **not** present (returned 404): - -| Removed | Reason | -|---|---| -| `/v1/chat/completions/tokens` | TITO collapsed into `/v1/chat/completions` via the `prompt_token_ids` top-level extension | -| `/v1/tokenize` | Out of scope for this surface (covered by a separate PR) | -| `/v1/detokenize` | Same as above | - -The handler functions and route helpers are kept in source under `#[allow(dead_code)]` so downstream code that still references them compiles; physical deletion is a follow-up cleanup commit. - ---- - -## 2. Architecture - -### Component Topology - -```mermaid -flowchart TD - subgraph rl_client["RL Trainer (external)"] - orch["Orchestrator
(rollouts + admin calls)"] - trainer["Trainer
(torchrun, FSDP/EP/etc.)"] - end - - subgraph dynamo["Dynamo Serving Stack"] - subgraph frontend["Frontend Pod (Rust, port 8000)"] - cc["/v1/chat/completions
+ prompt_token_ids extension
+ choices[].token_ids
+ stop_token_ids / allowed_token_ids / ..."] - rl["/v1/rl/* (admin)
state, liveness,
pause, resume, update_weights,
load_lora_adapter, unload_lora_adapter"] - end - subgraph worker["vLLM Worker Pod (Python, system port 9090)"] - eng["Engine routes:
get_state, liveness_probe,
pause_generation, resume_generation,
flush_cache, update_weights_from_path,
get_weight_version,
load_lora_adapter, unload_lora_adapter"] - gpu["GPU
Model Weights"] - end - end - - subgraph storage["Shared Storage (PVC)"] - pvc["safetensors checkpoints
+ adapter_model.safetensors / adapter_config.json"] - end - - orch -- "rollouts: POST /v1/chat/completions
(messages OR prompt_token_ids)" --> cc - orch -- "weight lifecycle:
pause → update_weights → resume" --> rl - rl -- "fan-out (concurrent)" --> eng - eng --> gpu - trainer -- "write checkpoint" --> pvc - eng -- "reload_weights
(collective_rpc)" --> pvc -``` - -### Key Design Decisions - -1. **Single entry point.** The trainer points both `base_url` and `admin_base_url` at the Dynamo frontend. No separate admin service. -2. **Fan-out in Rust.** `/v1/rl/*` handlers fan out to all vLLM workers via `DYN_RL_WORKER_SYSTEM_URLS`. Supports DP > 1 without the client needing to discover workers. Returns HTTP 200 only when every worker responds OK; otherwise 502 with per-worker details. -3. **Token IDs as a response extension.** When `DYN_ENABLE_RL=true`, `prompt_token_ids` and `choices[i].token_ids` are injected into every non-streaming response automatically. `nvext.completion_token_ids` is the canonical Dynamo location; the choice-level field is a compatibility shim for clients that read tokens from the choice object. -4. **Backward compatible.** All new response fields use `#[serde(skip_serializing_if = "Option::is_none")]`. Clients that don't set `DYN_ENABLE_RL` see standard OpenAI-compatible responses with no extra fields. -5. **TITO without a URI fork.** Pre-tokenized input is a top-level extension on the standard chat-completions request (`prompt_token_ids`), not a separate `/v1/chat/completions/tokens` URI. Aligns with vLLM 0.20+ which accepts the same extension natively. - ---- - -## 3. Configuration - -### Environment Variables (Frontend) - -| Variable | Default | Description | -|---|---|---| -| `DYN_ENABLE_RL` | `false` | Master switch. Mounts `/v1/rl/*` routes and auto-injects token IDs in chat completion responses. | -| `DYN_RL_WORKER_SYSTEM_URLS` | `http://localhost:8081` | Comma-separated vLLM worker system HTTP base URLs for fan-out. | -| `DYN_RL_LIVENESS_TIMEOUT_MS` | `5000` | Per-worker timeout for `/v1/rl/liveness`. | - -### Environment Variables (Worker) - -| Variable | Default | Description | -|---|---|---| -| `DYN_SYSTEM_PORT` | `8081` (local) / `9090` (k8s) | Worker's system HTTP port where engine routes are registered. 
| - -### Sample trainer config - -```toml -[client] -base_url = ["http://:8000/v1"] -admin_base_url = ["http://:8000/v1/rl"] -backend = "vllm" -skip_model_check = true - -[weight_broadcast] -type = "filesystem" # NCCL is a Dynamo-side no-op today; see §8 -``` - -### Kubernetes (DGD frontend env) - -```yaml -- name: DYN_ENABLE_RL - value: "true" -- name: DYN_RL_WORKER_SYSTEM_URLS - value: "http://-vllmworker..svc.cluster.local:9090" -``` - ---- - -## 4. API Reference - -### 4.1 Chat Completions (RL-enhanced + TITO) - -``` -POST /v1/chat/completions -``` - -Standard OpenAI chat completions. When `DYN_ENABLE_RL=true`, every non-streaming response is automatically enriched with token IDs. - -#### RL request extensions - -The following top-level fields are accepted in addition to the OpenAI schema. They are validated by `validate.rs::PASSTHROUGH_EXTRA_FIELDS` and forwarded to the engine where vLLM 0.20+ accepts them natively: - -| Field | Type | Purpose | -|---|---|---| -| `prompt_token_ids` | `u32[]` | Pre-tokenized prompt (TITO). Mutually exclusive with non-empty `messages` (except for the legacy `nvext.token_data` renderer-mode placeholder, which still coexists). | -| `stop_token_ids` | `u32[]` | Plumbed into `SamplingParams.stop_token_ids`; forces stop on any of these IDs. Malformed input (e.g. `"not-an-array"`) returns a typed 400. | -| `allowed_token_ids` | `u32[]` | Restricts decoding to this set. | -| `bad_words_token_ids` | `u32[]` | Suppresses these IDs. | -| `truncate_prompt_tokens` | `int` | Truncates prompt to N most-recent tokens. | -| `weight_version` | `string` | Routing filter for IS-correction strict-version mode (today accepted; routing follow-up). | -| `cache_salt` | `string` | KV prefix-cache isolation hint. The equivalent `X-Tenant-Id` request header is also accepted; the header takes precedence when both are present. | -| `return_token_ids` | `bool` | Per-request opt-in for `nvext.completion_token_ids` (also achievable via `extra_fields`). 
| -| `return_routed_experts` | `bool` | MoE expert-routing replay capture. | -| `return_prompt_logprobs` | `bool` | Streaming logprobs for input tokens. | - -In the legacy `nvext` channel, `nvext.token_data` (renderer-mode pre-tokenized prompt) and `nvext.extra_fields = ["token_ids", "completion_token_ids", ...]` continue to work unchanged. - -#### TITO via `prompt_token_ids` - -```bash -curl -s -X POST http://localhost:8000/v1/chat/completions \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "Qwen/Qwen3-0.6B", - "messages": [], - "prompt_token_ids": [151644, 8948, 198, 151645, 198, 151644, 872, 198, - 49, 1075, 513, 420, 25, 24748, 1879, 198, 151645, - 198, 151644, 77091, 198], - "stop_token_ids": [151643], - "max_tokens": 64 - }' -``` - -Validation rules: - -- `messages` may be empty when `prompt_token_ids` is non-empty (the chat template short-circuits). -- `messages` non-empty + `prompt_token_ids` non-empty → 400 mutual-exclusion error (canonical channel only). -- `nvext.token_data` + non-empty `messages` → still allowed (legacy renderer-mode placeholder pattern that uses a synthetic user message alongside pre-tokenized input). - -#### Sample response (non-streaming, `DYN_ENABLE_RL=true`) - -```jsonc -{ - "id": "chatcmpl-abc123", - "object": "chat.completion", - "model": "Qwen/Qwen3-0.6B", - "choices": [{ - "index": 0, - "message": {"role": "assistant", "content": "dlrow olleh"}, - "finish_reason": "stop", - "logprobs": {"content": [...]}, - "token_ids": [67, 1245, 893, 15] - }], - "prompt_token_ids": [151644, 8948, 198, ...], - "usage": {"prompt_tokens": 21, "completion_tokens": 4, "total_tokens": 25}, - "nvext": { - "completion_token_ids": [67, 1245, 893, 15] - } -} -``` - -#### Response field reference - -| Field | JSON path | Description | -|---|---|---| -| `prompt_token_ids` | `response.prompt_token_ids` | Promoted by `rl_tokenize_prompt`: messages → tokenizer (model chat template) → token IDs. 
| -| `token_ids` | `response.choices[i].token_ids` | Per-choice output token IDs, promoted by `rl_promote_token_ids_in_response` from `nvext.completion_token_ids`. | -| `completion_token_ids` | `response.nvext.completion_token_ids` | Canonical Dynamo location; accumulated across SSE chunks by `DeltaGenerator`. | - -**Why two locations?** Some RL clients read tokens from `response.prompt_token_ids` / `choices[i].token_ids`; Dynamo natively emits them under `nvext.completion_token_ids`. The Rust post-processor promotes the canonical field to the choice-level field so both client conventions work. - -**Invariant:** `len(completion_token_ids) == len(logprobs.content)`. - -#### Streaming (SSE) - -Intermediate chunks carry `delta.content` only. Token IDs appear exclusively on the **final chunk** (the one with a non-null `finish_reason`). - ---- - -### 4.2 RL Lifecycle (`/v1/rl/*`) - -Mounted only when `DYN_ENABLE_RL=true`. All non-trivial routes fan out to the worker URLs in `DYN_RL_WORKER_SYSTEM_URLS`. - -#### `GET /v1/rl/state` — composite read-only - -Single endpoint that returns the full fleet state in one call. Aggregates `get_state` per-worker payloads. - -```bash -curl -s http://localhost:8000/v1/rl/state -``` - -```jsonc -// 200 -{ - "ready": true, - "ingress_alive": true, - "engine_alive": true, - "pause_state": "running", // or "paused" | "mixed" - "applied_weight_version": "step_5", // null when workers disagree - "loras": [ - {"name": "r16-a32", "loaded_on": [0, 1]} - ], - "workers": [, ...] -} - -// 503 — no workers registered -{"ready": false, "ingress_alive": true, "engine_alive": false, "pause_state": "running", - "applied_weight_version": null, "loras": [], "workers": [], - "status": "error", "message": "no workers registered"} -``` - -`ready = ingress_alive AND engine_alive AND len(workers) > 0`. `ingress_alive` is unconditionally `true` because reaching this handler proves the frontend HTTP listener is up. 
- -#### `GET /v1/rl/liveness` — deep liveness probe - -Round-trips `engine_client.check_health()` per worker so a wedged event loop or hung NCCL collective surfaces as 503. Override timeout via `DYN_RL_LIVENESS_TIMEOUT_MS` (default 5000). - -```bash -curl -s http://localhost:8000/v1/rl/liveness -``` - -```jsonc -// 200 -{"status": "ok", "alive": true, "workers": [{"alive": true}, ...]} - -// 503 — at least one worker hung past timeout -{"status": "error", "alive": false, "workers": [{"alive": false, "error": "timeout"}]} -``` - -#### `POST /v1/rl/pause` — 3-mode pause + cache control - -Query parameters (or JSON body): - -| Param | Type | Default | Effect | -|---|---|---|---| -| `mode` | `keep` \| `wait` \| `abort` | `keep` | `keep`: drain in-flight (legacy behaviour). `wait`: same as `keep` but block on completion. `abort`: trigger `collective_rpc(abort_all_requests)` on the engine (graceful warn-fallback on vLLM 0.19 where that RPC isn't implemented). | -| `clear_cache` | `bool` | `false` | If `true`, calls `reset_prefix_cache` after the pause completes. | - -```bash -curl -s -X POST 'http://localhost:8000/v1/rl/pause?mode=abort&clear_cache=true' -``` - -400 on unknown `mode`: - -```json -{"status": "error", "message": "Invalid mode 'foo'; expected one of keep|wait|abort"} -``` - -#### `POST /v1/rl/resume` - -Resumes generation on all workers. - -```bash -curl -s -X POST http://localhost:8000/v1/rl/resume -H 'Content-Type: application/json' -d '{}' -``` - -```json -{"status": "ok", "workers": [{"status": "ok", "message": "Engine resumed"}]} -``` - -#### `POST /v1/rl/update_weights` — typed body - -Body schema: - -```jsonc -{ - "weight_dir": "/data/outputs/.../broadcasts/step_5", // required (string | null) - "weight_version": "step_5", // optional, defaults to basename(weight_dir) - "reset_prefix_cache": true // optional, default true -} -``` - -Behaviour: - -- `weight_dir = "/path/..."` → fan out `update_weights_from_path` to every worker. 
Each worker calls `engine_client.collective_rpc("reload_weights", kwargs={"weights_path": path})` (vLLM's in-place layerwise load). -- `weight_dir = null` → NCCL mode. Dynamo logs `"NCCL mode, no-op on Dynamo side"` and returns 200 immediately. The actual GPU↔GPU transfer happens out of band on a pre-established NCCL group between trainer and inference workers. **Today the inference-side NCCL receiver is not wired into `dynamo.vllm`**; see §8. -- `reset_prefix_cache = true` → flush prefix/KV cache after the load (default). - -```bash -# Filesystem mode -curl -s -X POST http://localhost:8000/v1/rl/update_weights \ - -H 'Content-Type: application/json' \ - -d '{"weight_dir": "/data/outputs/run_default/broadcasts/step_5"}' - -# NCCL mode (Dynamo no-op — see §8) -curl -s -X POST http://localhost:8000/v1/rl/update_weights \ - -H 'Content-Type: application/json' \ - -d '{"weight_dir": null}' -``` - -```jsonc -// 200 -{ - "status": "ok", - "applied_weight_version": "step_5", - "workers": [ - {"status": "ok", "message": "Weights loaded from /data/...", "version": "step_5"} - ] -} - -// 502 (some worker failed) -{"status": "error", "stage": "update_weights_from_path", - "workers": [{"status": "ok", ...}, {"status": "error", "message": "..."}]} -``` - -#### `POST /v1/rl/load_lora_adapter` - -Hot-load / hot-swap a LoRA adapter from a filesystem path. Adapter dir must contain PEFT-style `adapter_model.safetensors` and `adapter_config.json`. - -- First call for a given `lora_name` → `add_lora` + publish a ModelDeploymentCard so subsequent inference with `model=` routes here. -- Subsequent calls (hot-swap) → `remove_lora(old_id)` → `add_lora` with new weights → `reset_prefix_cache`. MDC is left in place. - -Pair with `/v1/rl/pause` + `/v1/rl/resume` for full drain-swap-resume. 
- -```bash -curl -s -X POST http://localhost:8000/v1/rl/load_lora_adapter \ - -H 'Content-Type: application/json' \ - -d '{"lora_name": "r16-a32", "lora_path": "/data/outputs/run_default/broadcasts/step_5"}' -``` - -```jsonc -// 200 -{"status": "ok", - "workers": [{"status": "ok", "message": "LoRA adapter 'r16-a32' loaded from /data/...", - "lora_name": "r16-a32", "lora_id": 788776416, "hot_swap": false}]} - -// 400 — missing/empty fields -{"status": "error", - "message": "Expected body: {\"lora_name\": str, \"lora_path\": str} (both required, non-empty)"} -``` - -vLLM worker requirements: started with `--enable-lora --max-lora-rank R --max-loras N`. For a single-adapter training loop, `--max-loras 1` is sufficient. - -#### `POST /v1/rl/unload_lora_adapter` - -Remove an adapter by name. Idempotent — unloading an already-absent adapter returns `status: ok`. - -```bash -curl -s -X POST http://localhost:8000/v1/rl/unload_lora_adapter \ - -H 'Content-Type: application/json' \ - -d '{"lora_name": "r16-a32"}' -``` - -#### Legacy endpoints (kept for back-compat) - -`GET /v1/rl/health`, `GET /v1/rl/ready`, `GET /v1/rl/weight_version` — return the same shapes they did before `/v1/rl/state` was added. They will be removed once existing clients migrate to `/v1/rl/state`. - ---- - -## 5. Data Flow - -### 5.1 Rollout (inference) path - -```mermaid -sequenceDiagram - participant Orch as RL Orchestrator - participant FE as Dynamo Frontend (Rust) - participant Worker as vLLM Worker (GPU) - - Orch->>FE: POST /v1/chat/completions
{messages OR prompt_token_ids, stop_token_ids?, ...} - Note over FE: validate.rs: PASSTHROUGH_EXTRA_FIELDS
plumbs RL extras into SamplingParams
If DYN_ENABLE_RL=true, inject
nvext.extra_fields = ["token_ids","completion_token_ids"]
force logprobs=true - FE->>Worker: forward request (TCP/NATS) - Worker-->>FE: SSE chunks (delta.content + delta.token_ids) - Note over FE: DeltaGenerator accumulates
completion_token_ids; serde failures
now log tracing::warn! (no silent drops) - Worker-->>FE: final chunk (finish_reason + nvext.completion_token_ids) - Note over FE: rl_tokenize_prompt(messages) -> prompt_token_ids
rl_promote_token_ids_in_response()
nvext.completion_token_ids -> choices[i].token_ids - FE-->>Orch: enriched response -``` - -### 5.2 Weight update path - -```mermaid -sequenceDiagram - participant Trainer as RL Trainer - participant PVC as Shared Storage - participant Orch as RL Orchestrator - participant FE as Dynamo Frontend (Rust) - participant W1 as vLLM Worker 1 - participant W2 as vLLM Worker 2 - - Trainer->>PVC: write checkpoint
/data/outputs/.../step_N/*.safetensors - Orch->>FE: POST /v1/rl/pause?mode=keep - FE->>W1: pause_generation - FE->>W2: pause_generation - W1-->>FE: ok - W2-->>FE: ok - FE-->>Orch: {status: ok} - Orch->>FE: POST /v1/rl/update_weights
{weight_dir: /data/.../step_N, reset_prefix_cache: true} - FE->>W1: update_weights_from_path - FE->>W2: update_weights_from_path - Note over W1,W2: collective_rpc("reload_weights")
vLLM in-place layerwise load - W1-->>FE: {status: ok, version: step_N} - W2-->>FE: {status: ok, version: step_N} - FE-->>Orch: {status: ok, applied_weight_version: step_N} - Orch->>FE: POST /v1/rl/resume - FE->>W1: resume_generation - FE->>W2: resume_generation - W1-->>FE: ok - W2-->>FE: ok - FE-->>Orch: {status: ok} -``` - -NCCL mode: `weight_dir=null` returns 200 immediately; the actual GPU↔GPU broadcast must be coordinated out of band (see §8 for the wiring gap). - -### 5.3 LoRA hot-swap - -```mermaid -sequenceDiagram - participant Orch as RL Orchestrator - participant FE as Dynamo Frontend - participant W1 as vLLM Worker 1 - - Orch->>FE: POST /v1/rl/pause?mode=keep - FE-->>Orch: ok - Orch->>FE: POST /v1/rl/load_lora_adapter
{lora_name, lora_path} - Note over FE,W1: First call: add_lora + publish MDC
Subsequent: remove_lora(old) → add_lora → reset_prefix_cache - FE-->>Orch: {status: ok, lora_id, hot_swap} - Orch->>FE: POST /v1/rl/resume - FE-->>Orch: ok -``` - ---- - -## 6. Key Data Structures - -### `NvCreateChatCompletionRequest` (Rust, request side) - -Custom fields (top-level, beyond stock OpenAI): - -| Field | `serde` behaviour | Notes | -|---|---|---| -| `prompt_token_ids` | passthrough | Canonical TITO channel. Read by `NvExtProvider::get_pretokenized_input`. | -| `stop_token_ids` | passthrough | Read by `OpenAIStopConditionsProvider::get_stop_token_ids() → Result>>`. Malformed input returns 400. | -| `allowed_token_ids`, `bad_words_token_ids`, `truncate_prompt_tokens` | passthrough | Plumbed into `SamplingParams`. | -| `weight_version`, `cache_salt`, `return_*` | passthrough | See §4.1. | -| `tokens` | `skip_serializing` | Legacy compat — caught and ignored. | -| `return_token_ids` | `skip_serializing` | Legacy compat — use `nvext.extra_fields` or `DYN_ENABLE_RL`. | - -### `NvCreateChatCompletionResponse` (Rust, response side) - -```rust -NvCreateChatCompletionResponse { - inner: CreateChatCompletionResponse, // standard OpenAI - nvext: Option, // NvExtResponse JSON - prompt_token_ids: Option>, // RL only -} -``` - -### `NvExtResponse` - -Serialized as `nvext` on each SSE chunk and the unary response body: - -```rust -NvExtResponse { - worker_id: Option, - timing: Option, - token_ids: Option>, // pre-tokenized prompt (used by disaggregated query/fill stages) - routed_experts: Option, - completion_token_ids: Option>, // RL output, final chunk only -} -``` - -### `RlUpdateWeightsBody` - -```rust -struct RlUpdateWeightsBody { - weight_dir: Option, // null => NCCL mode - weight_version: Option, // defaults to basename(weight_dir) - #[serde(default = "default_reset_prefix_cache")] - reset_prefix_cache: bool, // default true -} -``` - -### `DeltaGenerator` (streaming pipeline) - -Tracks `accumulated_completion_token_ids: Vec` per request. 
Activated when `extra_fields` includes `"completion_token_ids"` (auto-set under `DYN_ENABLE_RL`). Emits the full vector in `nvext.completion_token_ids` on the final chunk. - -### Post-processing helpers - -- `rl_tokenize_prompt(state, model, messages) -> Option>` — resolves the model card, builds `PromptFormatter`, renders messages through the chat template, tokenizes, returns IDs. -- `rl_promote_token_ids_in_response(json_val)` — copies `nvext.completion_token_ids` to `choices[i].token_ids` per choice. Doc-block now lives on this function (commit `d295ebc6` move). - ---- - -## 7. Worker Engine Routes (Internal) - -Registered on each vLLM worker's system HTTP port (default `8081` local / `9090` k8s) by `worker_factory.py::register_engine_routes()`. Called by the Rust `/v1/rl/*` handlers — not by external clients directly. - -| Route | vLLM API called | Used by | -|---|---|---| -| `pause_generation` | `engine_client.pause_generation()` (+ `abort_all_requests` when mode=abort) | `/v1/rl/pause` | -| `resume_generation` | `engine_client.resume_generation()` | `/v1/rl/resume` | -| `flush_cache` | `engine_client.reset_prefix_cache()` | `/v1/rl/update_weights` (when `reset_prefix_cache=true`) | -| `update_weights_from_path` | `collective_rpc("reload_weights", weights_path=...)` | `/v1/rl/update_weights` | -| `get_weight_version` | reads `self._weight_version` | `/v1/rl/weight_version` (legacy) | -| `get_state` | composite per-worker snapshot (engine_alive, pause_state, applied_weight_version, loras) | `/v1/rl/state` | -| `liveness_probe` | round-trips `engine_client.check_health()` so a wedged event loop returns 503 | `/v1/rl/liveness` | -| `load_lora_adapter` | `add_lora`, `remove_lora` | `/v1/rl/load_lora_adapter` | -| `unload_lora_adapter` | `remove_lora` + MDC unregister | `/v1/rl/unload_lora_adapter` | - -### `publisher.py` crash guard - -`DynamoStatLoggerPublisher.record()` guards against `scheduler_stats is None`. 
This prevents an `AttributeError` crash during the transient window right after a weight reload, when the vLLM stats logger fires before the engine core has re-initialized its scheduler. - ---- - -## 8. Known Limitations - -| Limitation | Workaround | Notes | -|---|---|---| -| **NCCL mode is a no-op on Dynamo's vLLM side.** `update_weights` with `weight_dir=null` returns 200 immediately, but `dynamo.vllm` does not load an NCCL weight-broadcast receiver as a vLLM worker class — so the trainer's NCCL broadcast has no peer on the inference side, and `init_process_group` on the trainer times out at `weight_broadcast.timeout` (default 1200 s). | Use `weight_broadcast.type = "filesystem"`. The `dynamo.sglang` backend ships `update_weights_from_distributed` natively and does work over NCCL. | The bootstrap admin route the trainer expects (`/v1/rl/init_broadcaster`) is not exposed by `dynamo.vllm` today; wiring it (and the receiver class) is the next workstream. | -| `cache_salt` not yet honored end-to-end | Disable prefix-cache-salt on the client side, or send the equivalent `X-Tenant-Id` header. | Field is whitelisted (`PASSTHROUGH_EXTRA_FIELDS`) so requests don't 400; routing-side filter is a follow-up. | -| `prompt_token_ids` only injected for non-streaming responses | Use non-streaming mode for RL rollouts (the default). | Streaming final-chunk injection is planned. | -| Weight version `"initial"` before first update | Use `/v1/rl/state.applied_weight_version` for source-of-truth; don't rely on the version string for correctness. | | -| Filesystem weight broadcast scales poorly for large models | Ok for 0.6B (~250 ms load); marginal at 7B (~25 s); ~150 s at 30B-A3B BF16; impractical at 70B+. | RDMA / NCCL-receive on dynamo.vllm planned. 
| - diff --git a/docs/index.yml b/docs/index.yml index 14b48b8e012d..a2b1aa8fbd2f 100644 --- a/docs/index.yml +++ b/docs/index.yml @@ -346,6 +346,9 @@ navigation: - section: Additional Resources hidden: true contents: + # -- RL -- + - page: RL + path: RL.md # -- Development -- - page: Runtime Guide path: development/runtime-guide.md diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index c2ff76c0b509..2fb9ffd1898b 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -1159,6 +1159,49 @@ impl Client { Ok(AsyncResponseStream::new(rx, annotated)) }) } + + /// Directly send a request to a specific endpoint without fallback re-selection. + #[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))] + fn direct_strict<'p>( + &self, + py: Python<'p>, + request: PyObject, + instance_id: u64, + annotated: Option, + context: Option, + ) -> PyResult> { + let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?; + let request_ctx = create_request_context(request, &context); + let annotated = annotated.unwrap_or(false); + + let (tx, rx) = tokio::sync::mpsc::channel(32); + let client = self.router.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let stream = match context { + Some(context) => { + let span = get_span_for_direct_context( + &context, + "direct_strict", + &instance_id.to_string(), + ); + client + .direct_strict(request_ctx, instance_id) + .instrument(span) + .await + .map_err(to_pyerr)? 
+ } + _ => client + .direct_strict(request_ctx, instance_id) + .await + .map_err(to_pyerr)?, + }; + + tokio::spawn(process_stream(stream, tx)); + + Ok(AsyncResponseStream::new(rx, annotated)) + }) + } } async fn process_stream( diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index 87663b5e4af5..808b561b5770 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -320,6 +320,7 @@ pub(crate) struct EntrypointArgs { kv_cache_block_size: Option, http_host: Option, http_port: u16, + rl_port: Option, http_metrics_port: Option, tls_cert_path: Option, tls_key_path: Option, @@ -339,7 +340,7 @@ pub(crate) struct EntrypointArgs { impl EntrypointArgs { #[allow(clippy::too_many_arguments)] #[new] - #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, migration_max_seq_len=None, chat_engine_factory=None, aic_perf_config=None))] + #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, rl_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, migration_max_seq_len=None, chat_engine_factory=None, aic_perf_config=None))] pub fn new( py: Python<'_>, engine_type: EngineType, @@ -352,6 +353,7 @@ impl EntrypointArgs { kv_cache_block_size: Option, http_host: Option, http_port: Option, + rl_port: Option, http_metrics_port: 
Option, tls_cert_path: Option, tls_key_path: Option, @@ -402,6 +404,7 @@ impl EntrypointArgs { kv_cache_block_size, http_host, http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT), + rl_port, http_metrics_port, tls_cert_path, tls_key_path, @@ -450,6 +453,7 @@ pub fn make_engine<'p>( .migration_max_seq_len(args.migration_max_seq_len) .http_host(args.http_host.clone()) .http_port(args.http_port) + .rl_port(args.rl_port) .http_metrics_port(args.http_metrics_port) .tls_cert_path(args.tls_cert_path.clone()) .tls_key_path(args.tls_key_path.clone()) diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index dba5466d0532..eb3a847abeca 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -246,6 +246,18 @@ class Client: """ ... + async def direct_strict( + self, + request: JsonLike, + instance_id: int, + annotated: bool | None = True, + context: Context | None = None, + ) -> AsyncIterator[JsonLike]: + """ + Pick a specific instance of the endpoint without fallback re-selection. + """ + ... 
+ async def generate( self, request: JsonLike, @@ -2114,6 +2126,7 @@ class EntrypointArgs: kv_cache_block_size: Optional[int] = None, http_host: Optional[str] = None, http_port: Optional[int] = None, + rl_port: Optional[int] = None, http_metrics_port: Optional[int] = None, tls_cert_path: Optional[str] = None, tls_key_path: Optional[str] = None, @@ -2141,6 +2154,7 @@ class EntrypointArgs: kv_cache_block_size: Optional KV cache block size http_host: HTTP host to bind to http_port: HTTP port to bind to + rl_port: Dedicated RL admin HTTP port to bind to http_metrics_port: HTTP metrics port (for gRPC service) tls_cert_path: TLS certificate path (PEM format) tls_key_path: TLS key path (PEM format) diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs index 28b231ae3e67..b54a0f799fd5 100644 --- a/lib/llm/src/entrypoint/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -49,6 +49,9 @@ pub async fn run( if let Some(http_host) = local_model.http_host() { http_service_builder = http_service_builder.host(http_host); } + if let Some(rl_port) = local_model.rl_port() { + http_service_builder = http_service_builder.rl_port(rl_port); + } http_service_builder = http_service_builder.cancel_token(Some(distributed_runtime.primary_token())); http_service_builder = diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index c2cac10cf122..c6a87a43638d 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -2977,12 +2977,25 @@ fn rl_promote_token_ids_in_response(json_val: &mut serde_json::Value) { /// `smart_json_error_middleware` so 422s are coerced to 400s consistently /// with the OpenAI-compat surface. /// -/// Exposed only when `DYN_ENABLE_RL=true` or `HttpServiceConfig.enable_rl` -/// is set. Mounted by `service_v2.rs`. +/// Exposed only on the dedicated RL listener when +/// `DYN_ENABLE_RL_ENDPOINTS=true` or `HttpServiceConfig.enable_rl` is set. 
pub fn rl_router( drt: std::sync::Arc, ) -> anyhow::Result<(Vec, Router)> { - let (rl_docs, router) = dynamo_rl::rl_router(drt)?; + let namespace = std::env::var("DYN_NAMESPACE").unwrap_or_else(|_| "dynamo".into()); + let mut policy = dynamo_rl::FanoutPolicy::default_admin(); + if let Ok(component) = std::env::var("DYN_RL_COMPONENT") { + policy = policy.with_component_filter(vec![component]); + } + + let client = dynamo_rl::RlClient::new(dynamo_rl::RlClientConfig { + runtime: drt, + namespace, + rl_endpoint: dynamo_rl::DEFAULT_RL_ENDPOINT.to_string(), + policy, + })?; + + let (rl_docs, router) = dynamo_rl::rl_router(dynamo_rl::RlHttpDeps { client })?; let docs = rl_docs .into_iter() .map(|d| RouteDoc::new(d.method, d.path)) diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index f42f2fdb93be..e87edbf87c21 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::env::var; +use std::io::ErrorKind; use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicBool; @@ -204,6 +205,8 @@ pub struct HttpService { router: axum::Router, port: u16, + rl_router: Option, + rl_port: u16, host: String, enable_tls: bool, tls_cert_path: Option, @@ -217,6 +220,9 @@ pub struct HttpServiceConfig { #[builder(default = "8787")] port: u16, + #[builder(default = "default_rl_port()")] + rl_port: u16, + #[builder(setter(into), default = "String::from(\"0.0.0.0\")")] host: String, @@ -246,9 +252,9 @@ pub struct HttpServiceConfig { #[builder(default = "false")] enable_anthropic_endpoints: bool, - /// When true, expose the RL admin routes at `/v1/rl/*` (pause, resume, - /// update_weights, weight_version, ready). Worker system URLs are read - /// from `DYN_RL_WORKER_SYSTEM_URLS` (comma-separated, default `http://localhost:8081`). + /// When true, expose the RL admin routes at `/v1/rl/*` on the dedicated + /// `rl_port` listener. 
Fan-out uses dynamo-rl over the discovery and + /// request planes; worker system ports are not part of this contract. #[builder(default = "false")] enable_rl: bool, @@ -301,103 +307,57 @@ impl HttpService { } pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> { - let address = format!("{}:{}", self.host, self.port); - let protocol = if self.enable_tls { "HTTPS" } else { "HTTP" }; - tracing::info!(protocol, address, "Starting HTTP(S) service"); - - let router = self.router.clone(); - let observer = cancel_token.child_token(); - - let state_cancel = self.state.cancel_token().clone(); - - let addr: SocketAddr = address - .parse() - .map_err(|e| anyhow::anyhow!("Invalid address '{}': {}", address, e))?; - - if self.enable_tls { - let cert_path = self - .tls_cert_path - .as_ref() - .ok_or_else(|| anyhow::anyhow!("TLS certificate path not provided"))?; - let key_path = self - .tls_key_path - .as_ref() - .ok_or_else(|| anyhow::anyhow!("TLS private key path not provided"))?; - - // aws_lc_rs is the default but other crates pull in `ring` also, - // so rustls doesn't know which one to use. Tell it. 
- if let Err(e) = rustls::crypto::aws_lc_rs::default_provider().install_default() { - tracing::debug!("TLS crypto provider already installed: {e:?}"); - } + let mut handles = vec![spawn_http_listener(HttpListenerConfig { + name: "openai", + router: self.router.clone(), + host: self.host.clone(), + port: self.port, + port_arg: "--http-port", + enable_tls: self.enable_tls, + tls_cert_path: self.tls_cert_path.clone(), + tls_key_path: self.tls_key_path.clone(), + cancel_token: cancel_token.clone(), + state_cancel: self.state.cancel_token().clone(), + })]; + + if let Some(router) = self.rl_router.clone() { + handles.push(spawn_http_listener(HttpListenerConfig { + name: "rl", + router, + host: self.host.clone(), + port: self.rl_port, + port_arg: "--rl-port", + enable_tls: self.enable_tls, + tls_cert_path: self.tls_cert_path.clone(), + tls_key_path: self.tls_key_path.clone(), + cancel_token: cancel_token.clone(), + state_cancel: self.state.cancel_token().clone(), + })); + } - let config = RustlsConfig::from_pem_file(cert_path, key_path) - .await - .map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?; + tokio::spawn(tokio_metrics_and_canary_loop(cancel_token.clone())); - let handle = axum_server::Handle::new(); - let server = axum_server::bind_rustls(addr, config) - .handle(handle.clone()) - .serve(router.into_make_service()); + let (first_result, _idx, remaining) = futures::future::select_all(handles).await; + cancel_token.cancel(); - // Spawn canary after all fallible startup so it won't leak on early errors - tokio::spawn(tokio_metrics_and_canary_loop(cancel_token.clone())); + let mut result = match first_result { + Ok(result) => result, + Err(err) => Err(anyhow::anyhow!("HTTP listener task failed: {err}")), + }; - tokio::select! 
{ - result = server => { - let result = result.map_err(|e| anyhow::anyhow!("HTTPS server error: {}", e)); - cancel_token.cancel(); - result?; - } - _ = observer.cancelled() => { - state_cancel.cancel(); - tracing::info!("HTTPS server shutdown requested"); - // accepting requests for 5 more seconds, to allow incorrectly routed requests to arrive - handle.graceful_shutdown(Some(Duration::from_secs(get_graceful_shutdown_timeout() as u64))); - // no longer accepting requests, draining all existing connections + for handle in remaining { + match handle.await { + Ok(Ok(())) => {} + Ok(Err(err)) if result.is_ok() => result = Err(err), + Ok(Err(_)) => {} + Err(err) if result.is_ok() => { + result = Err(anyhow::anyhow!("HTTP listener task failed: {err}")); } + Err(_) => {} } - } else { - let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| { - tracing::error!( - protocol = %protocol, - address = %address, - error = %e, - "Failed to bind server to address" - ); - match e.kind() { - std::io::ErrorKind::AddrInUse => anyhow::anyhow!( - "Failed to start {} server: port {} already in use. 
Use --http-port to specify a different port.", - protocol, - self.port - ), - _ => anyhow::anyhow!( - "Failed to start {} server on {}: {}", - protocol, - address, - e - ), - } - })?; - - // Spawn canary after all fallible startup so it won't leak on early errors - tokio::spawn(tokio_metrics_and_canary_loop(cancel_token.clone())); - - axum::serve(listener, router) - .with_graceful_shutdown(async move { - observer.cancelled_owned().await; - state_cancel.cancel(); - tracing::info!("HTTP server shutdown requested"); - // accepting requests for 5 more seconds, to allow incorrectly routed requests to arrive - tokio::time::sleep(Duration::from_secs(get_graceful_shutdown_timeout() as u64)) - .await; - // no longer accepting requests, draining all existing connections - }) - .await - .inspect_err(|_| cancel_token.cancel())?; - cancel_token.cancel(); } - Ok(()) + result } /// Documentation of exposed HTTP endpoints @@ -415,6 +375,150 @@ impl HttpService { } } +struct HttpListenerConfig { + name: &'static str, + router: axum::Router, + host: String, + port: u16, + port_arg: &'static str, + enable_tls: bool, + tls_cert_path: Option, + tls_key_path: Option, + cancel_token: CancellationToken, + state_cancel: CancellationToken, +} + +fn spawn_http_listener(config: HttpListenerConfig) -> JoinHandle> { + tokio::spawn(run_http_listener(config)) +} + +async fn run_http_listener(config: HttpListenerConfig) -> Result<()> { + let address = format!("{}:{}", config.host, config.port); + let protocol = if config.enable_tls { "HTTPS" } else { "HTTP" }; + tracing::info!( + listener = config.name, + protocol, + address, + "Starting HTTP listener" + ); + + let addr: SocketAddr = address + .parse() + .map_err(|e| anyhow::anyhow!("Invalid address '{}': {}", address, e))?; + + if config.enable_tls { + run_tls_listener(config, addr, protocol, address).await + } else { + run_plain_listener(config, addr, protocol, address).await + } +} + +async fn run_tls_listener( + config: HttpListenerConfig, + 
addr: SocketAddr, + protocol: &'static str, + address: String, +) -> Result<()> { + let cert_path = config + .tls_cert_path + .as_ref() + .ok_or_else(|| anyhow::anyhow!("TLS certificate path not provided"))?; + let key_path = config + .tls_key_path + .as_ref() + .ok_or_else(|| anyhow::anyhow!("TLS private key path not provided"))?; + + // aws_lc_rs is the default but other crates pull in `ring` also, + // so rustls doesn't know which one to use. Tell it. + if let Err(e) = rustls::crypto::aws_lc_rs::default_provider().install_default() { + tracing::debug!("TLS crypto provider already installed: {e:?}"); + } + + let tls_config = RustlsConfig::from_pem_file(cert_path, key_path) + .await + .map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?; + + let handle = axum_server::Handle::new(); + let observer = config.cancel_token.child_token(); + let state_cancel = config.state_cancel.clone(); + let listener_name = config.name; + let server = axum_server::bind_rustls(addr, tls_config) + .handle(handle.clone()) + .serve(config.router.into_make_service()); + + tokio::select! { + result = server => { + result.map_err(|e| { + tracing::error!( + listener = listener_name, + protocol = %protocol, + address = %address, + error = %e, + "HTTP listener failed" + ); + anyhow::anyhow!("{} listener '{}' error: {}", protocol, listener_name, e) + })?; + } + _ = observer.cancelled_owned() => { + state_cancel.cancel(); + tracing::info!(listener = listener_name, "HTTP listener shutdown requested"); + // accepting requests for a short window allows incorrectly routed + // requests already in flight to arrive before draining connections. 
+ handle.graceful_shutdown(Some(Duration::from_secs(get_graceful_shutdown_timeout() as u64))); + } + } + + Ok(()) +} + +async fn run_plain_listener( + config: HttpListenerConfig, + addr: SocketAddr, + protocol: &'static str, + address: String, +) -> Result<()> { + let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| { + tracing::error!( + listener = config.name, + protocol = %protocol, + address = %address, + error = %e, + "Failed to bind HTTP listener to address" + ); + match e.kind() { + ErrorKind::AddrInUse => anyhow::anyhow!( + "Failed to start {} listener '{}': port {} already in use. Use {} to specify a different port.", + protocol, + config.name, + config.port, + config.port_arg + ), + _ => anyhow::anyhow!( + "Failed to start {} listener '{}' on {}: {}", + protocol, + config.name, + address, + e + ), + } + })?; + + let observer = config.cancel_token.child_token(); + let state_cancel = config.state_cancel.clone(); + let listener_name = config.name; + + axum::serve(listener, config.router) + .with_graceful_shutdown(async move { + observer.cancelled_owned().await; + state_cancel.cancel(); + tracing::info!(listener = listener_name, "HTTP listener shutdown requested"); + tokio::time::sleep(Duration::from_secs(get_graceful_shutdown_timeout() as u64)).await; + }) + .await?; + + Ok(()) +} + fn get_graceful_shutdown_timeout() -> usize { std::env::var(env_llm::DYN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT_SECS) .ok() @@ -422,6 +526,16 @@ fn get_graceful_shutdown_timeout() -> usize { .unwrap_or(5) } +const DEFAULT_RL_PORT: u16 = 8002; +const DYN_RL_PORT_ENV: &str = "DYN_RL_PORT"; + +fn default_rl_port() -> u16 { + std::env::var(DYN_RL_PORT_ENV) + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(DEFAULT_RL_PORT) +} + /// Environment variable to set the metrics endpoint path (default: `/metrics`) static HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH"; /// Environment variable to set the models endpoint path (default: `/v1/models`) @@ -530,7 +644,7 
@@ impl HttpServiceConfigBuilder { }; // System routes (health, metrics, models) — debug-level spans - let mut system_routes = vec![ + let system_routes = vec![ metrics::router( registry, var(HTTP_SVC_METRICS_PATH_ENV).ok(), @@ -548,20 +662,29 @@ impl HttpServiceConfigBuilder { super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), super::busy_threshold::busy_threshold_router(state.clone(), None), ]; - // RL admin routes: gated by `DYN_ENABLE_RL_ENDPOINTS` (frontend-only). - // `DYN_ENABLE_RL` is preserved as a fallback alias for the previous - // single-flag deployment shape until clients migrate. The - // builder-time `enable_rl` flag forces routes on regardless of env. - // PR C of `rl-crate.md`: split inference-plane (DYN_ENABLE_RL) from - // admin-plane (DYN_ENABLE_RL_ENDPOINTS). - if config.enable_rl - || env_is_truthy("DYN_ENABLE_RL_ENDPOINTS") - || env_is_truthy("DYN_ENABLE_RL") - { + // RL admin routes: gated by `DYN_ENABLE_RL_ENDPOINTS` (frontend-only) + // and served on a separate listener (`DYN_RL_PORT`, default 8002). + // `DYN_ENABLE_RL` remains the inference-plane flag and no longer + // mounts admin routes on the OpenAI listener. 
+ let rl_router = if config.enable_rl || env_is_truthy("DYN_ENABLE_RL_ENDPOINTS") { match config.runtime.as_ref() { Some(drt) => { - tracing::info!("RL admin routes enabled at /v1/rl/* (request-plane fan-out)"); - system_routes.push(super::openai::rl_router(drt.clone())?); + tracing::info!( + rl_port = config.rl_port, + "RL admin routes enabled at /v1/rl/* on dedicated listener" + ); + let (rl_docs, router) = super::openai::rl_router(drt.clone())?; + let (_openapi_docs, openapi_route) = + super::openapi_docs::openapi_router(rl_docs, None); + let router = router + .merge(openapi_route) + .layer( + TraceLayer::new_for_http() + .make_span_with(make_system_request_span) + .on_response(on_response), + ) + .layer(axum::middleware::from_fn(echo_request_id_header)); + Some(router) } None => { tracing::warn!( @@ -569,9 +692,12 @@ impl HttpServiceConfigBuilder { HttpServiceConfigBuilder.runtime is None — skipping mount. \ The frontend caller must supply the DistributedRuntime." ); + None } } - } + } else { + None + }; let mut system_router = axum::Router::new(); for (route_docs, route) in system_routes { system_router = system_router.merge(route); @@ -612,6 +738,8 @@ impl HttpServiceConfigBuilder { state, router, port: config.port, + rl_router, + rl_port: config.rl_port, host: config.host, enable_tls: config.enable_tls, tls_cert_path: config.tls_cert_path, diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 000f87a300b3..3c88dd64e254 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -57,6 +57,7 @@ pub struct LocalModelBuilder { kv_cache_block_size: u32, http_host: Option, http_port: u16, + rl_port: Option, http_metrics_port: Option, tls_cert_path: Option, tls_key_path: Option, @@ -80,6 +81,7 @@ impl Default for LocalModelBuilder { kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE, http_host: Default::default(), http_port: DEFAULT_HTTP_PORT, + rl_port: None, http_metrics_port: None, tls_cert_path: Default::default(), 
tls_key_path: Default::default(), @@ -152,6 +154,11 @@ impl LocalModelBuilder { self } + pub fn rl_port(&mut self, port: Option) -> &mut Self { + self.rl_port = port; + self + } + pub fn http_metrics_port(&mut self, port: Option) -> &mut Self { self.http_metrics_port = port; self @@ -282,6 +289,7 @@ impl LocalModelBuilder { template, http_host: self.http_host.take(), http_port: self.http_port, + rl_port: self.rl_port, http_metrics_port: self.http_metrics_port, tls_cert_path: self.tls_cert_path.take(), tls_key_path: self.tls_key_path.take(), @@ -339,6 +347,7 @@ impl LocalModelBuilder { template, http_host: self.http_host.take(), http_port: self.http_port, + rl_port: self.rl_port, http_metrics_port: self.http_metrics_port, tls_cert_path: self.tls_cert_path.take(), tls_key_path: self.tls_key_path.take(), @@ -362,6 +371,7 @@ pub struct LocalModel { template: Option, http_host: Option, http_port: u16, + rl_port: Option, http_metrics_port: Option, tls_cert_path: Option, tls_key_path: Option, @@ -418,6 +428,10 @@ impl LocalModel { self.http_port } + pub fn rl_port(&self) -> Option { + self.rl_port + } + pub fn http_metrics_port(&self) -> Option { self.http_metrics_port } diff --git a/lib/rl/src/lib.rs b/lib/rl/src/lib.rs index 9629779a57fb..63f77a604855 100644 --- a/lib/rl/src/lib.rs +++ b/lib/rl/src/lib.rs @@ -8,12 +8,17 @@ //! **PR B status:** request-plane fan-out via the dynamo discovery plane. //! Workers register one endpoint `dyn://..rl` (see //! `worker_factory.py::rl_endpoint.serve_endpoint(handler.rl_dispatch, …)`) -//! and the frontend dispatches by listing live `rl` instances and calling -//! each via [`PushRouter::direct`]. The legacy `register_engine_route` +//! and the frontend dispatches by snapshotting live `rl` instances and calling +//! each via strict request-plane direct routing. The legacy `register_engine_route` //! HTTP-on-system-port mechanism + `DYN_RL_WORKER_SYSTEM_URLS` static URL //! list are gone. 
-use std::sync::Arc; +use std::{ + collections::{HashMap, hash_map::DefaultHasher}, + hash::{Hash, Hasher}, + sync::Arc, + time::Duration, +}; use axum::{ Json, Router, @@ -24,209 +29,609 @@ use axum::{ }; use dynamo_runtime::{ DistributedRuntime, + component::Client, + discovery::{DiscoveryInstance, DiscoveryQuery}, pipeline::{ SingleIn, network::egress::push_router::{PushRouter, RouterMode}, }, protocols::annotated::Annotated, }; -use futures::StreamExt; +use futures::{FutureExt, StreamExt}; + +pub const DEFAULT_RL_ENDPOINT: &str = "rl"; -/// Documentation tuple for an RL admin route. The dynamo-llm caller wraps -/// each tuple into its own `RouteDoc` for `/openapi.json` aggregation. #[derive(Debug, Clone)] -pub struct RlRouteDoc { - pub method: Method, - pub path: String, +pub enum RlError { + NoWorkers { + namespace: String, + rl_endpoint: String, + }, + MembershipChanged { + before_epoch: u64, + after_epoch: u64, + }, } -impl RlRouteDoc { - fn new(method: Method, path: impl Into) -> Self { +impl std::fmt::Display for RlError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RlError::NoWorkers { + namespace, + rl_endpoint, + } => write!( + f, + "no live RL workers found in namespace '{namespace}' for endpoint '{rl_endpoint}'" + ), + RlError::MembershipChanged { + before_epoch, + after_epoch, + } => write!( + f, + "RL worker membership changed during fan-out (before={before_epoch}, after={after_epoch})" + ), + } + } +} + +impl std::error::Error for RlError {} + +#[derive(Debug, Clone)] +pub struct RlClientConfig { + pub runtime: Arc, + pub namespace: String, + pub rl_endpoint: String, + pub policy: FanoutPolicy, +} + +#[derive(Debug, Clone)] +pub struct FanoutPolicy { + pub min_workers: usize, + pub membership_timeout: Duration, + pub request_timeout: Duration, + pub strict_direct: bool, + pub abort_on_membership_change: bool, + pub component_filter: Option>, +} + +impl FanoutPolicy { + pub fn default_admin() -> Self { 
Self { - method, - path: path.into(), + min_workers: 1, + membership_timeout: Duration::from_secs(5), + request_timeout: Duration::from_secs(30), + strict_direct: true, + abort_on_membership_change: true, + component_filter: None, } } + + pub fn with_component_filter(mut self, components: Vec) -> Self { + let components: Vec = components + .into_iter() + .map(|c| c.trim().to_string()) + .filter(|c| !c.is_empty()) + .collect(); + self.component_filter = if components.is_empty() { + None + } else { + Some(components) + }; + self + } +} + +impl Default for FanoutPolicy { + fn default() -> Self { + Self::default_admin() + } +} + +#[derive( + Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord, +)] +pub struct WorkerTarget { + pub namespace: String, + pub component: String, + pub endpoint: String, + pub instance_id: u64, +} + +impl WorkerTarget { + fn endpoint_key(&self) -> (String, String, String) { + ( + self.namespace.clone(), + self.component.clone(), + self.endpoint.clone(), + ) + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct MembershipSnapshot { + pub epoch: u64, + pub targets: Vec, +} + +impl MembershipSnapshot { + fn new(mut targets: Vec) -> Self { + targets.sort(); + targets.dedup(); + + let mut hasher = DefaultHasher::new(); + targets.hash(&mut hasher); + let epoch = hasher.finish(); + + Self { epoch, targets } + } + + pub fn is_empty(&self) -> bool { + self.targets.is_empty() + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct RlRequest { + pub op: String, + #[serde(default)] + pub body: serde_json::Value, +} + +impl RlRequest { + pub fn new(op: impl Into, body: serde_json::Value) -> Self { + Self { + op: op.into(), + body, + } + } + + pub fn describe(_req: DescribeRequest) -> Self { + Self::new("describe", serde_json::json!({})) + } + + pub fn pause(req: PauseRequest) -> Self { + Self::new("pause", serde_json::to_value(req).unwrap_or_default()) + } + + 
pub fn resume(_req: ResumeRequest) -> Self { + Self::new("resume", serde_json::json!({})) + } + + pub fn init_transport(req: InitTransportRequest) -> Self { + Self::new("init_transport", req.0) + } + + pub fn update_weights(req: UpdateWeightsRequest) -> Self { + Self::new("update_weights", req.into_body()) + } +} + +#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] +pub struct DescribeRequest {} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct PauseRequest { + pub mode: String, + pub clear_cache: bool, +} + +impl Default for PauseRequest { + fn default() -> Self { + Self { + mode: "keep".to_string(), + clear_cache: false, + } + } +} + +#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] +pub struct ResumeRequest {} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(transparent)] +pub struct InitTransportRequest(pub serde_json::Value); + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct UpdateWeightsRequest { + pub version: String, + pub target: serde_json::Value, + pub transport: serde_json::Value, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub pause_mode: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub clear_cache: Option, +} + +impl UpdateWeightsRequest { + fn into_body(self) -> serde_json::Value { + let mut body = serde_json::json!({ + "version": self.version, + "target": self.target, + "transport": self.transport, + }); + if let Some(pause_mode) = self.pause_mode { + body["pause_mode"] = serde_json::Value::String(pause_mode); + } + if let Some(clear_cache) = self.clear_cache { + body["clear_cache"] = serde_json::Value::Bool(clear_cache); + } + body + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct WorkerResult { + pub target: WorkerTarget, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub response: Option, + #[serde(default, skip_serializing_if = 
"Option::is_none")] + pub error: Option, +} + +impl WorkerResult { + fn ok(target: WorkerTarget, response: serde_json::Value) -> Self { + Self { + target, + response: Some(response), + error: None, + } + } + + fn error(target: WorkerTarget, error: impl Into) -> Self { + Self { + target, + response: None, + error: Some(error.into()), + } + } + + pub fn is_ok(&self) -> bool { + self.error.is_none() + && self + .response + .as_ref() + .and_then(|r| r.get("status")) + .and_then(|s| s.as_str()) + == Some("ok") + } + + pub fn payload(&self) -> serde_json::Value { + match (&self.response, &self.error) { + (Some(response), None) => response.clone(), + (_, Some(error)) => serde_json::json!({ + "status": "error", + "namespace": self.target.namespace, + "component": self.target.component, + "endpoint": self.target.endpoint, + "instance_id": self.target.instance_id, + "message": error, + }), + _ => serde_json::json!({ + "status": "error", + "namespace": self.target.namespace, + "component": self.target.component, + "endpoint": self.target.endpoint, + "instance_id": self.target.instance_id, + "message": "missing worker response", + }), + } + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct FanoutReport { + pub snapshot: MembershipSnapshot, + pub workers: Vec, +} + +impl FanoutReport { + pub fn all_ok(&self) -> bool { + !self.workers.is_empty() && self.workers.iter().all(WorkerResult::is_ok) + } + + pub fn worker_payloads(&self) -> Vec { + self.workers.iter().map(WorkerResult::payload).collect() + } } -/// Shared state for the RL admin router. -/// -/// Holds a runtime handle, a target `.` pair, and the -/// name of the unified RL endpoint (always `"rl"`). Each fan-out call: -/// -/// 1. Lists live instances of `..rl` via discovery. -/// 2. Builds a [`PushRouter`] over the runtime's request plane (NATS / shared TCP). -/// 3. Calls [`PushRouter::direct`] per `instance_id` with a JSON -/// `{"op": , "body": }` envelope. -/// 4. 
Drains the response stream and extracts the first `Annotated.data`. #[derive(Clone)] -struct RlState { - drt: Arc, +pub struct RlClient { + runtime: Arc, namespace: String, - component: String, - /// The endpoint name workers serve their RL dispatcher on. Always `"rl"`. rl_endpoint: String, + policy: FanoutPolicy, } -impl RlState { - fn from_env(drt: Arc) -> anyhow::Result { - let namespace = std::env::var("DYN_NAMESPACE").unwrap_or_else(|_| "dynamo".into()); - // Workers default to component="backend" (vLLM, sglang). Allow - // override for disagg / multi-component deployments. - let component = std::env::var("DYN_RL_COMPONENT").unwrap_or_else(|_| "backend".into()); - let rl_endpoint = "rl".to_string(); - tracing::info!( - ns = %namespace, - comp = %component, - rl_endpoint = %rl_endpoint, - "RL admin router configured (request-plane discovery)" - ); +impl RlClient { + pub fn new(config: RlClientConfig) -> anyhow::Result { + if config.namespace.trim().is_empty() { + anyhow::bail!("RlClientConfig.namespace must not be empty"); + } + if config.rl_endpoint.trim().is_empty() { + anyhow::bail!("RlClientConfig.rl_endpoint must not be empty"); + } + Ok(Self { - drt, - namespace, - component, - rl_endpoint, + runtime: config.runtime, + namespace: config.namespace, + rl_endpoint: config.rl_endpoint, + policy: config.policy, }) } - /// Fan out an admin op to every live worker via the request plane. - /// - /// `route` is the legacy engine-route name (`pause_generation`, - /// `resume_generation`, `weight_transport_init`, `weight_transport_update`) - /// preserved from the call sites; we map it to the unified op name on - /// the wire. - /// - /// Source of truth for "which workers are live" is the - /// [`Client::instance_source`] watcher (etcd-backed), not a one-shot - /// discovery `list()`. PushRouter's `direct()` checks the same client - /// view internally — going through the client avoids the race where a - /// freshly-built client hasn't populated yet. 
- async fn fan_out(&self, route: &str, body: serde_json::Value) -> Vec { - let op = route_to_op(route); - - let endpoint = match self - .drt - .namespace(&self.namespace) - .and_then(|ns| ns.component(&self.component)) - { - Ok(comp) => comp.endpoint(&self.rl_endpoint), - Err(err) => { - tracing::warn!(%err, route, "RL fan_out: failed to build endpoint"); - return vec![serde_json::json!({ - "status": "error", - "message": format!("endpoint build failed: {err}"), - })]; - } - }; + pub async fn snapshot(&self) -> anyhow::Result { + let instances = self + .runtime + .discovery() + .list(DiscoveryQuery::NamespacedEndpoints { + namespace: self.namespace.clone(), + }) + .await?; + + let targets = instances + .into_iter() + .filter_map(|instance| match instance { + DiscoveryInstance::Endpoint(instance) if instance.endpoint == self.rl_endpoint => { + Some(instance) + } + _ => None, + }) + .filter(|instance| { + self.policy + .component_filter + .as_ref() + .map(|components| components.iter().any(|c| c == &instance.component)) + .unwrap_or(true) + }) + .map(|instance| WorkerTarget { + namespace: instance.namespace, + component: instance.component, + endpoint: instance.endpoint, + instance_id: instance.instance_id, + }) + .collect(); + + Ok(MembershipSnapshot::new(targets)) + } + + pub async fn describe(&self) -> anyhow::Result { + self.fanout(RlRequest::describe(DescribeRequest::default())) + .await + } + + pub async fn pause(&self, req: PauseRequest) -> anyhow::Result { + self.fanout(RlRequest::pause(req)).await + } + + pub async fn resume(&self, req: ResumeRequest) -> anyhow::Result { + self.fanout(RlRequest::resume(req)).await + } + + pub async fn init_transport(&self, req: InitTransportRequest) -> anyhow::Result { + self.fanout(RlRequest::init_transport(req)).await + } + + pub async fn update_weights(&self, req: UpdateWeightsRequest) -> anyhow::Result { + self.fanout(RlRequest::update_weights(req)).await + } - let client = match endpoint.client().await { - Ok(c) => c, - 
Err(err) => { - tracing::warn!(%err, route, "RL fan_out: failed to create endpoint client"); - return vec![serde_json::json!({ - "status": "error", - "message": format!("client create failed: {err}"), - })]; + pub async fn fanout(&self, request: RlRequest) -> anyhow::Result { + let snapshot = self.snapshot().await?; + self.fanout_snapshot(snapshot, request).await + } + + pub async fn fanout_snapshot( + &self, + snapshot: MembershipSnapshot, + request: RlRequest, + ) -> anyhow::Result { + if snapshot.targets.len() < self.policy.min_workers { + return Err(RlError::NoWorkers { + namespace: self.namespace.clone(), + rl_endpoint: self.rl_endpoint.clone(), } - }; + .into()); + } - // Bound the watcher-population race: wait until the client sees - // ≥1 instance (or a short deadline elapses, in which case we - // surface the empty-fanout warning below). 5s is generous — - // workers register synchronously on serve_endpoint() before they - // start serving traffic, so by the time anything POSTs `/v1/rl/*` - // they should already be in etcd. 
- let _ = tokio::time::timeout( - std::time::Duration::from_secs(5), - client.wait_for_instances(), - ) - .await; - - let instance_ids: Vec = client.instance_ids(); - if instance_ids.is_empty() { - tracing::warn!( - ns = %self.namespace, - comp = %self.component, - route, - "RL fan_out: no live workers under {}.{}.rl; \ - check DYN_NAMESPACE / DYN_RL_COMPONENT vs worker --component", - self.namespace, - self.component, - ); - return Vec::new(); + let mut grouped: HashMap<(String, String, String), Vec> = HashMap::new(); + for target in &snapshot.targets { + grouped + .entry(target.endpoint_key()) + .or_default() + .push(target.clone()); } - let router = - match PushRouter::>::from_client( - client, - RouterMode::Direct, - ) - .await + let mut calls: Vec> = Vec::new(); + for ((namespace, component, endpoint_name), targets) in grouped { + let endpoint = match self + .runtime + .namespace(&namespace) + .and_then(|ns| ns.component(&component)) { - Ok(r) => r, + Ok(component) => component.endpoint(endpoint_name), + Err(err) => { + for target in targets { + calls.push( + futures::future::ready(WorkerResult::error( + target, + format!("endpoint build failed: {err}"), + )) + .boxed(), + ); + } + continue; + } + }; + + let client = match endpoint.client().await { + Ok(client) => client, Err(err) => { - tracing::warn!(%err, route, "RL fan_out: failed to build PushRouter"); - return vec![serde_json::json!({ - "status": "error", - "message": format!("PushRouter build failed: {err}"), - })]; + for target in targets { + calls.push( + futures::future::ready(WorkerResult::error( + target, + format!("client create failed: {err}"), + )) + .boxed(), + ); + } + continue; } }; - let envelope = serde_json::json!({"op": op, "body": body}); - - let futures: Vec<_> = instance_ids - .iter() - .copied() - .map(|id| { - let router = router.clone(); - let envelope = envelope.clone(); - async move { - let req = SingleIn::new(envelope.clone()); - match router.direct(req, id).await { - Ok(mut 
stream) => { - // Drain the first non-empty data chunk from the - // worker's async-generator response. - while let Some(chunk) = stream.next().await { - if let Some(data) = chunk.data { - return data; - } - if let Some(err) = chunk.error { - return serde_json::json!({ - "status": "error", - "instance_id": id, - "message": err.to_string(), - }); - } - } - serde_json::json!({ - "status": "error", - "instance_id": id, - "message": "empty response stream", - }) + let target_ids: Vec = targets.iter().map(|target| target.instance_id).collect(); + wait_for_client_targets(&client, &target_ids, self.policy.membership_timeout).await; + + let router = + match PushRouter::>::from_client( + client, + RouterMode::Direct, + ) + .await + { + Ok(router) => router, + Err(err) => { + for target in targets { + calls.push( + futures::future::ready(WorkerResult::error( + target, + format!("PushRouter build failed: {err}"), + )) + .boxed(), + ); } - Err(err) => serde_json::json!({ - "status": "error", - "instance_id": id, - "message": format!("dispatch failed: {err}"), - }), + continue; } + }; + + for target in targets { + calls.push( + call_worker( + router.clone(), + target, + request.clone(), + self.policy.request_timeout, + self.policy.strict_direct, + ) + .boxed(), + ); + } + } + + let workers = futures::future::join_all(calls).await; + + if self.policy.abort_on_membership_change { + let after = self.snapshot().await?; + if after.epoch != snapshot.epoch { + return Err(RlError::MembershipChanged { + before_epoch: snapshot.epoch, + after_epoch: after.epoch, } - }) - .collect(); - futures::future::join_all(futures).await + .into()); + } + } + + Ok(FanoutReport { snapshot, workers }) + } +} + +async fn wait_for_client_targets(client: &Client, target_ids: &[u64], timeout: Duration) { + let wait = async { + loop { + let instance_ids = client.instance_ids(); + if target_ids.iter().all(|id| instance_ids.contains(id)) { + break; + } + tokio::time::sleep(Duration::from_millis(25)).await; + } 
+ }; + + let _ = tokio::time::timeout(timeout, wait).await; +} + +async fn call_worker( + router: PushRouter>, + target: WorkerTarget, + request: RlRequest, + timeout: Duration, + strict_direct: bool, +) -> WorkerResult { + let request_value = match serde_json::to_value(request) { + Ok(value) => value, + Err(err) => return WorkerResult::error(target, format!("request encode failed: {err}")), + }; + + let instance_id = target.instance_id; + let dispatch = async { + let req = SingleIn::new(request_value); + let mut stream = if strict_direct { + router.direct_strict(req, instance_id).await? + } else { + router.direct(req, instance_id).await? + }; + + while let Some(chunk) = stream.next().await { + if let Some(data) = chunk.data { + return anyhow::Ok(data); + } + if let Some(err) = chunk.error { + anyhow::bail!(err.to_string()); + } + } + + anyhow::bail!("empty response stream"); + }; + + match tokio::time::timeout(timeout, dispatch).await { + Ok(Ok(response)) => WorkerResult::ok(target, response), + Ok(Err(err)) => WorkerResult::error(target, format!("dispatch failed: {err}")), + Err(_) => WorkerResult::error( + target, + format!("dispatch timed out after {}s", timeout.as_secs()), + ), + } +} + +/// Documentation tuple for an RL admin route. The dynamo-llm caller wraps +/// each tuple into its own `RouteDoc` for `/openapi.json` aggregation. +#[derive(Debug, Clone)] +pub struct RlRouteDoc { + pub method: Method, + pub path: String, +} + +impl RlRouteDoc { + fn new(method: Method, path: impl Into) -> Self { + Self { + method, + path: path.into(), + } + } +} + +/// Shared state for the RL admin HTTP facade. +#[derive(Clone)] +struct RlState { + client: RlClient, +} + +impl RlState { + fn new(client: RlClient) -> Self { + Self { client } } - /// Returns true only if every result is `status: "ok"` AND there is at - /// least one. Empty fan-out (no workers found) is `503`, not silent OK. 
- fn all_ok(results: &[serde_json::Value]) -> bool { - !results.is_empty() - && results - .iter() - .all(|r| r.get("status").and_then(|s| s.as_str()) == Some("ok")) + async fn fan_out(&self, route: &str, body: serde_json::Value) -> anyhow::Result { + self.client + .fanout(RlRequest::new(route_to_op(route), body)) + .await } } +#[derive(Clone)] +pub struct RlHttpDeps { + pub client: RlClient, +} + /// Map a legacy engine-route name to the corresponding `rl_dispatch` op. fn route_to_op(route: &str) -> &str { match route { @@ -283,21 +688,45 @@ struct RlPauseQuery { clear_cache: Option, } +fn rl_error_response(err: anyhow::Error) -> (StatusCode, Json) { + let (status, error_type) = match err.downcast_ref::() { + Some(RlError::NoWorkers { .. }) => (StatusCode::SERVICE_UNAVAILABLE, "no_workers"), + Some(RlError::MembershipChanged { .. }) => (StatusCode::CONFLICT, "membership_changed"), + None => (StatusCode::BAD_GATEWAY, "fanout_failed"), + }; + + ( + status, + Json(serde_json::json!({ + "status": "error", + "error_type": error_type, + "message": err.to_string(), + })), + ) +} + async fn rl_pause( State(state): State>, axum::extract::Query(q): axum::extract::Query, ) -> impl IntoResponse { let mode = q.mode.unwrap_or_default(); let clear_cache = q.clear_cache.unwrap_or(false); - let results = state + let report = match state .fan_out( "pause_generation", serde_json::json!({"mode": mode.as_str(), "clear_cache": clear_cache}), ) - .await; - if RlState::all_ok(&results) { + .await + { + Ok(report) => report, + Err(err) => return rl_error_response(err), + }; + + let workers = report.worker_payloads(); + if report.all_ok() { tracing::info!( - worker_count = results.len(), + worker_count = workers.len(), + membership_epoch = report.snapshot.epoch, mode = %mode.as_str(), clear_cache, "RL pause: all workers paused" @@ -308,37 +737,57 @@ async fn rl_pause( "status": "ok", "mode": mode.as_str(), "clear_cache": clear_cache, - "workers": results, + "membership_epoch": 
report.snapshot.epoch, + "workers": workers, })), ) } else { - tracing::warn!(?results, "RL pause: some workers failed"); + tracing::warn!(?workers, "RL pause: some workers failed"); ( StatusCode::BAD_GATEWAY, - Json(serde_json::json!({"status": "error", "workers": results})), + Json(serde_json::json!({ + "status": "error", + "membership_epoch": report.snapshot.epoch, + "workers": workers, + })), ) } } /// `POST /v1/rl/resume` — fan out `resume_generation` to all workers. async fn rl_resume(State(state): State>) -> impl IntoResponse { - let results = state + let report = match state .fan_out("resume_generation", serde_json::json!({})) - .await; - if RlState::all_ok(&results) { + .await + { + Ok(report) => report, + Err(err) => return rl_error_response(err), + }; + + let workers = report.worker_payloads(); + if report.all_ok() { tracing::info!( - worker_count = results.len(), + worker_count = workers.len(), + membership_epoch = report.snapshot.epoch, "RL resume: all workers resumed" ); ( StatusCode::OK, - Json(serde_json::json!({"status": "ok", "workers": results})), + Json(serde_json::json!({ + "status": "ok", + "membership_epoch": report.snapshot.epoch, + "workers": workers, + })), ) } else { - tracing::warn!(?results, "RL resume: some workers failed"); + tracing::warn!(?workers, "RL resume: some workers failed"); ( StatusCode::BAD_GATEWAY, - Json(serde_json::json!({"status": "error", "workers": results})), + Json(serde_json::json!({ + "status": "error", + "membership_epoch": report.snapshot.epoch, + "workers": workers, + })), ) } } @@ -440,10 +889,16 @@ async fn rl_update_weights_inner( if let Some(cc) = clear_cache { body["clear_cache"] = serde_json::Value::Bool(cc); } - let results = state.fan_out("weight_transport_update", body).await; - if RlState::all_ok(&results) { + let report = match state.fan_out("weight_transport_update", body).await { + Ok(report) => report, + Err(err) => return rl_error_response(err), + }; + + let workers = report.worker_payloads(); + 
if report.all_ok() { tracing::info!( - worker_count = results.len(), + worker_count = workers.len(), + membership_epoch = report.snapshot.epoch, backend = %backend, version = %version, "RL update_weights: all workers updated" @@ -454,18 +909,20 @@ async fn rl_update_weights_inner( "status": "ok", "applied_weight_version": version, "backend": backend, - "workers": results, + "membership_epoch": report.snapshot.epoch, + "workers": workers, })), ) } else { - tracing::warn!(?results, backend = %backend, "RL update_weights: some workers failed"); + tracing::warn!(?workers, backend = %backend, "RL update_weights: some workers failed"); ( StatusCode::BAD_GATEWAY, Json(serde_json::json!({ "status": "error", "stage": "weight_transport_update", "backend": backend, - "workers": results, + "membership_epoch": report.snapshot.epoch, + "workers": workers, })), ) } @@ -503,10 +960,16 @@ async fn rl_init_transport( .to_string(); tracing::info!(%backend, %transport_id, "RL init_transport"); - let results = state.fan_out("weight_transport_init", body).await; - if RlState::all_ok(&results) { + let report = match state.fan_out("weight_transport_init", body).await { + Ok(report) => report, + Err(err) => return rl_error_response(err), + }; + + let workers = report.worker_payloads(); + if report.all_ok() { tracing::info!( - worker_count = results.len(), + worker_count = workers.len(), + membership_epoch = report.snapshot.epoch, %backend, %transport_id, "RL init_transport: all workers ready" @@ -518,18 +981,20 @@ async fn rl_init_transport( "transport_id": transport_id, "backend": backend, "ready": true, - "workers": results, + "membership_epoch": report.snapshot.epoch, + "workers": workers, })), ) } else { - tracing::warn!(?results, %backend, "RL init_transport: some workers failed"); + tracing::warn!(?workers, %backend, "RL init_transport: some workers failed"); ( StatusCode::BAD_GATEWAY, Json(serde_json::json!({ "status": "error", "transport_id": transport_id, "backend": backend, - 
"workers": results, + "membership_epoch": report.snapshot.epoch, + "workers": workers, })), ) } @@ -537,13 +1002,9 @@ async fn rl_init_transport( /// Create an Axum [`Router`] for the RL admin endpoints at `/v1/rl/*`. /// -/// **PR B:** fan-out goes through the dynamo discovery plane + request -/// plane. Workers register `..rl` (default -/// `dynamo.backend.rl`) on the request plane via -/// `runtime.endpoint(...).serve_endpoint(handler.rl_dispatch, ...)`. The -/// frontend lists live instances via [`DistributedRuntime::discovery`] -/// + [`DiscoveryQuery::NamespacedEndpoints`] and dispatches each call via -/// [`PushRouter::direct`] over NATS / shared TCP. +/// Fan-out goes through [`RlClient`], which snapshots the discovery plane, +/// groups live `..rl` workers, and dispatches with +/// request-plane strict direct calls over NATS / TCP / HTTP. /// /// **Surface:** four POST routes after Phase 3. /// `pause`, `resume`, `init_transport`, `update_weights`. Read-side @@ -555,9 +1016,9 @@ async fn rl_init_transport( /// /// Mounted on the dedicated `/v1/rl/*` listener when /// `DYN_ENABLE_RL_ENDPOINTS=true`. prime-rl usage: -/// `admin_base_url = "http://dynamo-frontend:8000/v1/rl"`. -pub fn rl_router(drt: Arc) -> anyhow::Result<(Vec, Router)> { - let rl_state_arc = Arc::new(RlState::from_env(drt)?); +/// `admin_base_url = "http://dynamo-frontend:8002/v1/rl"`. +pub fn rl_router(deps: RlHttpDeps) -> anyhow::Result<(Vec, Router)> { + let rl_state_arc = Arc::new(RlState::new(deps.client)); let docs = vec![ // Pause / resume bracket. 
RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/pause"), diff --git a/lib/runtime/src/pipeline/network/egress/push_router.rs b/lib/runtime/src/pipeline/network/egress/push_router.rs index 41fa024caced..001865c95dca 100644 --- a/lib/runtime/src/pipeline/network/egress/push_router.rs +++ b/lib/runtime/src/pipeline/network/egress/push_router.rs @@ -452,6 +452,27 @@ where .await } + /// Issue a request to a specific endpoint without fallback re-selection. + /// + /// This is intended for admin/control-plane operations where the caller has + /// already selected a concrete membership snapshot and routing the request + /// to any other instance would be incorrect. + pub async fn direct_strict( + &self, + request: SingleIn, + instance_id: u64, + ) -> anyhow::Result> { + if !self.client.instance_ids().contains(&instance_id) { + return Err(anyhow::anyhow!( + "instance_id={instance_id} not found for endpoint {}", + self.client.endpoint.id() + )); + } + + self.generate_with_fault_detection_options(instance_id, request, false) + .await + } + /// Issue a request using device-aware weighted routing. 
/// /// Instances are partitioned by device type (CPU vs non-CPU), then the router @@ -650,9 +671,19 @@ where */ async fn generate_with_fault_detection( + &self, + instance_id: u64, + request: SingleIn, + ) -> anyhow::Result> { + self.generate_with_fault_detection_options(instance_id, request, true) + .await + } + + async fn generate_with_fault_detection_options( &self, mut instance_id: u64, request: SingleIn, + allow_fallback: bool, ) -> anyhow::Result> { let route_start = Instant::now(); let request_id = request.id().to_string(); @@ -734,6 +765,12 @@ where if let Some(result) = resolve_transport(instance_id) { result + } else if !allow_fallback { + return Err(anyhow::anyhow!( + "Instance {} not found for endpoint {}", + instance_id, + self.client.endpoint.id() + )); } else { // Instance vanished — pick a different one from the current // availability list and retry the lookup once. From 7c15b6b9cd86cacf96603c691bc93cd91c109b15 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Fri, 8 May 2026 03:57:04 -0700 Subject: [PATCH 17/18] rl: allow lora unload without transport --- lib/rl/src/lib.rs | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/lib/rl/src/lib.rs b/lib/rl/src/lib.rs index 63f77a604855..dfd453afc330 100644 --- a/lib/rl/src/lib.rs +++ b/lib/rl/src/lib.rs @@ -235,6 +235,7 @@ pub struct InitTransportRequest(pub serde_json::Value); pub struct UpdateWeightsRequest { pub version: String, pub target: serde_json::Value, + #[serde(default)] pub transport: serde_json::Value, #[serde(default, skip_serializing_if = "Option::is_none")] pub pause_mode: Option, @@ -838,6 +839,7 @@ async fn rl_resume(State(state): State>) -> impl IntoResponse { struct RlUpdateWeightsBody { version: String, target: serde_json::Value, + #[serde(default)] transport: serde_json::Value, #[serde(default)] pause_mode: Option, @@ -852,13 +854,45 @@ async fn rl_update_weights( let RlUpdateWeightsBody { version, target, - transport, + mut 
transport, pause_mode, clear_cache, } = body.0; + if is_lora_unload(&target) && transport.get("backend").is_none() { + transport = serde_json::json!({ + "backend": "filesystem", + "filesystem": {}, + }); + } rl_update_weights_inner(state, version, target, transport, pause_mode, clear_cache).await } +fn is_lora_unload(target: &serde_json::Value) -> bool { + target.get("kind").and_then(|v| v.as_str()) == Some("lora") + && target.get("op").and_then(|v| v.as_str()) == Some("unload") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn update_weights_body_accepts_lora_unload_without_transport() { + let body: RlUpdateWeightsBody = serde_json::from_value(serde_json::json!({ + "version": "step_44", + "target": { + "kind": "lora", + "name": "adapter", + "op": "unload" + } + })) + .unwrap(); + + assert!(is_lora_unload(&body.target)); + assert!(body.transport.is_null()); + } +} + /// WeightTransferConfig path — fans out to ``weight_transport_update``. async fn rl_update_weights_inner( state: Arc, From dbda27c75781cae62f1ce25ac1c43747fd069dd1 Mon Sep 17 00:00:00 2001 From: Biswa Panda Date: Mon, 11 May 2026 09:47:59 -0700 Subject: [PATCH 18/18] feat(vllm): gate RL endpoint on --enable-rl / DYN_ENABLE_RL --- components/src/dynamo/vllm/backend_args.py | 13 ++++++++ components/src/dynamo/vllm/worker_factory.py | 33 ++++++++++---------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/components/src/dynamo/vllm/backend_args.py b/components/src/dynamo/vllm/backend_args.py index 28c9d0eb3034..9081dd960c9c 100644 --- a/components/src/dynamo/vllm/backend_args.py +++ b/components/src/dynamo/vllm/backend_args.py @@ -136,6 +136,16 @@ def add_arguments(self, parser) -> None: choices=[m.value for m in EmbeddingTransferMode], ) + # RL admin control plane + add_negatable_bool_argument( + g, + flag_name="--enable-rl", + env_var="DYN_ENABLE_RL", + default=False, + help="Register the RL admin endpoint (dyn://..rl) so the " + "frontend can fan out 
pause/resume/update_weights operations to this worker.", + ) + # Headless mode for multi-node TP/PP add_negatable_bool_argument( g, @@ -267,6 +277,9 @@ class DynamoVllmConfig(ConfigBase): str, EmbeddingTransferMode ] # resolved to enum in validate() + # RL admin control plane + enable_rl: bool = False + # Headless mode for multi-node TP/PP headless: bool = False diff --git a/components/src/dynamo/vllm/worker_factory.py b/components/src/dynamo/vllm/worker_factory.py index 9018f22a9fd1..a00c7d016527 100644 --- a/components/src/dynamo/vllm/worker_factory.py +++ b/components/src/dynamo/vllm/worker_factory.py @@ -241,21 +241,20 @@ async def _create_decode_worker( f"{config.namespace}.{config.component}.clear_kv_blocks" ) - # PR B: unified RL admin endpoint on the request plane. Discoverable - # via etcd as ``..rl``; the dynamo-rl frontend crate - # uses Discovery::list(NamespacedEndpoints) + PushRouter::direct to - # fan out admin ops here, replacing the legacy HTTP-on-system-port - # ``register_engine_route("pause_generation", …)`` etc. mechanism. - rl_endpoint = runtime.endpoint( - f"{config.namespace}.{config.component}.rl" - ) - shutdown_endpoints[:] = [ generate_endpoint, clear_endpoint, - rl_endpoint, ] + # RL admin endpoint — registered only when --enable-rl / DYN_ENABLE_RL=true. + # Discoverable via etcd as ``..rl``; the dynamo-rl frontend + # crate fans out pause/resume/update_weights operations here. + if config.enable_rl: + rl_endpoint = runtime.endpoint( + f"{config.namespace}.{config.component}.rl" + ) + shutdown_endpoints.append(rl_endpoint) + lora_enabled = config.engine_args.enable_lora if lora_enabled: load_lora_endpoint = runtime.endpoint( @@ -452,14 +451,16 @@ async def _create_decode_worker( handler.get_perf_metrics, metrics_labels=model_metrics_labels, ), - # PR B: unified RL admin endpoint (rl_dispatch dispatches - # by op name to pause/resume/init_transport/update_weights). 
- rl_endpoint.serve_endpoint( - handler.rl_dispatch, - metrics_labels=model_metrics_labels, - ), ] + if config.enable_rl: + serve_tasks.append( + rl_endpoint.serve_endpoint( + handler.rl_dispatch, + metrics_labels=model_metrics_labels, + ) + ) + if lora_enabled: serve_tasks.extend( [