Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@

_DATA_PLANE_RETRIES = 30

# Translation table from Tinker SamplingParams field names (keys) to the
# vLLM sampling-parameter names (values) accepted by /inference/v1/generate.
# Fields absent from a request are simply not forwarded (see sample()).
_TINKER_SAMPLE_TO_VLLM_PARAM_MAP = {
    "temperature": "temperature",
    "max_tokens": "max_tokens",
    "seed": "seed",
    "top_k": "top_k",
    "top_p": "top_p",
    "stop_strings": "stop",
    "stop_tokens": "stop_token_ids",
}

if TYPE_CHECKING:
from skyrl.backends.skyrl_train.weight_sync.transfer_strategy import (
WeightSyncInitInfo,
Expand Down Expand Up @@ -361,6 +371,81 @@ async def _generate_single(
"response_logprobs": response_logprobs,
}

async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Sample completions via /inference/v1/generate (Tinker API).

    Translates a Tinker-style sample request into the vLLM generate
    endpoint's format and reshapes the response into Tinker sequences.
    Delegates the HTTP call to self._post(), which retries with backoff
    on transient errors.

    Args:
        request_payload: Dict with {"json": <request-body>}.
            Expected keys in json: prompt, num_samples, sampling_params, session_id.

    Returns:
        Dict with type="sample", a sequences list, and stub prompt_logprobs fields.
    """
    session_id, body = _extract_session_id_and_body(request_payload)

    prompt_spec = body.get("prompt", {})
    n_samples = body.get("num_samples", 1)
    requested = body.get("sampling_params", {})

    # Flatten prompt chunks into a single token-ID list.
    # TODO(review): multi-modal prompts will need the token-concatenation
    # scheme instead of this flattening — to be replaced in a follow-up.
    flat_tokens: List[int] = []
    for chunk in prompt_spec.get("chunks", []):
        flat_tokens.extend(chunk.get("tokens", []))

    # Map Tinker SamplingParams onto vLLM's parameter names; only forward
    # fields the caller actually set.
    vllm_params: Dict[str, Any] = {
        "n": n_samples,
        "logprobs": 0,
        "output_kind": 2,
    }
    vllm_params.update(
        {
            vllm_name: requested[tinker_name]
            for tinker_name, vllm_name in _TINKER_SAMPLE_TO_VLLM_PARAM_MAP.items()
            if requested.get(tinker_name) is not None
        }
    )

    # Prefer the active LoRA adapter when one is loaded.
    model_id = self.active_lora_name if self.active_lora_name else self.model_name

    request_headers = {"Content-Type": "application/json"}
    if session_id:
        request_headers["X-Session-ID"] = str(session_id)

    response = await self._post(
        f"{self.proxy_url}/inference/v1/generate",
        json={
            "sampling_params": vllm_params,
            "model": model_id,
            "token_ids": flat_tokens,
        },
        headers=request_headers,
    )

    # Reshape vLLM choices into Tinker sequences.
    sequences = []
    for choice in response.get("choices", []):
        lp_info = choice.get("logprobs")
        lp_content = lp_info.get("content", []) if lp_info is not None else []
        sequences.append(
            {
                "tokens": choice["token_ids"],
                "logprobs": [entry["logprob"] for entry in lp_content] if lp_content else None,
                "stop_reason": choice.get("finish_reason"),
            }
        )

    return {
        "type": "sample",
        "sequences": sequences,
        # TODO(review): prompt logprobs are planned for a follow-up PR
        # (pending a decision on how vision prompts are handled).
        "prompt_logprobs": None,
        "topk_prompt_logprobs": None,
    }

async def chat_completion(
self,
request_payload: Dict[str, Any],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,90 @@ def test_client_tokenize_detokenize_roundtrip(vllm_server: InferenceEngineState)

decoded = asyncio.run(client.detokenize([token_ids]))[0]
assert decoded == text


# --- Group C: RemoteInferenceClient.sample() (Tinker API) ---


def _build_sample_payload(
token_ids: List[int],
num_samples: int = 1,
sampling_params: Dict[str, Any] | None = None,
session_id: str | None = None,
) -> Dict[str, Any]:
"""Build a Tinker-format sample request payload."""
body: Dict[str, Any] = {
"prompt": {"chunks": [{"tokens": token_ids}]},
"num_samples": num_samples,
"sampling_params": sampling_params or {"temperature": 0.7, "max_tokens": 64},
}
if session_id is not None:
body["session_id"] = session_id
return {"json": body}


def _get_test_token_ids(model: str) -> List[int]:
    """Render one test conversation for *model* into prompt token IDs."""
    tokenizer = AutoTokenizer.from_pretrained(model)
    conversation = get_test_prompts(model, num_samples=1)[0]
    # Apply the chat template with the generation prompt appended, returning
    # token IDs directly rather than text.
    return tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
    )


@pytest.mark.vllm
def test_client_sample(vllm_server: InferenceEngineState):
    """Test sample with n=1 returns correct Tinker response structure."""
    token_ids = _get_test_token_ids(MODEL_QWEN2_5)
    payload = _build_sample_payload(
        token_ids, num_samples=1, sampling_params={"temperature": 0.7, "max_tokens": 64}
    )

    result = asyncio.run(vllm_server.client.sample(payload))

    assert result["type"] == "sample"
    sequences = result["sequences"]
    assert len(sequences) == 1

    (seq,) = sequences
    tokens, logprobs = seq["tokens"], seq["logprobs"]
    assert isinstance(tokens, list) and tokens
    assert all(isinstance(t, int) for t in tokens)
    assert isinstance(logprobs, list) and logprobs
    assert all(isinstance(lp, float) for lp in logprobs)
    assert seq["stop_reason"] in ("stop", "length")


@pytest.mark.vllm
def test_client_sample_multiple(vllm_server: InferenceEngineState):
    """Test sample with n=3 returns three independent sequences."""
    token_ids = _get_test_token_ids(MODEL_QWEN2_5)
    payload = _build_sample_payload(
        token_ids, num_samples=3, sampling_params={"temperature": 1.0, "max_tokens": 64}
    )

    result = asyncio.run(vllm_server.client.sample(payload))

    assert result["type"] == "sample"
    assert len(result["sequences"]) == 3

    for seq in result["sequences"]:
        assert isinstance(seq["tokens"], list) and seq["tokens"]
        assert isinstance(seq["logprobs"], list) and seq["logprobs"]
        assert seq["stop_reason"] in ("stop", "length")

    # Sampling at temperature=1.0 should not produce three identical outputs.
    distinct = {tuple(seq["tokens"]) for seq in result["sequences"]}
    assert len(distinct) > 1, "All 3 sequences are identical at temperature=1.0"


@pytest.mark.vllm
def test_client_sample_deterministic(vllm_server: InferenceEngineState):
    """Test that sample with seed + temperature=0 is deterministic across calls."""
    token_ids = _get_test_token_ids(MODEL_QWEN2_5)
    params = {"temperature": 0.0, "max_tokens": 32, "seed": 42}

    def run_once():
        # Build a fresh payload per call so the two requests are independent.
        payload = _build_sample_payload(token_ids, num_samples=1, sampling_params=params)
        return asyncio.run(vllm_server.client.sample(payload))

    first, second = run_once(), run_once()
    assert first["sequences"][0]["tokens"] == second["sequences"][0]["tokens"]
Comment on lines +672 to +681
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The pull request description mentions adding a unit test for session_id routing for the sample method, but it seems to be missing from the submitted tests. Please consider adding a test case that utilizes the session_id parameter in _build_sample_payload to verify that session-based routing works as expected for the new endpoint.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving the arg in for _build_sample_payload since we may want to test it in the future. I'm not sure how to test session based routing in our current setup, so leaving for now.

Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,24 @@ async def completions(request: Request):
@app.post("/inference/v1/generate")
async def generate(request: Request):
    """Dummy /inference/v1/generate endpoint.

    Returns deterministic choices (token_ids [i, i+1, i+2] and a distinct
    logprob per choice) so tests can assert exact values without a real engine.
    """
    body = await request.json()  # Consume body
    num_prompts = len(body.get("token_ids", []))
    sp = body.get("sampling_params", {})
    n = sp.get("n", 1)
    # If logprobs is explicitly set (sample path), use n for num_choices.
    # Otherwise (generate path), use len(token_ids) for per-prompt responses.
    if "logprobs" in sp:
        num_choices = n
    else:
        # BUG FIX: this branch previously returned a single choice, but the
        # generate path must return one choice per prompt (as the comment
        # above and the pre-existing behavior both require).
        num_choices = num_prompts

    return {
        "choices": [
            {
                "request_id": "dummy",
                "token_ids": [i, i + 1, i + 2],
                "finish_reason": "stop",
                # Distinct logprob per choice lets tests tell sequences apart.
                "logprobs": {"content": [{"logprob": -0.1 * (i + 1)}]},
            }
            for i in range(num_choices)
        ]
    }

Expand Down Expand Up @@ -421,6 +433,50 @@ async def test_get_world_size(self, client):
assert total_world_size2 == 4


class TestSample:
    """Test sample() method (Tinker API)."""

    @pytest.mark.asyncio
    async def test_sample(self, client):
        """Test sample with n=1 returns correct structure."""
        payload = {
            "json": {
                "prompt": {"chunks": [{"tokens": [10, 20, 30]}]},
                "num_samples": 1,
                "sampling_params": {"temperature": 0.7, "max_tokens": 64},
            }
        }

        result = await client.sample(payload)

        assert result["type"] == "sample"
        assert result["prompt_logprobs"] is None
        assert result["topk_prompt_logprobs"] is None

        sequences = result["sequences"]
        assert len(sequences) == 1
        (only,) = sequences
        assert only["tokens"] == [0, 1, 2]
        assert only["logprobs"] == [-0.1]
        assert only["stop_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_sample_n2(self, client):
        """Test sample with n=2 returns two sequences."""
        payload = {
            "json": {
                "prompt": {"chunks": [{"tokens": [1, 2]}, {"tokens": [3]}]},
                "num_samples": 2,
                "sampling_params": {"temperature": 1.0, "max_tokens": 32},
            }
        }

        result = await client.sample(payload)

        assert len(result["sequences"]) == 2
        first, second = result["sequences"]
        # The dummy server enumerates choices, so tokens/logprobs are fixed.
        assert first["tokens"] == [0, 1, 2]
        assert second["tokens"] == [1, 2, 3]
        assert first["logprobs"] == [-0.1]
        assert second["logprobs"] == [-0.2]


class TestRenderChatCompletion:
"""Test render_chat_completion method (multimodal and text-only)."""

Expand Down
Loading