diff --git a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py index 3276612a01..7434f71418 100644 --- a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py @@ -115,6 +115,7 @@ def create_ray_wrapped_inference_engines( enable_return_routed_experts: bool = False, served_model_name: str | None = None, distributed_executor_backend: str = "ray", + max_logprobs: int = 1, ) -> List[InferenceEngineInterface]: """ Create a list of RayWrappedInferenceEngine instances wrapping Ray actor handles to InferenceEngineInterface @@ -301,7 +302,7 @@ def create_ray_wrapped_inference_engines( noset_visible_devices=noset_visible_devices, max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=max_num_seqs, - max_logprobs=1, # only need chosen-token logprobs + max_logprobs=max_logprobs, enable_ray_prometheus_stats=enable_ray_prometheus_stats, enable_return_routed_experts=enable_return_routed_experts, **dp_kwargs,