From 395f187b50fcac181de366698c0898fb18d778a8 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Thu, 5 Mar 2026 18:06:58 -0800 Subject: [PATCH 01/38] remove test guard, add sample method --- .../remote_inference_client.py | 41 +++++++++++++++++++ .../gpu/gpu_ci/test_engine_generation.py | 2 - 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index 0e0cac0352..e5825dc0f8 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -293,6 +293,47 @@ async def _generate_single( "response_logprobs": response_logprobs if len(response_logprobs) > 0 else None, } + async def sample( + self, + prompt_token_ids: List[int], + num_samples: int, + sampling_params: Dict[str, Any], + ) -> InferenceEngineOutput: + """ + Generate multiple independent samples for the same prompt. + + Fires num_samples parallel calls to _generate_single with the same + prompt_token_ids and sampling_params, then aggregates into a single + InferenceEngineOutput. + + Args: + prompt_token_ids: Token IDs for the prompt. + num_samples: Number of independent samples to generate. + sampling_params: Sampling parameters for generation. + + Returns: + InferenceEngineOutput with num_samples responses. + """ + get_logprobs = sampling_params.get("logprobs") is not None + + tasks = [ + self._generate_single( + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + session_id=None, + ) + for _ in range(num_samples) + ] + + results = await asyncio.gather(*tasks) + + return InferenceEngineOutput( + responses=[r["response"] for r in results], + stop_reasons=[r["stop_reason"] for r in results], + response_ids=[r["response_ids"] for r in results], + response_logprobs=[r["response_logprobs"] for r in results] if get_logprobs else None, + ) + async def chat_completion( self, request_payload: Dict[str, Any], diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py index a7682bb76b..ef74ec1e4c 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py @@ -286,8 +286,6 @@ def test_token_based_generation_consistency(ray_init_fixture, tp_size: int, pp_s ) -# TODO: Remove this once sample API is also supported in the new inference pathway -@pytest.mark.skipif(_SKYRL_USE_NEW_INFERENCE, reason="New inference pathway doesn't support sample API yet") @pytest.mark.parametrize( "tp_size,dp_size", [ From f0ec828141da70fc41e260c7eeabc366a61ee288 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Thu, 5 Mar 2026 18:33:17 -0800 Subject: [PATCH 02/38] add session id --- .../skyrl_train/inference_servers/remote_inference_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index e5825dc0f8..bb21289c0a 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -298,6 +298,7 @@ async def sample( prompt_token_ids: List[int], num_samples: int, sampling_params: Dict[str, Any], + session_id: Optional[Union[str, int]] = None, ) -> 
InferenceEngineOutput:
         """
         Generate multiple independent samples for the same prompt.
@@ -310,6 +311,7 @@ async def sample(
             prompt_token_ids: Token IDs for the prompt.
             num_samples: Number of independent samples to generate.
             sampling_params: Sampling parameters for generation.
+            session_id: Optional session ID for consistent routing via X-Session-ID header.
 
         Returns:
             InferenceEngineOutput with num_samples responses.
@@ -320,7 +322,7 @@ async def sample(
             self._generate_single(
                 prompt_token_ids=prompt_token_ids,
                 sampling_params=sampling_params,
-                session_id=None,
+                session_id=session_id,
             )
             for _ in range(num_samples)
         ]

From 07ae4debdd9df0c478ba20bc331902b85455d401 Mon Sep 17 00:00:00 2001
From: Nithin Chalapathi
Date: Fri, 6 Mar 2026 10:42:26 -0800
Subject: [PATCH 03/38] add render api

---
 .../remote_inference_client.py                | 42 +++++++++++++++++++
 .../test_new_inference_generation.py          | 18 ++++++++
 .../test_remote_inference_client.py           | 23 ++++++++++
 3 files changed, 83 insertions(+)

diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
index bb21289c0a..66881f2d54 100644
--- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
+++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
@@ -368,6 +368,48 @@ async def chat_completion(
         raise_for_status(resp, response)
         return response
 
+    async def render_chat_completion(
+        self,
+        messages: List[Dict[str, Any]],
+        add_generation_prompt: bool = True,
+        continue_final_message: bool = False,
+        session_id: Optional[Union[str, int]] = None,
+    ) -> List[Any]:
+        """
+        Render chat messages into a tokenized prompt via /v1/chat/completions/render.
+
+        Applies the model's chat template and tokenizes without generating text.
+
+        Args:
+            messages: List of chat messages (e.g., [{"role": "user", "content": "Hello"}]).
+            add_generation_prompt: Whether to add a generation prompt after the messages.
+            continue_final_message: Whether to continue the final message.
+            session_id: Optional session ID for consistent routing via X-Session-ID header.
+                Needed for multimodal inputs where vLLM caches processed data on a specific backend.
+
+        Returns:
+            List of [conversation, engine_prompts] where engine_prompts contains
+            dicts with "prompt" and "prompt_token_ids".
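+
+        Example (illustrative; the exact prompt text and token IDs depend on
+        the model's chat template and tokenizer):
+
+            result = await client.render_chat_completion(
+                messages=[{"role": "user", "content": "Hello"}]
+            )
+            conversation, engine_prompts = result
+            token_ids = engine_prompts[0]["prompt_token_ids"]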
+ """ + session = await self._get_session() + url = f"{self.proxy_url}/v1/chat/completions/render" + + payload = { + "model": self.model_name, + "messages": messages, + "add_generation_prompt": add_generation_prompt, + "continue_final_message": continue_final_message, + } + + headers = {"Content-Type": "application/json"} + if session_id: + headers["X-Session-ID"] = str(session_id) + + async with session.post(url, json=payload, headers=headers) as resp: + response = await resp.json() + raise_for_status(resp, response) + return response + async def completion( self, request_payload: Dict[str, Any], diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py index b1486356db..e65a8ac216 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py @@ -32,6 +32,7 @@ from litellm import acompletion as litellm_async_completion from litellm import atext_completion as litellm_async_text_completion +from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE MODEL_QWEN2_5 = "Qwen/Qwen2.5-0.5B-Instruct" SERVED_MODEL_NAME = "my_qwen" @@ -585,3 +586,20 @@ def test_client_tokenize_detokenize_roundtrip(vllm_server: InferenceEngineState) decoded = asyncio.run(client.detokenize([token_ids]))[0] assert decoded == text + + +@pytest.mark.vllm +@pytest.mark.skipif(not _SKYRL_USE_NEW_INFERENCE, reason="Render API only supported with new inference client") +def test_client_render_chat_completion(vllm_server: InferenceEngineState): + """Test render_chat_completion via RemoteInferenceClient against real vLLM.""" + client = vllm_server.client + messages = [{"role": "user", "content": "Hello"}] + result = asyncio.run(client.render_chat_completion(messages=messages)) + # vLLM returns [conversation, engine_prompts] + assert isinstance(result, list) + assert len(result) == 2 + conversation, engine_prompts = result + # engine_prompts should have prompt_token_ids + assert len(engine_prompts) > 0 + assert "prompt_token_ids" in engine_prompts[0] + assert len(engine_prompts[0]["prompt_token_ids"]) > 0 diff --git a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py index 87486b56b8..635cc4c6da 100644 --- a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py +++ b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py @@ -14,6 +14,7 @@ from skyrl.backends.skyrl_train.inference_servers.common import get_open_port from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient, PauseMode +from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE def create_mock_vllm_server(server_id: int) -> FastAPI: @@ -62,6 +63,15 @@ async def generate(request: Request): async def chat_completions(request: Request): return {"choices": [{"message": {"content": f"Chat from server {server_id}"}}]} + @app.post("/v1/chat/completions/render") + async def render_chat_completion(request: Request): + body = await request.json() + messages = body.get("messages", []) + return [ + messages, # conversation (echo back) + [{"prompt": "rendered prompt", "prompt_token_ids": [1, 2, 3]}], # engine_prompts + ] + @app.post("/tokenize") async def tokenize(request: Request): return {"tokens": [1, 2, 3]} @@ -243,6 +253,19 @@ 
async def test_completion(self, client): result = await client.completion(request_payload) assert "choices" in result + @pytest.mark.asyncio + @pytest.mark.skipif(not _SKYRL_USE_NEW_INFERENCE, reason="Render API only supported with new inference client") + async def test_render_chat_completion(self, client): + """Test render_chat_completion method.""" + messages = [{"role": "user", "content": "Hello"}] + result = await client.render_chat_completion(messages=messages) + assert isinstance(result, list) + assert len(result) == 2 + conversation, engine_prompts = result + assert conversation == messages + assert len(engine_prompts) > 0 + assert "prompt_token_ids" in engine_prompts[0] + @pytest.mark.asyncio async def test_tokenize(self, client): """Test tokenize method.""" From 713ada8ec5dbe2749c32cb009561061628c43917 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Fri, 6 Mar 2026 20:44:25 +0000 Subject: [PATCH 04/38] fix gpu save weights test --- .../gpu/gpu_ci/test_save_weights_for_sampler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py index c02a62cdb4..2f0af688a1 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py @@ -14,6 +14,7 @@ from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend from skyrl.train.utils.utils import validate_cfg from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch +from skyrl.utils.tok import get_tokenizer from tests.backends.skyrl_train.gpu.utils import ( get_test_prompts, @@ -76,6 +77,7 @@ def test_save_weights_for_sampler_then_inference(ray_init_fixture, colocate_all, tp_size=cfg.generator.inference_engine.tensor_parallel_size, colocate_all=cfg.trainer.placement.colocate_all, sleep_level=2, # Full sleep since we explicitly sync weights + gpu_memory_utilization=0.5 if colocate_all else None, ) as engines: client, pg = engines.client, engines.pg # Initialize policy worker @@ -119,7 +121,8 @@ def test_save_weights_for_sampler_then_inference(ray_init_fixture, colocate_all, sampling_params = get_sampling_params_for_backend( cfg.generator.inference_engine.backend, cfg.generator.sampling_params ) - outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL, num_samples=5), sampling_params)) + tokenizer = get_tokenizer(MODEL) + outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL, num_samples=5), sampling_params, tokenizer=tokenizer)) # Verify we got responses assert "responses" in outputs, "Inference should return responses" @@ -184,5 +187,6 @@ def test_save_weights_for_sampler_multiple_training_steps(ray_init_fixture): sampling_params = get_sampling_params_for_backend( cfg.generator.inference_engine.backend, cfg.generator.sampling_params ) - outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL, num_samples=2), sampling_params)) + tokenizer = get_tokenizer(MODEL) + outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL, num_samples=2), sampling_params, tokenizer=tokenizer)) assert len(outputs["responses"]) == 2, "Should get 2 responses" From 8ab564e30eabdd6ccfd4017267f717c2a36977e2 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Fri, 6 Mar 2026 13:57:26 -0800 Subject: [PATCH 05/38] add init of new inf backend --- skyrl/backends/skyrl_train_backend.py | 90 
++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 9 deletions(-) diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py index a618233cb6..9a93fbc262 100644 --- a/skyrl/backends/skyrl_train_backend.py +++ b/skyrl/backends/skyrl_train_backend.py @@ -25,7 +25,7 @@ from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch from skyrl.train.utils.utils import initialize_ray, get_ray_pg_ready_with_timeout from skyrl.train.config import get_config_as_yaml_str -from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S +from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S, _SKYRL_USE_NEW_INFERENCE from skyrl.backends.skyrl_train.inference_engines.ray_wrapped_inference_engine import ( create_ray_wrapped_inference_engines, ) @@ -116,6 +116,8 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides): self._tokenizer: AutoTokenizer = get_tokenizer(self.base_model) self._inference_engine_client = None self._inference_engines_initialized = False + self._server_group = None + self._inference_router = None def has_model(self, model_id: str) -> bool: return self._model_id == model_id @@ -184,19 +186,89 @@ def init_weight_sync_state(self): self._dispatch.init_weight_sync_state(self._inference_engine_client) logger.info("Initialized weight sync state for policy model and inference engines.") + def _create_remote_inference_client(self): + """Create a RemoteInferenceClient using HTTP endpoints. + + Mirrors main_base.py._get_new_inference_client() with the same 4-way + branching on external_proxy_url / external_server_urls. + """ + from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient + from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter + from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup + from skyrl.backends.skyrl_train.inference_servers.utils import build_vllm_cli_args + + ie_cfg = self._cfg.generator.inference_engine + is_colocated = self._cfg.trainer.placement.colocate_all + external_proxy_url = ie_cfg.external_proxy_url + external_server_urls = ie_cfg.external_server_urls + + has_external_proxy = external_proxy_url is not None + has_external_servers = external_server_urls is not None + + if has_external_proxy and has_external_servers: + proxy_url = external_proxy_url + server_urls = list(external_server_urls) + logger.info( + f"HTTP Inference: Using fully external setup - proxy_url={proxy_url}, server_urls={server_urls}" + ) + + elif has_external_proxy and not has_external_servers: + proxy_url = external_proxy_url + server_urls = [proxy_url] + logger.info(f"HTTP Inference: Using external proxy for both data and control plane - proxy_url={proxy_url}") + + elif has_external_servers and not has_external_proxy: + server_urls = list(external_server_urls) + self._inference_router = InferenceRouter(server_urls=server_urls) + proxy_url = self._inference_router.start() + logger.info( + f"HTTP Inference: Created internal router over external " + f"servers - server_urls={server_urls}, proxy_url={proxy_url}" + ) + + else: + cli_args = build_vllm_cli_args(self._cfg) + + self._server_group = ServerGroup( + cli_args=cli_args, + num_servers=ie_cfg.num_engines, + placement_group=self._colocate_pg if is_colocated else None, + enable_dp=ie_cfg.data_parallel_size > 1, + ) + server_infos = self._server_group.start() + server_urls = [info.url for info in server_infos] + + self._inference_router = InferenceRouter(server_urls=server_urls) + 
proxy_url = self._inference_router.start() + logger.info( + f"HTTP Inference: Built servers and router internally - " + f"proxy_url={proxy_url}, server_urls={server_urls}, colocated={is_colocated}" + ) + + return RemoteInferenceClient( + proxy_url=proxy_url, + server_urls=server_urls, + model_name=self._cfg.trainer.policy.model.path, + ) + def _ensure_inference_engines(self): """Lazily create inference engines and init weight sync on first sampling-related call.""" if self._inference_engines_initialized: return - logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines") - self._inference_engine_client = InferenceEngineClient( - create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer), - self._tokenizer, - self._cfg.trainer.policy.model.path, - self._cfg.trainer.policy.model.lora, - self._cfg.generator.inference_engine, - ) + if _SKYRL_USE_NEW_INFERENCE: + logger.info("Using new HTTP-based inference client (RemoteInferenceClient)") + self._inference_engine_client = self._create_remote_inference_client() + else: + logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines") + self._inference_engine_client = InferenceEngineClient( + create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer), + self._tokenizer, + self._cfg.trainer.policy.model.path, + self._cfg.trainer.policy.model.lora, + self._cfg.generator.inference_engine, + ) + self._dispatch.set_inference_engine_client(self._inference_engine_client) self.init_weight_sync_state() self._inference_engines_initialized = True From 7d8e5f4a63a19dbb12c9c7a740f1fd5839ccafce Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Fri, 6 Mar 2026 18:53:44 -0800 Subject: [PATCH 06/38] add test for weight sync --- .../test_backend_weight_sync.py | 185 ++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py new file mode 100644 index 0000000000..cea7bba524 --- /dev/null +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py @@ -0,0 +1,185 @@ +""" +GPU CI test for weight sync through the SkyRLTrainBackend API (new inference path). 
+ +Uses colocated mode with 2 GPUs (TP=1, 2 engines, 2 FSDP2 workers): + - Backend creates FSDP2 workers with real weights from HF + - Inference servers start with dummy (random) weights via engine_init_kwargs + - save_sampler_checkpoint() broadcasts real training weights via NCCL + - Verified by querying the server before and after sync + +Run: + uv run --isolated --extra dev --extra fsdp pytest \ + tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py -v -s +""" + +import os +import time +from unittest import mock + +import httpx +import pytest +import ray +from functools import lru_cache +from loguru import logger + +from skyrl.train.utils.utils import peer_access_supported +from skyrl.env_vars import SKYRL_PYTHONPATH_EXPORT + + +MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + +@lru_cache(5) +def log_once(msg): + logger.info(msg) + return None + + +def wait_for_url(url: str, timeout: float = 180.0) -> bool: + """Wait for a URL to become available.""" + start = time.time() + while time.time() - start < timeout: + try: + resp = httpx.get(f"{url}/health", timeout=5.0) + if resp.status_code == 200: + return True + except httpx.RequestError: + time.sleep(2.0) + return False + + +@pytest.fixture(scope="class") +def ray_env_with_new_inference(): + """Ray init fixture with _SKYRL_USE_NEW_INFERENCE=1 in runtime env.""" + if ray.is_initialized(): + ray.shutdown() + + env_vars = { + "VLLM_USE_V1": "1", + "VLLM_ENABLE_V1_MULTIPROCESSING": "0", + "VLLM_ALLOW_INSECURE_SERIALIZATION": "1", + "_SKYRL_USE_NEW_INFERENCE": "1", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NVTE_FUSED_ATTN": "0", + } + + if not peer_access_supported(max_num_gpus_per_node=2): + log_once("Disabling NCCL P2P for CI environment") + env_vars.update( + { + "NCCL_P2P_DISABLE": "1", + "NCCL_SHM_DISABLE": "1", + } + ) + + if SKYRL_PYTHONPATH_EXPORT: + pythonpath = os.environ.get("PYTHONPATH") + if pythonpath is None: + raise RuntimeError("SKYRL_PYTHONPATH_EXPORT is set but PYTHONPATH is not defined in environment") + env_vars["PYTHONPATH"] = pythonpath + + logger.info(f"Initializing Ray with environment variables: {env_vars}") + ray.init(runtime_env={"env_vars": env_vars}) + + yield + + ray.shutdown() + + +@pytest.mark.asyncio(loop_scope="class") +class TestBackendWeightSync: + """Test weight sync through SkyRLTrainBackend with new inference path (colocated).""" + + async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): + """ + End-to-end colocated weight sync test via SkyRLTrainBackend: + + 1. Create backend with 2 FSDP2 workers (real weights from HF) + 2. Start 2 inference servers with dummy (random) weights + 3. Verify dummy weights produce gibberish + 4. Run save_sampler_checkpoint() to broadcast real weights via NCCL + 5. 
Verify real weights produce correct output + """ + from skyrl.backends.skyrl_train_backend import ( + SkyRLTrainBackend, + FSDPBackendOverrides, + ) + from skyrl.tinker.types import LoraConfig + + print("\n[TEST] Setting up SkyRLTrainBackend (colocated, 2 GPU)") + + # ===== Step 1: Create backend ===== + overrides = { + "trainer.placement.colocate_all": True, + "trainer.placement.policy_num_gpus_per_node": 2, + "trainer.placement.policy_num_nodes": 1, + "trainer.logger": "console", + "generator.inference_engine.tensor_parallel_size": 1, + "generator.inference_engine.num_engines": 2, + "generator.inference_engine.gpu_memory_utilization": 0.5, + "generator.inference_engine.async_engine": True, + } + backend = SkyRLTrainBackend(MODEL, FSDPBackendOverrides(**overrides)) + + # ===== Step 2: Create model (real weights, FSDP2 sharded across 2 GPUs) ===== + model_id = "test-model" + backend.create_model(model_id, LoraConfig(rank=0, alpha=0, seed=42)) + print("[Step 2] Model created with real weights (FSDP2, 2 workers)") + + # ===== Step 3: Inject dummy weight config before inference engine creation ===== + backend._cfg.generator.inference_engine.engine_init_kwargs = {"load_format": "dummy"} + print("[Step 3] Injected load_format=dummy into engine_init_kwargs") + + # ===== Step 4: Create inference engines with dummy weights ===== + with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): + backend._ensure_inference_engines() + print("[Step 4] Inference engines created (dummy weights)") + + # Wait for servers to be healthy + server_urls = backend._server_group.get_server_urls() + assert len(server_urls) == 2, f"Expected 2 server URLs, got {len(server_urls)}" + for url in server_urls: + assert wait_for_url(url), f"Server {url} failed to start" + print(f"[Step 4] Servers healthy: {server_urls}") + + try: + # ===== Step 5: Verify dummy weights produce gibberish ===== + payload = { + "model": MODEL, + "prompt": "What is the capital of France?", + "max_tokens": 32, + "temperature": 0.0, + } + + async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: + resp = await http_client.post(f"{server_urls[0]}/v1/completions", json=payload) + assert resp.status_code == 200, f"Completions request failed: {resp.text}" + + text_before = resp.json()["choices"][0]["text"] + print(f"[Step 5] Dummy weights output: {text_before!r}") + assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" + + # ===== Step 6: Sleep inference engines (required before colocated weight sync) ===== + await backend._inference_engine_client.sleep() + print("[Step 6] Inference engines sleeping") + + # ===== Step 7: Sync weights via save_sampler_checkpoint ===== + with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): + backend.save_sampler_checkpoint("/tmp/test_backend_weight_sync.tar", model_id, persist=False) + print("[Step 7] save_sampler_checkpoint() completed (NCCL broadcast done)") + + # ===== Step 8: Verify real weights produce correct output ===== + async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: + resp = await http_client.post(f"{server_urls[0]}/v1/completions", json=payload) + assert resp.status_code == 200, f"Completions request failed: {resp.text}" + + text_after = resp.json()["choices"][0]["text"] + print(f"[Step 8] Real weights output: {text_after!r}") + assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" + + print("[SUCCESS] Backend weight sync test 
passed!") + + finally: + # Cleanup: teardown inference client session + if backend._inference_engine_client is not None: + await backend._inference_engine_client.teardown() From c7e4d525363835c64ce180d68eed0f094bc256f7 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 17:31:13 +0000 Subject: [PATCH 07/38] remove test --- .../test_backend_weight_sync.py | 185 ------------------ 1 file changed, 185 deletions(-) delete mode 100644 tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py deleted file mode 100644 index cea7bba524..0000000000 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py +++ /dev/null @@ -1,185 +0,0 @@ -""" -GPU CI test for weight sync through the SkyRLTrainBackend API (new inference path). - -Uses colocated mode with 2 GPUs (TP=1, 2 engines, 2 FSDP2 workers): - - Backend creates FSDP2 workers with real weights from HF - - Inference servers start with dummy (random) weights via engine_init_kwargs - - save_sampler_checkpoint() broadcasts real training weights via NCCL - - Verified by querying the server before and after sync - -Run: - uv run --isolated --extra dev --extra fsdp pytest \ - tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py -v -s -""" - -import os -import time -from unittest import mock - -import httpx -import pytest -import ray -from functools import lru_cache -from loguru import logger - -from skyrl.train.utils.utils import peer_access_supported -from skyrl.env_vars import SKYRL_PYTHONPATH_EXPORT - - -MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - - -@lru_cache(5) -def log_once(msg): - logger.info(msg) - return None - - -def wait_for_url(url: str, timeout: float = 180.0) -> bool: - """Wait for a URL to become available.""" - start = time.time() - while time.time() - start < timeout: - try: - resp = httpx.get(f"{url}/health", timeout=5.0) - if resp.status_code == 200: - return True - except httpx.RequestError: - time.sleep(2.0) - return False - - -@pytest.fixture(scope="class") -def ray_env_with_new_inference(): - """Ray init fixture with _SKYRL_USE_NEW_INFERENCE=1 in runtime env.""" - if ray.is_initialized(): - ray.shutdown() - - env_vars = { - "VLLM_USE_V1": "1", - "VLLM_ENABLE_V1_MULTIPROCESSING": "0", - "VLLM_ALLOW_INSECURE_SERIALIZATION": "1", - "_SKYRL_USE_NEW_INFERENCE": "1", - "CUDA_DEVICE_MAX_CONNECTIONS": "1", - "NVTE_FUSED_ATTN": "0", - } - - if not peer_access_supported(max_num_gpus_per_node=2): - log_once("Disabling NCCL P2P for CI environment") - env_vars.update( - { - "NCCL_P2P_DISABLE": "1", - "NCCL_SHM_DISABLE": "1", - } - ) - - if SKYRL_PYTHONPATH_EXPORT: - pythonpath = os.environ.get("PYTHONPATH") - if pythonpath is None: - raise RuntimeError("SKYRL_PYTHONPATH_EXPORT is set but PYTHONPATH is not defined in environment") - env_vars["PYTHONPATH"] = pythonpath - - logger.info(f"Initializing Ray with environment variables: {env_vars}") - ray.init(runtime_env={"env_vars": env_vars}) - - yield - - ray.shutdown() - - -@pytest.mark.asyncio(loop_scope="class") -class TestBackendWeightSync: - """Test weight sync through SkyRLTrainBackend with new inference path (colocated).""" - - async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): - """ - End-to-end colocated weight sync test via SkyRLTrainBackend: - - 1. 
Create backend with 2 FSDP2 workers (real weights from HF) - 2. Start 2 inference servers with dummy (random) weights - 3. Verify dummy weights produce gibberish - 4. Run save_sampler_checkpoint() to broadcast real weights via NCCL - 5. Verify real weights produce correct output - """ - from skyrl.backends.skyrl_train_backend import ( - SkyRLTrainBackend, - FSDPBackendOverrides, - ) - from skyrl.tinker.types import LoraConfig - - print("\n[TEST] Setting up SkyRLTrainBackend (colocated, 2 GPU)") - - # ===== Step 1: Create backend ===== - overrides = { - "trainer.placement.colocate_all": True, - "trainer.placement.policy_num_gpus_per_node": 2, - "trainer.placement.policy_num_nodes": 1, - "trainer.logger": "console", - "generator.inference_engine.tensor_parallel_size": 1, - "generator.inference_engine.num_engines": 2, - "generator.inference_engine.gpu_memory_utilization": 0.5, - "generator.inference_engine.async_engine": True, - } - backend = SkyRLTrainBackend(MODEL, FSDPBackendOverrides(**overrides)) - - # ===== Step 2: Create model (real weights, FSDP2 sharded across 2 GPUs) ===== - model_id = "test-model" - backend.create_model(model_id, LoraConfig(rank=0, alpha=0, seed=42)) - print("[Step 2] Model created with real weights (FSDP2, 2 workers)") - - # ===== Step 3: Inject dummy weight config before inference engine creation ===== - backend._cfg.generator.inference_engine.engine_init_kwargs = {"load_format": "dummy"} - print("[Step 3] Injected load_format=dummy into engine_init_kwargs") - - # ===== Step 4: Create inference engines with dummy weights ===== - with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): - backend._ensure_inference_engines() - print("[Step 4] Inference engines created (dummy weights)") - - # Wait for servers to be healthy - server_urls = backend._server_group.get_server_urls() - assert len(server_urls) == 2, f"Expected 2 server URLs, got {len(server_urls)}" - for url in server_urls: - assert wait_for_url(url), f"Server {url} failed to start" - print(f"[Step 4] Servers healthy: {server_urls}") - - try: - # ===== Step 5: Verify dummy weights produce gibberish ===== - payload = { - "model": MODEL, - "prompt": "What is the capital of France?", - "max_tokens": 32, - "temperature": 0.0, - } - - async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: - resp = await http_client.post(f"{server_urls[0]}/v1/completions", json=payload) - assert resp.status_code == 200, f"Completions request failed: {resp.text}" - - text_before = resp.json()["choices"][0]["text"] - print(f"[Step 5] Dummy weights output: {text_before!r}") - assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" - - # ===== Step 6: Sleep inference engines (required before colocated weight sync) ===== - await backend._inference_engine_client.sleep() - print("[Step 6] Inference engines sleeping") - - # ===== Step 7: Sync weights via save_sampler_checkpoint ===== - with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): - backend.save_sampler_checkpoint("/tmp/test_backend_weight_sync.tar", model_id, persist=False) - print("[Step 7] save_sampler_checkpoint() completed (NCCL broadcast done)") - - # ===== Step 8: Verify real weights produce correct output ===== - async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: - resp = await http_client.post(f"{server_urls[0]}/v1/completions", json=payload) - assert resp.status_code == 200, f"Completions request failed: {resp.text}" - - text_after = 
resp.json()["choices"][0]["text"] - print(f"[Step 8] Real weights output: {text_after!r}") - assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" - - print("[SUCCESS] Backend weight sync test passed!") - - finally: - # Cleanup: teardown inference client session - if backend._inference_engine_client is not None: - await backend._inference_engine_client.teardown() From 5ec91b7d9b9ad8fef350f828037d3018e7789290 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 18:06:08 +0000 Subject: [PATCH 08/38] add inference backend weight sync test --- .../test_backend_weight_sync.py | 188 ++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py new file mode 100644 index 0000000000..169f3d317d --- /dev/null +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py @@ -0,0 +1,188 @@ +""" +GPU CI test for weight sync through the SkyRLTrainBackend API (new inference path). + +Uses colocated mode with 2 GPUs (TP=1, 2 engines, 2 FSDP2 workers): + - Backend creates FSDP2 workers with real weights from HF + - Inference servers start with dummy (random) weights via engine_init_kwargs + - save_sampler_checkpoint() broadcasts real training weights via NCCL + - Verified by querying the server before and after sync + +Run: + uv run --isolated --extra dev --extra fsdp pytest \ + tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py -v -s +""" +ßßß +import os +import time +from unittest import mock + +import httpx +import pytest +import ray +from functools import lru_cache +from loguru import logger + +from skyrl.train.utils.utils import peer_access_supported +from skyrl.env_vars import SKYRL_PYTHONPATH_EXPORT + + +MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + +@lru_cache(5) +def log_once(msg): + logger.info(msg) + return None + + +def wait_for_url(url: str, timeout: float = 180.0) -> bool: + """Wait for a URL to become available.""" + start = time.time() + while time.time() - start < timeout: + try: + resp = httpx.get(f"{url}/health", timeout=5.0) + if resp.status_code == 200: + return True + except httpx.RequestError: + time.sleep(2.0) + return False + + +@pytest.fixture(scope="class") +def ray_env_with_new_inference(): + """Ray init fixture with _SKYRL_USE_NEW_INFERENCE=1 in runtime env.""" + if ray.is_initialized(): + ray.shutdown() + + env_vars = { + "VLLM_USE_V1": "1", + "VLLM_ENABLE_V1_MULTIPROCESSING": "0", + "VLLM_ALLOW_INSECURE_SERIALIZATION": "1", + "_SKYRL_USE_NEW_INFERENCE": "1", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + "NVTE_FUSED_ATTN": "0", + } + + if not peer_access_supported(max_num_gpus_per_node=2): + log_once("Disabling NCCL P2P for CI environment") + env_vars.update( + { + "NCCL_P2P_DISABLE": "1", + "NCCL_SHM_DISABLE": "1", + } + ) + + if SKYRL_PYTHONPATH_EXPORT: + pythonpath = os.environ.get("PYTHONPATH") + if pythonpath is None: + raise RuntimeError("SKYRL_PYTHONPATH_EXPORT is set but PYTHONPATH is not defined in environment") + env_vars["PYTHONPATH"] = pythonpath + + logger.info(f"Initializing Ray with environment variables: {env_vars}") + ray.init(runtime_env={"env_vars": env_vars}) + + yield + + ray.shutdown() + + +@pytest.mark.asyncio(loop_scope="class") +class TestBackendWeightSync: + """Test 
weight sync through SkyRLTrainBackend with new inference path (colocated).""" + + async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): + """ + End-to-end colocated weight sync test via SkyRLTrainBackend: + + 1. Create backend with 2 FSDP2 workers (real weights from HF) + 2. Start 2 inference servers with dummy (random) weights + 3. Verify dummy weights produce gibberish + 4. Run save_sampler_checkpoint() to broadcast real weights via NCCL + 5. Verify real weights produce correct output + """ + from skyrl.backends.skyrl_train_backend import ( + SkyRLTrainBackend, + FSDPBackendOverrides, + ) + from skyrl.tinker.types import LoraConfig + + print("\n[TEST] Setting up SkyRLTrainBackend (colocated, 2 GPU)") + + # ===== Step 1: Create backend ===== + overrides = { + "trainer.placement.colocate_all": True, + "trainer.placement.policy_num_gpus_per_node": 2, + "trainer.placement.policy_num_nodes": 1, + "trainer.logger": "console", + "generator.inference_engine.tensor_parallel_size": 1, + "generator.inference_engine.num_engines": 2, + "generator.inference_engine.gpu_memory_utilization": 0.5, + "generator.inference_engine.async_engine": True, + } + backend = SkyRLTrainBackend(MODEL, FSDPBackendOverrides(**overrides)) + + # ===== Step 2: Create model (real weights, FSDP2 sharded across 2 GPUs) ===== + model_id = "test-model" + backend.create_model(model_id, LoraConfig(rank=0, alpha=0, seed=42)) + print("[Step 2] Model created with real weights (FSDP2, 2 workers)") + + # ===== Step 3: Inject dummy weight config before inference engine creation ===== + backend._cfg.generator.inference_engine.engine_init_kwargs = {"load_format": "dummy"} + print("[Step 3] Injected load_format=dummy into engine_init_kwargs") + + # ===== Step 4: Create inference engines with dummy weights ===== + with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): + backend._ensure_inference_engines() + print("[Step 4] Inference engines created (dummy weights)") + + # Wait for servers to be healthy + server_urls = backend._server_group.get_server_urls() + assert len(server_urls) == 2, f"Expected 2 server URLs, got {len(server_urls)}" + for url in server_urls: + assert wait_for_url(url), f"Server {url} failed to start" + print(f"[Step 4] Servers healthy: {server_urls}") + + try: + # ===== Step 5: Verify dummy weights produce gibberish ===== + payload = { + "model": MODEL, + "prompt": "What is the capital of France?", + "max_tokens": 32, + "temperature": 0.0, + } + + async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: + resp = await http_client.post(f"{server_urls[0]}/v1/completions", json=payload) + assert resp.status_code == 200, f"Completions request failed: {resp.text}" + + text_before = resp.json()["choices"][0]["text"] + print(f"[Step 5] Dummy weights output: {text_before!r}") + assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" + + # ===== Step 6: Sleep inference engines (required before colocated weight sync) ===== + await backend._inference_engine_client.sleep() + print("[Step 6] Inference engines sleeping") + + # ===== Step 7: Sync weights via save_sampler_checkpoint ===== + with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): + backend.save_sampler_checkpoint("/tmp/test_backend_weight_sync.tar", model_id, persist=False) + backend._validate_model_state(model_id) + backend._ensure_inference_engines() + await backend._dispatch.save_weights_for_sampler() + print("[Step 7] 
save_sampler_checkpoint() completed (NCCL broadcast done)") + + # ===== Step 8: Verify real weights produce correct output ===== + async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: + resp = await http_client.post(f"{server_urls[0]}/v1/completions", json=payload) + assert resp.status_code == 200, f"Completions request failed: {resp.text}" + + text_after = resp.json()["choices"][0]["text"] + print(f"[Step 8] Real weights output: {text_after!r}") + assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" + + print("[SUCCESS] Backend weight sync test passed!") + + finally: + # Cleanup: teardown inference client session + if backend._inference_engine_client is not None: + await backend._inference_engine_client.teardown() \ No newline at end of file From f353605e4b63dbe18eee578b5bfa0ca5e07b6c70 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 18:12:09 +0000 Subject: [PATCH 09/38] add test back in --- .../gpu/gpu_ci/inference_servers/test_backend_weight_sync.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py index 169f3d317d..cc5cd43681 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py @@ -11,7 +11,7 @@ uv run --isolated --extra dev --extra fsdp pytest \ tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py -v -s """ -ßßß + import os import time from unittest import mock @@ -165,7 +165,6 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): # ===== Step 7: Sync weights via save_sampler_checkpoint ===== with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): - backend.save_sampler_checkpoint("/tmp/test_backend_weight_sync.tar", model_id, persist=False) backend._validate_model_state(model_id) backend._ensure_inference_engines() await backend._dispatch.save_weights_for_sampler() From 652688713dc653a9fc8cee8549fc7d13f0b48f59 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 21:19:31 +0000 Subject: [PATCH 10/38] colocate all off --- .../gpu/gpu_ci/inference_servers/test_backend_weight_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py index cc5cd43681..a4c0ceeb12 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py @@ -110,7 +110,7 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): # ===== Step 1: Create backend ===== overrides = { - "trainer.placement.colocate_all": True, + "trainer.placement.colocate_all": False, "trainer.placement.policy_num_gpus_per_node": 2, "trainer.placement.policy_num_nodes": 1, "trainer.logger": "console", From 2c0b47fdab5830de2066f84c3d9cdac1ccb45a54 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 22:05:11 +0000 Subject: [PATCH 11/38] port fix --- .../inference_servers/vllm_server_actor.py | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git 
a/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py index c14099118f..25f78dfa51 100644 --- a/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py +++ b/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py @@ -6,6 +6,7 @@ import logging import os import pickle +import socket import time from argparse import Namespace from typing import Any, Dict, Optional, Tuple @@ -28,7 +29,7 @@ SKYRL_VLLM_DP_PORT_OFFSET, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, ) -from skyrl.backends.skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_open_port +from skyrl.backends.skyrl_train.inference_servers.common import ServerInfo, get_node_ip from skyrl.backends.skyrl_train.inference_servers.protocols import ServerActorProtocol from skyrl.backends.skyrl_train.inference_servers.vllm_worker import VLLM_WORKER_EXTENSION_CLS @@ -96,7 +97,7 @@ def __init__( """ self._cli_args = vllm_cli_args self._ip = get_node_ip() - self._port = get_open_port(start_port) + self._port, self._port_reservation = self._find_and_reserve_port(start_port) self._server_idx = server_idx self._num_gpus_per_server = self.compute_num_gpus_per_server(vllm_cli_args) @@ -148,6 +149,30 @@ def __init__( self._engine: Optional[AsyncLLMEngine] = None self._server_task: Optional[asyncio.Task] = None + @staticmethod + def _find_and_reserve_port(start_port: int) -> Tuple[int, socket.socket]: + """Find an available port and hold the socket to prevent TOCTOU races. + + Unlike get_open_port() which tests-then-releases, this keeps the socket + bound so no other process can claim the same port between discovery and + actual server startup. + + Returns: + (port, socket) — caller must close the socket before rebinding. + """ + port = start_port + while True: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("", port)) + sock.listen(1) + return port, sock + except OSError: + port += 1 + if port > 65535: + raise RuntimeError(f"No available port found starting from {start_port}") + def _ensure_worker_extension(self) -> None: """ Ensure the SkyRL worker extension is configured. @@ -267,6 +292,12 @@ async def _wait_until_healthy(self, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_ async def _run_server(self) -> None: """Internal method to run the HTTP server.""" + # Release the port reservation right before vLLM rebinds. + # SO_REUSEADDR on both sockets makes the hand-off atomic. 
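+        # (The reservation socket has been held since __init__, so no other
+        # process could claim this port between discovery and this point.)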
+        if self._port_reservation is not None:
+            self._port_reservation.close()
+            self._port_reservation = None
+
         sock_addr = (self._cli_args.host, self._cli_args.port)
         sock = create_server_socket(sock_addr)
         app = build_app(self._cli_args)

From 5b2fa340c88fb44e92820ad48cd5ed960427e0f6 Mon Sep 17 00:00:00 2001
From: Nithin Chalapathi
Date: Mon, 9 Mar 2026 22:16:29 +0000
Subject: [PATCH 12/38] move port allocation to shared common

---
 .../skyrl_train/inference_servers/common.py   | 26 +++++++++++++++++
 .../inference_servers/vllm_server_actor.py    | 29 ++-----------------
 2 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/skyrl/backends/skyrl_train/inference_servers/common.py b/skyrl/backends/skyrl_train/inference_servers/common.py
index 17ae4bb36e..0a5d0bbc4d 100644
--- a/skyrl/backends/skyrl_train/inference_servers/common.py
+++ b/skyrl/backends/skyrl_train/inference_servers/common.py
@@ -7,6 +7,7 @@
 import logging
 import socket
 from dataclasses import dataclass
+from typing import Tuple
 
 import ray
 
@@ -72,3 +73,28 @@ def get_open_port(start_port: int | None = None) -> int:
     with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
         s.bind(("", 0))
         return s.getsockname()[1]
+
+
+def find_and_reserve_port(start_port: int) -> Tuple[int, socket.socket]:
+    """Find an available port and hold the socket to prevent TOCTOU races.
+
+    Unlike get_open_port() which tests-then-releases, this keeps the socket
+    bound so no other process can claim the same port between discovery and
+    actual server startup. Use SO_REUSEADDR on both this socket and the
+    eventual server socket for a seamless hand-off.
+
+    Returns:
+        (port, socket) -- caller must close the socket before rebinding.
+    """
+    port = start_port
+    while True:
+        try:
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            sock.bind(("", port))
+            sock.listen(1)
+            return port, sock
+        except OSError:
+            port += 1
+            if port > 65535:
+                raise RuntimeError(f"No available port found starting from {start_port}")
diff --git a/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py
index 25f78dfa51..ab1d30ac69 100644
--- a/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py
+++ b/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py
@@ -6,7 +6,6 @@
 import logging
 import os
 import pickle
-import socket
 import time
 from argparse import Namespace
 from typing import Any, Dict, Optional, Tuple
@@ -29,7 +28,7 @@
     SKYRL_VLLM_DP_PORT_OFFSET,
     SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S,
 )
-from skyrl.backends.skyrl_train.inference_servers.common import ServerInfo, get_node_ip
+from skyrl.backends.skyrl_train.inference_servers.common import ServerInfo, find_and_reserve_port, get_node_ip
 from skyrl.backends.skyrl_train.inference_servers.protocols import ServerActorProtocol
 from skyrl.backends.skyrl_train.inference_servers.vllm_worker import VLLM_WORKER_EXTENSION_CLS
 
@@ -96,7 +97,7 @@ def __init__(
         """
         self._cli_args = vllm_cli_args
         self._ip = get_node_ip()
-        self._port, self._port_reservation = self._find_and_reserve_port(start_port)
+        self._port, self._port_reservation = find_and_reserve_port(start_port)
         self._server_idx = server_idx
 
         self._num_gpus_per_server = self.compute_num_gpus_per_server(vllm_cli_args)
@@ -149,6 +148,30 @@ def __init__(
         self._engine: Optional[AsyncLLMEngine] = None
         self._server_task: Optional[asyncio.Task] = None
 
-    @staticmethod
-    def 
_find_and_reserve_port(start_port: int) -> Tuple[int, socket.socket]: - """Find an available port and hold the socket to prevent TOCTOU races. - - Unlike get_open_port() which tests-then-releases, this keeps the socket - bound so no other process can claim the same port between discovery and - actual server startup. - - Returns: - (port, socket) — caller must close the socket before rebinding. - """ - port = start_port - while True: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("", port)) - sock.listen(1) - return port, sock - except OSError: - port += 1 - if port > 65535: - raise RuntimeError(f"No available port found starting from {start_port}") - def _ensure_worker_extension(self) -> None: """ Ensure the SkyRL worker extension is configured. From 29125ab950e90048303872afd30f7880690c6494 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 22:25:30 +0000 Subject: [PATCH 13/38] add tests for port reservation --- .../inference_servers/test_common.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/backends/skyrl_train/inference_servers/test_common.py b/tests/backends/skyrl_train/inference_servers/test_common.py index dd604147b0..6621bc3e20 100644 --- a/tests/backends/skyrl_train/inference_servers/test_common.py +++ b/tests/backends/skyrl_train/inference_servers/test_common.py @@ -3,6 +3,7 @@ import socket from skyrl.backends.skyrl_train.inference_servers.common import ( + find_and_reserve_port, get_node_ip, get_open_port, ) @@ -35,3 +36,79 @@ def _verify_port_is_free(self, port: int): s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind(("", port)) s.listen(1) + + +def _occupy_port(port: int) -> socket.socket: + """Bind+listen on *port* to simulate another service (e.g. Tinker API).""" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("", port)) + sock.listen(1) + return sock + + +class TestFindAndReservePort: + """ + get_open_port() probes-then-releases, so concurrent callers + could both claim the same port. find_and_reserve_port() holds the socket + open, forcing subsequent callers to skip to the next port. 
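+
+    Sketch of the intended hand-off (port number is illustrative):
+
+        port, sock = find_and_reserve_port(15000)
+        ...                 # advertise `port` while the socket keeps it reserved
+        sock.close()        # release just before the server rebinds
+        server.bind(port)   # hypothetical server; rebinds with SO_REUSEADDR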
+ """ + + def test_sequential_reservations_are_unique(self): + port_a, sock_a = find_and_reserve_port(15000) + try: + port_b, sock_b = find_and_reserve_port(15000) + try: + assert port_a != port_b, f"Duplicate port: {port_a}" + finally: + sock_b.close() + finally: + sock_a.close() + + def test_occupied_base_port_is_skipped(self): + """If the base port is already taken, the reservation must pick a higher port.""" + base = get_open_port() + blocker = _occupy_port(base) + try: + port, sock = find_and_reserve_port(base) + try: + assert port != base, f"Reserved the occupied port {base}" + assert port > base + finally: + sock.close() + finally: + blocker.close() + + def test_overlapping_ranges_no_collision(self): + """When base port N is occupied, reserving from N and N+1 must + yield different ports even though both scan through N+1.""" + base = get_open_port() + blocker = _occupy_port(base) + try: + port_0, sock_0 = find_and_reserve_port(base) + try: + port_1, sock_1 = find_and_reserve_port(base + 1) + try: + assert port_0 != port_1, f"Port collision: both got {port_0}" + finally: + sock_1.close() + finally: + sock_0.close() + finally: + blocker.close() + + def test_many_reservations_all_unique(self): + base = get_open_port() + blocker = _occupy_port(base) + sockets = [] + try: + for _ in range(4): + port, sock = find_and_reserve_port(base) + sockets.append((port, sock)) + + ports = [p for p, _ in sockets] + assert len(set(ports)) == len(ports), f"Duplicate among {ports}" + assert base not in ports + finally: + for _, sock in sockets: + sock.close() + blocker.close() From 3f071647fe085c385c2420b275c95072d855206b Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 22:29:28 +0000 Subject: [PATCH 14/38] remove print statements --- .../inference_servers/test_backend_weight_sync.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py index a4c0ceeb12..f749ff9fbb 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py @@ -106,7 +106,6 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): ) from skyrl.tinker.types import LoraConfig - print("\n[TEST] Setting up SkyRLTrainBackend (colocated, 2 GPU)") # ===== Step 1: Create backend ===== overrides = { @@ -124,23 +123,19 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): # ===== Step 2: Create model (real weights, FSDP2 sharded across 2 GPUs) ===== model_id = "test-model" backend.create_model(model_id, LoraConfig(rank=0, alpha=0, seed=42)) - print("[Step 2] Model created with real weights (FSDP2, 2 workers)") # ===== Step 3: Inject dummy weight config before inference engine creation ===== backend._cfg.generator.inference_engine.engine_init_kwargs = {"load_format": "dummy"} - print("[Step 3] Injected load_format=dummy into engine_init_kwargs") # ===== Step 4: Create inference engines with dummy weights ===== with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): backend._ensure_inference_engines() - print("[Step 4] Inference engines created (dummy weights)") # Wait for servers to be healthy server_urls = backend._server_group.get_server_urls() assert len(server_urls) == 2, f"Expected 2 server URLs, got {len(server_urls)}" for url in 
server_urls: assert wait_for_url(url), f"Server {url} failed to start" - print(f"[Step 4] Servers healthy: {server_urls}") try: # ===== Step 5: Verify dummy weights produce gibberish ===== @@ -156,19 +151,16 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): assert resp.status_code == 200, f"Completions request failed: {resp.text}" text_before = resp.json()["choices"][0]["text"] - print(f"[Step 5] Dummy weights output: {text_before!r}") assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" # ===== Step 6: Sleep inference engines (required before colocated weight sync) ===== await backend._inference_engine_client.sleep() - print("[Step 6] Inference engines sleeping") # ===== Step 7: Sync weights via save_sampler_checkpoint ===== with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): backend._validate_model_state(model_id) backend._ensure_inference_engines() await backend._dispatch.save_weights_for_sampler() - print("[Step 7] save_sampler_checkpoint() completed (NCCL broadcast done)") # ===== Step 8: Verify real weights produce correct output ===== async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: @@ -176,10 +168,8 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): assert resp.status_code == 200, f"Completions request failed: {resp.text}" text_after = resp.json()["choices"][0]["text"] - print(f"[Step 8] Real weights output: {text_after!r}") assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" - print("[SUCCESS] Backend weight sync test passed!") finally: # Cleanup: teardown inference client session From c6fbfded095325a10d2ad81c1417c9dc8067dc19 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 16:36:05 -0700 Subject: [PATCH 15/38] Revert "remove print statements" This reverts commit 3f071647fe085c385c2420b275c95072d855206b. 
--- .../inference_servers/test_backend_weight_sync.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py index f749ff9fbb..a4c0ceeb12 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py @@ -106,6 +106,7 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): ) from skyrl.tinker.types import LoraConfig + print("\n[TEST] Setting up SkyRLTrainBackend (colocated, 2 GPU)") # ===== Step 1: Create backend ===== overrides = { @@ -123,19 +124,23 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): # ===== Step 2: Create model (real weights, FSDP2 sharded across 2 GPUs) ===== model_id = "test-model" backend.create_model(model_id, LoraConfig(rank=0, alpha=0, seed=42)) + print("[Step 2] Model created with real weights (FSDP2, 2 workers)") # ===== Step 3: Inject dummy weight config before inference engine creation ===== backend._cfg.generator.inference_engine.engine_init_kwargs = {"load_format": "dummy"} + print("[Step 3] Injected load_format=dummy into engine_init_kwargs") # ===== Step 4: Create inference engines with dummy weights ===== with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): backend._ensure_inference_engines() + print("[Step 4] Inference engines created (dummy weights)") # Wait for servers to be healthy server_urls = backend._server_group.get_server_urls() assert len(server_urls) == 2, f"Expected 2 server URLs, got {len(server_urls)}" for url in server_urls: assert wait_for_url(url), f"Server {url} failed to start" + print(f"[Step 4] Servers healthy: {server_urls}") try: # ===== Step 5: Verify dummy weights produce gibberish ===== @@ -151,16 +156,19 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): assert resp.status_code == 200, f"Completions request failed: {resp.text}" text_before = resp.json()["choices"][0]["text"] + print(f"[Step 5] Dummy weights output: {text_before!r}") assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" # ===== Step 6: Sleep inference engines (required before colocated weight sync) ===== await backend._inference_engine_client.sleep() + print("[Step 6] Inference engines sleeping") # ===== Step 7: Sync weights via save_sampler_checkpoint ===== with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): backend._validate_model_state(model_id) backend._ensure_inference_engines() await backend._dispatch.save_weights_for_sampler() + print("[Step 7] save_sampler_checkpoint() completed (NCCL broadcast done)") # ===== Step 8: Verify real weights produce correct output ===== async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: @@ -168,8 +176,10 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): assert resp.status_code == 200, f"Completions request failed: {resp.text}" text_after = resp.json()["choices"][0]["text"] + print(f"[Step 8] Real weights output: {text_after!r}") assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" + print("[SUCCESS] Backend weight sync test passed!") finally: # Cleanup: teardown inference client session From c950ced5deffbf07b003ca9fc2cb18067fc10458 Mon Sep 17 
00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 16:36:46 -0700 Subject: [PATCH 16/38] Revert "add tests for port reservation" This reverts commit 29125ab950e90048303872afd30f7880690c6494. --- .../inference_servers/test_common.py | 77 ------------------- 1 file changed, 77 deletions(-) diff --git a/tests/backends/skyrl_train/inference_servers/test_common.py b/tests/backends/skyrl_train/inference_servers/test_common.py index 6621bc3e20..dd604147b0 100644 --- a/tests/backends/skyrl_train/inference_servers/test_common.py +++ b/tests/backends/skyrl_train/inference_servers/test_common.py @@ -3,7 +3,6 @@ import socket from skyrl.backends.skyrl_train.inference_servers.common import ( - find_and_reserve_port, get_node_ip, get_open_port, ) @@ -36,79 +35,3 @@ def _verify_port_is_free(self, port: int): s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind(("", port)) s.listen(1) - - -def _occupy_port(port: int) -> socket.socket: - """Bind+listen on *port* to simulate another service (e.g. Tinker API).""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(("", port)) - sock.listen(1) - return sock - - -class TestFindAndReservePort: - """ - get_open_port() probes-then-releases, so concurrent callers - could both claim the same port. find_and_reserve_port() holds the socket - open, forcing subsequent callers to skip to the next port. - """ - - def test_sequential_reservations_are_unique(self): - port_a, sock_a = find_and_reserve_port(15000) - try: - port_b, sock_b = find_and_reserve_port(15000) - try: - assert port_a != port_b, f"Duplicate port: {port_a}" - finally: - sock_b.close() - finally: - sock_a.close() - - def test_occupied_base_port_is_skipped(self): - """If the base port is already taken, the reservation must pick a higher port.""" - base = get_open_port() - blocker = _occupy_port(base) - try: - port, sock = find_and_reserve_port(base) - try: - assert port != base, f"Reserved the occupied port {base}" - assert port > base - finally: - sock.close() - finally: - blocker.close() - - def test_overlapping_ranges_no_collision(self): - """When base port N is occupied, reserving from N and N+1 must - yield different ports even though both scan through N+1.""" - base = get_open_port() - blocker = _occupy_port(base) - try: - port_0, sock_0 = find_and_reserve_port(base) - try: - port_1, sock_1 = find_and_reserve_port(base + 1) - try: - assert port_0 != port_1, f"Port collision: both got {port_0}" - finally: - sock_1.close() - finally: - sock_0.close() - finally: - blocker.close() - - def test_many_reservations_all_unique(self): - base = get_open_port() - blocker = _occupy_port(base) - sockets = [] - try: - for _ in range(4): - port, sock = find_and_reserve_port(base) - sockets.append((port, sock)) - - ports = [p for p, _ in sockets] - assert len(set(ports)) == len(ports), f"Duplicate among {ports}" - assert base not in ports - finally: - for _, sock in sockets: - sock.close() - blocker.close() From 26b7abfb6f03288c42fa39270f7d3387d1fe56da Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 16:37:16 -0700 Subject: [PATCH 17/38] Revert "move port allocation to shared commmoon" This reverts commit 5b2fa340c88fb44e92820ad48cd5ed960427e0f6. 
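For context on what these reverts walk back: get_open_port() probes for a free port and then releases it, so two callers scanning from the same start port can both claim the same one, while find_and_reserve_port() keeps the socket bound until the server rebinds it. A minimal self-contained sketch of that TOCTOU race and its fix, with probe_port and reserve_port as illustrative stand-ins for the two helpers:

import socket

def probe_port(start_port: int) -> int:
    # Probe-then-release, like get_open_port(start_port): the socket is
    # closed before returning, so the port is only *probably* still free.
    port = start_port
    while port <= 65535:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("", port))
                return port
        except OSError:
            port += 1
    raise RuntimeError(f"no free port at or above {start_port}")

def reserve_port(start_port: int) -> tuple[int, socket.socket]:
    # Reserve-and-hold, like find_and_reserve_port: the bound socket is
    # returned still open, which forces a second caller onto the next port.
    port = start_port
    while port <= 65535:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind(("", port))
            sock.listen(1)
            return port, sock
        except OSError:
            sock.close()
            port += 1
    raise RuntimeError(f"no free port at or above {start_port}")

# Assuming port 15000 itself is free: both probes return it -- the race.
assert probe_port(15000) == probe_port(15000)

# With the reservation held open, the second caller is pushed past 15000.
port_a, sock_a = reserve_port(15000)
port_b, sock_b = reserve_port(15000)
assert port_a != port_b
sock_a.close()
sock_b.close()

The diff below moves the reserve-and-hold helper back into VLLMServerActor, and the following revert then drops it entirely in favor of plain get_open_port.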
--- .../skyrl_train/inference_servers/common.py | 26 ----------------- .../inference_servers/vllm_server_actor.py | 29 +++++++++++++++++-- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/common.py b/skyrl/backends/skyrl_train/inference_servers/common.py index 0a5d0bbc4d..17ae4bb36e 100644 --- a/skyrl/backends/skyrl_train/inference_servers/common.py +++ b/skyrl/backends/skyrl_train/inference_servers/common.py @@ -7,7 +7,6 @@ import logging import socket from dataclasses import dataclass -from typing import Tuple import ray @@ -73,28 +72,3 @@ def get_open_port(start_port: int | None = None) -> int: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: s.bind(("", 0)) return s.getsockname()[1] - - -def find_and_reserve_port(start_port: int) -> Tuple[int, socket.socket]: - """Find an available port and hold the socket to prevent TOCTOU races. - - Unlike get_open_port() which tests-then-releases, this keeps the socket - bound so no other process can claim the same port between discovery and - actual server startup. Use SO_REUSEADDR on both this socket and the - eventual server socket for a seamless hand-off. - - Returns: - (port, socket) -- caller must close the socket before rebinding. - """ - port = start_port - while True: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("", port)) - sock.listen(1) - return port, sock - except OSError: - port += 1 - if port > 65535: - raise RuntimeError(f"No available port found starting from {start_port}") diff --git a/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py index ab1d30ac69..25f78dfa51 100644 --- a/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py +++ b/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py @@ -6,6 +6,7 @@ import logging import os import pickle +import socket import time from argparse import Namespace from typing import Any, Dict, Optional, Tuple @@ -28,7 +29,7 @@ SKYRL_VLLM_DP_PORT_OFFSET, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, ) -from skyrl.backends.skyrl_train.inference_servers.common import ServerInfo, find_and_reserve_port, get_node_ip +from skyrl.backends.skyrl_train.inference_servers.common import ServerInfo, get_node_ip from skyrl.backends.skyrl_train.inference_servers.protocols import ServerActorProtocol from skyrl.backends.skyrl_train.inference_servers.vllm_worker import VLLM_WORKER_EXTENSION_CLS @@ -96,7 +97,7 @@ def __init__( """ self._cli_args = vllm_cli_args self._ip = get_node_ip() - self._port, self._port_reservation = find_and_reserve_port(start_port) + self._port, self._port_reservation = self._find_and_reserve_port(start_port) self._server_idx = server_idx self._num_gpus_per_server = self.compute_num_gpus_per_server(vllm_cli_args) @@ -148,6 +149,30 @@ def __init__( self._engine: Optional[AsyncLLMEngine] = None self._server_task: Optional[asyncio.Task] = None + @staticmethod + def _find_and_reserve_port(start_port: int) -> Tuple[int, socket.socket]: + """Find an available port and hold the socket to prevent TOCTOU races. + + Unlike get_open_port() which tests-then-releases, this keeps the socket + bound so no other process can claim the same port between discovery and + actual server startup. + + Returns: + (port, socket) — caller must close the socket before rebinding. 
+ """ + port = start_port + while True: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("", port)) + sock.listen(1) + return port, sock + except OSError: + port += 1 + if port > 65535: + raise RuntimeError(f"No available port found starting from {start_port}") + def _ensure_worker_extension(self) -> None: """ Ensure the SkyRL worker extension is configured. From 516ae5ea7e45d619a0bf56e96900019c8d07c08a Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 16:37:39 -0700 Subject: [PATCH 18/38] Revert "port fix" This reverts commit 2c0b47fdab5830de2066f84c3d9cdac1ccb45a54. --- .../inference_servers/vllm_server_actor.py | 35 ++----------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py index 25f78dfa51..c14099118f 100644 --- a/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py +++ b/skyrl/backends/skyrl_train/inference_servers/vllm_server_actor.py @@ -6,7 +6,6 @@ import logging import os import pickle -import socket import time from argparse import Namespace from typing import Any, Dict, Optional, Tuple @@ -29,7 +28,7 @@ SKYRL_VLLM_DP_PORT_OFFSET, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, ) -from skyrl.backends.skyrl_train.inference_servers.common import ServerInfo, get_node_ip +from skyrl.backends.skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_open_port from skyrl.backends.skyrl_train.inference_servers.protocols import ServerActorProtocol from skyrl.backends.skyrl_train.inference_servers.vllm_worker import VLLM_WORKER_EXTENSION_CLS @@ -97,7 +96,7 @@ def __init__( """ self._cli_args = vllm_cli_args self._ip = get_node_ip() - self._port, self._port_reservation = self._find_and_reserve_port(start_port) + self._port = get_open_port(start_port) self._server_idx = server_idx self._num_gpus_per_server = self.compute_num_gpus_per_server(vllm_cli_args) @@ -149,30 +148,6 @@ def __init__( self._engine: Optional[AsyncLLMEngine] = None self._server_task: Optional[asyncio.Task] = None - @staticmethod - def _find_and_reserve_port(start_port: int) -> Tuple[int, socket.socket]: - """Find an available port and hold the socket to prevent TOCTOU races. - - Unlike get_open_port() which tests-then-releases, this keeps the socket - bound so no other process can claim the same port between discovery and - actual server startup. - - Returns: - (port, socket) — caller must close the socket before rebinding. - """ - port = start_port - while True: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("", port)) - sock.listen(1) - return port, sock - except OSError: - port += 1 - if port > 65535: - raise RuntimeError(f"No available port found starting from {start_port}") - def _ensure_worker_extension(self) -> None: """ Ensure the SkyRL worker extension is configured. @@ -292,12 +267,6 @@ async def _wait_until_healthy(self, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_ async def _run_server(self) -> None: """Internal method to run the HTTP server.""" - # Release the port reservation right before vLLM rebinds. - # SO_REUSEADDR on both sockets makes the hand-off atomic. 
- if self._port_reservation is not None: - self._port_reservation.close() - self._port_reservation = None - sock_addr = (self._cli_args.host, self._cli_args.port) sock = create_server_socket(sock_addr) app = build_app(self._cli_args) From 281ba77f21a3e1b9d4c0111269cc2a8de4dd934a Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 16:44:28 -0700 Subject: [PATCH 19/38] remove print statements --- .../inference_servers/test_backend_weight_sync.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py index a4c0ceeb12..77c1171ff5 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py @@ -106,8 +106,6 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): ) from skyrl.tinker.types import LoraConfig - print("\n[TEST] Setting up SkyRLTrainBackend (colocated, 2 GPU)") - # ===== Step 1: Create backend ===== overrides = { "trainer.placement.colocate_all": False, @@ -124,23 +122,19 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): # ===== Step 2: Create model (real weights, FSDP2 sharded across 2 GPUs) ===== model_id = "test-model" backend.create_model(model_id, LoraConfig(rank=0, alpha=0, seed=42)) - print("[Step 2] Model created with real weights (FSDP2, 2 workers)") # ===== Step 3: Inject dummy weight config before inference engine creation ===== backend._cfg.generator.inference_engine.engine_init_kwargs = {"load_format": "dummy"} - print("[Step 3] Injected load_format=dummy into engine_init_kwargs") # ===== Step 4: Create inference engines with dummy weights ===== with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): backend._ensure_inference_engines() - print("[Step 4] Inference engines created (dummy weights)") # Wait for servers to be healthy server_urls = backend._server_group.get_server_urls() assert len(server_urls) == 2, f"Expected 2 server URLs, got {len(server_urls)}" for url in server_urls: assert wait_for_url(url), f"Server {url} failed to start" - print(f"[Step 4] Servers healthy: {server_urls}") try: # ===== Step 5: Verify dummy weights produce gibberish ===== @@ -156,19 +150,16 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): assert resp.status_code == 200, f"Completions request failed: {resp.text}" text_before = resp.json()["choices"][0]["text"] - print(f"[Step 5] Dummy weights output: {text_before!r}") assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" # ===== Step 6: Sleep inference engines (required before colocated weight sync) ===== await backend._inference_engine_client.sleep() - print("[Step 6] Inference engines sleeping") # ===== Step 7: Sync weights via save_sampler_checkpoint ===== with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): backend._validate_model_state(model_id) backend._ensure_inference_engines() await backend._dispatch.save_weights_for_sampler() - print("[Step 7] save_sampler_checkpoint() completed (NCCL broadcast done)") # ===== Step 8: Verify real weights produce correct output ===== async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: @@ -176,12 +167,9 @@ async def 
test_backend_weight_sync_colocated(self, ray_env_with_new_inference): assert resp.status_code == 200, f"Completions request failed: {resp.text}" text_after = resp.json()["choices"][0]["text"] - print(f"[Step 8] Real weights output: {text_after!r}") assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" - print("[SUCCESS] Backend weight sync test passed!") - finally: # Cleanup: teardown inference client session if backend._inference_engine_client is not None: - await backend._inference_engine_client.teardown() \ No newline at end of file + await backend._inference_engine_client.teardown() From 4d8298c7f31819c1232fa379a2728d5a109ca424 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 16:53:48 -0700 Subject: [PATCH 20/38] stricter tests --- .../inference_servers/test_new_inference_generation.py | 8 +++++--- .../inference_servers/test_remote_inference_client.py | 5 +++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py index e65a8ac216..fc74cd7afa 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py @@ -593,13 +593,15 @@ def test_client_tokenize_detokenize_roundtrip(vllm_server: InferenceEngineState) def test_client_render_chat_completion(vllm_server: InferenceEngineState): """Test render_chat_completion via RemoteInferenceClient against real vLLM.""" client = vllm_server.client - messages = [{"role": "user", "content": "Hello"}] + messages = [{"role": "user", "content": "Hello world!"}] result = asyncio.run(client.render_chat_completion(messages=messages)) # vLLM returns [conversation, engine_prompts] assert isinstance(result, list) assert len(result) == 2 conversation, engine_prompts = result - # engine_prompts should have prompt_token_ids + # engine_prompts should have prompt_token_ids matching local tokenizer output assert len(engine_prompts) > 0 assert "prompt_token_ids" in engine_prompts[0] - assert len(engine_prompts[0]["prompt_token_ids"]) > 0 + tokenizer = AutoTokenizer.from_pretrained(MODEL_QWEN2_5) + expected_token_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + assert engine_prompts[0]["prompt_token_ids"] == expected_token_ids diff --git a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py index 635cc4c6da..86556417b1 100644 --- a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py +++ b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py @@ -263,8 +263,9 @@ async def test_render_chat_completion(self, client): assert len(result) == 2 conversation, engine_prompts = result assert conversation == messages - assert len(engine_prompts) > 0 - assert "prompt_token_ids" in engine_prompts[0] + assert len(engine_prompts) == 1 + assert engine_prompts[0]["prompt_token_ids"] == [1, 2, 3] + assert engine_prompts[0]["prompt"] == "rendered prompt" @pytest.mark.asyncio async def test_tokenize(self, client): From 4651da7ebc7027fb588ddec0a02d2887b931d76a Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 16:56:26 -0700 Subject: [PATCH 21/38] move imports up --- skyrl/backends/skyrl_train_backend.py | 9 
++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py index 9a93fbc262..a4361f9f6c 100644 --- a/skyrl/backends/skyrl_train_backend.py +++ b/skyrl/backends/skyrl_train_backend.py @@ -31,6 +31,10 @@ ) from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient from skyrl.utils.tok import get_tokenizer +from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient +from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter +from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup +from skyrl.backends.skyrl_train.inference_servers.utils import build_vllm_cli_args class SkyRLTrainBackendOverrides(BaseModel, extra="allow"): @@ -192,11 +196,6 @@ def _create_remote_inference_client(self): Mirrors main_base.py._get_new_inference_client() with the same 4-way branching on external_proxy_url / external_server_urls. """ - from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient - from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter - from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup - from skyrl.backends.skyrl_train.inference_servers.utils import build_vllm_cli_args - ie_cfg = self._cfg.generator.inference_engine is_colocated = self._cfg.trainer.placement.colocate_all external_proxy_url = ie_cfg.external_proxy_url From 39514695f4e888d6b9d6c9636e6d666bd147381d Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 17:08:31 -0700 Subject: [PATCH 22/38] fmt --- .../gpu/gpu_ci/test_save_weights_for_sampler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py index 2f0af688a1..af243fc851 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py @@ -122,7 +122,9 @@ def test_save_weights_for_sampler_then_inference(ray_init_fixture, colocate_all, cfg.generator.inference_engine.backend, cfg.generator.sampling_params ) tokenizer = get_tokenizer(MODEL) - outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL, num_samples=5), sampling_params, tokenizer=tokenizer)) + outputs = asyncio.run( + run_inference(client, get_test_prompts(MODEL, num_samples=5), sampling_params, tokenizer=tokenizer) + ) # Verify we got responses assert "responses" in outputs, "Inference should return responses" @@ -188,5 +190,7 @@ def test_save_weights_for_sampler_multiple_training_steps(ray_init_fixture): cfg.generator.inference_engine.backend, cfg.generator.sampling_params ) tokenizer = get_tokenizer(MODEL) - outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL, num_samples=2), sampling_params, tokenizer=tokenizer)) + outputs = asyncio.run( + run_inference(client, get_test_prompts(MODEL, num_samples=2), sampling_params, tokenizer=tokenizer) + ) assert len(outputs["responses"]) == 2, "Should get 2 responses" From 2c399e02167ad45ba2957e7fb7536e3b98b4d158 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Tue, 10 Mar 2026 00:40:53 +0000 Subject: [PATCH 23/38] add typing --- .../skyrl_train/inference_servers/remote_inference_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index 66881f2d54..6f420cb58d 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -297,7 +297,7 @@ async def sample( self, prompt_token_ids: List[int], num_samples: int, - sampling_params: Dict[str, Any], + sampling_params: Dict[str, Union[str, int, float, bool, List, Dict]], session_id: Optional[Union[str, int]] = None, ) -> InferenceEngineOutput: """ From fde53424f48fe310b4af8b2cc8b6946321912549 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Tue, 10 Mar 2026 01:06:10 +0000 Subject: [PATCH 24/38] add lora_id --- .../skyrl_train/inference_servers/remote_inference_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index 6f420cb58d..ee6a4aac49 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -299,6 +299,7 @@ async def sample( num_samples: int, sampling_params: Dict[str, Union[str, int, float, bool, List, Dict]], session_id: Optional[Union[str, int]] = None, + lora_id: Optional[int] = None, ) -> InferenceEngineOutput: """ Generate multiple independent samples for the same prompt. From d894f8208bb394b98eb9b1a35e8784f4535e0812 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 19:15:53 -0700 Subject: [PATCH 25/38] fix typing Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- .../skyrl_train/inference_servers/remote_inference_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index ee6a4aac49..1068734399 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -375,7 +375,7 @@ async def render_chat_completion( add_generation_prompt: bool = True, continue_final_message: bool = False, session_id: Optional[Union[str, int]] = None, - ) -> Dict[str, Any]: + ) -> List[Any]: """ Render chat messages into a tokenized prompt via /v1/chat/completions/render. From 269a81c7ca4eb3e1d7af4ce637e2a57030c65bb8 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Mon, 9 Mar 2026 19:18:19 -0700 Subject: [PATCH 26/38] fix docstrings --- .../inference_servers/test_backend_weight_sync.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py index 77c1171ff5..93a9414099 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py @@ -1,7 +1,7 @@ """ GPU CI test for weight sync through the SkyRLTrainBackend API (new inference path). 
-Uses colocated mode with 2 GPUs (TP=1, 2 engines, 2 FSDP2 workers): +Uses the non-colocated setting (colocate_all=False) with 2 GPUs (TP=1, 2 engines, 2 FSDP2 workers): - Backend creates FSDP2 workers with real weights from HF - Inference servers start with dummy (random) weights via engine_init_kwargs - save_sampler_checkpoint() broadcasts real training weights via NCCL @@ -88,11 +88,11 @@ def ray_env_with_new_inference(): @pytest.mark.asyncio(loop_scope="class") class TestBackendWeightSync: - """Test weight sync through SkyRLTrainBackend with new inference path (colocated).""" + """Test weight sync through SkyRLTrainBackend with new inference path (non-colocated).""" - async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): + async def test_backend_weight_sync_non_colocated(self, ray_env_with_new_inference): """ - End-to-end colocated weight sync test via SkyRLTrainBackend: + End-to-end non-colocated weight sync test via SkyRLTrainBackend: 1. Create backend with 2 FSDP2 workers (real weights from HF) 2. Start 2 inference servers with dummy (random) weights @@ -152,7 +152,7 @@ async def test_backend_weight_sync_colocated(self, ray_env_with_new_inference): text_before = resp.json()["choices"][0]["text"] assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" - # ===== Step 6: Sleep inference engines (required before colocated weight sync) ===== + # ===== Step 6: Sleep inference engines (required before weight sync) ===== await backend._inference_engine_client.sleep() # ===== Step 7: Sync weights via save_sampler_checkpoint ===== From bd2601513c28a723cb47ec34a0169827726fc275 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Wed, 11 Mar 2026 00:42:07 +0000 Subject: [PATCH 27/38] remove duplicate test fixture --- .../test_backend_weight_sync.py | 54 +------------------ 1 file changed, 1 insertion(+), 53 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py index 93a9414099..03004c8b6c 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py @@ -12,29 +12,15 @@ tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py -v -s """ -import os import time from unittest import mock import httpx import pytest -import ray -from functools import lru_cache -from loguru import logger - -from skyrl.train.utils.utils import peer_access_supported -from skyrl.env_vars import SKYRL_PYTHONPATH_EXPORT - MODEL = "Qwen/Qwen2.5-0.5B-Instruct" -@lru_cache(5) -def log_once(msg): - logger.info(msg) - return None - - def wait_for_url(url: str, timeout: float = 180.0) -> bool: """Wait for a URL to become available.""" start = time.time() @@ -48,49 +34,11 @@ def wait_for_url(url: str, timeout: float = 180.0) -> bool: return False -@pytest.fixture(scope="class") -def ray_env_with_new_inference(): - """Ray init fixture with _SKYRL_USE_NEW_INFERENCE=1 in runtime env.""" - if ray.is_initialized(): - ray.shutdown() - - env_vars = { - "VLLM_USE_V1": "1", - "VLLM_ENABLE_V1_MULTIPROCESSING": "0", - "VLLM_ALLOW_INSECURE_SERIALIZATION": "1", - "_SKYRL_USE_NEW_INFERENCE": "1", - "CUDA_DEVICE_MAX_CONNECTIONS": "1", - "NVTE_FUSED_ATTN": "0", - } - - if not peer_access_supported(max_num_gpus_per_node=2): - log_once("Disabling NCCL P2P for CI environment") - env_vars.update( - { - 
"NCCL_P2P_DISABLE": "1", - "NCCL_SHM_DISABLE": "1", - } - ) - - if SKYRL_PYTHONPATH_EXPORT: - pythonpath = os.environ.get("PYTHONPATH") - if pythonpath is None: - raise RuntimeError("SKYRL_PYTHONPATH_EXPORT is set but PYTHONPATH is not defined in environment") - env_vars["PYTHONPATH"] = pythonpath - - logger.info(f"Initializing Ray with environment variables: {env_vars}") - ray.init(runtime_env={"env_vars": env_vars}) - - yield - - ray.shutdown() - - @pytest.mark.asyncio(loop_scope="class") class TestBackendWeightSync: """Test weight sync through SkyRLTrainBackend with new inference path (non-colocated).""" - async def test_backend_weight_sync_non_colocated(self, ray_env_with_new_inference): + async def test_backend_weight_sync_non_colocated(self, ray_init_fixture): """ End-to-end non-colocated weight sync test via SkyRLTrainBackend: From b5bda898d81ca445bd2d0f96fe608accbb2b752e Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Tue, 10 Mar 2026 19:33:50 -0700 Subject: [PATCH 28/38] update to use request_payload format --- .../remote_inference_client.py | 114 ++++++++++++------ skyrl/backends/skyrl_train_backend.py | 91 +++++++++----- 2 files changed, 141 insertions(+), 64 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index 1068734399..f29de69c15 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -293,49 +293,95 @@ async def _generate_single( "response_logprobs": response_logprobs if len(response_logprobs) > 0 else None, } - async def sample( - self, - prompt_token_ids: List[int], - num_samples: int, - sampling_params: Dict[str, Union[str, int, float, bool, List, Dict]], - session_id: Optional[Union[str, int]] = None, - lora_id: Optional[int] = None, - ) -> InferenceEngineOutput: - """ - Generate multiple independent samples for the same prompt. + async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: + """Sample completions via /inference/v1/generate. - Fires num_samples parallel calls to _generate_single with the same - prompt_token_ids and sampling_params, then aggregates into a single - InferenceEngineOutput. + Single request with n in sampling_params. No retry-on-abort. Args: - prompt_token_ids: Token IDs for the prompt. - num_samples: Number of independent samples to generate. - sampling_params: Sampling parameters for generation. - session_id: Optional session ID for consistent routing via X-Session-ID header. + request_payload: {"json": {...}, "headers": {...}} + json body matches Tinker SamplingClient.sample() args: + prompt (ModelInput), num_samples, sampling_params (SamplingParams), + include_prompt_logprobs, topk_prompt_logprobs. + session_id is optional for routing. Returns: - InferenceEngineOutput with num_samples responses. + Dict matching Tinker SampleResponse schema. 
""" - get_logprobs = sampling_params.get("logprobs") is not None + body = request_payload.get("json", {}) + session_id = body.pop("session_id", None) - tasks = [ - self._generate_single( - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, - session_id=session_id, - ) - for _ in range(num_samples) - ] + # Tinker input fields + prompt = body.get("prompt", {}) # ModelInput dict: {"chunks": [{"tokens": [...]}]} + num_samples = body.get("num_samples", 1) + sampling_params = body.get("sampling_params", {}) + prompt_logprobs = body.get("prompt_logprobs", False) + topk_prompt_logprobs = body.get("topk_prompt_logprobs", 0) + + prompt_token_ids = [tok for chunk in prompt.get("chunks", []) for tok in chunk.get("tokens", [])] + + # Tinker types.py SamplingParams -> vLLM /inference/v1/generate sampling_params + vllm_sampling_params = { + "n": num_samples, + "temperature": sampling_params.get("temperature"), + "max_tokens": sampling_params.get("max_tokens"), + "seed": sampling_params.get("seed"), + "top_k": sampling_params.get("top_k", -1), + "top_p": sampling_params.get("top_p", 1.0), + "prompt_logprobs": topk_prompt_logprobs if prompt_logprobs else 0, + "logprobs": 0, # response logprobs + } + if sampling_params.get("stop_strings"): + vllm_sampling_params["stop"] = sampling_params["stop_strings"] + if sampling_params.get("stop_tokens"): + vllm_sampling_params["stop_token_ids"] = sampling_params["stop_tokens"] - results = await asyncio.gather(*tasks) + payload = { + "sampling_params": vllm_sampling_params, + "model": self.model_name, + "token_ids": prompt_token_ids, + } - return InferenceEngineOutput( - responses=[r["response"] for r in results], - stop_reasons=[r["stop_reason"] for r in results], - response_ids=[r["response_ids"] for r in results], - response_logprobs=[r["response_logprobs"] for r in results] if get_logprobs else None, - ) + headers = {"Content-Type": "application/json"} + if session_id: + headers["X-Session-ID"] = str(session_id) + + session = await self._get_session() + url = f"{self.proxy_url}/inference/v1/generate" + + async with session.post(url, json=payload, headers=headers) as resp: + result = await resp.json() + raise_for_status(resp, result) + + # Transform response choices -> SampleResponse dict + sequences = [] + for choice in result["choices"]: + raw_stop = choice.get("finish_reason", "length") + stop_reason = "stop" if raw_stop in ("stop", "stop_token") else "length" + + token_ids = choice.get("token_ids", []) + + logprobs = None + lp = choice.get("logprobs") + if lp is not None: + logprobs_content = lp.get("content", []) + if logprobs_content: + logprobs = [info["logprob"] if info["logprob"] is not None else 0.0 for info in logprobs_content] + + sequences.append( + { + "stop_reason": stop_reason, + "tokens": token_ids, + "logprobs": logprobs, + } + ) + + return { + "type": "sample", + "sequences": sequences, + "prompt_logprobs": None, + "topk_prompt_logprobs": None, + } async def chat_completion( self, diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py index a4361f9f6c..7797971f6e 100644 --- a/skyrl/backends/skyrl_train_backend.py +++ b/skyrl/backends/skyrl_train_backend.py @@ -604,28 +604,43 @@ async def sample_all(): prompt = prepared_batch.all_prompts[i] sampling_params = prepared_batch.all_sampling_params[i] - # Pass through common fields; only stop needs name translation - # (Tinker uses stop_strings/stop_tokens, vLLM uses stop/stop_token_ids) - params_dict = { - "temperature": sampling_params.temperature, 
- "max_tokens": sampling_params.max_tokens, - "seed": sampling_params.seed, - "top_k": sampling_params.top_k, - "top_p": sampling_params.top_p, - "logprobs": 0, - } - if sampling_params.stop_strings: - params_dict["stop"] = sampling_params.stop_strings - if sampling_params.stop_tokens: - params_dict["stop_token_ids"] = sampling_params.stop_tokens - - tasks.append( - self._inference_engine_client.sample( - prompt_token_ids=prompt, - num_samples=1, # Tinker batches multiple samples separately - sampling_params=params_dict, + if _SKYRL_USE_NEW_INFERENCE: + # Right now, prompt is list[int] (token IDs), so we wrap in ModelInput format + json_body = { + "prompt": {"chunks": [{"tokens": prompt}]}, + "num_samples": 1, # Tinker batches multiple samples separately + "sampling_params": { + "temperature": sampling_params.temperature, + "max_tokens": sampling_params.max_tokens, + "seed": sampling_params.seed, + "top_k": sampling_params.top_k, + "top_p": sampling_params.top_p, + "stop_tokens": sampling_params.stop_tokens, + "stop_strings": sampling_params.stop_strings, + }, + } + tasks.append(self._inference_engine_client.sample({"json": json_body, "headers": {}})) + else: + params_dict = { + "temperature": sampling_params.temperature, + "max_tokens": sampling_params.max_tokens, + "seed": sampling_params.seed, + "top_k": sampling_params.top_k, + "top_p": sampling_params.top_p, + "logprobs": 0, + } + if sampling_params.stop_strings: + params_dict["stop"] = sampling_params.stop_strings + if sampling_params.stop_tokens: + params_dict["stop_token_ids"] = sampling_params.stop_tokens + + tasks.append( + self._inference_engine_client.sample( + prompt_token_ids=prompt, + num_samples=1, # Tinker batches multiple samples separately + sampling_params=params_dict, + ) ) - ) return await asyncio.gather(*tasks, return_exceptions=True) @@ -636,14 +651,25 @@ async def sample_all(): # We preserve these to include error messages in responses # 4. Aggregate results by request - return self._aggregate_sample_results(prepared_batch, sample_outputs) + return self._aggregate_sample_results( + prepared_batch, sample_outputs, use_new_inference=_SKYRL_USE_NEW_INFERENCE + ) def _aggregate_sample_results( self, prepared_batch: types.PreparedSampleBatch, sample_outputs: list, + use_new_inference: bool = False, ) -> dict[str, types.SampleOutput | types.ErrorResponse]: - """Convert InferenceEngineClient outputs to Tinker format.""" + """Convert inference outputs to Tinker format. + + Args: + prepared_batch: The prepared sample batch. + sample_outputs: List of outputs from inference client. + use_new_inference: If True, outputs are SampleResponse dicts from + RemoteInferenceClient. If False, outputs are InferenceEngineOutput + dicts from InferenceEngineClient. 
+ """ results = {} for request_id, model_id, start_idx, end_idx, needs_prompt_logprobs in prepared_batch.request_batch_slices: @@ -666,13 +692,18 @@ def _aggregate_sample_results( logger.error(error_msg) break - # Extract tokens and logprobs - response_tokens = output["response_ids"][0] - response_logprobs = (output.get("response_logprobs") or [[]])[0] - stop_reason_raw = output["stop_reasons"][0] - - # Map vLLM stop reason to Tinker format - stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length" + if use_new_inference: + # New path: SampleResponse dict from RemoteInferenceClient + seq = output["sequences"][0] + response_tokens = seq["tokens"] + response_logprobs = seq.get("logprobs") or [] + stop_reason = seq["stop_reason"] + else: + # Old path: InferenceEngineOutput from InferenceEngineClient + response_tokens = output["response_ids"][0] + response_logprobs = (output.get("response_logprobs") or [[]])[0] + stop_reason_raw = output["stop_reasons"][0] + stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length" # Ensure logprobs exist (critical for RL) if response_logprobs is None or len(response_logprobs) == 0: From b7484b3c8f3583c59f2f279cfd5524893bcf9258 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Wed, 11 Mar 2026 00:17:47 -0700 Subject: [PATCH 29/38] add mock test --- .../test_remote_inference_client.py | 97 ++++++++++++++++++- 1 file changed, 93 insertions(+), 4 deletions(-) diff --git a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py index 86556417b1..4681813548 100644 --- a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py +++ b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py @@ -49,13 +49,18 @@ async def completions(request: Request): @app.post("/inference/v1/generate") async def generate(request: Request): - body = await request.json() # Consume body - num_prompts = len(body.get("token_ids", [])) + body = await request.json() + n = body.get("sampling_params", {}).get("n", 1) return { "choices": [ - {"request_id": "dummy", "token_ids": [i, i + 1, i + 2], "finish_reason": "stop"} - for i in range(num_prompts) + { + "request_id": "dummy", + "token_ids": [i, i + 1, i + 2], + "finish_reason": "stop", + "logprobs": {"content": [{"logprob": -0.1}, {"logprob": -0.2}, {"logprob": -0.3}]}, + } + for i in range(n) ] } @@ -511,3 +516,87 @@ async def test_retry_on_abort(self, abort_mock_server): assert len(result["response_ids"][0]) > 0 finally: await client.teardown() + + +class TestSample: + """Test sample() with request_payload / SampleResponse dict interface.""" + + @pytest.mark.asyncio + async def test_sample(self, client): + """Test sample with n=1: verify SampleResponse dict contents.""" + request_payload = { + "json": { + "prompt": {"chunks": [{"tokens": [10, 20, 30]}]}, + "num_samples": 1, + "sampling_params": { + "temperature": 0.8, + "max_tokens": 512, + "seed": 42, + }, + }, + "headers": {}, + } + result = await client.sample(request_payload) + + assert result["type"] == "sample" + assert result["prompt_logprobs"] is None + assert result["topk_prompt_logprobs"] is None + assert len(result["sequences"]) == 1 + + # Mock returns token_ids=[0, 1, 2] for choice i=0, logprobs=[-0.1, -0.2, -0.3] + seq = result["sequences"][0] + assert seq["tokens"] == [0, 1, 2] + assert seq["logprobs"] == [-0.1, -0.2, -0.3] + assert seq["stop_reason"] == "stop" + + @pytest.mark.asyncio + 
async def test_sample_n2(self, client): + """Test sample with n=2: verify both sequences match mock output.""" + request_payload = { + "json": { + "prompt": {"chunks": [{"tokens": [10, 20, 30]}]}, + "num_samples": 2, + "sampling_params": { + "temperature": 0.8, + "max_tokens": 512, + "seed": 42, + }, + }, + "headers": {}, + } + result = await client.sample(request_payload) + + assert result["type"] == "sample" + assert len(result["sequences"]) == 2 + + # Mock returns token_ids=[i, i+1, i+2] for choice i + assert result["sequences"][0]["tokens"] == [0, 1, 2] + assert result["sequences"][0]["logprobs"] == [-0.1, -0.2, -0.3] + assert result["sequences"][0]["stop_reason"] == "stop" + + assert result["sequences"][1]["tokens"] == [1, 2, 3] + assert result["sequences"][1]["logprobs"] == [-0.1, -0.2, -0.3] + assert result["sequences"][1]["stop_reason"] == "stop" + + @pytest.mark.asyncio + async def test_sample_with_session_id(self, client): + """Test that session_id is popped from body and used for routing.""" + request_payload = { + "json": { + "prompt": {"chunks": [{"tokens": [1, 2]}]}, + "num_samples": 1, + "sampling_params": {"temperature": 1.0, "max_tokens": 10}, + "session_id": "test-routing-session", + }, + "headers": {}, + } + result = await client.sample(request_payload) + + # session_id should be consumed (popped) from body + assert "session_id" not in request_payload["json"] + + assert result["type"] == "sample" + assert len(result["sequences"]) == 1 + assert result["sequences"][0]["tokens"] == [0, 1, 2] + assert result["sequences"][0]["logprobs"] == [-0.1, -0.2, -0.3] + assert result["sequences"][0]["stop_reason"] == "stop" From aff3835c91c91ac88f931c95e5bc7880fbfed696 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Wed, 11 Mar 2026 00:42:57 -0700 Subject: [PATCH 30/38] add comments + update gpu test --- skyrl/backends/skyrl_train_backend.py | 6 +- .../gpu/gpu_ci/test_engine_generation.py | 62 +++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py index 7797971f6e..076d467536 100644 --- a/skyrl/backends/skyrl_train_backend.py +++ b/skyrl/backends/skyrl_train_backend.py @@ -693,16 +693,18 @@ def _aggregate_sample_results( break if use_new_inference: - # New path: SampleResponse dict from RemoteInferenceClient + # New inference server: SampleResponse dict seq = output["sequences"][0] response_tokens = seq["tokens"] response_logprobs = seq.get("logprobs") or [] stop_reason = seq["stop_reason"] else: - # Old path: InferenceEngineOutput from InferenceEngineClient + # Old inference engine: InferenceEngineOutput + # Extract tokens and logprobs response_tokens = output["response_ids"][0] response_logprobs = (output.get("response_logprobs") or [[]])[0] stop_reason_raw = output["stop_reasons"][0] + # Map vLLM stop reason to Tinker format stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length" # Ensure logprobs exist (critical for RL) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py index ef74ec1e4c..28731709c5 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py @@ -286,6 +286,7 @@ def test_token_based_generation_consistency(ray_init_fixture, tp_size: int, pp_s ) +@pytest.mark.skipif(_SKYRL_USE_NEW_INFERENCE, reason="Old sample API not used with new inference path") 
@pytest.mark.parametrize( "tp_size,dp_size", [ @@ -337,3 +338,64 @@ async def run_sample(): print(f"Generated {len(unique_responses)} unique responses from {num_samples} samples") for i, resp in enumerate(output["responses"]): print(f"Sample {i}: {resp[:100]}..." if len(resp) > 100 else f"Sample {i}: {resp}") + + +@pytest.mark.parametrize( + "tp_size,dp_size", + [ + pytest.param(2, 1), + ], + ids=["tp2"], +) +def test_sample_api_remote(ray_init_fixture, tp_size: int, dp_size: int): + """Test the sample() API via RemoteInferenceClient (new inference path).""" + cfg = get_test_actor_config() + cfg.generator.sampling_params.temperature = 0.7 + + prompts = get_test_prompts(MODEL, 1) + tokenizer = AutoTokenizer.from_pretrained(MODEL) + prompt_token_ids = tokenizer.apply_chat_template( + prompts, add_generation_prompt=True, tokenize=True, return_dict=True + )["input_ids"][0] + + cfg.generator.inference_engine.tensor_parallel_size = tp_size + cfg.generator.inference_engine.data_parallel_size = dp_size + + num_samples = 3 + + with InferenceEngineState.create(cfg, sleep_level=1, use_new_inference_servers=True) as engines: + llm_client = engines.client + + request_payload = { + "json": { + "prompt": {"chunks": [{"tokens": prompt_token_ids}]}, + "num_samples": num_samples, + "sampling_params": { + "temperature": cfg.generator.sampling_params.temperature, + "max_tokens": cfg.generator.sampling_params.max_generate_length, + "top_k": cfg.generator.sampling_params.top_k, + "top_p": cfg.generator.sampling_params.top_p, + }, + }, + "headers": {}, + } + + async def run_sample(): + return await llm_client.sample(request_payload) + + output = asyncio.run(run_sample()) + + assert output["type"] == "sample" + assert len(output["sequences"]) == num_samples + + for i, seq in enumerate(output["sequences"]): + assert seq["stop_reason"] in ("stop", "length"), f"Unexpected stop_reason: {seq['stop_reason']}" + assert isinstance(seq["tokens"], list), f"tokens should be a list, got {type(seq['tokens'])}" + assert len(seq["tokens"]) > 0, f"Sequence {i} has no tokens" + assert all(isinstance(t, int) for t in seq["tokens"]) + if seq.get("logprobs") is not None: + assert isinstance(seq["logprobs"], list) + + print(f"Generated {num_samples} samples via RemoteInferenceClient") + for i, seq in enumerate(output["sequences"]): + print(f"Sample {i}: {len(seq['tokens'])} tokens, stop_reason={seq['stop_reason']}") From 15b2ddb71c83ab94f9f38b0429f1b73eb802af58 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Wed, 11 Mar 2026 00:51:41 -0700 Subject: [PATCH 31/38] stronger gpu checks --- .../gpu/gpu_ci/test_engine_generation.py | 83 ++++++++++++++----- 1 file changed, 62 insertions(+), 21 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py index 28731709c5..0d9d5182b7 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py @@ -348,9 +348,13 @@ async def run_sample(): ids=["tp2"], ) def test_sample_api_remote(ray_init_fixture, tp_size: int, dp_size: int): - """Test the sample() API via RemoteInferenceClient (new inference path).""" + """Test the sample() API via RemoteInferenceClient (new inference path). 
+ + Makes two calls to validate both output correctness and sampling behavior: + - Call A (temp=1.0): schema checks, token decode, diversity across samples + - Call B (temp=0.0): determinism (all samples should be identical) + """ cfg = get_test_actor_config() - cfg.generator.sampling_params.temperature = 0.7 prompts = get_test_prompts(MODEL, 1) tokenizer = AutoTokenizer.from_pretrained(MODEL) @@ -366,28 +370,28 @@ def test_sample_api_remote(ray_init_fixture, tp_size: int, dp_size: int): with InferenceEngineState.create(cfg, sleep_level=1, use_new_inference_servers=True) as engines: llm_client = engines.client - request_payload = { - "json": { - "prompt": {"chunks": [{"tokens": prompt_token_ids}]}, - "num_samples": num_samples, - "sampling_params": { - "temperature": cfg.generator.sampling_params.temperature, - "max_tokens": cfg.generator.sampling_params.max_generate_length, - "top_k": cfg.generator.sampling_params.top_k, - "top_p": cfg.generator.sampling_params.top_p, + def build_payload(temperature, n_samples): + return { + "json": { + "prompt": {"chunks": [{"tokens": prompt_token_ids}]}, + "num_samples": n_samples, + "sampling_params": { + "temperature": temperature, + "max_tokens": cfg.generator.sampling_params.max_generate_length, + "top_k": cfg.generator.sampling_params.top_k, + "top_p": cfg.generator.sampling_params.top_p, + }, }, - }, - "headers": {}, - } - - async def run_sample(): - return await llm_client.sample(request_payload) + "headers": {}, + } - output = asyncio.run(run_sample()) + # --- Call A: temp=1.0, expect diverse outputs --- + output = asyncio.run(llm_client.sample(build_payload(1.0, num_samples))) assert output["type"] == "sample" assert len(output["sequences"]) == num_samples + decoded_texts = [] for i, seq in enumerate(output["sequences"]): assert seq["stop_reason"] in ("stop", "length"), f"Unexpected stop_reason: {seq['stop_reason']}" assert isinstance(seq["tokens"], list), f"tokens should be a list, got {type(seq['tokens'])}" @@ -396,6 +400,43 @@ async def run_sample(): if seq.get("logprobs") is not None: assert isinstance(seq["logprobs"], list) - print(f"Generated {num_samples} samples via RemoteInferenceClient") - for i, seq in enumerate(output["sequences"]): - print(f"Sample {i}: {len(seq['tokens'])} tokens, stop_reason={seq['stop_reason']}") + text = tokenizer.decode(seq["tokens"], skip_special_tokens=True) + assert len(text.strip()) > 0, f"Sequence {i} decoded to empty text from {len(seq['tokens'])} tokens" + decoded_texts.append(text) + + unique_texts = set(decoded_texts) + assert len(unique_texts) > 1, ( + f"All {num_samples} samples at temp=1.0 are identical — sampling params may be ignored. " + f"Text: {decoded_texts[0][:120]!r}" + ) + + print(f"Call A (temp=1.0): {len(unique_texts)}/{num_samples} unique samples") + for i, text in enumerate(decoded_texts): + print(f" Sample {i}: {text[:100]}..." 
if len(text) > 100 else f" Sample {i}: {text}") + + # --- Call B: temp=0.0, expect deterministic outputs --- + output_det = asyncio.run(llm_client.sample(build_payload(0.0, num_samples))) + + assert output_det["type"] == "sample" + assert len(output_det["sequences"]) == num_samples + + det_token_seqs = [] + det_texts = [] + for i, seq in enumerate(output_det["sequences"]): + assert seq["stop_reason"] in ("stop", "length"), f"Unexpected stop_reason: {seq['stop_reason']}" + assert isinstance(seq["tokens"], list) + assert len(seq["tokens"]) > 0, f"Deterministic sequence {i} has no tokens" + det_token_seqs.append(tuple(seq["tokens"])) + + text = tokenizer.decode(seq["tokens"], skip_special_tokens=True) + assert len(text.strip()) > 0, f"Deterministic sequence {i} decoded to empty text" + det_texts.append(text) + + unique_det = set(det_token_seqs) + assert len(unique_det) == 1, ( + f"temp=0.0 produced {len(unique_det)} distinct token sequences — expected deterministic output. " + f"Lengths: {[len(s) for s in det_token_seqs]}" + ) + + print(f"Call B (temp=0.0): all {num_samples} samples identical (deterministic)") + print(f" Text: {det_texts[0][:120]}") From 863c203098e04d7ef53def348ca150a6c5dc1ff0 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Wed, 11 Mar 2026 08:05:57 +0000 Subject: [PATCH 32/38] update gpu sample api test --- .../gpu/gpu_ci/test_engine_generation.py | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py index 0d9d5182b7..acc74df8d6 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py @@ -370,11 +370,11 @@ def test_sample_api_remote(ray_init_fixture, tp_size: int, dp_size: int): with InferenceEngineState.create(cfg, sleep_level=1, use_new_inference_servers=True) as engines: llm_client = engines.client - def build_payload(temperature, n_samples): + def build_payload(temperature): return { "json": { "prompt": {"chunks": [{"tokens": prompt_token_ids}]}, - "num_samples": n_samples, + "num_samples": 1, "sampling_params": { "temperature": temperature, "max_tokens": cfg.generator.sampling_params.max_generate_length, @@ -385,20 +385,28 @@ def build_payload(temperature, n_samples): "headers": {}, } - # --- Call A: temp=1.0, expect diverse outputs --- - output = asyncio.run(llm_client.sample(build_payload(1.0, num_samples))) + async def run_samples(temperature, n): + results = [] + for _ in range(n): + output = await llm_client.sample(build_payload(temperature)) + assert output["type"] == "sample", f"Expected type 'sample', got {output['type']!r}" + assert len(output["sequences"]) == 1, f"Expected 1 sequence per call, got {len(output['sequences'])}" + results.append(output["sequences"][0]) + return results - assert output["type"] == "sample" - assert len(output["sequences"]) == num_samples + # --- Call A: temp=1.0, expect diverse outputs --- + sequences = asyncio.run(run_samples(1.0, num_samples)) decoded_texts = [] - for i, seq in enumerate(output["sequences"]): + for i, seq in enumerate(sequences): assert seq["stop_reason"] in ("stop", "length"), f"Unexpected stop_reason: {seq['stop_reason']}" assert isinstance(seq["tokens"], list), f"tokens should be a list, got {type(seq['tokens'])}" assert len(seq["tokens"]) > 0, f"Sequence {i} has no tokens" - assert all(isinstance(t, int) for t in seq["tokens"]) + assert 
all(isinstance(t, int) for t in seq["tokens"]), f"Sequence {i} contains non-int tokens" if seq.get("logprobs") is not None: - assert isinstance(seq["logprobs"], list) + assert isinstance( + seq["logprobs"], list + ), f"Sequence {i} logprobs should be a list, got {type(seq['logprobs'])}" text = tokenizer.decode(seq["tokens"], skip_special_tokens=True) assert len(text.strip()) > 0, f"Sequence {i} decoded to empty text from {len(seq['tokens'])} tokens" @@ -415,16 +423,15 @@ def build_payload(temperature, n_samples): print(f" Sample {i}: {text[:100]}..." if len(text) > 100 else f" Sample {i}: {text}") # --- Call B: temp=0.0, expect deterministic outputs --- - output_det = asyncio.run(llm_client.sample(build_payload(0.0, num_samples))) - - assert output_det["type"] == "sample" - assert len(output_det["sequences"]) == num_samples + det_sequences = asyncio.run(run_samples(0.0, num_samples)) det_token_seqs = [] det_texts = [] - for i, seq in enumerate(output_det["sequences"]): + for i, seq in enumerate(det_sequences): assert seq["stop_reason"] in ("stop", "length"), f"Unexpected stop_reason: {seq['stop_reason']}" - assert isinstance(seq["tokens"], list) + assert isinstance( + seq["tokens"], list + ), f"Deterministic sequence {i} tokens should be a list, got {type(seq['tokens'])}" assert len(seq["tokens"]) > 0, f"Deterministic sequence {i} has no tokens" det_token_seqs.append(tuple(seq["tokens"])) From 7eaa5d8f3f6719550c60e2149f7b8013ddff794a Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Wed, 11 Mar 2026 01:22:14 -0700 Subject: [PATCH 33/38] update render to use payload_request --- .../remote_inference_client.py | 28 +++++++------------ .../test_new_inference_generation.py | 8 +++++- .../test_remote_inference_client.py | 9 +++++- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index f29de69c15..b59bdf8ee1 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -417,10 +417,7 @@ async def chat_completion( async def render_chat_completion( self, - messages: List[Dict[str, Any]], - add_generation_prompt: bool = True, - continue_final_message: bool = False, - session_id: Optional[Union[str, int]] = None, + request_payload: Dict[str, Any], ) -> List[Any]: """ Render chat messages into a tokenized prompt via /v1/chat/completions/render. @@ -428,31 +425,26 @@ async def render_chat_completion( Applies the model's chat template and tokenizes without generating text. Args: - messages: List of chat messages (e.g., [{"role": "user", "content": "Hello"}]). - add_generation_prompt: Whether to add generation prompt after messages. - continue_final_message: Whether to continue the final message. - session_id: Optional session ID for consistent routing via X-Session-ID header. - Needed for multimodal inputs where vLLM caches processed data on a specific backend. + request_payload: Dict with {"json": , "headers": }. + The request body should contain messages and optional chat template params. + session_id can be included in json for consistent routing. Returns: List of [conversation, engine_prompts] where engine_prompts contains dicts with "prompt" and "prompt_token_ids". 
""" - session = await self._get_session() - url = f"{self.proxy_url}/v1/chat/completions/render" + body = request_payload.get("json", {}) - payload = { - "model": self.model_name, - "messages": messages, - "add_generation_prompt": add_generation_prompt, - "continue_final_message": continue_final_message, - } + session_id = body.pop("session_id", None) headers = {"Content-Type": "application/json"} if session_id: headers["X-Session-ID"] = str(session_id) - async with session.post(url, json=payload, headers=headers) as resp: + session = await self._get_session() + url = f"{self.proxy_url}/v1/chat/completions/render" + + async with session.post(url, json=body, headers=headers) as resp: response = await resp.json() raise_for_status(resp, response) return response diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py index fc74cd7afa..26d940606e 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py @@ -594,7 +594,13 @@ def test_client_render_chat_completion(vllm_server: InferenceEngineState): """Test render_chat_completion via RemoteInferenceClient against real vLLM.""" client = vllm_server.client messages = [{"role": "user", "content": "Hello world!"}] - result = asyncio.run(client.render_chat_completion(messages=messages)) + request_payload = { + "json": { + "messages": messages, + }, + "headers": {}, + } + result = asyncio.run(client.render_chat_completion(request_payload)) # vLLM returns [conversation, engine_prompts] assert isinstance(result, list) assert len(result) == 2 diff --git a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py index 4681813548..d315a91d51 100644 --- a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py +++ b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py @@ -263,7 +263,14 @@ async def test_completion(self, client): async def test_render_chat_completion(self, client): """Test render_chat_completion method.""" messages = [{"role": "user", "content": "Hello"}] - result = await client.render_chat_completion(messages=messages) + request_payload = { + "json": { + "model": "test", + "messages": messages, + }, + "headers": {}, + } + result = await client.render_chat_completion(request_payload) assert isinstance(result, list) assert len(result) == 2 conversation, engine_prompts = result From 03ff39dffc34263d408d99ab69868c12de69bd9d Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Wed, 11 Mar 2026 01:59:43 -0700 Subject: [PATCH 34/38] add log probs extraction --- .../remote_inference_client.py | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index b59bdf8ee1..28c4e09611 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -328,7 +328,7 @@ async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: "seed": sampling_params.get("seed"), "top_k": sampling_params.get("top_k", -1), "top_p": sampling_params.get("top_p", 1.0), - 
"prompt_logprobs": topk_prompt_logprobs if prompt_logprobs else 0, + "prompt_logprobs": max(topk_prompt_logprobs, 1) if prompt_logprobs else 0, "logprobs": 0, # response logprobs } if sampling_params.get("stop_strings"): @@ -376,11 +376,39 @@ async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: } ) + # Transform prompt_logprobs from vLLM format to Tinker SampleResponse format + transformed_prompt_logprobs = None + transformed_topk_prompt_logprobs = None + # raw_prompt_logprobs: VLLM type list[dict[int, Logprob]] | None + raw_prompt_logprobs = result.get("prompt_logprobs") if prompt_logprobs else None + + if raw_prompt_logprobs is not None: + # prompt_logprobs: single float per token (logprob of the actual prompt token) + transformed_prompt_logprobs = [] + for i, pos in enumerate(raw_prompt_logprobs): + if pos is None: + transformed_prompt_logprobs.append(None) + else: + token_key = str(prompt_token_ids[i]) + entry = pos.get(token_key) + transformed_prompt_logprobs.append(entry["logprob"] if entry is not None else None) + + # topk_prompt_logprobs: list of (token_id, logprob) tuples per position + if topk_prompt_logprobs > 0: + transformed_topk_prompt_logprobs = [] + for pos in raw_prompt_logprobs: + if pos is None: + transformed_topk_prompt_logprobs.append(None) + else: + transformed_topk_prompt_logprobs.append( + [(int(tid), info["logprob"]) for tid, info in pos.items()] + ) + return { "type": "sample", "sequences": sequences, - "prompt_logprobs": None, - "topk_prompt_logprobs": None, + "prompt_logprobs": transformed_prompt_logprobs, + "topk_prompt_logprobs": transformed_topk_prompt_logprobs, } async def chat_completion( From 71ca1cf1fc7354012fc3130543de6de0646ba075 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Wed, 11 Mar 2026 09:19:21 +0000 Subject: [PATCH 35/38] comment + remove duplicate test --- .../remote_inference_client.py | 4 +- .../test_backend_weight_sync.py | 123 ------------------ 2 files changed, 2 insertions(+), 125 deletions(-) delete mode 100644 tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index 28c4e09611..dde3cf24de 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -353,7 +353,7 @@ async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: result = await resp.json() raise_for_status(resp, result) - # Transform response choices -> SampleResponse dict + # Transform response choices -> tinker type SampleResponse dict sequences = [] for choice in result["choices"]: raw_stop = choice.get("finish_reason", "length") @@ -367,7 +367,7 @@ async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: logprobs_content = lp.get("content", []) if logprobs_content: logprobs = [info["logprob"] if info["logprob"] is not None else 0.0 for info in logprobs_content] - + # Convert to tinker type SampledSequence dict sequences.append( { "stop_reason": stop_reason, diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py deleted file mode 100644 index 03004c8b6c..0000000000 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -GPU 
CI test for weight sync through the SkyRLTrainBackend API (new inference path). - -Uses the non-colocated setting (colocate_all=False) with 2 GPUs (TP=1, 2 engines, 2 FSDP2 workers): - - Backend creates FSDP2 workers with real weights from HF - - Inference servers start with dummy (random) weights via engine_init_kwargs - - save_sampler_checkpoint() broadcasts real training weights via NCCL - - Verified by querying the server before and after sync - -Run: - uv run --isolated --extra dev --extra fsdp pytest \ - tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py -v -s -""" - -import time -from unittest import mock - -import httpx -import pytest - -MODEL = "Qwen/Qwen2.5-0.5B-Instruct" - - -def wait_for_url(url: str, timeout: float = 180.0) -> bool: - """Wait for a URL to become available.""" - start = time.time() - while time.time() - start < timeout: - try: - resp = httpx.get(f"{url}/health", timeout=5.0) - if resp.status_code == 200: - return True - except httpx.RequestError: - time.sleep(2.0) - return False - - -@pytest.mark.asyncio(loop_scope="class") -class TestBackendWeightSync: - """Test weight sync through SkyRLTrainBackend with new inference path (non-colocated).""" - - async def test_backend_weight_sync_non_colocated(self, ray_init_fixture): - """ - End-to-end non-colocated weight sync test via SkyRLTrainBackend: - - 1. Create backend with 2 FSDP2 workers (real weights from HF) - 2. Start 2 inference servers with dummy (random) weights - 3. Verify dummy weights produce gibberish - 4. Run save_sampler_checkpoint() to broadcast real weights via NCCL - 5. Verify real weights produce correct output - """ - from skyrl.backends.skyrl_train_backend import ( - SkyRLTrainBackend, - FSDPBackendOverrides, - ) - from skyrl.tinker.types import LoraConfig - - # ===== Step 1: Create backend ===== - overrides = { - "trainer.placement.colocate_all": False, - "trainer.placement.policy_num_gpus_per_node": 2, - "trainer.placement.policy_num_nodes": 1, - "trainer.logger": "console", - "generator.inference_engine.tensor_parallel_size": 1, - "generator.inference_engine.num_engines": 2, - "generator.inference_engine.gpu_memory_utilization": 0.5, - "generator.inference_engine.async_engine": True, - } - backend = SkyRLTrainBackend(MODEL, FSDPBackendOverrides(**overrides)) - - # ===== Step 2: Create model (real weights, FSDP2 sharded across 2 GPUs) ===== - model_id = "test-model" - backend.create_model(model_id, LoraConfig(rank=0, alpha=0, seed=42)) - - # ===== Step 3: Inject dummy weight config before inference engine creation ===== - backend._cfg.generator.inference_engine.engine_init_kwargs = {"load_format": "dummy"} - - # ===== Step 4: Create inference engines with dummy weights ===== - with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): - backend._ensure_inference_engines() - - # Wait for servers to be healthy - server_urls = backend._server_group.get_server_urls() - assert len(server_urls) == 2, f"Expected 2 server URLs, got {len(server_urls)}" - for url in server_urls: - assert wait_for_url(url), f"Server {url} failed to start" - - try: - # ===== Step 5: Verify dummy weights produce gibberish ===== - payload = { - "model": MODEL, - "prompt": "What is the capital of France?", - "max_tokens": 32, - "temperature": 0.0, - } - - async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: - resp = await http_client.post(f"{server_urls[0]}/v1/completions", json=payload) - assert resp.status_code == 200, f"Completions request 
failed: {resp.text}" - - text_before = resp.json()["choices"][0]["text"] - assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" - - # ===== Step 6: Sleep inference engines (required before weight sync) ===== - await backend._inference_engine_client.sleep() - - # ===== Step 7: Sync weights via save_sampler_checkpoint ===== - with mock.patch("skyrl.backends.skyrl_train_backend._SKYRL_USE_NEW_INFERENCE", True): - backend._validate_model_state(model_id) - backend._ensure_inference_engines() - await backend._dispatch.save_weights_for_sampler() - - # ===== Step 8: Verify real weights produce correct output ===== - async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as http_client: - resp = await http_client.post(f"{server_urls[0]}/v1/completions", json=payload) - assert resp.status_code == 200, f"Completions request failed: {resp.text}" - - text_after = resp.json()["choices"][0]["text"] - assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" - - finally: - # Cleanup: teardown inference client session - if backend._inference_engine_client is not None: - await backend._inference_engine_client.teardown() From 47c94002ac5ee3cfdae224fc2404cec8bb4cb893 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Thu, 12 Mar 2026 14:21:15 -0700 Subject: [PATCH 36/38] Update skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- .../skyrl_train/inference_servers/remote_inference_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index c5100d3246..8e00361062 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -332,7 +332,7 @@ async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: "seed": sampling_params.get("seed"), "top_k": sampling_params.get("top_k", -1), "top_p": sampling_params.get("top_p", 1.0), - "prompt_logprobs": max(topk_prompt_logprobs, 1) if prompt_logprobs else 0, + "prompt_logprobs": max(topk_prompt_logprobs, 1) if prompt_logprobs else None, "logprobs": 0, # response logprobs } if sampling_params.get("stop_strings"): From 9a5eb863da446df755d9743b86367434de41cfc9 Mon Sep 17 00:00:00 2001 From: Nithin Chalapathi Date: Thu, 12 Mar 2026 14:27:40 -0700 Subject: [PATCH 37/38] lint --- .../remote_inference_client.py | 2 +- skyrl/backends/skyrl_train_backend.py | 20 +++++++++---------- .../test_new_inference_generation.py | 3 +-- .../gpu_ci/test_save_weights_for_sampler.py | 2 -- .../test_remote_inference_client.py | 2 +- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py index 8e00361062..019f3b27ba 100644 --- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py @@ -319,7 +319,7 @@ async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: prompt = body.get("prompt", {}) # ModelInput dict: {"chunks": [{"tokens": [...]}]} num_samples = body.get("num_samples", 1) sampling_params = body.get("sampling_params", {}) - prompt_logprobs = 
body.get("prompt_logprobs", False) + prompt_logprobs = body.get("include_prompt_logprobs", False) topk_prompt_logprobs = body.get("topk_prompt_logprobs", 0) prompt_token_ids = [tok for chunk in prompt.get("chunks", []) for tok in chunk.get("tokens", [])] diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py index 803101943f..4072718920 100644 --- a/skyrl/backends/skyrl_train_backend.py +++ b/skyrl/backends/skyrl_train_backend.py @@ -22,23 +22,21 @@ from skyrl.backends.skyrl_train.inference_engines.ray_wrapped_inference_engine import ( create_ray_wrapped_inference_engines, ) +from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import ( + RemoteInferenceClient, +) +from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter +from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup +from skyrl.backends.skyrl_train.inference_servers.utils import build_vllm_cli_args from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch -from skyrl.train.utils.utils import initialize_ray, get_ray_pg_ready_with_timeout -from skyrl.train.config import SkyRLTrainConfig, get_config_as_yaml_str -from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S, _SKYRL_USE_NEW_INFERENCE -from skyrl.backends.skyrl_train.inference_engines.ray_wrapped_inference_engine import ( - create_ray_wrapped_inference_engines, -) -from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient +from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S from skyrl.tinker import types +from skyrl.train.config import SkyRLTrainConfig, get_config_as_yaml_str +from skyrl.train.utils.utils import get_ray_pg_ready_with_timeout, initialize_ray from skyrl.utils.log import logger from skyrl.utils.tok import get_tokenizer -from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient -from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter -from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup -from skyrl.backends.skyrl_train.inference_servers.utils import build_vllm_cli_args class SkyRLTrainBackendOverrides(BaseModel, extra="allow"): diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py index d7e71a066b..246bcfa915 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py @@ -24,7 +24,6 @@ import requests from litellm import acompletion as litellm_async_completion from litellm import atext_completion as litellm_async_text_completion - from pydantic import BaseModel from transformers import AutoTokenizer @@ -35,9 +34,9 @@ from skyrl.backends.skyrl_train.inference_engines.utils import ( get_sampling_params_for_backend, ) +from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE from skyrl.train.config import SkyRLTrainConfig from tests.backends.skyrl_train.gpu.utils import InferenceEngineState, get_test_prompts -from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE MODEL_QWEN2_5 = "Qwen/Qwen2.5-0.5B-Instruct" SERVED_MODEL_NAME = "my_qwen" diff --git 
a/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py
index 7d7333c3ae..8c17f02660 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py
@@ -17,9 +17,7 @@
 from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
 from skyrl.train.config import SkyRLTrainConfig
 from skyrl.train.utils.utils import validate_cfg
-from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
 from skyrl.utils.tok import get_tokenizer
-
 from tests.backends.skyrl_train.gpu.utils import (
     InferenceEngineState,
     get_test_prompts,
diff --git a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py
index 63eee4f086..a7f24cf10d 100644
--- a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py
+++ b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py
@@ -13,11 +13,11 @@
 from fastapi import FastAPI, Request
 
 from skyrl.backends.skyrl_train.inference_servers.common import get_open_port
-from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE
 from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import (
     PauseMode,
     RemoteInferenceClient,
 )
+from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE
 

From f9874a003f379e4f15455e34a40b90f5fa484305 Mon Sep 17 00:00:00 2001
From: Nithin Chalapathi
Date: Mon, 23 Mar 2026 17:58:41 -0700
Subject: [PATCH 38/38] remove /render api

---
 .../remote_inference_client.py | 34 -------
 .../test_new_inference_generation.py | 26 --------
 .../test_remote_inference_client.py | 31 -----------
 3 files changed, 91 deletions(-)

diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
index 019f3b27ba..805acf0ec5 100644
--- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
+++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
@@ -447,40 +447,6 @@ async def chat_completion(
             raise_for_status(resp, response)
             return response
 
-    async def render_chat_completion(
-        self,
-        request_payload: Dict[str, Any],
-    ) -> List[Any]:
-        """
-        Render chat messages into a tokenized prompt via /v1/chat/completions/render.
-
-        Applies the model's chat template and tokenizes without generating text.
-
-        Args:
-            request_payload: Dict with {"json": <request body>, "headers": <request headers>}.
-                The request body should contain messages and optional chat template params.
-                session_id can be included in json for consistent routing.
-
-        Returns:
-            List of [conversation, engine_prompts] where engine_prompts contains
-            dicts with "prompt" and "prompt_token_ids".
- """ - body = request_payload.get("json", {}) - - session_id = body.pop("session_id", None) - - headers = {"Content-Type": "application/json"} - if session_id: - headers["X-Session-ID"] = str(session_id) - - session = await self._get_session() - url = f"{self.proxy_url}/v1/chat/completions/render" - - async with session.post(url, json=body, headers=headers) as resp: - response = await resp.json() - raise_for_status(resp, response) - return response - async def completion( self, request_payload: Dict[str, Any], diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py index 246bcfa915..66ebc443e6 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_new_inference_generation.py @@ -34,7 +34,6 @@ from skyrl.backends.skyrl_train.inference_engines.utils import ( get_sampling_params_for_backend, ) -from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE from skyrl.train.config import SkyRLTrainConfig from tests.backends.skyrl_train.gpu.utils import InferenceEngineState, get_test_prompts @@ -592,28 +591,3 @@ def test_client_tokenize_detokenize_roundtrip(vllm_server: InferenceEngineState) decoded = asyncio.run(client.detokenize([token_ids]))[0] assert decoded == text - - -@pytest.mark.vllm -@pytest.mark.skipif(not _SKYRL_USE_NEW_INFERENCE, reason="Render API only supported with new inference client") -def test_client_render_chat_completion(vllm_server: InferenceEngineState): - """Test render_chat_completion via RemoteInferenceClient against real vLLM.""" - client = vllm_server.client - messages = [{"role": "user", "content": "Hello world!"}] - request_payload = { - "json": { - "messages": messages, - }, - "headers": {}, - } - result = asyncio.run(client.render_chat_completion(request_payload)) - # vLLM returns [conversation, engine_prompts] - assert isinstance(result, list) - assert len(result) == 2 - conversation, engine_prompts = result - # engine_prompts should have prompt_token_ids matching local tokenizer output - assert len(engine_prompts) > 0 - assert "prompt_token_ids" in engine_prompts[0] - tokenizer = AutoTokenizer.from_pretrained(MODEL_QWEN2_5) - expected_token_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True) - assert engine_prompts[0]["prompt_token_ids"] == expected_token_ids diff --git a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py index a7f24cf10d..893828254b 100644 --- a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py +++ b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py @@ -17,7 +17,6 @@ PauseMode, RemoteInferenceClient, ) -from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE def create_mock_vllm_server(server_id: int) -> FastAPI: @@ -65,15 +64,6 @@ async def generate(request: Request): async def chat_completions(request: Request): return {"choices": [{"message": {"content": f"Chat from server {server_id}"}}]} - @app.post("/v1/chat/completions/render") - async def render_chat_completion(request: Request): - body = await request.json() - messages = body.get("messages", []) - return [ - messages, # conversation (echo back) - [{"prompt": "rendered prompt", "prompt_token_ids": [1, 2, 3]}], # engine_prompts - ] - @app.post("/tokenize") async def 
tokenize(request: Request): return {"tokens": [1, 2, 3]} @@ -251,27 +241,6 @@ async def test_completion(self, client): result = await client.completion(request_payload) assert "choices" in result - @pytest.mark.asyncio - @pytest.mark.skipif(not _SKYRL_USE_NEW_INFERENCE, reason="Render API only supported with new inference client") - async def test_render_chat_completion(self, client): - """Test render_chat_completion method.""" - messages = [{"role": "user", "content": "Hello"}] - request_payload = { - "json": { - "model": "test", - "messages": messages, - }, - "headers": {}, - } - result = await client.render_chat_completion(request_payload) - assert isinstance(result, list) - assert len(result) == 2 - conversation, engine_prompts = result - assert conversation == messages - assert len(engine_prompts) == 1 - assert engine_prompts[0]["prompt_token_ids"] == [1, 2, 3] - assert engine_prompts[0]["prompt"] == "rendered prompt" - @pytest.mark.asyncio async def test_tokenize(self, client): """Test tokenize method."""