diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index 4eb3ad1532..eddf4b1b7d 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -24,10 +24,15 @@ class AdminAPI(Protocol): """ async def health(self, client: AsyncClient) -> None: ... + async def list_models(self, client: AsyncClient) -> list[dict]: ... + async def pause(self, client: AsyncClient) -> None: ... + async def resume(self, client: AsyncClient) -> None: ... + async def update_weights(self, client: AsyncClient, weight_dir: str | None) -> None: ... + async def load_lora_adapter( self, client: AsyncClient, @@ -36,6 +41,7 @@ async def load_lora_adapter( *, timeout: httpx.Timeout, ) -> None: ... + async def init_broadcaster( self, client: AsyncClient, @@ -310,7 +316,16 @@ def setup_clients( ) -> list[vf.ClientConfig]: clients = [] client_idx = 0 - renderer_transport = "dynamo" if client_type == "renderer" and client_config.backend == "dynamo" else "vllm" + # `renderer_transport` selects the engine wire shape. Dynamo accepts the + # `dynamo` shape (placeholder messages + nvext.token_data on + # `/v1/chat/completions`) for both renderer-mode and TITO; vanilla vLLM + # accepts only the legacy shapes (`/generate` for renderer, `/chat/ + # completions/tokens` for TITO). + is_token_aware_client = client_type in {"renderer", "openai_chat_completions_token"} + if is_token_aware_client and client_config.backend == "dynamo": + renderer_transport = "dynamo" + else: + renderer_transport = "vllm" for base_url in client_config.base_url: for dp_rank in range(client_config.dp_rank_count): headers = client_config.headers.copy() diff --git a/tests/unit/utils/test_client.py b/tests/unit/utils/test_client.py index 5a071a1d92..3461abf932 100644 --- a/tests/unit/utils/test_client.py +++ b/tests/unit/utils/test_client.py @@ -101,6 +101,18 @@ def test_setup_clients_uses_dynamo_transport_for_dynamo_renderer(): assert clients[0].renderer_transport == "dynamo" +def test_setup_clients_uses_dynamo_transport_for_dynamo_token_client(): + client_config = ClientConfig( + base_url=["http://worker-a:8000/v1"], + api_key_var="PRIME_API_KEY", + backend="dynamo", + ) + + clients = setup_clients(client_config, client_type="openai_chat_completions_token") + + assert clients[0].renderer_transport == "dynamo" + + def test_setup_clients_preserves_chat_client_defaults(): client_config = ClientConfig( base_url=["http://worker-a:8000/v1"],