-
Notifications
You must be signed in to change notification settings - Fork 285
[skyrl] Add /sample endpoint to RemoteInferenceClient following Tinker API #1396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
929e25b
c401a8c
57f895b
3b6d555
5b8d1c7
9af1924
b6fdbc1
5fa8d70
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,16 @@ | |
|
|
||
| _DATA_PLANE_RETRIES = 30 | ||
|
|
||
| _TINKER_SAMPLE_TO_VLLM_PARAM_MAP = { | ||
| "temperature": "temperature", | ||
| "max_tokens": "max_tokens", | ||
| "seed": "seed", | ||
| "top_k": "top_k", | ||
| "top_p": "top_p", | ||
| "stop_strings": "stop", | ||
| "stop_tokens": "stop_token_ids", | ||
| } | ||
|
|
||
| if TYPE_CHECKING: | ||
| from skyrl.backends.skyrl_train.weight_sync.transfer_strategy import ( | ||
| WeightSyncInitInfo, | ||
|
|
@@ -361,6 +371,81 @@ async def _generate_single( | |
| "response_logprobs": response_logprobs, | ||
| } | ||
|
|
||
async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Sample completions via /inference/v1/generate (Tinker API).

    Translates a Tinker-style sample request into a vLLM generate call,
    posting through self._post() so transient failures get automatic
    retry + backoff.

    Args:
        request_payload: Dict with {"json": <request-body>}. The body is
            expected to carry prompt, num_samples, sampling_params, and
            optionally session_id.

    Returns:
        Dict with type="sample", a sequences list, and stub prompt_logprobs
        fields (always None for now).
    """
    session_id, body = _extract_session_id_and_body(request_payload)

    # The prompt arrives as an ordered list of token chunks; concatenate
    # them into one flat token-ID sequence.
    flat_tokens: List[int] = []
    for chunk in body.get("prompt", {}).get("chunks", []):
        flat_tokens.extend(chunk.get("tokens", []))

    # Translate Tinker SamplingParams into the vLLM request format,
    # forwarding only the fields that were actually set (non-None).
    requested = body.get("sampling_params", {})
    vllm_params: Dict[str, Any] = {
        "n": body.get("num_samples", 1),
        "logprobs": 0,
        "output_kind": 2,
    }
    vllm_params.update(
        {
            vllm_key: requested[tinker_key]
            for tinker_key, vllm_key in _TINKER_SAMPLE_TO_VLLM_PARAM_MAP.items()
            if requested.get(tinker_key) is not None
        }
    )

    # Route to the active LoRA adapter when one is loaded, otherwise the
    # base model.
    request_body = {
        "sampling_params": vllm_params,
        "model": self.active_lora_name if self.active_lora_name else self.model_name,
        "token_ids": flat_tokens,
    }

    request_headers = {"Content-Type": "application/json"}
    if session_id:
        request_headers["X-Session-ID"] = str(session_id)

    response = await self._post(
        f"{self.proxy_url}/inference/v1/generate",
        json=request_body,
        headers=request_headers,
    )

    # Reshape each returned choice into a Tinker sequence record.
    sequences = []
    for choice in response.get("choices", []):
        per_token_logprobs: Optional[List[float]] = None
        logprob_info = choice.get("logprobs")
        if logprob_info is not None:
            content = logprob_info.get("content", [])
            if content:
                per_token_logprobs = [entry["logprob"] for entry in content]

        sequences.append(
            {
                "tokens": choice["token_ids"],
                "logprobs": per_token_logprobs,
                "stop_reason": choice.get("finish_reason"),
            }
        )

    # NOTE(review): prompt logprobs are stubbed out for now; per the PR
    # discussion, a follow-up change will add real support.
    return {
        "type": "sample",
        "sequences": sequences,
        "prompt_logprobs": None,
        "topk_prompt_logprobs": None,
    }
|
|
||
| async def chat_completion( | ||
| self, | ||
| request_payload: Dict[str, Any], | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -592,3 +592,90 @@ def test_client_tokenize_detokenize_roundtrip(vllm_server: InferenceEngineState) | |
|
|
||
| decoded = asyncio.run(client.detokenize([token_ids]))[0] | ||
| assert decoded == text | ||
|
|
||
|
|
||
| # --- Group C: RemoteInferenceClient.sample() (Tinker API) --- | ||
|
|
||
|
|
||
| def _build_sample_payload( | ||
| token_ids: List[int], | ||
| num_samples: int = 1, | ||
| sampling_params: Dict[str, Any] | None = None, | ||
| session_id: str | None = None, | ||
| ) -> Dict[str, Any]: | ||
| """Build a Tinker-format sample request payload.""" | ||
| body: Dict[str, Any] = { | ||
| "prompt": {"chunks": [{"tokens": token_ids}]}, | ||
| "num_samples": num_samples, | ||
| "sampling_params": sampling_params or {"temperature": 0.7, "max_tokens": 64}, | ||
| } | ||
| if session_id is not None: | ||
| body["session_id"] = session_id | ||
| return {"json": body} | ||
|
|
||
|
|
||
def _get_test_token_ids(model: str) -> List[int]:
    """Encode a single test conversation for *model* into token IDs."""
    conversation = get_test_prompts(model, num_samples=1)[0]
    tokenizer = AutoTokenizer.from_pretrained(model)
    # With tokenize=True, apply_chat_template returns the token-ID list directly.
    return tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
    )
|
|
||
|
|
||
@pytest.mark.vllm
def test_client_sample(vllm_server: InferenceEngineState):
    """Sample with n=1 must return a well-formed Tinker response."""
    inference_client = vllm_server.client
    prompt_tokens = _get_test_token_ids(MODEL_QWEN2_5)
    request = _build_sample_payload(
        prompt_tokens,
        num_samples=1,
        sampling_params={"temperature": 0.7, "max_tokens": 64},
    )

    result = asyncio.run(inference_client.sample(request))

    assert result["type"] == "sample"
    assert len(result["sequences"]) == 1

    (sequence,) = result["sequences"]
    tokens = sequence["tokens"]
    logprobs = sequence["logprobs"]
    assert isinstance(tokens, list) and len(tokens) > 0
    assert all(isinstance(t, int) for t in tokens)
    assert isinstance(logprobs, list) and len(logprobs) > 0
    assert all(isinstance(lp, float) for lp in logprobs)
    assert sequence["stop_reason"] in ["stop", "length"]
|
|
||
|
|
||
@pytest.mark.vllm
def test_client_sample_multiple(vllm_server: InferenceEngineState):
    """Sample with n=3 must return three independent sequences."""
    inference_client = vllm_server.client
    prompt_tokens = _get_test_token_ids(MODEL_QWEN2_5)
    request = _build_sample_payload(
        prompt_tokens,
        num_samples=3,
        sampling_params={"temperature": 1.0, "max_tokens": 64},
    )

    result = asyncio.run(inference_client.sample(request))

    assert result["type"] == "sample"
    assert len(result["sequences"]) == 3

    for sequence in result["sequences"]:
        assert isinstance(sequence["tokens"], list) and len(sequence["tokens"]) > 0
        assert isinstance(sequence["logprobs"], list) and len(sequence["logprobs"]) > 0
        assert sequence["stop_reason"] in ["stop", "length"]

    # Sampling at temperature=1.0 should yield at least two distinct outputs.
    distinct_outputs = {tuple(sequence["tokens"]) for sequence in result["sequences"]}
    assert len(distinct_outputs) > 1, "All 3 sequences are identical at temperature=1.0"
|
|
||
|
|
||
@pytest.mark.vllm
def test_client_sample_deterministic(vllm_server: InferenceEngineState):
    """Seeded sampling at temperature=0 must be reproducible across calls."""
    inference_client = vllm_server.client
    prompt_tokens = _get_test_token_ids(MODEL_QWEN2_5)
    seeded_params = {"temperature": 0.0, "max_tokens": 32, "seed": 42}

    def run_once():
        payload = _build_sample_payload(prompt_tokens, num_samples=1, sampling_params=seeded_params)
        return asyncio.run(inference_client.sample(payload))

    first = run_once()
    second = run_once()
    assert first["sequences"][0]["tokens"] == second["sequences"][0]["tokens"]
|
Comment on lines
+672
to
+681
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The pull request description mentions adding a unit test for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Leaving the arg in for _build_sample_payload since we may want to test it in the future. I'm not sure how to test session-based routing in our current setup, so leaving it for now. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will need adaptation for multi-modal inputs going forward, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this will have to be the token concatenation we talked about, so it will get replaced.