diff --git a/.github/workflows/test_trtllm_rollout.yml b/.github/workflows/test_trtllm_rollout.yml
new file mode 100644
index 00000000000..9c714de4892
--- /dev/null
+++ b/.github/workflows/test_trtllm_rollout.yml
@@ -0,0 +1,82 @@
+name: test_trtllm_rollout
+
+on:
+  push:
+    branches:
+      - main
+      - v0.*
+    paths:
+      - "verl/workers/rollout/trtllm_rollout/**"
+      - "tests/workers/rollout/rollout_trtllm/**"
+      - ".github/workflows/test_trtllm_rollout.yml"
+  pull_request:
+    branches:
+      - main
+      - v0.*
+    paths:
+      - "verl/workers/rollout/trtllm_rollout/**"
+      - "tests/workers/rollout/rollout_trtllm/**"
+      - ".github/workflows/test_trtllm_rollout.yml"
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
+permissions:
+  contents: read
+
+env:
+  IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:trtllm1.2.0rc6"
+  DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner"
+
+jobs:
+  setup:
+    if: github.repository_owner == 'verl-project'
+    runs-on: ubuntu-latest
+    outputs:
+      runner-label: ${{ steps.create-runner.outputs.runner-label }}
+      mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }}
+    steps:
+      - uses: actions/checkout@v4
+      - id: create-runner
+        uses: volcengine/vemlp-github-runner@v1
+        with:
+          mode: "create"
+          faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}"
+          mlp-image: "${{ env.IMAGE }}"
+
+  test_trtllm_rollout:
+    needs: setup
+    runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"]
+    timeout-minutes: 60
+    env:
+      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
+      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
+      NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
+      HF_ENDPOINT: "https://hf-mirror.com"
+      HF_HUB_ENABLE_HF_TRANSFER: "0"
+    steps:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          fetch-depth: 0
+      - name: Install the current repository
+        run: |
+          pip3 install -r requirements-test.txt
+          pip3 install --no-deps -e .
+      - name: Run TRT-LLM rollout tests
+        run: |
+          ray stop --force
+          pytest -v -s tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py
+
+  cleanup:
+    runs-on: ubuntu-latest
+    needs: [setup, test_trtllm_rollout]
+    if: always()
+    steps:
+      - id: destroy-runner
+        uses: volcengine/vemlp-github-runner@v1
+        with:
+          mode: "destroy"
+          faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}"
+          mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}"
+
diff --git a/tests/workers/rollout/rollout_trtllm/__init__.py b/tests/workers/rollout/rollout_trtllm/__init__.py
new file mode 100644
index 00000000000..d828409b82e
--- /dev/null
+++ b/tests/workers/rollout/rollout_trtllm/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2026 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py
new file mode 100644
index 00000000000..21ab5689113
--- /dev/null
+++ b/tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py
@@ -0,0 +1,485 @@
+# Copyright 2026 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import uuid
+
+import numpy as np
+import pytest
+import ray
+import torch
+from omegaconf import OmegaConf
+from PIL import Image
+from transformers import AutoTokenizer
+
+UNIMODAL_MODEL_PATH = "Qwen/Qwen2.5-Math-7B"
+MULTIMODAL_MODEL_PATH = "Qwen/Qwen2.5-VL-7B-Instruct"
+
+MAX_MODEL_LEN = 4096
+RESPONSE_LENGTH = 256
+MAX_NUM_SEQS = 16
+GPU_MEMORY_UTILIZATION = 0.8
+TENSOR_PARALLEL_SIZE = 1
+
+
+def create_test_image(width: int = 224, height: int = 224) -> Image.Image:
+    img_array = np.zeros((height, width, 3), dtype=np.uint8)
+    for i in range(height):
+        for j in range(width):
+            img_array[i, j] = [
+                int(255 * i / height),
+                int(255 * j / width),
+                int(255 * (i + j) / (height + width)),
+            ]
+    return Image.fromarray(img_array)
+
+
+def create_rollout_config_dict():
+    config_dict = {
+        "_target_": "verl.workers.config.RolloutConfig",
+        "name": "trtllm",
+        "mode": "async",
+        "temperature": 0.7,
+        "top_k": 50,
+        "top_p": 0.9,
+        "do_sample": True,
+        "n": 1,
+        "prompt_length": 512,
+        "response_length": RESPONSE_LENGTH,
+        "dtype": "bfloat16",
+        "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
+        "ignore_eos": False,
+        "enforce_eager": True,
+        "free_cache_engine": False,
+        "data_parallel_size": 1,
+        "tensor_model_parallel_size": TENSOR_PARALLEL_SIZE,
+        "pipeline_model_parallel_size": 1,
+        "max_num_batched_tokens": 8192,
+        "max_model_len": MAX_MODEL_LEN,
+        "max_num_seqs": MAX_NUM_SEQS,
+        "load_format": "auto",
+        "enable_chunked_prefill": True,
+        "enable_prefix_caching": True,
+    }
+    return OmegaConf.create(config_dict)
+
+
+def create_model_config_dict(model_path: str):
+    config_dict = {
+        "_target_": "verl.workers.config.HFModelConfig",
+        "path": model_path,
+        "trust_remote_code": True,
+        "load_tokenizer": True,
+    }
+    return OmegaConf.create(config_dict)
+
+
+def get_tokenizer(model_path: str):
+    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+
+def get_processor(model_path: str):
+    from transformers import AutoProcessor
+
+    return AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="CUDA not available",
+)
+class TestUnimodalTRTLLMRollout:
+    @pytest.fixture(scope="class")
+    def ray_context(self):
+        if ray.is_initialized():
+            ray.shutdown()
+        ray.init(ignore_reinit_error=True)
+        yield
+        ray.shutdown()
+
+    @pytest.fixture(scope="class")
+    def trtllm_replica(self, ray_context):
+        from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica
+
+        rollout_config = create_rollout_config_dict()
+        model_config = create_model_config_dict(UNIMODAL_MODEL_PATH)
+
+        replica = TRTLLMReplica(
+            replica_rank=0,
+            config=rollout_config,
+            model_config=model_config,
+            gpus_per_node=torch.cuda.device_count(),
+            is_reward_model=False,
+        )
+
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(replica.init_standalone())
+
+        yield replica
+
+        loop.close()
+
+    @pytest.fixture(scope="class")
+    def tokenizer(self):
+        return get_tokenizer(UNIMODAL_MODEL_PATH)
+
+    @pytest.mark.parametrize(
+        "prompt",
+        [
+            "What is 2 + 2?",
+            "Solve for x: 3x + 5 = 20",
+            "Calculate the derivative of x^2 + 3x + 1",
+        ],
+    )
+    def test_unimodal_generate(self, trtllm_replica, tokenizer, prompt):
+        replica = trtllm_replica
+
+        messages = [
+            {"role": "system", "content": "You are a helpful math assistant."},
+            {"role": "user", "content": prompt},
+        ]
+
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+        input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist()
+
+        sampling_params = {
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50,
+            "logprobs": True,
+        }
+
+        request_id = str(uuid.uuid4())
+        output = ray.get(
+            replica.server_handle.generate.remote(
+                prompt_ids=input_ids,
+                sampling_params=sampling_params,
+                request_id=request_id,
+            )
+        )
+
+        assert output is not None
+        assert hasattr(output, "token_ids")
+        assert len(output.token_ids) > 0
+
+        generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True)
+        print("\n[Unimodal Test]")
+        print(f"Prompt: {prompt}")
+        print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...")
+
+    def test_unimodal_batch_generate(self, trtllm_replica, tokenizer):
+        replica = trtllm_replica
+
+        prompts = [
+            "What is 1 + 1?",
+            "What is 2 * 3?",
+            "What is 10 / 2?",
+        ]
+
+        sampling_params = {
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50,
+            "logprobs": False,
+        }
+
+        results = []
+
+        for i, prompt in enumerate(prompts):
+            messages = [{"role": "user", "content": prompt}]
+            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist()
+
+            output = ray.get(
+                replica.server_handle.generate.remote(
+                    prompt_ids=input_ids,
+                    sampling_params=sampling_params,
+                    request_id=str(uuid.uuid4()),
+                )
+            )
+            results.append(output)
+
+        assert len(results) == len(prompts)
+        for i, (prompt, result) in enumerate(zip(prompts, results, strict=False)):
+            assert result is not None
+            assert len(result.token_ids) > 0
+            generated = tokenizer.decode(result.token_ids, skip_special_tokens=True)
+            print(f"\n[Batch {i}] Prompt: {prompt}")
+            print(f"Generated: {generated[:100]}...")
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="CUDA not available",
+)
+class TestMultimodalTRTLLMRollout:
+    @pytest.fixture(scope="class")
+    def ray_context(self):
+        if ray.is_initialized():
+            ray.shutdown()
+        ray.init(ignore_reinit_error=True)
+        yield
+        ray.shutdown()
+
+    @pytest.fixture(scope="class")
+    def trtllm_vlm_replica(self, ray_context):
+        from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica
+
+        rollout_config = create_rollout_config_dict()
+        model_config = create_model_config_dict(MULTIMODAL_MODEL_PATH)
+
+        replica = TRTLLMReplica(
+            replica_rank=0,
+            config=rollout_config,
+            model_config=model_config,
+            gpus_per_node=torch.cuda.device_count(),
+            is_reward_model=False,
+        )
+
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(replica.init_standalone())
+
+        yield replica
+
+        loop.close()
+
+    @pytest.fixture(scope="class")
+    def tokenizer(self):
+        return get_tokenizer(MULTIMODAL_MODEL_PATH)
+
+    @pytest.fixture(scope="class")
+    def processor(self):
+        return get_processor(MULTIMODAL_MODEL_PATH)
+
+    @pytest.mark.parametrize(
+        "prompt",
+        [
+            "Describe this image in detail.",
+            "What colors do you see in this image?",
+            "What patterns are visible in this image?",
+        ],
+    )
+    def test_multimodal_generate_with_image(self, trtllm_vlm_replica, processor, tokenizer, prompt):
+        replica = trtllm_vlm_replica
+
+        test_image = create_test_image(224, 224)
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": prompt},
+                ],
+            }
+        ]
+
+        text = processor.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+        print("text: ", text)
+        input_ids = processor.tokenizer(text, return_tensors="pt", padding=True)["input_ids"][0].tolist()
+
+        print(
+            "input_ids decoded: ",
+            processor.tokenizer.decode(input_ids, skip_special_tokens=False, add_special_tokens=False),
+        )
+
+        sampling_params = {
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50,
+            "logprobs": False,
+        }
+
+        output = ray.get(
+            replica.server_handle.generate.remote(
+                prompt_ids=input_ids,
+                sampling_params=sampling_params,
+                request_id=str(uuid.uuid4()),
+                image_data=[test_image],
+            )
+        )
+
+        assert output is not None
+        assert hasattr(output, "token_ids")
+        assert len(output.token_ids) > 0
+
+        generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True)
+        print("\n[Multimodal Test]")
+        print(f"Prompt: {prompt}")
+        print(f"Image size: {test_image.size}")
+        print(f"Generated ({len(output.token_ids)} tokens): {generated_text[:300]}...")
+
+    @pytest.mark.parametrize(
+        "image_size",
+        [(224, 224), (384, 384), (512, 512)],
+    )
+    def test_multimodal_different_image_sizes(self, trtllm_vlm_replica, processor, tokenizer, image_size):
+        replica = trtllm_vlm_replica
+
+        width, height = image_size
+        test_image = create_test_image(width, height)
+
+        prompt = "What is shown in this image?"
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": prompt},
+                ],
+            }
+        ]
+
+        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        input_ids = processor.tokenizer(text, return_tensors="pt", padding=True)["input_ids"][0].tolist()
+
+        sampling_params = {
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50,
+            "logprobs": False,
+        }
+
+        output = ray.get(
+            replica.server_handle.generate.remote(
+                prompt_ids=input_ids,
+                sampling_params=sampling_params,
+                request_id=str(uuid.uuid4()),
+                image_data=[test_image],
+            )
+        )
+
+        assert output is not None
+        assert len(output.token_ids) > 0
+        print(f"\n[Image Size {image_size}] Generated {len(output.token_ids)} tokens")
+
+    def test_multimodal_text_only_fallback(self, trtllm_vlm_replica, tokenizer):
+        replica = trtllm_vlm_replica
+
+        prompt = "What is the capital of China?"
+        messages = [{"role": "user", "content": prompt}]
+
+        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist()
+
+        sampling_params = {
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50,
+            "logprobs": False,
+        }
+
+        output = ray.get(
+            replica.server_handle.generate.remote(
+                prompt_ids=input_ids,
+                sampling_params=sampling_params,
+                request_id=str(uuid.uuid4()),
+            )
+        )
+
+        assert output is not None
+        assert len(output.token_ids) > 0
+
+        generated_text = tokenizer.decode(output.token_ids, skip_special_tokens=True)
+        print("\n[Text-only on VLM]")
+        print(f"Prompt: {prompt}")
+        print(f"Generated: {generated_text}")
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="CUDA not available",
+)
+class TestTRTLLMServerLifecycle:
+    @pytest.fixture(scope="class")
+    def ray_context(self):
+        if ray.is_initialized():
+            ray.shutdown()
+        ray.init(ignore_reinit_error=True)
+        yield
+        ray.shutdown()
+
+    @pytest.fixture(scope="class")
+    def trtllm_replica_lifecycle(self, ray_context):
+        from verl.workers.rollout.trtllm_rollout.trtllm_async_server import TRTLLMReplica
+
+        rollout_config = create_rollout_config_dict()
+        model_config = create_model_config_dict(UNIMODAL_MODEL_PATH)
+
+        replica = TRTLLMReplica(
+            replica_rank=0,
+            config=rollout_config,
+            model_config=model_config,
+            gpus_per_node=torch.cuda.device_count(),
+            is_reward_model=False,
+        )
+
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(replica.init_standalone())
+
+        yield replica, loop
+
+        loop.close()
+
+    @pytest.fixture(scope="class")
+    def tokenizer(self):
+        return get_tokenizer(UNIMODAL_MODEL_PATH)
+
+    def test_wake_sleep_cycle(self, trtllm_replica_lifecycle, tokenizer):
+        replica, loop = trtllm_replica_lifecycle
+
+        prompt = "Hello, world!"
+        messages = [{"role": "user", "content": prompt}]
+        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        input_ids = tokenizer.encode(text, return_tensors="pt")[0].tolist()
+
+        sampling_params = {"temperature": 0.7, "top_p": 0.9, "top_k": 50, "logprobs": False}
+
+        output1 = ray.get(
+            replica.server_handle.generate.remote(
+                prompt_ids=input_ids,
+                sampling_params=sampling_params,
+                request_id=str(uuid.uuid4()),
+            )
+        )
+        assert output1 is not None
+        assert len(output1.token_ids) > 0
+        print(f"\n[Before Sleep] Generated {len(output1.token_ids)} tokens")
+
+        loop.run_until_complete(replica.sleep())
+        print("[Sleep] Server put to sleep")
+
+        loop.run_until_complete(replica.wake_up())
+        print("[Wake Up] Server woken up")
+
+        output2 = ray.get(
+            replica.server_handle.generate.remote(
+                prompt_ids=input_ids,
+                sampling_params=sampling_params,
+                request_id=str(uuid.uuid4()),
+            )
+        )
+        assert output2 is not None
+        assert len(output2.token_ids) > 0
+        print(f"[After Wake Up] Generated {len(output2.token_ids)} tokens")
diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
index 2f7a7b1276a..80b83d3b501 100644
--- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
+++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
@@ -89,6 +89,10 @@ def __init__(
             logger.warning(f"rollout mode is {self.rollout_mode}, load_format is dummy, set to auto")
             self.config.load_format = "auto"

+        self.is_vlm_model = (
+            self.model_config.hf_config is not None and hasattr(self.model_config.hf_config, "vision_config")
+        ) or hasattr(self.model_config, "vision_config")
+
         # used for http server
         self._server_address = ray.util.get_node_ip_address().strip("[]")
         self._server_port = None
@@ -125,7 +129,7 @@ async def launch_server(self):
             "model": self.model_config.local_path,
             "backend": "pytorch",
             "orchestrator_type": "ray",
-            "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
+            "ray_worker_extension_cls": "verl.workers.rollout.trtllm_rollout.trtllm_worker_extension.WorkerExtension",
             "kv_cache_config": kv_cache_config,
             "max_seq_len": self.config.max_model_len,
             "max_batch_size": self.config.max_num_seqs,
@@ -159,15 +163,42 @@ async def launch_server(self):
                 }
             )

-        self.llm = await AsyncLLM(**llm_kwargs)
+        if self.is_vlm_model:
+            from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
+
+            multimodal_config = MultimodalServerConfig(
+                media_io_kwargs={
+                    "image": {
+                        "format": "pil",
+                        "device": "cpu",
+                    },
+                    "video": {
+                        "num_frames": 8,
+                        "fps": 30,
+                        "format": "pil",
+                        "device": "cpu",
+                    },
+                }
+            )
+            self.llm = await AsyncLLM(**llm_kwargs)
+            trtllm_server = OpenAIServer(
+                llm=self.llm,
+                model=self.model_config.local_path,
+                tool_parser=None,
+                server_role=None,
+                metadata_server_cfg=None,
+                multimodal_server_config=multimodal_config,
+            )
+        else:
+            self.llm = await AsyncLLM(**llm_kwargs)
+            trtllm_server = OpenAIServer(
+                llm=self.llm,
+                model=self.model_config.local_path,
+                tool_parser=None,
+                server_role=None,
+                metadata_server_cfg=None,
+            )

-        trtllm_server = OpenAIServer(
-            llm=self.llm,
-            model=self.model_config.local_path,
-            tool_parser=None,
-            server_role=None,
-            metadata_server_cfg=None,
-        )
         app = trtllm_server.app
         self._server_port, self._server_task = await run_unvicorn(app, None, self._server_address)

@@ -179,9 +210,6 @@ async def generate(
         image_data: Optional[list[Any]] = None,
         video_data: Optional[list[Any]] = None,
     ) -> TokenOutput:
"""Generate sequence with token-in-token-out.""" - assert image_data is None and video_data is None, "Multimodality is not yet supported in TRTLLMHttpServer." - from tensorrt_llm.llmapi import SamplingParams max_tokens = min(self.config.response_length, self.config.max_model_len - len(prompt_ids)) @@ -192,15 +220,35 @@ async def generate( sampling_params.update(self.sampling_args) trt_llm_sampling_params = SamplingParams(**sampling_params) - outputs = await self.llm.generate_async( - inputs=prompt_ids, - sampling_params=trt_llm_sampling_params, - ) - + if self.is_vlm_model: + if image_data or video_data: + input_dict = { + "prompt_token_ids": prompt_ids, + "multi_modal_data": {}, + "mm_processor_kwargs": {}, + } + if image_data: + input_dict["multi_modal_data"]["image"] = image_data + if video_data: + input_dict["multi_modal_data"]["video"] = video_data + outputs = await self.llm.generate_async( + inputs=input_dict, + sampling_params=trt_llm_sampling_params, + ) + else: + outputs = await self.llm.generate_async( + inputs=prompt_ids, + sampling_params=trt_llm_sampling_params, + ) + else: + outputs = await self.llm.generate_async( + inputs=prompt_ids, + sampling_params=trt_llm_sampling_params, + ) token_ids = outputs.outputs[0].token_ids log_probs = None - if trt_llm_sampling_params.logprobs is not None: - log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs] + if outputs.outputs[0].logprobs is not None: + log_probs = [logprobs[token_ids[i]].logprob for i, logprobs in enumerate(outputs.outputs[0].logprobs)] return TokenOutput(token_ids=token_ids, log_probs=log_probs) async def wake_up(self): diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 099c0eb7995..df82c0adf88 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -281,6 +281,7 @@ def __init__( self.is_leader_rank = None self.replica_rank = None self.is_dp_rank = None + self._supports_partial_loading = None # hybrid mode if self.device_mesh is not None: @@ -312,6 +313,21 @@ def __init__( self.node_ip = ray.util.get_node_ip_address().strip("[]") + async def get_supports_partial_loading(self) -> bool: + """Query and cache whether the model supports partial weight loading.""" + if self._supports_partial_loading is not None: + return self._supports_partial_loading + + await self._init_server_adapter() + try: + self._supports_partial_loading = await self.server_actor.supports_partial_loading.remote() + except Exception as e: + logger.warning(f"Failed to query partial loading support: {e}, defaulting to False") + self._supports_partial_loading = False + + logger.info(f"Model supports partial loading: {self._supports_partial_loading}") + return self._supports_partial_loading + async def _init_server_adapter(self): if self._adapter is not None: return @@ -406,15 +422,20 @@ async def flush(): cur_available_bytes = total_available_bytes cur_handles = [] + # Query if model supports partial loading + supports_partial_loading = await self.get_supports_partial_loading() + for name, param in weights: - size_in_bytes = param.element_size() * param.numel() - if size_in_bytes > cur_available_bytes: - await flush() + if supports_partial_loading: + size_in_bytes = param.element_size() * param.numel() + if size_in_bytes > cur_available_bytes: + await flush() + + assert cur_available_bytes >= size_in_bytes, ( + f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: 
+                )
+                cur_available_bytes -= size_in_bytes

-            assert cur_available_bytes >= size_in_bytes, (
-                f"cur_available_bytes: {cur_available_bytes:,} size_in_bytes: {size_in_bytes:,} name: {name}"
-            )
-            cur_available_bytes -= size_in_bytes
             handle = reduce_tensor(param.detach())
             cur_handles.append((name, handle))

diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py
new file mode 100644
index 00000000000..6bb5190dfbc
--- /dev/null
+++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py
@@ -0,0 +1,138 @@
+# Copyright 2026 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import inspect
+from typing import Optional
+
+import torch
+from tensorrt_llm import serialization
+from tensorrt_llm._ray_utils import control_action_decorator
+from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import MoeLoadBalancer
+from tensorrt_llm._torch.utils import get_device_uuid
+from tensorrt_llm.logger import logger
+
+
+class WorkerExtension:
+    def __init__(self):
+        pass
+
+    @control_action_decorator
+    def supports_partial_loading(self) -> bool:
+        """Check if the model supports partial weight loading."""
+        try:
+            model = self.engine.model_engine.model
+            load_weights_args = inspect.getfullargspec(model.load_weights).args
+            return "allow_partial_loading" in load_weights_args
+        except Exception as e:
+            logger.warning(f"Failed to check partial loading support: {e}")
+            return False
+
+    @control_action_decorator
+    def update_weights(self, ipc_handles: Optional[dict] = None):
+        try:
+            if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"):
+                for module in self.engine.model_engine.model.modules():
+                    if hasattr(module, "pre_reload_weights") and not getattr(module, "_weights_removed", False):
+                        module.pre_reload_weights()
+                self.engine.model_engine.model.first_pre_reload_weights = True
+
+            if ipc_handles is not None:
+                logger.info("Update weights from IPC handles")
+                device_uuid = get_device_uuid(self.device_id)
+
+                if device_uuid not in ipc_handles:
+                    raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")
+
+                weights = {}
+
+                serialized_handles = ipc_handles[device_uuid]
+                if isinstance(serialized_handles, str):
+                    # Data is base64-encoded pickled bytes - deserialize it
+                    # using restricted unpickler from tensorrt_llm.serialization
+                    logger.info("Deserializing base64-encoded weight handles")
+                    decoded_data = base64.b64decode(serialized_handles)
+                    # Allow basic builtins and all torch modules
+                    approved_imports = {
+                        "builtins": [
+                            "list",
+                            "tuple",
+                            "str",
+                            "int",
+                            "float",
+                            "bool",
+                            "bytes",
+                            "dict",
+                            "NoneType",
+                            "type",
+                        ],
+                    }
+                    all_handles = serialization.loads(
+                        decoded_data,
+                        approved_imports=approved_imports,
+                        approved_module_patterns=[r"^torch.*"],
+                    )
+
+                    # Verify the result is a list as expected
+                    if not isinstance(all_handles, list):
+                        raise ValueError(f"Deserialized data must be a list, got {type(all_handles).__name__} instead")
+                else:
+                    # Data is already in the correct format (backward compatibility)
+                    all_handles = serialized_handles
+
+                for param_name, tensor_handle in all_handles:
+                    func, args = tensor_handle
+                    list_args = list(args)
+                    list_args[6] = self.device_id
+                    tensor = func(*list_args)
+                    weights[param_name] = tensor
+
+                logger.info(f"weights key size: {len(weights.keys())}")
+
+                # Check if model supports partial loading and use appropriate strategy
+                model = self.engine.model_engine.model
+                load_weights_args = inspect.getfullargspec(model.load_weights).args
+                supports_partial_loading = "allow_partial_loading" in load_weights_args
+
+                if supports_partial_loading:
+                    self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=True)
+                else:
+                    self.engine.model_engine.model_loader.reload(model, weights, allow_partial_loading=False)
+            else:
+                logger.info("Finalize update weights")
+                for module in self.engine.model_engine.model.modules():
+                    if hasattr(module, "process_weights_after_loading") and not getattr(
+                        module, "_weights_removed", False
+                    ):
+                        module.process_weights_after_loading()
+                    if hasattr(module, "post_load_weights") and not getattr(module, "_weights_removed", False):
+                        module.post_load_weights()
+                moe_load_balancer = getattr(self.engine.model_engine, "moe_load_balancer", None)
+                if isinstance(moe_load_balancer, MoeLoadBalancer):
+                    moe_load_balancer.register_weight_slots_after_to_cuda()
+                    logger.info("moe_load_balancer finalizing model...")
+                    moe_load_balancer.finalize_model()
+                    logger.info("moe_load_balancer finalize model done")
+                self.engine.reset_prefix_cache()
+                delattr(self.engine.model_engine.model, "first_pre_reload_weights")
+
+        except Exception as e:
+            logger.error("Encountered an error in update_weights")
+            raise e
+
+    def check_weights_updated(self) -> bool:
+        """Check if the weights are updated to 0."""
+        weights_updated = True
+        for name, p in self.engine.model_engine.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
+        return weights_updated