Draft

Changes from all commits
39 commits
395f187
remove test guard, add sample method
nithinvc Mar 6, 2026
f0ec828
add session id
nithinvc Mar 6, 2026
07ae4de
add render api
nithinvc Mar 6, 2026
713ada8
fix gpu save weights test
nithinvc Mar 6, 2026
8ab564e
add init of new inf backend
nithinvc Mar 6, 2026
7d8e5f4
add test for weight sync
nithinvc Mar 7, 2026
c7e4d52
remove test
nithinvc Mar 9, 2026
5ec91b7
add inference backend weight sync test
nithinvc Mar 9, 2026
f353605
add test back in
nithinvc Mar 9, 2026
6526887
colocate all off
nithinvc Mar 9, 2026
2c0b47f
port fix
nithinvc Mar 9, 2026
5b2fa34
move port allocation to shared commmoon
nithinvc Mar 9, 2026
29125ab
add tests for port reservation
nithinvc Mar 9, 2026
3f07164
remove print statements
nithinvc Mar 9, 2026
c6fbfde
Revert "remove print statements"
nithinvc Mar 9, 2026
c950ced
Revert "add tests for port reservation"
nithinvc Mar 9, 2026
26b7abf
Revert "move port allocation to shared commmoon"
nithinvc Mar 9, 2026
516ae5e
Revert "port fix"
nithinvc Mar 9, 2026
281ba77
remove print statements
nithinvc Mar 9, 2026
4d8298c
stricter tests
nithinvc Mar 9, 2026
4651da7
move imports up
nithinvc Mar 9, 2026
3951469
fmt
nithinvc Mar 10, 2026
2c399e0
add typing
nithinvc Mar 10, 2026
fde5342
add lora_id
nithinvc Mar 10, 2026
d894f82
fix typing
nithinvc Mar 10, 2026
269a81c
fix docstrings
nithinvc Mar 10, 2026
bd26015
remove duplicate test fixture
nithinvc Mar 11, 2026
b5bda89
update to use request_payload format
nithinvc Mar 11, 2026
b7484b3
add mock test
nithinvc Mar 11, 2026
aff3835
add comments + update gpu test
nithinvc Mar 11, 2026
15b2ddb
stronger gpu checks
nithinvc Mar 11, 2026
863c203
update gpu sample api test
nithinvc Mar 11, 2026
7eaa5d8
update render to use payload_request
nithinvc Mar 11, 2026
03ff39d
add log probs extraction
nithinvc Mar 11, 2026
71ca1cf
comment + remove duplicate test
nithinvc Mar 11, 2026
4f5ccd4
Merge branch 'main' into nithinc/inference-server-sample
nithinvc Mar 12, 2026
47c9400
Update skyrl/backends/skyrl_train/inference_servers/remote_inference_…
nithinvc Mar 12, 2026
9a5eb86
lint
nithinvc Mar 12, 2026
f9874a0
remove /render api
nithinvc Mar 24, 2026
@@ -297,6 +297,124 @@ async def _generate_single(
"response_logprobs": response_logprobs if len(response_logprobs) > 0 else None,
}

async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""Sample completions via /inference/v1/generate.

Issues a single request with n set in sampling_params; there is no retry-on-abort.

Args:
request_payload: {"json": {...}, "headers": {...}}
The "json" body matches the Tinker SamplingClient.sample() args:
prompt (ModelInput), num_samples, sampling_params (SamplingParams),
include_prompt_logprobs, topk_prompt_logprobs.
An optional session_id may be included in the body for routing.

Returns:
Dict matching Tinker SampleResponse schema.
"""
body = request_payload.get("json", {})
session_id = body.pop("session_id", None)

# Tinker input fields
prompt = body.get("prompt", {}) # ModelInput dict: {"chunks": [{"tokens": [...]}]}
num_samples = body.get("num_samples", 1)
sampling_params = body.get("sampling_params", {})
prompt_logprobs = body.get("include_prompt_logprobs", False)
topk_prompt_logprobs = body.get("topk_prompt_logprobs", 0)

prompt_token_ids = [tok for chunk in prompt.get("chunks", []) for tok in chunk.get("tokens", [])]

# Tinker types.py SamplingParams -> vLLM /inference/v1/generate sampling_params
vllm_sampling_params = {
"n": num_samples,
"temperature": sampling_params.get("temperature"),
"max_tokens": sampling_params.get("max_tokens"),
"seed": sampling_params.get("seed"),
"top_k": sampling_params.get("top_k", -1),
"top_p": sampling_params.get("top_p", 1.0),
"prompt_logprobs": max(topk_prompt_logprobs, 1) if prompt_logprobs else None,
"logprobs": 0, # response logprobs
}
if sampling_params.get("stop_strings"):
vllm_sampling_params["stop"] = sampling_params["stop_strings"]
if sampling_params.get("stop_tokens"):
vllm_sampling_params["stop_token_ids"] = sampling_params["stop_tokens"]

payload = {
"sampling_params": vllm_sampling_params,
"model": self.model_name,
"token_ids": prompt_token_ids,
}

headers = {"Content-Type": "application/json"}
if session_id:
headers["X-Session-ID"] = str(session_id)

session = await self._get_session()
url = f"{self.proxy_url}/inference/v1/generate"

async with session.post(url, json=payload, headers=headers) as resp:
result = await resp.json()
raise_for_status(resp, result)

# Transform response choices -> tinker type SampleResponse dict
sequences = []
for choice in result["choices"]:
raw_stop = choice.get("finish_reason", "length")
stop_reason = "stop" if raw_stop in ("stop", "stop_token") else "length"

token_ids = choice.get("token_ids", [])

logprobs = None
lp = choice.get("logprobs")
if lp is not None:
logprobs_content = lp.get("content", [])
if logprobs_content:
logprobs = [info["logprob"] if info["logprob"] is not None else 0.0 for info in logprobs_content]
# Convert to tinker type SampledSequence dict
sequences.append(
{
"stop_reason": stop_reason,
"tokens": token_ids,
"logprobs": logprobs,
}
)

# Transform prompt_logprobs from vLLM format to Tinker SampleResponse format
transformed_prompt_logprobs = None
transformed_topk_prompt_logprobs = None
# raw_prompt_logprobs: vLLM type list[dict[int, Logprob]] | None
raw_prompt_logprobs = result.get("prompt_logprobs") if prompt_logprobs else None

if raw_prompt_logprobs is not None:
# prompt_logprobs: single float per token (logprob of the actual prompt token)
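# e.g., with a hypothetical position dict {"42": {"logprob": -0.3}, "7": {"logprob": -1.2}} and
# prompt token 42 at that position, this yields -0.3 here and [(42, -0.3), (7, -1.2)] in the top-k list below.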
transformed_prompt_logprobs = []
for i, pos in enumerate(raw_prompt_logprobs):
if pos is None:
transformed_prompt_logprobs.append(None)
else:
token_key = str(prompt_token_ids[i])
entry = pos.get(token_key)
transformed_prompt_logprobs.append(entry["logprob"] if entry is not None else None)

# topk_prompt_logprobs: list of (token_id, logprob) tuples per position
if topk_prompt_logprobs > 0:
transformed_topk_prompt_logprobs = []
for pos in raw_prompt_logprobs:
if pos is None:
transformed_topk_prompt_logprobs.append(None)
else:
transformed_topk_prompt_logprobs.append(
[(int(tid), info["logprob"]) for tid, info in pos.items()]
)

return {
"type": "sample",
"sequences": sequences,
"prompt_logprobs": transformed_prompt_logprobs,
"topk_prompt_logprobs": transformed_topk_prompt_logprobs,
}

async def chat_completion(
self,
request_payload: Dict[str, Any],
184 changes: 145 additions & 39 deletions skyrl/backends/skyrl_train_backend.py
@@ -22,10 +22,16 @@
from skyrl.backends.skyrl_train.inference_engines.ray_wrapped_inference_engine import (
create_ray_wrapped_inference_engines,
)
from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import (
RemoteInferenceClient,
)
from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter
from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup
from skyrl.backends.skyrl_train.inference_servers.utils import build_vllm_cli_args
from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S
from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S
from skyrl.tinker import types
from skyrl.train.config import SkyRLTrainConfig, get_config_as_yaml_str
from skyrl.train.utils.utils import get_ray_pg_ready_with_timeout, initialize_ray
@@ -116,6 +122,8 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides):
self._tokenizer: AutoTokenizer = get_tokenizer(self.base_model)
self._inference_engine_client = None
self._inference_engines_initialized = False
self._server_group = None
self._inference_router = None

def has_model(self, model_id: str) -> bool:
return self._model_id == model_id
@@ -184,19 +192,84 @@ def init_weight_sync_state(self):
self._dispatch.init_weight_sync_state(self._inference_engine_client)
logger.info("Initialized weight sync state for policy model and inference engines.")

def _create_remote_inference_client(self):
"""Create a RemoteInferenceClient using HTTP endpoints.

Mirrors main_base.py._get_new_inference_client() with the same 4-way
branching on external_proxy_url / external_server_urls.
"""
ie_cfg = self._cfg.generator.inference_engine
is_colocated = self._cfg.trainer.placement.colocate_all
external_proxy_url = ie_cfg.external_proxy_url
external_server_urls = ie_cfg.external_server_urls

has_external_proxy = external_proxy_url is not None
has_external_servers = external_server_urls is not None

if has_external_proxy and has_external_servers:
proxy_url = external_proxy_url
server_urls = list(external_server_urls)
logger.info(
f"HTTP Inference: Using fully external setup - proxy_url={proxy_url}, server_urls={server_urls}"
)

elif has_external_proxy and not has_external_servers:
proxy_url = external_proxy_url
server_urls = [proxy_url]
logger.info(f"HTTP Inference: Using external proxy for both data and control plane - proxy_url={proxy_url}")

elif has_external_servers and not has_external_proxy:
server_urls = list(external_server_urls)
self._inference_router = InferenceRouter(server_urls=server_urls)
proxy_url = self._inference_router.start()
logger.info(
f"HTTP Inference: Created internal router over external "
f"servers - server_urls={server_urls}, proxy_url={proxy_url}"
)

else:
cli_args = build_vllm_cli_args(self._cfg)

self._server_group = ServerGroup(
cli_args=cli_args,
num_servers=ie_cfg.num_engines,
placement_group=self._colocate_pg if is_colocated else None,
enable_dp=ie_cfg.data_parallel_size > 1,
)
server_infos = self._server_group.start()
server_urls = [info.url for info in server_infos]

self._inference_router = InferenceRouter(server_urls=server_urls)
proxy_url = self._inference_router.start()
logger.info(
f"HTTP Inference: Built servers and router internally - "
f"proxy_url={proxy_url}, server_urls={server_urls}, colocated={is_colocated}"
)

return RemoteInferenceClient(
proxy_url=proxy_url,
server_urls=server_urls,
model_name=self._cfg.trainer.policy.model.path,
)
Comment on lines +249 to +253
🔴 _create_remote_inference_client ignores served_model_name, causing request rejection when configured

_create_remote_inference_client always uses self._cfg.trainer.policy.model.path as model_name, but when served_model_name is configured in the inference engine config, the vLLM server only accepts that name (not the model path). This causes all data plane requests (sample, generate, chat_completion, etc.) to fail with a "model not found" error.

The old InferenceEngineClient correctly handles this at skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:68-70, and the test utility at tests/backends/skyrl_train/gpu/utils.py:512 also correctly uses served_model_name if served_model_name else cfg.trainer.policy.model.path. The production code here omits this logic.

Note: main_base.py:377 has the same pre-existing issue, which this code mirrors — but it should be fixed here nonetheless.

Suggested change
return RemoteInferenceClient(
proxy_url=proxy_url,
server_urls=server_urls,
model_name=self._cfg.trainer.policy.model.path,
)
ie_served_name = self._cfg.generator.inference_engine.served_model_name
return RemoteInferenceClient(
proxy_url=proxy_url,
server_urls=server_urls,
model_name=ie_served_name if ie_served_name else self._cfg.trainer.policy.model.path,
)


def _ensure_inference_engines(self):
"""Lazily create inference engines and init weight sync on first sampling-related call."""
if self._inference_engines_initialized:
return

logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines")
self._inference_engine_client = InferenceEngineClient(
create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer),
self._tokenizer,
self._cfg.trainer.policy.model.path,
self._cfg.trainer.policy.model.lora,
self._cfg.generator.inference_engine,
)
if _SKYRL_USE_NEW_INFERENCE:
logger.info("Using new HTTP-based inference client (RemoteInferenceClient)")
self._inference_engine_client = self._create_remote_inference_client()
else:
logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines")
self._inference_engine_client = InferenceEngineClient(
create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer),
self._tokenizer,
self._cfg.trainer.policy.model.path,
self._cfg.trainer.policy.model.lora,
self._cfg.generator.inference_engine,
)

self._dispatch.set_inference_engine_client(self._inference_engine_client)
self.init_weight_sync_state()
self._inference_engines_initialized = True
@@ -533,28 +606,43 @@ async def sample_all():
prompt = prepared_batch.all_prompts[i]
sampling_params = prepared_batch.all_sampling_params[i]

# Pass through common fields; only stop needs name translation
# (Tinker uses stop_strings/stop_tokens, vLLM uses stop/stop_token_ids)
params_dict = {
"temperature": sampling_params.temperature,
"max_tokens": sampling_params.max_tokens,
"seed": sampling_params.seed,
"top_k": sampling_params.top_k,
"top_p": sampling_params.top_p,
"logprobs": 0,
}
if sampling_params.stop_strings:
params_dict["stop"] = sampling_params.stop_strings
if sampling_params.stop_tokens:
params_dict["stop_token_ids"] = sampling_params.stop_tokens

tasks.append(
self._inference_engine_client.sample(
prompt_token_ids=prompt,
num_samples=1, # Tinker batches multiple samples separately
sampling_params=params_dict,
if _SKYRL_USE_NEW_INFERENCE:
# Right now, prompt is list[int] (token IDs), so we wrap in ModelInput format
json_body = {
"prompt": {"chunks": [{"tokens": prompt}]},
"num_samples": 1, # Tinker batches multiple samples separately
"sampling_params": {
"temperature": sampling_params.temperature,
"max_tokens": sampling_params.max_tokens,
"seed": sampling_params.seed,
"top_k": sampling_params.top_k,
"top_p": sampling_params.top_p,
"stop_tokens": sampling_params.stop_tokens,
"stop_strings": sampling_params.stop_strings,
},
}
tasks.append(self._inference_engine_client.sample({"json": json_body, "headers": {}}))
else:
params_dict = {
"temperature": sampling_params.temperature,
"max_tokens": sampling_params.max_tokens,
"seed": sampling_params.seed,
"top_k": sampling_params.top_k,
"top_p": sampling_params.top_p,
"logprobs": 0,
}
if sampling_params.stop_strings:
params_dict["stop"] = sampling_params.stop_strings
if sampling_params.stop_tokens:
params_dict["stop_token_ids"] = sampling_params.stop_tokens

tasks.append(
self._inference_engine_client.sample(
prompt_token_ids=prompt,
num_samples=1, # Tinker batches multiple samples separately
sampling_params=params_dict,
)
)
)

return await asyncio.gather(*tasks, return_exceptions=True)

@@ -565,14 +653,25 @@
# We preserve these to include error messages in responses

# 4. Aggregate results by request
return self._aggregate_sample_results(prepared_batch, sample_outputs)
return self._aggregate_sample_results(
prepared_batch, sample_outputs, use_new_inference=_SKYRL_USE_NEW_INFERENCE
)

def _aggregate_sample_results(
self,
prepared_batch: types.PreparedSampleBatch,
sample_outputs: list,
use_new_inference: bool = False,
) -> dict[str, types.SampleOutput | types.ErrorResponse]:
"""Convert InferenceEngineClient outputs to Tinker format."""
"""Convert inference outputs to Tinker format.

Args:
prepared_batch: The prepared sample batch.
sample_outputs: List of outputs from inference client.
use_new_inference: If True, outputs are SampleResponse dicts from
RemoteInferenceClient. If False, outputs are InferenceEngineOutput
dicts from InferenceEngineClient.
"""
results = {}

for request_id, model_id, start_idx, end_idx, needs_prompt_logprobs in prepared_batch.request_batch_slices:
@@ -595,13 +694,20 @@
logger.error(error_msg)
break

# Extract tokens and logprobs
response_tokens = output["response_ids"][0]
response_logprobs = (output.get("response_logprobs") or [[]])[0]
stop_reason_raw = output["stop_reasons"][0]

# Map vLLM stop reason to Tinker format
stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length"
if use_new_inference:
# New inference server: SampleResponse dict
seq = output["sequences"][0]
response_tokens = seq["tokens"]
response_logprobs = seq.get("logprobs") or []
stop_reason = seq["stop_reason"]
else:
# Old inference engine: InferenceEngineOutput
# Extract tokens and logprobs
response_tokens = output["response_ids"][0]
response_logprobs = (output.get("response_logprobs") or [[]])[0]
stop_reason_raw = output["stop_reasons"][0]
# Map vLLM stop reason to Tinker format
stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length"

# Ensure logprobs exist (critical for RL)
if response_logprobs is None or len(response_logprobs) == 0: