diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
index 3b3ce6610c..805acf0ec5 100644
--- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
+++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
@@ -297,6 +297,124 @@ async def _generate_single(
             "response_logprobs": response_logprobs if len(response_logprobs) > 0 else None,
         }
 
+    async def sample(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
+        """Sample completions via /inference/v1/generate.
+
+        Sends a single request with n set in sampling_params; aborted requests
+        are not retried.
+
+        Args:
+            request_payload: {"json": {...}, "headers": {...}}
+                json body matches Tinker SamplingClient.sample() args:
+                prompt (ModelInput), num_samples, sampling_params (SamplingParams),
+                include_prompt_logprobs, topk_prompt_logprobs.
+                session_id is optional for routing.
+
+        Returns:
+            Dict matching Tinker SampleResponse schema.
+        """
+        body = request_payload.get("json", {})
+        session_id = body.pop("session_id", None)
+
+        # Tinker input fields
+        prompt = body.get("prompt", {})  # ModelInput dict: {"chunks": [{"tokens": [...]}]}
+        num_samples = body.get("num_samples", 1)
+        sampling_params = body.get("sampling_params", {})
+        prompt_logprobs = body.get("include_prompt_logprobs", False)
+        topk_prompt_logprobs = body.get("topk_prompt_logprobs", 0)
+
+        prompt_token_ids = [tok for chunk in prompt.get("chunks", []) for tok in chunk.get("tokens", [])]
+
+        # Tinker types.py SamplingParams -> vLLM /inference/v1/generate sampling_params
+        vllm_sampling_params = {
+            "n": num_samples,
+            "temperature": sampling_params.get("temperature"),
+            "max_tokens": sampling_params.get("max_tokens"),
+            "seed": sampling_params.get("seed"),
+            "top_k": sampling_params.get("top_k", -1),
+            "top_p": sampling_params.get("top_p", 1.0),
+            "prompt_logprobs": max(topk_prompt_logprobs, 1) if prompt_logprobs else None,
+            "logprobs": 0,  # response logprobs
+        }
+        if sampling_params.get("stop_strings"):
+            vllm_sampling_params["stop"] = sampling_params["stop_strings"]
+        if sampling_params.get("stop_tokens"):
+            vllm_sampling_params["stop_token_ids"] = sampling_params["stop_tokens"]
+
+        payload = {
+            "sampling_params": vllm_sampling_params,
+            "model": self.model_name,
+            "token_ids": prompt_token_ids,
+        }
+
+        headers = {"Content-Type": "application/json"}
+        if session_id:
+            headers["X-Session-ID"] = str(session_id)
+
+        session = await self._get_session()
+        url = f"{self.proxy_url}/inference/v1/generate"
+
+        async with session.post(url, json=payload, headers=headers) as resp:
+            result = await resp.json()
+            raise_for_status(resp, result)
+
+        # Transform response choices -> tinker type SampleResponse dict
+        sequences = []
+        for choice in result["choices"]:
+            raw_stop = choice.get("finish_reason", "length")
+            stop_reason = "stop" if raw_stop in ("stop", "stop_token") else "length"
+
+            token_ids = choice.get("token_ids", [])
+
+            logprobs = None
+            lp = choice.get("logprobs")
+            if lp is not None:
+                logprobs_content = lp.get("content", [])
+                if logprobs_content:
+                    logprobs = [info["logprob"] if info["logprob"] is not None else 0.0 for info in logprobs_content]
+            # Convert to tinker type SampledSequence dict
+            sequences.append(
+                {
+                    "stop_reason": stop_reason,
+                    "tokens": token_ids,
+                    "logprobs": logprobs,
+                }
+            )
+
+        # Transform prompt_logprobs from vLLM format to Tinker SampleResponse format
+        transformed_prompt_logprobs = None
+        transformed_topk_prompt_logprobs = None
+        # raw_prompt_logprobs: vLLM type list[dict[int, Logprob]] | None
+        raw_prompt_logprobs = result.get("prompt_logprobs") if prompt_logprobs else None
+
+        if raw_prompt_logprobs is not None:
+            # prompt_logprobs: single float per token (logprob of the actual prompt token)
+            transformed_prompt_logprobs = []
+            for i, pos in enumerate(raw_prompt_logprobs):
+                if pos is None:
+                    transformed_prompt_logprobs.append(None)
+                else:
+                    token_key = str(prompt_token_ids[i])
+                    entry = pos.get(token_key)
+                    transformed_prompt_logprobs.append(entry["logprob"] if entry is not None else None)
+
+            # topk_prompt_logprobs: list of (token_id, logprob) tuples per position
+            if topk_prompt_logprobs > 0:
+                transformed_topk_prompt_logprobs = []
+                for pos in raw_prompt_logprobs:
+                    if pos is None:
+                        transformed_topk_prompt_logprobs.append(None)
+                    else:
+                        transformed_topk_prompt_logprobs.append(
+                            [(int(tid), info["logprob"]) for tid, info in pos.items()]
+                        )
+
+        return {
+            "type": "sample",
+            "sequences": sequences,
+            "prompt_logprobs": transformed_prompt_logprobs,
+            "topk_prompt_logprobs": transformed_topk_prompt_logprobs,
+        }
+
     async def chat_completion(
         self,
         request_payload: Dict[str, Any],
diff --git a/skyrl/backends/skyrl_train_backend.py b/skyrl/backends/skyrl_train_backend.py
index 33f376a89d..4072718920 100644
--- a/skyrl/backends/skyrl_train_backend.py
+++ b/skyrl/backends/skyrl_train_backend.py
@@ -22,10 +22,16 @@
 from skyrl.backends.skyrl_train.inference_engines.ray_wrapped_inference_engine import (
     create_ray_wrapped_inference_engines,
 )
+from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import (
+    RemoteInferenceClient,
+)
+from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter
+from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup
+from skyrl.backends.skyrl_train.inference_servers.utils import build_vllm_cli_args
 from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
 from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
 from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
-from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S
+from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S
 from skyrl.tinker import types
 from skyrl.train.config import SkyRLTrainConfig, get_config_as_yaml_str
 from skyrl.train.utils.utils import get_ray_pg_ready_with_timeout, initialize_ray
@@ -116,6 +122,8 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides):
         self._tokenizer: AutoTokenizer = get_tokenizer(self.base_model)
         self._inference_engine_client = None
         self._inference_engines_initialized = False
+        self._server_group = None
+        self._inference_router = None
 
     def has_model(self, model_id: str) -> bool:
         return self._model_id == model_id
@@ -184,19 +192,84 @@ def init_weight_sync_state(self):
         self._dispatch.init_weight_sync_state(self._inference_engine_client)
         logger.info("Initialized weight sync state for policy model and inference engines.")
 
+    def _create_remote_inference_client(self):
+        """Create a RemoteInferenceClient using HTTP endpoints.
+
+        Mirrors main_base.py._get_new_inference_client() with the same 4-way
+        branching on external_proxy_url / external_server_urls.
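+
+        A sketch of the four branches (labels are descriptive, not config
+        values):
+
+            proxy set,   servers set   -> use both as given (fully external)
+            proxy set,   servers unset -> proxy serves data and control plane
+            proxy unset, servers set   -> internal InferenceRouter over them
+            proxy unset, servers unset -> build ServerGroup + router in-process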
+ """ + ie_cfg = self._cfg.generator.inference_engine + is_colocated = self._cfg.trainer.placement.colocate_all + external_proxy_url = ie_cfg.external_proxy_url + external_server_urls = ie_cfg.external_server_urls + + has_external_proxy = external_proxy_url is not None + has_external_servers = external_server_urls is not None + + if has_external_proxy and has_external_servers: + proxy_url = external_proxy_url + server_urls = list(external_server_urls) + logger.info( + f"HTTP Inference: Using fully external setup - proxy_url={proxy_url}, server_urls={server_urls}" + ) + + elif has_external_proxy and not has_external_servers: + proxy_url = external_proxy_url + server_urls = [proxy_url] + logger.info(f"HTTP Inference: Using external proxy for both data and control plane - proxy_url={proxy_url}") + + elif has_external_servers and not has_external_proxy: + server_urls = list(external_server_urls) + self._inference_router = InferenceRouter(server_urls=server_urls) + proxy_url = self._inference_router.start() + logger.info( + f"HTTP Inference: Created internal router over external " + f"servers - server_urls={server_urls}, proxy_url={proxy_url}" + ) + + else: + cli_args = build_vllm_cli_args(self._cfg) + + self._server_group = ServerGroup( + cli_args=cli_args, + num_servers=ie_cfg.num_engines, + placement_group=self._colocate_pg if is_colocated else None, + enable_dp=ie_cfg.data_parallel_size > 1, + ) + server_infos = self._server_group.start() + server_urls = [info.url for info in server_infos] + + self._inference_router = InferenceRouter(server_urls=server_urls) + proxy_url = self._inference_router.start() + logger.info( + f"HTTP Inference: Built servers and router internally - " + f"proxy_url={proxy_url}, server_urls={server_urls}, colocated={is_colocated}" + ) + + return RemoteInferenceClient( + proxy_url=proxy_url, + server_urls=server_urls, + model_name=self._cfg.trainer.policy.model.path, + ) + def _ensure_inference_engines(self): """Lazily create inference engines and init weight sync on first sampling-related call.""" if self._inference_engines_initialized: return - logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines") - self._inference_engine_client = InferenceEngineClient( - create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer), - self._tokenizer, - self._cfg.trainer.policy.model.path, - self._cfg.trainer.policy.model.lora, - self._cfg.generator.inference_engine, - ) + if _SKYRL_USE_NEW_INFERENCE: + logger.info("Using new HTTP-based inference client (RemoteInferenceClient)") + self._inference_engine_client = self._create_remote_inference_client() + else: + logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines") + self._inference_engine_client = InferenceEngineClient( + create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer), + self._tokenizer, + self._cfg.trainer.policy.model.path, + self._cfg.trainer.policy.model.lora, + self._cfg.generator.inference_engine, + ) + self._dispatch.set_inference_engine_client(self._inference_engine_client) self.init_weight_sync_state() self._inference_engines_initialized = True @@ -533,28 +606,43 @@ async def sample_all(): prompt = prepared_batch.all_prompts[i] sampling_params = prepared_batch.all_sampling_params[i] - # Pass through common fields; only stop needs name translation - # (Tinker uses stop_strings/stop_tokens, vLLM uses stop/stop_token_ids) - params_dict = { - "temperature": 
sampling_params.temperature, - "max_tokens": sampling_params.max_tokens, - "seed": sampling_params.seed, - "top_k": sampling_params.top_k, - "top_p": sampling_params.top_p, - "logprobs": 0, - } - if sampling_params.stop_strings: - params_dict["stop"] = sampling_params.stop_strings - if sampling_params.stop_tokens: - params_dict["stop_token_ids"] = sampling_params.stop_tokens - - tasks.append( - self._inference_engine_client.sample( - prompt_token_ids=prompt, - num_samples=1, # Tinker batches multiple samples separately - sampling_params=params_dict, + if _SKYRL_USE_NEW_INFERENCE: + # Right now, prompt is list[int] (token IDs), so we wrap in ModelInput format + json_body = { + "prompt": {"chunks": [{"tokens": prompt}]}, + "num_samples": 1, # Tinker batches multiple samples separately + "sampling_params": { + "temperature": sampling_params.temperature, + "max_tokens": sampling_params.max_tokens, + "seed": sampling_params.seed, + "top_k": sampling_params.top_k, + "top_p": sampling_params.top_p, + "stop_tokens": sampling_params.stop_tokens, + "stop_strings": sampling_params.stop_strings, + }, + } + tasks.append(self._inference_engine_client.sample({"json": json_body, "headers": {}})) + else: + params_dict = { + "temperature": sampling_params.temperature, + "max_tokens": sampling_params.max_tokens, + "seed": sampling_params.seed, + "top_k": sampling_params.top_k, + "top_p": sampling_params.top_p, + "logprobs": 0, + } + if sampling_params.stop_strings: + params_dict["stop"] = sampling_params.stop_strings + if sampling_params.stop_tokens: + params_dict["stop_token_ids"] = sampling_params.stop_tokens + + tasks.append( + self._inference_engine_client.sample( + prompt_token_ids=prompt, + num_samples=1, # Tinker batches multiple samples separately + sampling_params=params_dict, + ) ) - ) return await asyncio.gather(*tasks, return_exceptions=True) @@ -565,14 +653,25 @@ async def sample_all(): # We preserve these to include error messages in responses # 4. Aggregate results by request - return self._aggregate_sample_results(prepared_batch, sample_outputs) + return self._aggregate_sample_results( + prepared_batch, sample_outputs, use_new_inference=_SKYRL_USE_NEW_INFERENCE + ) def _aggregate_sample_results( self, prepared_batch: types.PreparedSampleBatch, sample_outputs: list, + use_new_inference: bool = False, ) -> dict[str, types.SampleOutput | types.ErrorResponse]: - """Convert InferenceEngineClient outputs to Tinker format.""" + """Convert inference outputs to Tinker format. + + Args: + prepared_batch: The prepared sample batch. + sample_outputs: List of outputs from inference client. + use_new_inference: If True, outputs are SampleResponse dicts from + RemoteInferenceClient. If False, outputs are InferenceEngineOutput + dicts from InferenceEngineClient. 
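+
+        Shape sketch of the two formats (illustrative values only):
+
+            new: {"type": "sample",
+                  "sequences": [{"tokens": [...], "logprobs": [...],
+                                 "stop_reason": "stop"}], ...}
+            old: {"response_ids": [[...]], "response_logprobs": [[...]],
+                  "stop_reasons": ["stop"]}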
+ """ results = {} for request_id, model_id, start_idx, end_idx, needs_prompt_logprobs in prepared_batch.request_batch_slices: @@ -595,13 +694,20 @@ def _aggregate_sample_results( logger.error(error_msg) break - # Extract tokens and logprobs - response_tokens = output["response_ids"][0] - response_logprobs = (output.get("response_logprobs") or [[]])[0] - stop_reason_raw = output["stop_reasons"][0] - - # Map vLLM stop reason to Tinker format - stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length" + if use_new_inference: + # New inference server: SampleResponse dict + seq = output["sequences"][0] + response_tokens = seq["tokens"] + response_logprobs = seq.get("logprobs") or [] + stop_reason = seq["stop_reason"] + else: + # Old inference engine: InferenceEngineOutput + # Extract tokens and logprobs + response_tokens = output["response_ids"][0] + response_logprobs = (output.get("response_logprobs") or [[]])[0] + stop_reason_raw = output["stop_reasons"][0] + # Map vLLM stop reason to Tinker format + stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length" # Ensure logprobs exist (critical for RL) if response_logprobs is None or len(response_logprobs) == 0: diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py index 884d62efc0..68470bae0a 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py @@ -295,8 +295,7 @@ def test_token_based_generation_consistency(ray_init_fixture, tp_size: int, pp_s ) -# TODO: Remove this once sample API is also supported in the new inference pathway -@pytest.mark.skipif(_SKYRL_USE_NEW_INFERENCE, reason="New inference pathway doesn't support sample API yet") +@pytest.mark.skipif(_SKYRL_USE_NEW_INFERENCE, reason="Old sample API not used with new inference path") @pytest.mark.parametrize( "tp_size,dp_size", [ @@ -348,3 +347,112 @@ async def run_sample(): print(f"Generated {len(unique_responses)} unique responses from {num_samples} samples") for i, resp in enumerate(output["responses"]): print(f"Sample {i}: {resp[:100]}..." if len(resp) > 100 else f"Sample {i}: {resp}") + + +@pytest.mark.parametrize( + "tp_size,dp_size", + [ + pytest.param(2, 1), + ], + ids=["tp2"], +) +def test_sample_api_remote(ray_init_fixture, tp_size: int, dp_size: int): + """Test the sample() API via RemoteInferenceClient (new inference path). 
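+
+    The request payload mirrors RemoteInferenceClient.sample(), e.g. (token
+    IDs illustrative):
+
+        {"json": {"prompt": {"chunks": [{"tokens": [1, 2, 3]}]},
+                  "num_samples": 1, "sampling_params": {...}}, "headers": {}}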
+
+    Makes two calls to validate both output correctness and sampling behavior:
+    - Call A (temp=1.0): schema checks, token decode, diversity across samples
+    - Call B (temp=0.0): determinism (all samples should be identical)
+    """
+    cfg = get_test_actor_config()
+
+    prompts = get_test_prompts(MODEL, 1)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL)
+    prompt_token_ids = tokenizer.apply_chat_template(
+        prompts, add_generation_prompt=True, tokenize=True, return_dict=True
+    )["input_ids"][0]
+
+    cfg.generator.inference_engine.tensor_parallel_size = tp_size
+    cfg.generator.inference_engine.data_parallel_size = dp_size
+
+    num_samples = 3
+
+    with InferenceEngineState.create(cfg, sleep_level=1, use_new_inference_servers=True) as engines:
+        llm_client = engines.client
+
+        def build_payload(temperature):
+            return {
+                "json": {
+                    "prompt": {"chunks": [{"tokens": prompt_token_ids}]},
+                    "num_samples": 1,
+                    "sampling_params": {
+                        "temperature": temperature,
+                        "max_tokens": cfg.generator.sampling_params.max_generate_length,
+                        "top_k": cfg.generator.sampling_params.top_k,
+                        "top_p": cfg.generator.sampling_params.top_p,
+                    },
+                },
+                "headers": {},
+            }
+
+        async def run_samples(temperature, n):
+            results = []
+            for _ in range(n):
+                output = await llm_client.sample(build_payload(temperature))
+                assert output["type"] == "sample", f"Expected type 'sample', got {output['type']!r}"
+                assert len(output["sequences"]) == 1, f"Expected 1 sequence per call, got {len(output['sequences'])}"
+                results.append(output["sequences"][0])
+            return results
+
+        # --- Call A: temp=1.0, expect diverse outputs ---
+        sequences = asyncio.run(run_samples(1.0, num_samples))
+
+        decoded_texts = []
+        for i, seq in enumerate(sequences):
+            assert seq["stop_reason"] in ("stop", "length"), f"Unexpected stop_reason: {seq['stop_reason']}"
+            assert isinstance(seq["tokens"], list), f"tokens should be a list, got {type(seq['tokens'])}"
+            assert len(seq["tokens"]) > 0, f"Sequence {i} has no tokens"
+            assert all(isinstance(t, int) for t in seq["tokens"]), f"Sequence {i} contains non-int tokens"
+            if seq.get("logprobs") is not None:
+                assert isinstance(
+                    seq["logprobs"], list
+                ), f"Sequence {i} logprobs should be a list, got {type(seq['logprobs'])}"
+
+            text = tokenizer.decode(seq["tokens"], skip_special_tokens=True)
+            assert len(text.strip()) > 0, f"Sequence {i} decoded to empty text from {len(seq['tokens'])} tokens"
+            decoded_texts.append(text)
+
+        unique_texts = set(decoded_texts)
+        assert len(unique_texts) > 1, (
+            f"All {num_samples} samples at temp=1.0 are identical — sampling params may be ignored. "
+            f"Text: {decoded_texts[0][:120]!r}"
+        )
+
+        print(f"Call A (temp=1.0): {len(unique_texts)}/{num_samples} unique samples")
+        for i, text in enumerate(decoded_texts):
+            print(f"  Sample {i}: {text[:100]}..." if len(text) > 100 else f"  Sample {i}: {text}")
+
+        # --- Call B: temp=0.0, expect deterministic outputs ---
+        det_sequences = asyncio.run(run_samples(0.0, num_samples))
+
+        det_token_seqs = []
+        det_texts = []
+        for i, seq in enumerate(det_sequences):
+            assert seq["stop_reason"] in ("stop", "length"), f"Unexpected stop_reason: {seq['stop_reason']}"
+            assert isinstance(
+                seq["tokens"], list
+            ), f"Deterministic sequence {i} tokens should be a list, got {type(seq['tokens'])}"
+            assert len(seq["tokens"]) > 0, f"Deterministic sequence {i} has no tokens"
+            det_token_seqs.append(tuple(seq["tokens"]))
+
+            text = tokenizer.decode(seq["tokens"], skip_special_tokens=True)
+            assert len(text.strip()) > 0, f"Deterministic sequence {i} decoded to empty text"
+            det_texts.append(text)
+
+        unique_det = set(det_token_seqs)
+        assert len(unique_det) == 1, (
+            f"temp=0.0 produced {len(unique_det)} distinct token sequences — expected deterministic output. "
+            f"Lengths: {[len(s) for s in det_token_seqs]}"
+        )
+
+        print(f"Call B (temp=0.0): all {num_samples} samples identical (deterministic)")
+        print(f"  Text: {det_texts[0][:120]}")
diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py
index 2e519d5a43..8c17f02660 100644
--- a/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py
+++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_save_weights_for_sampler.py
@@ -17,6 +17,7 @@
 from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
 from skyrl.train.config import SkyRLTrainConfig
 from skyrl.train.utils.utils import validate_cfg
+from skyrl.utils.tok import get_tokenizer
 from tests.backends.skyrl_train.gpu.utils import (
     InferenceEngineState,
     get_test_prompts,
@@ -78,6 +79,7 @@ def test_save_weights_for_sampler_then_inference(ray_init_fixture, colocate_all,
         tp_size=cfg.generator.inference_engine.tensor_parallel_size,
         colocate_all=cfg.trainer.placement.colocate_all,
         sleep_level=2,  # Full sleep since we explicitly sync weights
+        gpu_memory_utilization=0.5 if colocate_all else None,
     ) as engines:
         client, pg = engines.client, engines.pg
         # Initialize policy worker
@@ -121,7 +123,10 @@ def test_save_weights_for_sampler_then_inference(ray_init_fixture, colocate_all,
     sampling_params = get_sampling_params_for_backend(
         cfg.generator.inference_engine.backend, cfg.generator.sampling_params
     )
-    outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL, num_samples=5), sampling_params))
+    tokenizer = get_tokenizer(MODEL)
+    outputs = asyncio.run(
+        run_inference(client, get_test_prompts(MODEL, num_samples=5), sampling_params, tokenizer=tokenizer)
+    )
 
     # Verify we got responses
     assert "responses" in outputs, "Inference should return responses"
@@ -186,5 +191,8 @@ def test_save_weights_for_sampler_multiple_training_steps(ray_init_fixture):
     sampling_params = get_sampling_params_for_backend(
         cfg.generator.inference_engine.backend, cfg.generator.sampling_params
     )
-    outputs = asyncio.run(run_inference(client, get_test_prompts(MODEL, num_samples=2), sampling_params))
+    tokenizer = get_tokenizer(MODEL)
+    outputs = asyncio.run(
+        run_inference(client, get_test_prompts(MODEL, num_samples=2), sampling_params, tokenizer=tokenizer)
+    )
     assert len(outputs["responses"]) == 2, "Should get 2 responses"
diff --git a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py
index ba547b0a10..893828254b 100644
--- a/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py
+++ b/tests/backends/skyrl_train/inference_servers/test_remote_inference_client.py
@@ -45,13 +45,18 @@ async def completions(request: Request):
 
     @app.post("/inference/v1/generate")
     async def generate(request: Request):
-        body = await request.json()  # Consume body
-        num_prompts = len(body.get("token_ids", []))
+        body = await request.json()
+        n = body.get("sampling_params", {}).get("n", 1)
         return {
             "choices": [
-                {"request_id": "dummy", "token_ids": [i, i + 1, i + 2], "finish_reason": "stop"}
-                for i in range(num_prompts)
+                {
+                    "request_id": "dummy",
+                    "token_ids": [i, i + 1, i + 2],
+                    "finish_reason": "stop",
+                    "logprobs": {"content": [{"logprob": -0.1}, {"logprob": -0.2}, {"logprob": -0.3}]},
+                }
+                for i in range(n)
             ]
         }
 
@@ -472,3 +477,87 @@ async def test_retry_on_abort(self, abort_mock_server):
                 assert len(result["response_ids"][0]) > 0
         finally:
             await client.teardown()
+
+
+class TestSample:
+    """Test sample() with request_payload / SampleResponse dict interface."""
+
+    @pytest.mark.asyncio
+    async def test_sample(self, client):
+        """Test sample with n=1: verify SampleResponse dict contents."""
+        request_payload = {
+            "json": {
+                "prompt": {"chunks": [{"tokens": [10, 20, 30]}]},
+                "num_samples": 1,
+                "sampling_params": {
+                    "temperature": 0.8,
+                    "max_tokens": 512,
+                    "seed": 42,
+                },
+            },
+            "headers": {},
+        }
+        result = await client.sample(request_payload)
+
+        assert result["type"] == "sample"
+        assert result["prompt_logprobs"] is None
+        assert result["topk_prompt_logprobs"] is None
+        assert len(result["sequences"]) == 1
+
+        # Mock returns token_ids=[0, 1, 2] for choice i=0, logprobs=[-0.1, -0.2, -0.3]
+        seq = result["sequences"][0]
+        assert seq["tokens"] == [0, 1, 2]
+        assert seq["logprobs"] == [-0.1, -0.2, -0.3]
+        assert seq["stop_reason"] == "stop"
+
+    @pytest.mark.asyncio
+    async def test_sample_n2(self, client):
+        """Test sample with n=2: verify both sequences match mock output."""
+        request_payload = {
+            "json": {
+                "prompt": {"chunks": [{"tokens": [10, 20, 30]}]},
+                "num_samples": 2,
+                "sampling_params": {
+                    "temperature": 0.8,
+                    "max_tokens": 512,
+                    "seed": 42,
+                },
+            },
+            "headers": {},
+        }
+        result = await client.sample(request_payload)
+
+        assert result["type"] == "sample"
+        assert len(result["sequences"]) == 2
+
+        # Mock returns token_ids=[i, i+1, i+2] for choice i
+        assert result["sequences"][0]["tokens"] == [0, 1, 2]
+        assert result["sequences"][0]["logprobs"] == [-0.1, -0.2, -0.3]
+        assert result["sequences"][0]["stop_reason"] == "stop"
+
+        assert result["sequences"][1]["tokens"] == [1, 2, 3]
+        assert result["sequences"][1]["logprobs"] == [-0.1, -0.2, -0.3]
+        assert result["sequences"][1]["stop_reason"] == "stop"
+
+    @pytest.mark.asyncio
+    async def test_sample_with_session_id(self, client):
+        """Test that session_id is popped from body and used for routing."""
+        request_payload = {
+            "json": {
+                "prompt": {"chunks": [{"tokens": [1, 2]}]},
+                "num_samples": 1,
+                "sampling_params": {"temperature": 1.0, "max_tokens": 10},
+                "session_id": "test-routing-session",
+            },
+            "headers": {},
+        }
+        result = await client.sample(request_payload)
+
+        # session_id should be consumed (popped) from body
+        assert "session_id" not in request_payload["json"]
+
+        assert result["type"] == "sample"
+        assert len(result["sequences"]) == 1
+        assert result["sequences"][0]["tokens"] == [0, 1, 2]
+        assert result["sequences"][0]["logprobs"] == [-0.1, -0.2, -0.3]
+        assert result["sequences"][0]["stop_reason"] == "stop"
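For reference, a minimal sketch of exercising the new sample() path end to end. The endpoint URLs and model name below are placeholders; in the backend they come from _create_remote_inference_client() and the trainer config.

    import asyncio

    from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import (
        RemoteInferenceClient,
    )


    async def main():
        # Placeholder endpoints; a router/proxy normally fronts the servers.
        client = RemoteInferenceClient(
            proxy_url="http://localhost:8080",
            server_urls=["http://localhost:8000"],
            model_name="placeholder/model-name",
        )
        try:
            payload = {
                "json": {
                    "prompt": {"chunks": [{"tokens": [1, 2, 3]}]},  # illustrative token IDs
                    "num_samples": 2,
                    "sampling_params": {"temperature": 0.8, "max_tokens": 64},
                },
                "headers": {},
            }
            out = await client.sample(payload)
            for seq in out["sequences"]:
                print(seq["stop_reason"], len(seq["tokens"]))
        finally:
            await client.teardown()


    asyncio.run(main())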