[skyrl] Add /sample endpoint to RemoteInferenceClient following Tinker API #1396

nithinvc wants to merge 8 commits into NovaSky-AI:main
Conversation
- Add `RemoteInferenceClient.sample()` mapping Tinker-style sample requests to the vLLM `/inference/v1/generate` endpoint
- Support `n` completions, logprobs, and configurable sampling params
- Add unit tests (n=1, n=2, session_id routing)
- Add GPU integration tests (sample, sample_multiple, sample_deterministic)
- Simplify `_force_close_connector` to use `transport.close()` directly
Code Review
This pull request introduces a new sample method to RemoteInferenceClient to support the Tinker API, along with corresponding unit tests and updates to the mock inference server. I have provided feedback regarding the optimization of the _PARAM_MAP constant, the need for a test case covering session_id routing, and a correction for the num_choices logic in the mock server.
skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
```python
def test_client_sample_deterministic(vllm_server: InferenceEngineState):
    """Test that sample with seed + temperature=0 is deterministic across calls."""
    client = vllm_server.client
    token_ids = _get_test_token_ids(MODEL_QWEN2_5)
    params = {"temperature": 0.0, "max_tokens": 32, "seed": 42}

    result1 = asyncio.run(client.sample(_build_sample_payload(token_ids, num_samples=1, sampling_params=params)))
    result2 = asyncio.run(client.sample(_build_sample_payload(token_ids, num_samples=1, sampling_params=params)))

    assert result1["sequences"][0]["tokens"] == result2["sequences"][0]["tokens"]
```
The pull request description mentions adding a unit test for session_id routing for the sample method, but it seems to be missing from the submitted tests. Please consider adding a test case that utilizes the session_id parameter in _build_sample_payload to verify that session-based routing works as expected for the new endpoint.
Leaving the arg in for _build_sample_payload since we may want to test it in the future. I'm not sure how to test session based routing in our current setup, so leaving for now.
tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py
…client.py revert change Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
```python
# Transform response choices → sequences
sequences = []
logger.info("num choices: %d", len(response.get("choices", [])))
```
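The choices → sequences transform noted in the comment above might look something like this. A sketch only, assuming an OpenAI-style response from the vLLM `/generate` endpoint; the `token_ids`, `logprobs`, and `finish_reason` field names are illustrative, not necessarily the exact ones the client reads:

```python
def choices_to_sequences(response: dict) -> list[dict]:
    # Map each generated choice to a Tinker-style sequence entry.
    sequences = []
    for choice in response.get("choices", []):
        sequences.append({
            "tokens": choice.get("token_ids", []),
            "logprobs": choice.get("logprobs"),
            "stop_reason": choice.get("finish_reason"),
        })
    return sequences
```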
Always logging with info here is probably a little too verbose, right?
Yes, I put it in for debugging originally. It shouldn't be in, and I've removed it.
```python
return {
    "type": "sample",
    "sequences": sequences,
    "prompt_logprobs": None,
```
Going forward, we might want / need to support this :)
Yes! The next PR will include `prompt_logprobs`, but I first need to check how they handle prompt logprobs for vision inputs to make sure we handle that case.
```python
tinker_params = body.get("sampling_params", {})

# Flatten prompt chunks → token IDs
token_ids = [tok for chunk in prompt.get("chunks", []) for tok in chunk.get("tokens", [])]
```
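To make the flatten step concrete, here is the same comprehension run on a small hand-built prompt (the chunk shape is taken from the snippet above):

```python
# Two chunks of token IDs collapse into a single flat list.
prompt = {"chunks": [{"tokens": [1, 2]}, {"tokens": [3]}]}
token_ids = [tok for chunk in prompt.get("chunks", []) for tok in chunk.get("tokens", [])]
# token_ids == [1, 2, 3]
```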
This will need adaptation for multi-modal inputs going forward, right?
Yes, this will have to be the token concatenation we talked about, so it will get replaced.
Add `/sample` API to `RemoteInferenceClient`

This PR adds the Tinker-compatible `/sample` API to `RemoteInferenceClient` on the new inference server codepath, addressing #1286.

Changes

- `RemoteInferenceClient.sample()` method that maps Tinker-style sample requests to the vLLM `/inference/v1/generate` endpoint, supporting `n` completions, logprobs, and configurable sampling params (temperature, top_k, top_p, seed, stop tokens, etc.)

Tests

- Unit tests (`TestSample`) covering n=1, n=2, and multi-chunk prompts
- GPU integration tests (`test_client_sample`, `test_client_sample_multiple`, `test_client_sample_deterministic`) validating end-to-end generation against a live vLLM server
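The review's suggestion to hoist `_PARAM_MAP` to a module-level constant might look roughly like this. A sketch under assumptions: the key names below are the sampling params listed in the description, and the helper name `map_sampling_params` is hypothetical, not the client's actual function:

```python
# Module-level Tinker → vLLM sampling-param name mapping, built once
# instead of per call. Keys here happen to match one-to-one, but the
# mapping leaves room for names that diverge between the two APIs.
_PARAM_MAP = {
    "temperature": "temperature",
    "top_k": "top_k",
    "top_p": "top_p",
    "seed": "seed",
    "max_tokens": "max_tokens",
    "stop": "stop",
}

def map_sampling_params(tinker_params: dict) -> dict:
    # Drop unrecognized keys rather than forwarding them to vLLM.
    return {_PARAM_MAP[k]: v for k, v in tinker_params.items() if k in _PARAM_MAP}
```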