From dcaacfec53ab96fb2a36a82b073266b352142689 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Sat, 31 Jan 2026 19:23:45 +0800 Subject: [PATCH 1/9] Fix partial load problem, Add vlm support for trtllm rollout --- .../trtllm_rollout/trtllm_async_server.py | 84 ++++++++++++++----- .../rollout/trtllm_rollout/trtllm_rollout.py | 35 ++++++-- .../trtllm_rollout/trtllm_worker_extension.py | 77 +++++++++++++++++ 3 files changed, 168 insertions(+), 28 deletions(-) create mode 100644 verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index f669a7bfe3b..6448075e476 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -125,7 +125,7 @@ async def launch_server(self): "model": self.model_config.local_path, "backend": "pytorch", "orchestrator_type": "ray", - "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", + "ray_worker_extension_cls": "verl.workers.rollout.trtllm_rollout.trtllm_worker_extension.WorkerExtension", "kv_cache_config": kv_cache_config, "max_seq_len": self.config.max_model_len, "max_batch_size": self.config.max_num_seqs, @@ -159,18 +159,45 @@ async def launch_server(self): } ) - self.llm = await AsyncLLM(**llm_kwargs) - - trtllm_server = OpenAIServer( - llm=self.llm, - model=self.model_config.local_path, - tool_parser=None, - server_role=None, - metadata_server_cfg=None, - ) + if self.is_vlm_model: + from tensorrt_llm.inputs.multimodal import MultimodalServerConfig + multimodal_config = MultimodalServerConfig( + media_io_kwargs={ + "image": { + "format": "pil", + "device": "cpu", + }, + "video": { + "num_frames": 8, + "fps": 30, + "format": "pil", + "device": "cpu", + }, + } + ) + self.llm = await AsyncLLM(**llm_kwargs) + trtllm_server = OpenAIServer( + llm=self.llm, + model=self.model_config.local_path, + tool_parser=None, + server_role=None, + metadata_server_cfg=None, + multimodal_server_config=multimodal_config, + ) + else: + self.llm = await AsyncLLM(**llm_kwargs) + trtllm_server = OpenAIServer( + llm=self.llm, + model=self.model_config.local_path, + tool_parser=None, + server_role=None, + metadata_server_cfg=None, + ) + app = trtllm_server.app self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address) + @resume_on_abort async def generate( self, prompt_ids: list[int], @@ -179,11 +206,7 @@ async def generate( image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, ) -> TokenOutput: - """Generate sequence with token-in-token-out.""" - assert image_data is None and video_data is None, "Multimodality is not yet supported in TRTLLMHttpServer." 
- from tensorrt_llm.llmapi import SamplingParams - max_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids)) sampling_params["max_tokens"] = max_tokens sampling_params["logprobs"] = 1 if sampling_params.pop("logprobs", False) else None @@ -192,15 +215,34 @@ async def generate( sampling_params.update(self.sampling_args) trt_llm_sampling_params = SamplingParams(**sampling_params) - outputs = await self.llm.generate_async( - inputs=prompt_ids, - sampling_params=trt_llm_sampling_params, - ) - + if self.is_vlm_model: + if image_data or video_data: + input_dict = { + "prompt_token_ids": prompt_ids, + "multi_modal_data": {}, + } + if image_data: + input_dict["multi_modal_data"]["image"] = image_data + if video_data: + input_dict["multi_modal_data"]["video"] = video_data + outputs = await self.llm.generate_async( + inputs=input_dict, + sampling_params=trt_llm_sampling_params, + ) + else: + outputs = await self.llm.generate_async( + inputs=prompt_ids, + sampling_params=trt_llm_sampling_params, + ) + else: + outputs = await self.llm.generate_async( + inputs=prompt_ids, + sampling_params=trt_llm_sampling_params, + ) token_ids = outputs.outputs[0].token_ids log_probs = None - if trt_llm_sampling_params.logprobs is not None: - log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs] + if outputs.outputs[0].logprobs is not None: + log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(outputs.outputs[0].logprobs)] return TokenOutput(token_ids=token_ids, log_probs=log_probs) async def wake_up(self): diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 3c42ee7bc73..ba6a991b57d 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -281,6 +281,7 @@ def __init__( self.is_leader_rank = None self.replica_rank = None self.is_dp_rank = None + self._supports_partial_loading = None # hybrid mode if self.device_mesh is not None: @@ -312,6 +313,21 @@ def __init__( self.node_ip = ray.util.get_node_ip_address().strip("[]") + async def get_supports_partial_loading(self) -> bool: + """Query and cache whether the model supports partial weight loading.""" + if self._supports_partial_loading is not None: + return self._supports_partial_loading + + await self._init_server_adapter() + try: + self._supports_partial_loading = await self.server_actor.supports_partial_loading.remote() + except Exception as e: + logger.warning(f"Failed to query partial loading support: {e}, defaulting to False") + self._supports_partial_loading = False + + logger.info(f"Model supports partial loading: {self._supports_partial_loading}") + return self._supports_partial_loading + async def _init_server_adapter(self): if self._adapter is not None: return @@ -405,16 +421,21 @@ async def flush(): await self.update_weights_from_ipc_handles(serialized_device_handles) cur_available_bytes = total_available_bytes cur_handles = [] + + # Query if model supports partial loading + supports_partial_loading = await self.get_supports_partial_loading() for name, param in weights: - size_in_bytes = param.element_size() * param.numel() - if size_in_bytes > cur_available_bytes: - await flush() + if supports_partial_loading: + size_in_bytes = param.element_size() * param.numel() + if size_in_bytes > cur_available_bytes: + await flush() + + assert cur_available_bytes >= size_in_bytes, ( + f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: 
{size_in_bytes:,} name: {name}" + ) + cur_available_bytes -= size_in_bytes - assert cur_available_bytes >= size_in_bytes, ( - f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: {size_in_bytes:,} name: {name}" - ) - cur_available_bytes -= size_in_bytes handle = reduce_tensor(param.detach()) cur_handles.append((name, handle)) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py new file mode 100644 index 00000000000..86b341dbf84 --- /dev/null +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -0,0 +1,77 @@ +import base64 +import inspect +import pickle +from typing import Optional + +from tensorrt_llm._ray_utils import control_action_decorator +from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer +from tensorrt_llm._torch.utils import get_device_uuid +from tensorrt_llm.logger import logger + + +class WorkerExtension: + + def __init__(self): + pass + + @control_action_decorator + def supports_partial_loading(self) -> bool: + """Check if the model supports partial weight loading.""" + try: + model = self.engine.model_engine.model + load_weights_args = inspect.getfullargspec(model.load_weights).args + return "allow_partial_loading" in load_weights_args + except Exception as e: + logger.warning(f"Failed to check partial loading support: {e}") + return False + + @control_action_decorator + def update_weights(self, ipc_handles: Optional[dict] = None): + try: + if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): + for module in self.engine.model_engine.model.modules(): + if hasattr(module, "pre_reload_weights") and not getattr( + module, "_weights_removed", False + ): + module.pre_reload_weights() + setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) + + if ipc_handles is not None: + device_uuid = get_device_uuid() + handles = ipc_handles.get(device_uuid, None) + if handles is not None: + weights = pickle.loads(base64.b64decode(handles)) + model = self.engine.model_engine.model + load_weights_args = inspect.getfullargspec(model.load_weights).args + supports_partial_loading = "allow_partial_loading" in load_weights_args + + if supports_partial_loading: + self.engine.model_engine.model_loader.reload( + model, weights, allow_partial_loading=True + ) + else: + self.engine.model_engine.model_loader.reload( + model, weights, allow_partial_loading=False + ) + else: + for module in self.engine.model_engine.model.modules(): + if hasattr(module, "process_weights_after_loading") and not getattr( + module, "_weights_removed", False + ): + module.process_weights_after_loading() + if hasattr(module, "post_load_weights") and not getattr( + module, "_weights_removed", False + ): + module.post_load_weights() + moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None) + if isinstance(moe_load_balancer, MoeLoadBalancer): + moe_load_balancer.register_weight_slots_after_to_cuda() + logger.info("moe_load_balancer finalizing model...") + moe_load_balancer.finalize_model() + logger.info("moe_load_balancer finalize model done") + self.engine.reset_prefix_cache() + delattr(self.engine.model_engine.model, "first_pre_reload_weights") + + except Exception as e: + logger.error("Encountered an error in update_weights") + raise e From 0394ab512fdefe6be8a6b8fa5e2393dfa5e0777e Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Sat, 31 Jan 2026 20:03:04 +0800 Subject: [PATCH 2/9] Precommit check --- 
.../trtllm_rollout/trtllm_async_server.py | 5 +-- .../rollout/trtllm_rollout/trtllm_rollout.py | 4 +-- .../trtllm_rollout/trtllm_worker_extension.py | 32 +++++++++++-------- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 6448075e476..3317a641fc1 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -161,6 +161,7 @@ async def launch_server(self): if self.is_vlm_model: from tensorrt_llm.inputs.multimodal import MultimodalServerConfig + multimodal_config = MultimodalServerConfig( media_io_kwargs={ "image": { @@ -193,11 +194,10 @@ async def launch_server(self): server_role=None, metadata_server_cfg=None, ) - + app = trtllm_server.app self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address) - @resume_on_abort async def generate( self, prompt_ids: list[int], @@ -207,6 +207,7 @@ async def generate( video_data: Optional[list[Any]] = None, ) -> TokenOutput: from tensorrt_llm.llmapi import SamplingParams + max_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids)) sampling_params["max_tokens"] = max_tokens sampling_params["logprobs"] = 1 if sampling_params.pop("logprobs", False) else None diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index ba6a991b57d..ce2527c66e7 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -281,7 +281,7 @@ def __init__( self.is_leader_rank = None self.replica_rank = None self.is_dp_rank = None - self._supports_partial_loading = None + self._supports_partial_loading = None # hybrid mode if self.device_mesh is not None: @@ -421,7 +421,7 @@ async def flush(): await self.update_weights_from_ipc_handles(serialized_device_handles) cur_available_bytes = total_available_bytes cur_handles = [] - + # Query if model supports partial loading supports_partial_loading = await self.get_supports_partial_loading() diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index 86b341dbf84..a7a96f607fa 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -1,3 +1,16 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import base64 import inspect import pickle @@ -10,7 +23,6 @@ class WorkerExtension: - def __init__(self): pass @@ -30,11 +42,9 @@ def update_weights(self, ipc_handles: Optional[dict] = None): try: if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): for module in self.engine.model_engine.model.modules(): - if hasattr(module, "pre_reload_weights") and not getattr( - module, "_weights_removed", False - ): + if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False): module.pre_reload_weights() - setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) + self.engine.model_engine.model.first_pre_reload_weights = True if ipc_handles is not None: device_uuid = get_device_uuid() @@ -46,22 +56,16 @@ def update_weights(self, ipc_handles: Optional[dict] = None): supports_partial_loading = "allow_partial_loading" in load_weights_args if supports_partial_loading: - self.engine.model_engine.model_loader.reload( - model, weights, allow_partial_loading=True - ) + self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True) else: - self.engine.model_engine.model_loader.reload( - model, weights, allow_partial_loading=False - ) + self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False) else: for module in self.engine.model_engine.model.modules(): if hasattr(module, "process_weights_after_loading") and not getattr( module, "_weights_removed", False ): module.process_weights_after_loading() - if hasattr(module, "post_load_weights") and not getattr( - module, "_weights_removed", False - ): + if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False): module.post_load_weights() moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None) if isinstance(moe_load_balancer, MoeLoadBalancer): From 0664ab102b059063d29b4d420d28abbd146eef1c Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Sat, 31 Jan 2026 22:55:36 +0800 Subject: [PATCH 3/9] Add check for if the model is vlm in trtllmhttpserver --- verl/workers/rollout/trtllm_rollout/trtllm_async_server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 3317a641fc1..24b380a6962 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -89,6 +89,10 @@ def __init__( logger.warning(f"rollout mode is {self.rollout_mode}, load_format is dummy, set to auto") self.config.load_format = "auto" + self.is_vlm_model = ( + self.model_config.hf_config is not None and hasattr(self.model_config.hf_config, "vision_config") + ) or hasattr(self.model_config, "vision_config") + # used for http server self._server_address = ray.util.get_node_ip_address().strip("[]") self._server_port = None From bf71c9b4c3c19b6496133d93568119aea6d8951d Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 17:12:58 +0800 Subject: [PATCH 4/9] Support latest trtllm --- .../config/test_optim_config_on_cpu.py | 48 ---------- .../trtllm_rollout/trtllm_worker_extension.py | 96 +++++++++++++++---- 2 files changed, 80 insertions(+), 64 deletions(-) delete mode 100644 tests/workers/config/test_optim_config_on_cpu.py diff --git a/tests/workers/config/test_optim_config_on_cpu.py b/tests/workers/config/test_optim_config_on_cpu.py deleted file mode 100644 index b44cb40c6b1..00000000000 --- 
a/tests/workers/config/test_optim_config_on_cpu.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from verl.workers.config.optimizer import FSDPOptimizerConfig - - -class TestFSDPOptimizerConfigCPU: - def test_default_configuration(self): - config = FSDPOptimizerConfig(lr=0.1) - assert config.min_lr_ratio is None - assert config.lr_scheduler_type == "constant" - assert config.num_cycles == 0.5 - - @pytest.mark.parametrize("lr_scheduler_type", ["constant", "cosine"]) - def test_valid_lr_scheduler_types(self, lr_scheduler_type): - config = FSDPOptimizerConfig(lr_scheduler_type=lr_scheduler_type, lr=0.1) - assert config.lr_scheduler_type == lr_scheduler_type - - @pytest.mark.parametrize("warmup_style", ["constant", "cosine"]) - def test_valid_warmup_style_types(self, warmup_style): - config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1) - assert config.lr_scheduler_type == warmup_style - - def test_invalid_lr_scheduler_type(self): - with pytest.raises((ValueError, AssertionError)): - FSDPOptimizerConfig(lr_scheduler_type="invalid_style", lr=0.1) - - def test_invalid_warmup_style_type(self): - with pytest.raises((ValueError, AssertionError)): - FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1) - - @pytest.mark.parametrize("num_cycles", [0.1, 1.0, 2.5]) - def test_num_cycles_configuration(self, num_cycles): - config = FSDPOptimizerConfig(num_cycles=num_cycles, lr=0.1) - assert config.num_cycles == num_cycles diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index a7a96f607fa..ce56c8b9b5c 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -13,9 +13,11 @@ # limitations under the License. 
import base64 import inspect -import pickle from typing import Optional +import torch + +from tensorrt_llm import serialization from tensorrt_llm._ray_utils import control_action_decorator from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer from tensorrt_llm._torch.utils import get_device_uuid @@ -42,30 +44,85 @@ def update_weights(self, ipc_handles: Optional[dict] = None): try: if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): for module in self.engine.model_engine.model.modules(): - if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False): + if hasattr(module, "pre_reload_weights") and not getattr( + module, "_weights_removed", False + ): module.pre_reload_weights() - self.engine.model_engine.model.first_pre_reload_weights = True + setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) if ipc_handles is not None: - device_uuid = get_device_uuid() - handles = ipc_handles.get(device_uuid, None) - if handles is not None: - weights = pickle.loads(base64.b64decode(handles)) - model = self.engine.model_engine.model - load_weights_args = inspect.getfullargspec(model.load_weights).args - supports_partial_loading = "allow_partial_loading" in load_weights_args - - if supports_partial_loading: - self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True) - else: - self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False) + logger.info("Update weights from IPC handles") + device_uuid = get_device_uuid(self.device_id) + + if device_uuid not in ipc_handles: + raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles") + + weights = {} + + serialized_handles = ipc_handles[device_uuid] + if isinstance(serialized_handles, str): + # Data is base64-encoded pickled bytes - deserialize it + # using restricted unpickler from tensorrt_llm.serialization + logger.info("Deserializing base64-encoded weight handles") + decoded_data = base64.b64decode(serialized_handles) + # Allow basic builtins and all torch modules + approved_imports = { + "builtins": [ + "list", + "tuple", + "str", + "int", + "float", + "bool", + "bytes", + "dict", + "NoneType", + "type", + ], + } + all_handles = serialization.loads( + decoded_data, + approved_imports=approved_imports, + approved_module_patterns=[r"^torch.*"], + ) + + # Verify the result is a list as expected + if not isinstance(all_handles, list): + raise ValueError( + f"Deserialized data must be a list, got {type(all_handles).__name__} instead" + ) + else: + # Data is already in the correct format (backward compatibility) + all_handles = serialized_handles + + for param_name, tensor_handle in all_handles: + func, args = tensor_handle + list_args = list(args) + list_args[6] = self.device_id + tensor = func(*list_args) + weights[param_name] = tensor + + logger.info(f"weights key size: {len(weights.keys())}") + + # Check if model supports partial loading and use appropriate strategy + model = self.engine.model_engine.model + load_weights_args = inspect.getfullargspec(model.load_weights).args + supports_partial_loading = "allow_partial_loading" in load_weights_args + + if supports_partial_loading: + self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True) + else: + self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False) else: + logger.info("Finalize update weights") for module in self.engine.model_engine.model.modules(): if 
hasattr(module, "process_weights_after_loading") and not getattr( module, "_weights_removed", False ): module.process_weights_after_loading() - if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False): + if hasattr(module, "post_load_weights") and not getattr( + module, "_weights_removed", False + ): module.post_load_weights() moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None) if isinstance(moe_load_balancer, MoeLoadBalancer): @@ -79,3 +136,10 @@ def update_weights(self, ipc_handles: Optional[dict] = None): except Exception as e: logger.error("Encountered an error in update_weights") raise e + + def check_weights_updated(self) -> bool: + """Check if the weights are updated to 0.""" + weights_updated = True + for name, p in self.engine.model_engine.model.named_parameters(): + weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p)) + return weights_updated From f6e58b882ecccdfc42fde76dda58b556c9ea3fc6 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 19:53:43 +0800 Subject: [PATCH 5/9] Support for qwen2.5 vl --- verl/workers/rollout/trtllm_rollout/trtllm_async_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 24b380a6962..221655847a4 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -225,6 +225,7 @@ async def generate( input_dict = { "prompt_token_ids": prompt_ids, "multi_modal_data": {}, + "mm_processor_kwargs": {}, } if image_data: input_dict["multi_modal_data"]["image"] = image_data From 7af6917e3ed9d0b3c9fa68a769a1e7d2fc8d4ee6 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 19:56:46 +0800 Subject: [PATCH 6/9] Add trtllm rollout test script --- .../rollout/rollout_trtllm/__init__.py | 14 + .../test_trtllm_rollout_utils.py | 458 ++++++++++++++++++ 2 files changed, 472 insertions(+) create mode 100644 tests/workers/rollout/rollout_trtllm/__init__.py create mode 100644 tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py diff --git a/tests/workers/rollout/rollout_trtllm/__init__.py b/tests/workers/rollout/rollout_trtllm/__init__.py new file mode 100644 index 00000000000..46866da4cd9 --- /dev/null +++ b/tests/workers/rollout/rollout_trtllm/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ diff --git a/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py new file mode 100644 index 00000000000..dd99f09f60c --- /dev/null +++ b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py @@ -0,0 +1,458 @@ +import asyncio +import os +import uuid + +import numpy as np +import pytest +import ray +import torch +from omegaconf import OmegaConf +from PIL import Image +from transformers import AutoTokenizer + +UNIMODAL_MODEL_PATH = "Qwen/Qwen2.5-Math-7B" +MULTIMODAL_MODEL_PATH = "Qwen/Qwen2.5-VL-7B-Instruct" + +MAX_MODEL_LEN = 4096 +RESPONSE_LENGTH = 256 +MAX_NUM_SEQS = 16 +GPU_MEMORY_UTILIZATION = 0.8 +TENSOR_PARALLEL_SIZE = 1 + + +def create_test_image(width: int = 224, height: int = 224) -> Image.Image: + img_array = np.zeros((height, width, 3), dtype=np.uint8) + for i in range(height): + for j in range(width): + img_array[i, j] = [ + int(255 * i / height), + int(255 * j / width), + int(255 * (i + j) / (height + width)), + ] + return Image.fromarray(img_array) + + +def create_rollout_config_dict(): + config_dict = { + "_target_": "verl.workers.config.RolloutConfig", + "name": "trtllm", + "mode": "async", + "temperature": 0.7, + "top_k": 50, + "top_p": 0.9, + "do_sample": True, + "n": 1, + "prompt_length": 512, + "response_length": RESPONSE_LENGTH, + "dtype": "bfloat16", + "gpu_memory_utilization": GPU_MEMORY_UTILIZATION, + "ignore_eos": False, + "enforce_eager": True, + "free_cache_engine": False, + "data_parallel_size": 1, + "tensor_model_parallel_size": TENSOR_PARALLEL_SIZE, + "pipeline_model_parallel_size": 1, + "max_num_batched_tokens": 8192, + "max_model_len": MAX_MODEL_LEN, + "max_num_seqs": MAX_NUM_SEQS, + "load_format": "auto", + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + } + return OmegaConf.create(config_dict) + + +def create_model_config_dict(model_path: str): + config_dict = { + "_target_": "verl.workers.config.HFModelConfig", + "path": model_path, + "trust_remote_code": True, + "load_tokenizer": True, + } + return OmegaConf.create(config_dict) + + +def get_tokenizer(model_path: str): + return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + +def get_processor(model_path: str): + from transformers import AutoProcessor + return AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) +class TestUnimodalTRTLLMRollout: + + @pytest.fixture(scope="class") + def ray_context(self): + if ray.is_initialized(): + ray.shutdown() + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + @pytest.fixture(scope="class") + def trtllm_replica(self, ray_context): + from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica + + rollout_config = create_rollout_config_dict() + model_config = create_model_config_dict(UNIMODAL_MODEL_PATH) + + replica = TRTLLMReplica( + replica_rank=0, + config=rollout_config, + model_config=model_config, + gpus_per_node=torch.cuda.device_count(), + is_reward_model=False, + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(replica.init_standalone()) + + yield replica + + loop.close() + + @pytest.fixture(scope="class") + def tokenizer(self): + return get_tokenizer(UNIMODAL_MODEL_PATH) + + @pytest.mark.parametrize( + "prompt", + [ + "What is 2 + 2?", + "Solve for x: 3x + 5 = 20", + "Calculate the derivative of x^2 + 3x + 1", + ], + ) + def 
test_unimodal_generate(self, trtllm_replica, tokenizer, prompt): + replica = trtllm_replica + + messages = [ + {"role": "system", "content": "You are a helpful math assistant."}, + {"role": "user", "content": prompt}, + ] + + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": True, + } + + request_id = str(uuid.uuid4()) + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=request_id, + )) + + assert output is not None + assert hasattr(output, "token_ids") + assert len(output.token_ids) > 0 + + generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) + print(f"\n[Unimodal Test]") + print(f"Prompt: {prompt}") + print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...") + + def test_unimodal_batch_generate(self, trtllm_replica, tokenizer): + replica = trtllm_replica + + prompts = [ + "What is 1 + 1?", + "What is 2 * 3?", + "What is 10 / 2?", + ] + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": False, + } + + results = [] + + for i, prompt in enumerate(prompts): + messages = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() + + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + )) + results.append(output) + + assert len(results) == len(prompts) + for i, (prompt, result) in enumerate(zip(prompts, results)): + assert result is not None + assert len(result.token_ids) > 0 + generated = tokenizer.decode(result.token_ids, skip_special_tokens=True) + print(f"\n[Batch {i}] Prompt: {prompt}") + print(f"Generated: {generated[:100]}...") + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) +class TestMultimodalTRTLLMRollout: + + @pytest.fixture(scope="class") + def ray_context(self): + if ray.is_initialized(): + ray.shutdown() + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + @pytest.fixture(scope="class") + def trtllm_vlm_replica(self, ray_context): + from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica + + rollout_config = create_rollout_config_dict() + model_config = create_model_config_dict(MULTIMODAL_MODEL_PATH) + + replica = TRTLLMReplica( + replica_rank=0, + config=rollout_config, + model_config=model_config, + gpus_per_node=torch.cuda.device_count(), + is_reward_model=False, + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(replica.init_standalone()) + + yield replica + + loop.close() + + @pytest.fixture(scope="class") + def tokenizer(self): + return get_tokenizer(MULTIMODAL_MODEL_PATH) + + @pytest.fixture(scope="class") + def processor(self): + return get_processor(MULTIMODAL_MODEL_PATH) + + @pytest.mark.parametrize( + "prompt", + [ + "Describe this image in detail.", + "What colors do you see in this image?", + "What patterns are visible in this image?", + ], + ) + def test_multimodal_generate_with_image(self, trtllm_vlm_replica, processor, tokenizer, prompt): + replica = trtllm_vlm_replica + + test_image = create_test_image(224, 224) + + messages = [ + { 
+ "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": prompt}, + ], + } + ] + + text = processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + print("text: ", text) + input_ids = processor.tokenizer(text, return_tensors="pt", padding=True)["input_ids"][0].tolist() + + print("input_ids decoded: ", processor.tokenizer.decode(input_ids, skip_special_tokens=False, add_special_tokens=False)) + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": False, + } + + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + image_data=[test_image], + )) + + assert output is not None + assert hasattr(output, "token_ids") + assert len(output.token_ids) > 0 + + generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) + print(f"\n[Multimodal Test]") + print(f"Prompt: {prompt}") + print(f"Image size: {test_image.size}") + print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...") + + @pytest.mark.parametrize( + "image_size", + [(224, 224), (384, 384), (512, 512)], + ) + def test_multimodal_different_image_sizes(self, trtllm_vlm_replica, processor, tokenizer, image_size): + replica = trtllm_vlm_replica + + width, height = image_size + test_image = create_test_image(width, height) + + prompt = "What is shown in this image?" + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": prompt}, + ], + } + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + input_ids = processor.tokenizer(text, return_tensors="pt", padding=True)["input_ids"][0].tolist() + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": False, + } + + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + image_data=[test_image], + )) + + assert output is not None + assert len(output.token_ids) > 0 + print(f"\n[Image Size {image_size}] Generated {len(output.token_ids)} tokens") + + def test_multimodal_text_only_fallback(self, trtllm_vlm_replica, tokenizer): + replica = trtllm_vlm_replica + + prompt = "What is the capital of China?" 
+ messages = [{"role": "user", "content": prompt}] + + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() + + sampling_params = { + "temperature": 0.7, + "top_p": 0.9, + "top_k": 50, + "logprobs": False, + } + + output = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + )) + + assert output is not None + assert len(output.token_ids) > 0 + + generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) + print(f"\n[Text-only on VLM]") + print(f"Prompt: {prompt}") + print(f"Generated: {generated_text}") + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) +class TestTRTLLMServerLifecycle: + + @pytest.fixture(scope="class") + def ray_context(self): + if ray.is_initialized(): + ray.shutdown() + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + @pytest.fixture(scope="class") + def trtllm_replica_lifecycle(self, ray_context): + from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica + + rollout_config = create_rollout_config_dict() + model_config = create_model_config_dict(UNIMODAL_MODEL_PATH) + + replica = TRTLLMReplica( + replica_rank=0, + config=rollout_config, + model_config=model_config, + gpus_per_node=torch.cuda.device_count(), + is_reward_model=False, + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(replica.init_standalone()) + + yield replica, loop + + loop.close() + + @pytest.fixture(scope="class") + def tokenizer(self): + return get_tokenizer(UNIMODAL_MODEL_PATH) + + def test_wake_sleep_cycle(self, trtllm_replica_lifecycle, tokenizer): + replica, loop = trtllm_replica_lifecycle + + prompt = "Hello, world!" 
+ messages = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() + + sampling_params = {"temperature": 0.7, "top_p": 0.9, "top_k": 50, "logprobs": False} + + output1 = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + )) + assert output1 is not None + assert len(output1.token_ids) > 0 + print(f"\n[Before Sleep] Generated {len(output1.token_ids)} tokens") + + loop.run_until_complete(replica.sleep()) + print("[Sleep] Server put to sleep") + + loop.run_until_complete(replica.wake_up()) + print("[Wake Up] Server woken up") + + output2 = ray.get(replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + )) + assert output2 is not None + assert len(output2.token_ids) > 0 + print(f"[After Wake Up] Generated {len(output2.token_ids)} tokens") From 94c4eb0a1adce3881fd6d978c3a7e07cb8d6c0ae Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 19:57:22 +0800 Subject: [PATCH 7/9] Add test_trtllm_rollout workflow to test trtllm_rollout --- .github/workflows/test_trtllm_rollout.yml | 82 +++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 .github/workflows/test_trtllm_rollout.yml diff --git a/.github/workflows/test_trtllm_rollout.yml b/.github/workflows/test_trtllm_rollout.yml new file mode 100644 index 00000000000..9c714de4892 --- /dev/null +++ b/.github/workflows/test_trtllm_rollout.yml @@ -0,0 +1,82 @@ +name: test_trtllm_rollout + +on: + push: + branches: + - main + - v0.* + paths: + - "verl/workers/rollout/trtllm_rollout/**" + - "tests/workers/rollout/rollout_trtllm/**" + - ".github/workflows/test_trtllm_rollout.yml" + pull_request: + branches: + - main + - v0.* + paths: + - "verl/workers/rollout/trtllm_rollout/**" + - "tests/workers/rollout/rollout_trtllm/**" + - ".github/workflows/test_trtllm_rollout.yml" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +permissions: + contents: read + +env: + IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:trtllm1.2.0rc6" + DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" + +jobs: + setup: + if: github.repository_owner == 'verl-project' + runs-on: ubuntu-latest + outputs: + runner-label: ${{ steps.create-runner.outputs.runner-label }} + mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} + steps: + - uses: actions/checkout@v4 + - id: create-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "create" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-image: "${{ env.IMAGE }}" + + test_trtllm_rollout: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install -r requirements-test.txt + pip3 install --no-deps -e . 
+ - name: Run TRT-LLM rollout tests + run: | + ray stop --force + pytest -v -s tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py + + cleanup: + runs-on: ubuntu-latest + needs: [setup, test_trtllm_rollout] + if: always() + steps: + - id: destroy-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "destroy" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" + From 25518fee2daee8e5d06575d1526a08c7a20fe124 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 20:00:34 +0800 Subject: [PATCH 8/9] Add back mistakenly deleted file --- .../config/test_optim_config_on_cpu.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/workers/config/test_optim_config_on_cpu.py diff --git a/tests/workers/config/test_optim_config_on_cpu.py b/tests/workers/config/test_optim_config_on_cpu.py new file mode 100644 index 00000000000..5aae6bb8c2c --- /dev/null +++ b/tests/workers/config/test_optim_config_on_cpu.py @@ -0,0 +1,48 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from verl.workers.config.optimizer import FSDPOptimizerConfig + + +class TestFSDPOptimizerConfigCPU: + def test_default_configuration(self): + config = FSDPOptimizerConfig(lr=0.1) + assert config.min_lr_ratio is None + assert config.lr_scheduler_type == "constant" + assert config.num_cycles == 0.5 + + @pytest.mark.parametrize("lr_scheduler_type", ["constant", "cosine"]) + def test_valid_lr_scheduler_types(self, lr_scheduler_type): + config = FSDPOptimizerConfig(lr_scheduler_type=lr_scheduler_type, lr=0.1) + assert config.lr_scheduler_type == lr_scheduler_type + + @pytest.mark.parametrize("warmup_style", ["constant", "cosine"]) + def test_valid_warmup_style_types(self, warmup_style): + config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1) + assert config.lr_scheduler_type == warmup_style + + def test_invalid_lr_scheduler_type(self): + with pytest.raises((ValueError, AssertionError)): + FSDPOptimizerConfig(lr_scheduler_type="invalid_style", lr=0.1) + + def test_invalid_warmup_style_type(self): + with pytest.raises((ValueError, AssertionError)): + FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1) + + @pytest.mark.parametrize("num_cycles", [0.1, 1.0, 2.5]) + def test_num_cycles_configuration(self, num_cycles): + config = FSDPOptimizerConfig(num_cycles=num_cycles, lr=0.1) + assert config.num_cycles == num_cycles \ No newline at end of file From fd007fb333b72d8fed9cc11aa105407f6fe9eaa5 Mon Sep 17 00:00:00 2001 From: dingruiyi Date: Mon, 2 Feb 2026 20:03:20 +0800 Subject: [PATCH 9/9] Precommit check --- .../config/test_optim_config_on_cpu.py | 2 +- .../rollout/rollout_trtllm/__init__.py | 3 +- .../test_trtllm_rollout_utils.py | 121 +++++++++++------- .../trtllm_rollout/trtllm_worker_extension.py | 15 +-- 4 files changed, 80 insertions(+), 61 deletions(-) diff --git a/tests/workers/config/test_optim_config_on_cpu.py 
b/tests/workers/config/test_optim_config_on_cpu.py index 5aae6bb8c2c..b44cb40c6b1 100644 --- a/tests/workers/config/test_optim_config_on_cpu.py +++ b/tests/workers/config/test_optim_config_on_cpu.py @@ -45,4 +45,4 @@ def test_invalid_warmup_style_type(self): @pytest.mark.parametrize("num_cycles", [0.1, 1.0, 2.5]) def test_num_cycles_configuration(self, num_cycles): config = FSDPOptimizerConfig(num_cycles=num_cycles, lr=0.1) - assert config.num_cycles == num_cycles \ No newline at end of file + assert config.num_cycles == num_cycles diff --git a/tests/workers/rollout/rollout_trtllm/__init__.py b/tests/workers/rollout/rollout_trtllm/__init__.py index 46866da4cd9..d828409b82e 100644 --- a/tests/workers/rollout/rollout_trtllm/__init__.py +++ b/tests/workers/rollout/rollout_trtllm/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2026 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py index dd99f09f60c..21ab5689113 100644 --- a/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py +++ b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py @@ -1,5 +1,17 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import asyncio -import os import uuid import numpy as np @@ -78,6 +90,7 @@ def get_tokenizer(model_path: str): def get_processor(model_path: str): from transformers import AutoProcessor + return AutoProcessor.from_pretrained(model_path, trust_remote_code=True) @@ -86,7 +99,6 @@ def get_processor(model_path: str): reason="CUDA not available", ) class TestUnimodalTRTLLMRollout: - @pytest.fixture(scope="class") def ray_context(self): if ray.is_initialized(): @@ -153,18 +165,20 @@ def test_unimodal_generate(self, trtllm_replica, tokenizer, prompt): } request_id = str(uuid.uuid4()) - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=request_id, - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=request_id, + ) + ) assert output is not None assert hasattr(output, "token_ids") assert len(output.token_ids) > 0 generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) - print(f"\n[Unimodal Test]") + print("\n[Unimodal Test]") print(f"Prompt: {prompt}") print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...") @@ -191,15 +205,17 @@ def test_unimodal_batch_generate(self, trtllm_replica, tokenizer): text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist() - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + ) results.append(output) assert len(results) == len(prompts) - for i, (prompt, result) in enumerate(zip(prompts, results)): + for i, (prompt, result) in enumerate(zip(prompts, results, strict=False)): assert result is not None assert len(result.token_ids) > 0 generated = tokenizer.decode(result.token_ids, skip_special_tokens=True) @@ -212,7 +228,6 @@ def test_unimodal_batch_generate(self, trtllm_replica, tokenizer): reason="CUDA not available", ) class TestMultimodalTRTLLMRollout: - @pytest.fixture(scope="class") def ray_context(self): if ray.is_initialized(): @@ -283,8 +298,11 @@ def test_multimodal_generate_with_image(self, trtllm_vlm_replica, processor, tok print("text: ", text) input_ids = processor.tokenizer(text, return_tensors="pt", padding=True)["input_ids"][0].tolist() - print("input_ids decoded: ", processor.tokenizer.decode(input_ids, skip_special_tokens=False, add_special_tokens=False)) - + print( + "input_ids decoded: ", + processor.tokenizer.decode(input_ids, skip_special_tokens=False, add_special_tokens=False), + ) + sampling_params = { "temperature": 0.7, "top_p": 0.9, @@ -292,19 +310,21 @@ def test_multimodal_generate_with_image(self, trtllm_vlm_replica, processor, tok "logprobs": False, } - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - image_data=[test_image], - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + image_data=[test_image], + ) + ) assert output is not None assert hasattr(output, "token_ids") assert len(output.token_ids) > 0 generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) - 
print(f"\n[Multimodal Test]") + print("\n[Multimodal Test]") print(f"Prompt: {prompt}") print(f"Image size: {test_image.size}") print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...") @@ -340,12 +360,14 @@ def test_multimodal_different_image_sizes(self, trtllm_vlm_replica, processor, t "logprobs": False, } - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - image_data=[test_image], - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + image_data=[test_image], + ) + ) assert output is not None assert len(output.token_ids) > 0 @@ -367,17 +389,19 @@ def test_multimodal_text_only_fallback(self, trtllm_vlm_replica, tokenizer): "logprobs": False, } - output = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - )) + output = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + ) assert output is not None assert len(output.token_ids) > 0 generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True) - print(f"\n[Text-only on VLM]") + print("\n[Text-only on VLM]") print(f"Prompt: {prompt}") print(f"Generated: {generated_text}") @@ -387,7 +411,6 @@ def test_multimodal_text_only_fallback(self, trtllm_vlm_replica, tokenizer): reason="CUDA not available", ) class TestTRTLLMServerLifecycle: - @pytest.fixture(scope="class") def ray_context(self): if ray.is_initialized(): @@ -433,11 +456,13 @@ def test_wake_sleep_cycle(self, trtllm_replica_lifecycle, tokenizer): sampling_params = {"temperature": 0.7, "top_p": 0.9, "top_k": 50, "logprobs": False} - output1 = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - )) + output1 = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + ) assert output1 is not None assert len(output1.token_ids) > 0 print(f"\n[Before Sleep] Generated {len(output1.token_ids)} tokens") @@ -448,11 +473,13 @@ def test_wake_sleep_cycle(self, trtllm_replica_lifecycle, tokenizer): loop.run_until_complete(replica.wake_up()) print("[Wake Up] Server woken up") - output2 = ray.get(replica.server_handle.generate.remote( - prompt_ids=input_ids, - sampling_params=sampling_params, - request_id=str(uuid.uuid4()), - )) + output2 = ray.get( + replica.server_handle.generate.remote( + prompt_ids=input_ids, + sampling_params=sampling_params, + request_id=str(uuid.uuid4()), + ) + ) assert output2 is not None assert len(output2.token_ids) > 0 print(f"[After Wake Up] Generated {len(output2.token_ids)} tokens") diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index ce56c8b9b5c..6bb5190dfbc 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -16,7 +16,6 @@ from typing import Optional import torch - from tensorrt_llm import serialization from tensorrt_llm._ray_utils import control_action_decorator from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer @@ -44,11 +43,9 @@ def update_weights(self, 
ipc_handles: Optional[dict] = None): try: if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): for module in self.engine.model_engine.model.modules(): - if hasattr(module, "pre_reload_weights") and not getattr( - module, "_weights_removed", False - ): + if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False): module.pre_reload_weights() - setattr(self.engine.model_engine.model, "first_pre_reload_weights", True) + self.engine.model_engine.model.first_pre_reload_weights = True if ipc_handles is not None: logger.info("Update weights from IPC handles") @@ -88,9 +85,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None): # Verify the result is a list as expected if not isinstance(all_handles, list): - raise ValueError( - f"Deserialized data must be a list, got {type(all_handles).__name__} instead" - ) + raise ValueError(f"Deserialized data must be a list, got {type(all_handles).__name__} instead") else: # Data is already in the correct format (backward compatibility) all_handles = serialized_handles @@ -120,9 +115,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None): module, "_weights_removed", False ): module.process_weights_after_loading() - if hasattr(module, "post_load_weights") and not getattr( - module, "_weights_removed", False - ): + if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False): module.post_load_weights() moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None) if isinstance(moe_load_balancer, MoeLoadBalancer):
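
Note: the partial-loading path added in this series hinges on a single check — whether the rollout model's load_weights() accepts an allow_partial_loading keyword — performed via signature introspection rather than a version probe. The sketch below shows that check in isolation; the Dummy* classes are stand-ins invented for this example only and are not TRT-LLM or verl classes.

import inspect

class DummyLegacyModel:
    def load_weights(self, weights):
        """Older interface: expects the full weight dict in a single call."""

class DummyPartialModel:
    def load_weights(self, weights, allow_partial_loading=False):
        """Newer interface: can merge successive partial weight shards."""

def supports_partial_loading(model) -> bool:
    # Same introspection used by WorkerExtension.supports_partial_loading()
    # and by the reload branch in update_weights().
    try:
        return "allow_partial_loading" in inspect.getfullargspec(model.load_weights).args
    except TypeError:
        return False

print(supports_partial_loading(DummyLegacyModel()))   # False -> send all handles together, full reload
print(supports_partial_loading(DummyPartialModel()))  # True  -> stream weights in size-bounded buckets

When the check returns True, the bucketing logic in trtllm_rollout.py splits the weight stream into chunks that fit the currently available bytes and flushes each chunk through update_weights_from_ipc_handles() with allow_partial_loading=True; when it returns False, the IPC handles are accumulated without the size-based guard and reloaded in one pass.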