-
Notifications
You must be signed in to change notification settings - Fork 286
[tinker][SkyRL] Add sample and renderer API to new inference client #1287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
395f187
f0ec828
07ae4de
713ada8
8ab564e
7d8e5f4
c7e4d52
5ec91b7
f353605
6526887
2c0b47f
5b2fa34
29125ab
3f07164
c6fbfde
c950ced
26b7abf
516ae5e
281ba77
4d8298c
4651da7
3951469
2c399e0
fde5342
d894f82
269a81c
bd26015
b5bda89
b7484b3
aff3835
15b2ddb
863c203
7eaa5d8
03ff39d
71ca1cf
4f5ccd4
47c9400
9a5eb86
f9874a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -22,10 +22,16 @@ | |||||||||||||||||||||||
| from skyrl.backends.skyrl_train.inference_engines.ray_wrapped_inference_engine import ( | ||||||||||||||||||||||||
| create_ray_wrapped_inference_engines, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import ( | ||||||||||||||||||||||||
| RemoteInferenceClient, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.inference_servers.router import InferenceRouter | ||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup | ||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.inference_servers.utils import build_vllm_cli_args | ||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch | ||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup | ||||||||||||||||||||||||
| from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch | ||||||||||||||||||||||||
| from skyrl.env_vars import SKYRL_RAY_PG_TIMEOUT_IN_S | ||||||||||||||||||||||||
| from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S | ||||||||||||||||||||||||
| from skyrl.tinker import types | ||||||||||||||||||||||||
| from skyrl.train.config import SkyRLTrainConfig, get_config_as_yaml_str | ||||||||||||||||||||||||
| from skyrl.train.utils.utils import get_ray_pg_ready_with_timeout, initialize_ray | ||||||||||||||||||||||||
|
|
@@ -116,6 +122,8 @@ def __init__(self, base_model: str, config: SkyRLTrainBackendOverrides): | |||||||||||||||||||||||
| self._tokenizer: AutoTokenizer = get_tokenizer(self.base_model) | ||||||||||||||||||||||||
| self._inference_engine_client = None | ||||||||||||||||||||||||
| self._inference_engines_initialized = False | ||||||||||||||||||||||||
| self._server_group = None | ||||||||||||||||||||||||
| self._inference_router = None | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def has_model(self, model_id: str) -> bool: | ||||||||||||||||||||||||
| return self._model_id == model_id | ||||||||||||||||||||||||
|
|
@@ -184,19 +192,84 @@ def init_weight_sync_state(self): | |||||||||||||||||||||||
| self._dispatch.init_weight_sync_state(self._inference_engine_client) | ||||||||||||||||||||||||
| logger.info("Initialized weight sync state for policy model and inference engines.") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _create_remote_inference_client(self): | ||||||||||||||||||||||||
| """Create a RemoteInferenceClient using HTTP endpoints. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Mirrors main_base.py._get_new_inference_client() with the same 4-way | ||||||||||||||||||||||||
| branching on external_proxy_url / external_server_urls. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| ie_cfg = self._cfg.generator.inference_engine | ||||||||||||||||||||||||
| is_colocated = self._cfg.trainer.placement.colocate_all | ||||||||||||||||||||||||
| external_proxy_url = ie_cfg.external_proxy_url | ||||||||||||||||||||||||
| external_server_urls = ie_cfg.external_server_urls | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| has_external_proxy = external_proxy_url is not None | ||||||||||||||||||||||||
| has_external_servers = external_server_urls is not None | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if has_external_proxy and has_external_servers: | ||||||||||||||||||||||||
| proxy_url = external_proxy_url | ||||||||||||||||||||||||
| server_urls = list(external_server_urls) | ||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||
| f"HTTP Inference: Using fully external setup - proxy_url={proxy_url}, server_urls={server_urls}" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| elif has_external_proxy and not has_external_servers: | ||||||||||||||||||||||||
| proxy_url = external_proxy_url | ||||||||||||||||||||||||
| server_urls = [proxy_url] | ||||||||||||||||||||||||
| logger.info(f"HTTP Inference: Using external proxy for both data and control plane - proxy_url={proxy_url}") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| elif has_external_servers and not has_external_proxy: | ||||||||||||||||||||||||
| server_urls = list(external_server_urls) | ||||||||||||||||||||||||
| self._inference_router = InferenceRouter(server_urls=server_urls) | ||||||||||||||||||||||||
| proxy_url = self._inference_router.start() | ||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||
| f"HTTP Inference: Created internal router over external " | ||||||||||||||||||||||||
| f"servers - server_urls={server_urls}, proxy_url={proxy_url}" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| cli_args = build_vllm_cli_args(self._cfg) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| self._server_group = ServerGroup( | ||||||||||||||||||||||||
| cli_args=cli_args, | ||||||||||||||||||||||||
| num_servers=ie_cfg.num_engines, | ||||||||||||||||||||||||
| placement_group=self._colocate_pg if is_colocated else None, | ||||||||||||||||||||||||
| enable_dp=ie_cfg.data_parallel_size > 1, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| server_infos = self._server_group.start() | ||||||||||||||||||||||||
| server_urls = [info.url for info in server_infos] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| self._inference_router = InferenceRouter(server_urls=server_urls) | ||||||||||||||||||||||||
| proxy_url = self._inference_router.start() | ||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||
| f"HTTP Inference: Built servers and router internally - " | ||||||||||||||||||||||||
| f"proxy_url={proxy_url}, server_urls={server_urls}, colocated={is_colocated}" | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| return RemoteInferenceClient( | ||||||||||||||||||||||||
| proxy_url=proxy_url, | ||||||||||||||||||||||||
| server_urls=server_urls, | ||||||||||||||||||||||||
| model_name=self._cfg.trainer.policy.model.path, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
Comment on lines
+249
to
+253
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴
The old Note:
Suggested change
Was this helpful? React with 👍 or 👎 to provide feedback. |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _ensure_inference_engines(self): | ||||||||||||||||||||||||
| """Lazily create inference engines and init weight sync on first sampling-related call.""" | ||||||||||||||||||||||||
| if self._inference_engines_initialized: | ||||||||||||||||||||||||
| return | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines") | ||||||||||||||||||||||||
| self._inference_engine_client = InferenceEngineClient( | ||||||||||||||||||||||||
| create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer), | ||||||||||||||||||||||||
| self._tokenizer, | ||||||||||||||||||||||||
| self._cfg.trainer.policy.model.path, | ||||||||||||||||||||||||
| self._cfg.trainer.policy.model.lora, | ||||||||||||||||||||||||
| self._cfg.generator.inference_engine, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| if _SKYRL_USE_NEW_INFERENCE: | ||||||||||||||||||||||||
| logger.info("Using new HTTP-based inference client (RemoteInferenceClient)") | ||||||||||||||||||||||||
| self._inference_engine_client = self._create_remote_inference_client() | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| logger.info(f"Creating {self._cfg.generator.inference_engine.num_engines} inference engines") | ||||||||||||||||||||||||
| self._inference_engine_client = InferenceEngineClient( | ||||||||||||||||||||||||
| create_ray_wrapped_inference_engines_from_config(self._cfg, self._colocate_pg, self._tokenizer), | ||||||||||||||||||||||||
| self._tokenizer, | ||||||||||||||||||||||||
| self._cfg.trainer.policy.model.path, | ||||||||||||||||||||||||
| self._cfg.trainer.policy.model.lora, | ||||||||||||||||||||||||
| self._cfg.generator.inference_engine, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| self._dispatch.set_inference_engine_client(self._inference_engine_client) | ||||||||||||||||||||||||
| self.init_weight_sync_state() | ||||||||||||||||||||||||
| self._inference_engines_initialized = True | ||||||||||||||||||||||||
|
|
@@ -533,28 +606,43 @@ async def sample_all(): | |||||||||||||||||||||||
| prompt = prepared_batch.all_prompts[i] | ||||||||||||||||||||||||
| sampling_params = prepared_batch.all_sampling_params[i] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Pass through common fields; only stop needs name translation | ||||||||||||||||||||||||
| # (Tinker uses stop_strings/stop_tokens, vLLM uses stop/stop_token_ids) | ||||||||||||||||||||||||
| params_dict = { | ||||||||||||||||||||||||
| "temperature": sampling_params.temperature, | ||||||||||||||||||||||||
| "max_tokens": sampling_params.max_tokens, | ||||||||||||||||||||||||
| "seed": sampling_params.seed, | ||||||||||||||||||||||||
| "top_k": sampling_params.top_k, | ||||||||||||||||||||||||
| "top_p": sampling_params.top_p, | ||||||||||||||||||||||||
| "logprobs": 0, | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| if sampling_params.stop_strings: | ||||||||||||||||||||||||
| params_dict["stop"] = sampling_params.stop_strings | ||||||||||||||||||||||||
| if sampling_params.stop_tokens: | ||||||||||||||||||||||||
| params_dict["stop_token_ids"] = sampling_params.stop_tokens | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| tasks.append( | ||||||||||||||||||||||||
| self._inference_engine_client.sample( | ||||||||||||||||||||||||
| prompt_token_ids=prompt, | ||||||||||||||||||||||||
| num_samples=1, # Tinker batches multiple samples separately | ||||||||||||||||||||||||
| sampling_params=params_dict, | ||||||||||||||||||||||||
| if _SKYRL_USE_NEW_INFERENCE: | ||||||||||||||||||||||||
| # Right now, prompt is list[int] (token IDs), so we wrap in ModelInput format | ||||||||||||||||||||||||
| json_body = { | ||||||||||||||||||||||||
| "prompt": {"chunks": [{"tokens": prompt}]}, | ||||||||||||||||||||||||
| "num_samples": 1, # Tinker batches multiple samples separately | ||||||||||||||||||||||||
| "sampling_params": { | ||||||||||||||||||||||||
| "temperature": sampling_params.temperature, | ||||||||||||||||||||||||
| "max_tokens": sampling_params.max_tokens, | ||||||||||||||||||||||||
| "seed": sampling_params.seed, | ||||||||||||||||||||||||
| "top_k": sampling_params.top_k, | ||||||||||||||||||||||||
| "top_p": sampling_params.top_p, | ||||||||||||||||||||||||
| "stop_tokens": sampling_params.stop_tokens, | ||||||||||||||||||||||||
| "stop_strings": sampling_params.stop_strings, | ||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| tasks.append(self._inference_engine_client.sample({"json": json_body, "headers": {}})) | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| params_dict = { | ||||||||||||||||||||||||
| "temperature": sampling_params.temperature, | ||||||||||||||||||||||||
| "max_tokens": sampling_params.max_tokens, | ||||||||||||||||||||||||
| "seed": sampling_params.seed, | ||||||||||||||||||||||||
| "top_k": sampling_params.top_k, | ||||||||||||||||||||||||
| "top_p": sampling_params.top_p, | ||||||||||||||||||||||||
| "logprobs": 0, | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| if sampling_params.stop_strings: | ||||||||||||||||||||||||
| params_dict["stop"] = sampling_params.stop_strings | ||||||||||||||||||||||||
| if sampling_params.stop_tokens: | ||||||||||||||||||||||||
| params_dict["stop_token_ids"] = sampling_params.stop_tokens | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| tasks.append( | ||||||||||||||||||||||||
| self._inference_engine_client.sample( | ||||||||||||||||||||||||
| prompt_token_ids=prompt, | ||||||||||||||||||||||||
| num_samples=1, # Tinker batches multiple samples separately | ||||||||||||||||||||||||
| sampling_params=params_dict, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| return await asyncio.gather(*tasks, return_exceptions=True) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
@@ -565,14 +653,25 @@ async def sample_all(): | |||||||||||||||||||||||
| # We preserve these to include error messages in responses | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # 4. Aggregate results by request | ||||||||||||||||||||||||
| return self._aggregate_sample_results(prepared_batch, sample_outputs) | ||||||||||||||||||||||||
| return self._aggregate_sample_results( | ||||||||||||||||||||||||
| prepared_batch, sample_outputs, use_new_inference=_SKYRL_USE_NEW_INFERENCE | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _aggregate_sample_results( | ||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||
| prepared_batch: types.PreparedSampleBatch, | ||||||||||||||||||||||||
| sample_outputs: list, | ||||||||||||||||||||||||
| use_new_inference: bool = False, | ||||||||||||||||||||||||
| ) -> dict[str, types.SampleOutput | types.ErrorResponse]: | ||||||||||||||||||||||||
| """Convert InferenceEngineClient outputs to Tinker format.""" | ||||||||||||||||||||||||
| """Convert inference outputs to Tinker format. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||
| prepared_batch: The prepared sample batch. | ||||||||||||||||||||||||
| sample_outputs: List of outputs from inference client. | ||||||||||||||||||||||||
| use_new_inference: If True, outputs are SampleResponse dicts from | ||||||||||||||||||||||||
| RemoteInferenceClient. If False, outputs are InferenceEngineOutput | ||||||||||||||||||||||||
| dicts from InferenceEngineClient. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| results = {} | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| for request_id, model_id, start_idx, end_idx, needs_prompt_logprobs in prepared_batch.request_batch_slices: | ||||||||||||||||||||||||
|
|
@@ -595,13 +694,20 @@ def _aggregate_sample_results( | |||||||||||||||||||||||
| logger.error(error_msg) | ||||||||||||||||||||||||
| break | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Extract tokens and logprobs | ||||||||||||||||||||||||
| response_tokens = output["response_ids"][0] | ||||||||||||||||||||||||
| response_logprobs = (output.get("response_logprobs") or [[]])[0] | ||||||||||||||||||||||||
| stop_reason_raw = output["stop_reasons"][0] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Map vLLM stop reason to Tinker format | ||||||||||||||||||||||||
| stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length" | ||||||||||||||||||||||||
| if use_new_inference: | ||||||||||||||||||||||||
| # New inference server: SampleResponse dict | ||||||||||||||||||||||||
| seq = output["sequences"][0] | ||||||||||||||||||||||||
| response_tokens = seq["tokens"] | ||||||||||||||||||||||||
| response_logprobs = seq.get("logprobs") or [] | ||||||||||||||||||||||||
| stop_reason = seq["stop_reason"] | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| # Old inference engine: InferenceEngineOutput | ||||||||||||||||||||||||
| # Extract tokens and logprobs | ||||||||||||||||||||||||
| response_tokens = output["response_ids"][0] | ||||||||||||||||||||||||
| response_logprobs = (output.get("response_logprobs") or [[]])[0] | ||||||||||||||||||||||||
| stop_reason_raw = output["stop_reasons"][0] | ||||||||||||||||||||||||
| # Map vLLM stop reason to Tinker format | ||||||||||||||||||||||||
| stop_reason = "stop" if stop_reason_raw in ["stop", "stop_token"] else "length" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Ensure logprobs exist (critical for RL) | ||||||||||||||||||||||||
| if response_logprobs is None or len(response_logprobs) == 0: | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.