Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/prime_rl/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@ class AdminAPI(Protocol):
"""

# Protocol stub (no body): takes the AsyncClient to use and returns nothing.
# NOTE(review): presumably a liveness probe of the server — confirm against implementations.
async def health(self, client: AsyncClient) -> None: ...

# Protocol stub (no body): takes the AsyncClient to use and returns a list of dicts.
# NOTE(review): the schema of each dict is not visible here — check the implementations.
async def list_models(self, client: AsyncClient) -> list[dict]: ...

# Protocol stub (no body): takes the AsyncClient to use and returns nothing.
# NOTE(review): presumably the counterpart of `resume` — confirm what exactly is paused.
async def pause(self, client: AsyncClient) -> None: ...

# Protocol stub (no body): takes the AsyncClient to use and returns nothing.
# NOTE(review): presumably the counterpart of `pause` — confirm what exactly is resumed.
async def resume(self, client: AsyncClient) -> None: ...

# Protocol stub (no body): takes the AsyncClient to use plus a weight directory
# path, which may be None, and returns nothing.
# NOTE(review): the meaning of `weight_dir=None` is not visible here — confirm with implementations.
async def update_weights(self, client: AsyncClient, weight_dir: str | None) -> None: ...

async def load_lora_adapter(
self,
client: AsyncClient,
Expand All @@ -36,6 +41,7 @@ async def load_lora_adapter(
*,
timeout: httpx.Timeout,
) -> None: ...

async def init_broadcaster(
self,
client: AsyncClient,
Expand Down Expand Up @@ -310,7 +316,16 @@ def setup_clients(
) -> list[vf.ClientConfig]:
clients = []
client_idx = 0
renderer_transport = "dynamo" if client_type == "renderer" and client_config.backend == "dynamo" else "vllm"
# `renderer_transport` selects the engine wire shape. Dynamo accepts the
# `dynamo` shape (placeholder messages + `nvext.token_data` on
# `/v1/chat/completions`) for both renderer mode and TITO; vanilla vLLM
# accepts only the legacy shapes: `/generate` for renderer mode and
# `/chat/completions/tokens` for TITO.
is_token_aware_client = client_type in {"renderer", "openai_chat_completions_token"}
if is_token_aware_client and client_config.backend == "dynamo":
renderer_transport = "dynamo"
else:
renderer_transport = "vllm"
for base_url in client_config.base_url:
for dp_rank in range(client_config.dp_rank_count):
headers = client_config.headers.copy()
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/utils/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ def test_setup_clients_uses_dynamo_transport_for_dynamo_renderer():
assert clients[0].renderer_transport == "dynamo"


def test_setup_clients_uses_dynamo_transport_for_dynamo_token_client():
    """A token chat-completions client on a Dynamo backend gets the dynamo transport.

    Builds a single-URL ``ClientConfig`` with ``backend="dynamo"``, asks
    ``setup_clients`` for ``openai_chat_completions_token`` clients, and checks
    that the first returned client reports ``renderer_transport == "dynamo"``.
    """
    dynamo_config = ClientConfig(
        backend="dynamo",
        api_key_var="PRIME_API_KEY",
        base_url=["http://worker-a:8000/v1"],
    )

    built_clients = setup_clients(
        dynamo_config,
        client_type="openai_chat_completions_token",
    )
    first_client = built_clients[0]

    assert first_client.renderer_transport == "dynamo"


def test_setup_clients_preserves_chat_client_defaults():
client_config = ClientConfig(
base_url=["http://worker-a:8000/v1"],
Expand Down
Loading