-
Notifications
You must be signed in to change notification settings - Fork 311
[skyrl] Add /sample endpoint to RemoteInferenceClient following Tinker API #1396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
929e25b
c401a8c
57f895b
3b6d555
5b8d1c7
9af1924
b6fdbc1
5fa8d70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,6 +48,7 @@ | |
|
|
||
| import asyncio | ||
| import logging | ||
| import os | ||
| from dataclasses import dataclass, field | ||
| from enum import Enum | ||
| from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union | ||
|
|
@@ -361,6 +362,90 @@ async def _generate_single( | |
| "response_logprobs": response_logprobs, | ||
| } | ||
|
|
||
async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Sample completions via /inference/v1/generate (Tinker API).

    Maps Tinker-style sample requests to the vLLM generate endpoint.
    Uses self._post() for automatic retry + backoff on transient errors.

    Args:
        request_payload: Dict with {"json": <request-body>}.
            Expected keys in json: prompt, num_samples, sampling_params, session_id.

    Returns:
        Dict with type="sample", a sequences list (tokens / logprobs /
        stop_reason per sampled sequence), and stub prompt_logprobs fields
        (always None until prompt-logprob support lands in a follow-up).
    """
    session_id, body = _extract_session_id_and_body(request_payload)

    prompt = body.get("prompt", {})
    num_samples = body.get("num_samples", 1)
    tinker_params = body.get("sampling_params", {})

    # Flatten prompt chunks -> token IDs.
    # NOTE(review): assumes text-only token chunks; multi-modal prompts will
    # need the planned token-concatenation handling.
    token_ids = [tok for chunk in prompt.get("chunks", []) for tok in chunk.get("tokens", [])]

    # Map Tinker SamplingParams -> vLLM format.
    sampling_params: Dict[str, Any] = {
        "n": num_samples,
        "logprobs": 0,  # presumably "logprob of the sampled token only" — confirm against vLLM
        "output_kind": 2,  # NOTE(review): magic value; confirm against vLLM's output-kind enum
    }
    _PARAM_MAP = {
        "temperature": "temperature",
        "max_tokens": "max_tokens",
        "seed": "seed",
        "top_k": "top_k",
        "top_p": "top_p",
        "stop": "stop_strings",
        "stop_token_ids": "stop_tokens",
    }
    # Only forward parameters the caller actually set; None means "use default".
    for tinker_key, vllm_key in _PARAM_MAP.items():
        val = tinker_params.get(tinker_key)
        if val is not None:
            sampling_params[vllm_key] = val

    # Route to the active LoRA adapter when one is loaded, else the base model.
    effective_model = self.active_lora_name if self.active_lora_name else self.model_name

    payload = {
        "sampling_params": sampling_params,
        "model": effective_model,
        "token_ids": token_ids,
    }

    headers = {"Content-Type": "application/json"}
    if session_id:
        # Session affinity header so the proxy can route sticky sessions.
        headers["X-Session-ID"] = str(session_id)

    url = f"{self.proxy_url}/inference/v1/generate"
    response = await self._post(url, json=payload, headers=headers)

    # Transform response choices -> Tinker-style sequences.
    # (Removed leftover debug logger.info on the choice count — it logged at
    # INFO for every request; flagged in review.)
    sequences = []
    for choice in response.get("choices", []):
        seq_logprobs: Optional[List[float]] = None
        logprobs_data = choice.get("logprobs")
        if logprobs_data is not None:
            logprobs_content = logprobs_data.get("content", [])
            if logprobs_content:
                seq_logprobs = [lp["logprob"] for lp in logprobs_content]

        sequences.append(
            {
                "tokens": choice["token_ids"],
                "logprobs": seq_logprobs,
                "stop_reason": choice.get("finish_reason"),
            }
        )

    return {
        "type": "sample",
        "sequences": sequences,
        # Prompt logprobs are not supported yet; planned for a follow-up PR.
        "prompt_logprobs": None,
        "topk_prompt_logprobs": None,
    }
|
|
||
| async def chat_completion( | ||
| self, | ||
| request_payload: Dict[str, Any], | ||
|
|
@@ -838,7 +923,7 @@ def _release(proto: aiohttp.client_proto.ResponseHandler) -> None: | |
| fd = tsock.fileno() | ||
| if fd != -1: | ||
| try: | ||
| tsock.close() | ||
| os.close(fd) | ||
|
nithinvc marked this conversation as resolved.
Outdated
|
||
| except OSError: | ||
| pass | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -592,3 +592,90 @@ def test_client_tokenize_detokenize_roundtrip(vllm_server: InferenceEngineState) | |
|
|
||
| decoded = asyncio.run(client.detokenize([token_ids]))[0] | ||
| assert decoded == text | ||
|
|
||
|
|
||
| # --- Group C: RemoteInferenceClient.sample() (Tinker API) --- | ||
|
|
||
|
|
||
| def _build_sample_payload( | ||
| token_ids: List[int], | ||
| num_samples: int = 1, | ||
| sampling_params: Dict[str, Any] | None = None, | ||
| session_id: str | None = None, | ||
| ) -> Dict[str, Any]: | ||
| """Build a Tinker-format sample request payload.""" | ||
| body: Dict[str, Any] = { | ||
| "prompt": {"chunks": [{"tokens": token_ids}]}, | ||
| "num_samples": num_samples, | ||
| "sampling_params": sampling_params or {"temperature": 0.7, "max_tokens": 64}, | ||
| } | ||
| if session_id is not None: | ||
| body["session_id"] = session_id | ||
| return {"json": body} | ||
|
|
||
|
|
||
def _get_test_token_ids(model: str) -> List[int]:
    """Tokenize one chat-formatted test prompt into token IDs for *model*."""
    tokenizer = AutoTokenizer.from_pretrained(model)
    conversation = get_test_prompts(model, num_samples=1)[0]
    # With tokenize=True, apply_chat_template returns the token-ID list directly.
    return tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
    )
|
|
||
|
|
||
@pytest.mark.vllm
def test_client_sample(vllm_server: InferenceEngineState):
    """Sample with n=1 should return a well-formed Tinker-style response."""
    client = vllm_server.client
    prompt_tokens = _get_test_token_ids(MODEL_QWEN2_5)
    request = _build_sample_payload(
        prompt_tokens, num_samples=1, sampling_params={"temperature": 0.7, "max_tokens": 64}
    )

    result = asyncio.run(client.sample(request))

    assert result["type"] == "sample"
    sequences = result["sequences"]
    assert len(sequences) == 1

    # Single sequence: non-empty int tokens, matching float logprobs, valid stop reason.
    seq = sequences[0]
    tokens, logprobs = seq["tokens"], seq["logprobs"]
    assert isinstance(tokens, list) and len(tokens) > 0
    assert all(isinstance(t, int) for t in tokens)
    assert isinstance(logprobs, list) and len(logprobs) > 0
    assert all(isinstance(lp, float) for lp in logprobs)
    assert seq["stop_reason"] in ["stop", "length"]
|
|
||
|
|
||
@pytest.mark.vllm
def test_client_sample_multiple(vllm_server: InferenceEngineState):
    """Sample with n=3 should return three independent sequences."""
    client = vllm_server.client
    prompt_tokens = _get_test_token_ids(MODEL_QWEN2_5)
    request = _build_sample_payload(
        prompt_tokens, num_samples=3, sampling_params={"temperature": 1.0, "max_tokens": 64}
    )

    result = asyncio.run(client.sample(request))

    assert result["type"] == "sample"
    sequences = result["sequences"]
    assert len(sequences) == 3

    for seq in sequences:
        assert isinstance(seq["tokens"], list) and len(seq["tokens"]) > 0
        assert isinstance(seq["logprobs"], list) and len(seq["logprobs"]) > 0
        assert seq["stop_reason"] in ["stop", "length"]

    # At temperature=1.0 the samples should not all be identical.
    distinct_sequences = {tuple(seq["tokens"]) for seq in sequences}
    assert len(distinct_sequences) > 1, "All 3 sequences are identical at temperature=1.0"
|
|
||
|
|
||
@pytest.mark.vllm
def test_client_sample_deterministic(vllm_server: InferenceEngineState):
    """Greedy sampling (temperature=0) with a fixed seed must repeat exactly."""
    client = vllm_server.client
    prompt_tokens = _get_test_token_ids(MODEL_QWEN2_5)
    greedy_params = {"temperature": 0.0, "max_tokens": 32, "seed": 42}

    first = asyncio.run(
        client.sample(_build_sample_payload(prompt_tokens, num_samples=1, sampling_params=greedy_params))
    )
    second = asyncio.run(
        client.sample(_build_sample_payload(prompt_tokens, num_samples=1, sampling_params=greedy_params))
    )

    assert first["sequences"][0]["tokens"] == second["sequences"][0]["tokens"]
|
Comment on lines
+672
to
+681
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The pull request description mentions adding a unit test for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Leaving the arg in for _build_sample_payload since we may want to test it in the future. I'm not sure how to test session-based routing in our current setup, so leaving for now. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will need adaptation for multi-modal inputs going forward, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this will have to be the token concatenation we talked about, so it will get replaced.