From 630147656abac618a97860d232d9534c4d698569 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 8 Sep 2025 15:39:13 +0800 Subject: [PATCH 01/42] support only sync lora weight --- swift/llm/argument/deploy_args.py | 4 +- .../infer/infer_engine/grpo_vllm_engine.py | 27 ++++- swift/llm/infer/rollout.py | 2 + swift/trainers/rlhf_arguments.py | 1 + swift/trainers/rlhf_trainer/grpo_trainer.py | 73 ++++++++++-- swift/trainers/rlhf_trainer/utils.py | 105 +++++++++++++++++- 6 files changed, 196 insertions(+), 16 deletions(-) diff --git a/swift/llm/argument/deploy_args.py b/swift/llm/argument/deploy_args.py index c6a762bec0..fe326a9df5 100644 --- a/swift/llm/argument/deploy_args.py +++ b/swift/llm/argument/deploy_args.py @@ -86,7 +86,9 @@ class RolloutArguments(DeployArguments): # only for GRPO rollout with AsyncEngine, see details in swift/plugin/multi_turn multi_turn_scheduler: Optional[str] = None max_turns: Optional[int] = None - + # lora, TODO: modify example script for lora + vllm_enable_lora: bool = False + vllm_max_lora_rank: int = 16 # GYM env gym_env: Optional[str] = None context_manager: Optional[str] = None diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index ac546f4b0d..eb682b5424 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -14,6 +14,7 @@ try: os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '86400' + from swift.trainers.rlhf_trainer.utils import TensorLoRARequest except Exception: raise @@ -96,8 +97,18 @@ def infer( *, template: Optional[Template] = None, use_tqdm: Optional[bool] = None, - adapter_request: Optional[AdapterRequest] = None, + adapter_request: Optional[Union[AdapterRequest, TensorLoRARequest]] = None, ) -> List[RolloutOutput]: + if not adapter_request and self.enable_lora: + # TODO: check if get the latest lora + lora_int_ids = list(self.llm_engine.list_loras()) + if lora_int_ids: + adapter_request = TensorLoRARequest( + lora_name=f'lora_{lora_int_ids[0]}', + lora_int_id=lora_int_ids[0], + path='dummy_lora_path', + ) + res = super().infer( infer_requests, request_config, @@ -189,3 +200,17 @@ def _create_chat_completion_response(self, result, inputs, template: Template, r id=request_id, prompt_token_ids=prompt_token_ids, images_size=images_size) + + def _add_adapter(self, adapter_request: Optional[Union[AdapterRequest, TensorLoRARequest]] = None): + assert self.enable_lora, f'adapter_request: {adapter_request}, self.enable_lora: {self.enable_lora}' + from vllm.lora.request import LoRARequest + if isinstance(adapter_request, AdapterRequest): + return super()._add_adapter(adapter_request) + elif isinstance(adapter_request, TensorLoRARequest): + return adapter_request + else: + raise ValueError(f'Invalid adapter request: {adapter_request}') + + @property + def llm_engine(self): + return self.engine.llm_engine diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index fe2f366f8c..a130e67b89 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -222,6 +222,8 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): 'torch_dtype': args.torch_dtype, 'template': template, 'use_async_engine': args.vllm_use_async_engine, + 'enable_lora': args.vllm_enable_lora, + 'max_lora_rank': args.vllm_max_lora_rank, }) infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend if infer_backend != 'vllm': diff --git a/swift/trainers/rlhf_arguments.py b/swift/trainers/rlhf_arguments.py index bdb9f83c9a..ce9a6b33af 100644 --- a/swift/trainers/rlhf_arguments.py +++ b/swift/trainers/rlhf_arguments.py @@ -51,6 +51,7 @@ class GKDConfig(SwiftArgumentsMixin, HfGKDConfig): @dataclass class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig): stop_words: List[str] = field(default_factory=list) + lora_rank: int = 8 # for vllm lora adapter def __post_init__(self): GRPOArgumentsMixin.__post_init__(self) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index d17a8cac88..11d8fb7515 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -7,7 +7,7 @@ import re import time import uuid -from collections import defaultdict, deque +from collections import OrderedDict, defaultdict, deque from concurrent.futures import Future from contextlib import contextmanager, nullcontext from copy import copy, deepcopy @@ -25,6 +25,7 @@ from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from dacite import from_dict from packaging import version +from peft.utils.save_and_load import get_peft_model_state_dict from torch.nn import ModuleList from torch.utils.data import DataLoader from transformers import PreTrainedModel, TrainerCallback @@ -49,9 +50,10 @@ unwrap_model_for_generation) from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin -from .utils import (_ForwardRedirection, compute_chord_loss, identity_data_collator, load_pil_img, - make_chord_sft_dataset, patch_lora_merge, patch_lora_unmerge, patch_profiling_context, - patch_profiling_decorator, patch_save_last_checkpoint, replace_assistant_response_with_ids) +from .utils import (TensorLoRARequest, _ForwardRedirection, compute_chord_loss, get_gather_if_zero3_context, + identity_data_collator, load_pil_img, make_chord_sft_dataset, patch_lora_merge, patch_lora_unmerge, + patch_profiling_context, patch_profiling_decorator, patch_save_last_checkpoint, + patch_vllm_load_adapter, replace_assistant_response_with_ids) from .vllm_client import VLLMClient try: @@ -258,6 +260,9 @@ def __init__(self, if not is_vllm_available(): raise ImportError('vLLM is not available and `use_vllm` is set to True. ' 'Please install vLLM with `pip install vllm -U` to use it.') + self.args.train_type = 'full' + self.base_sync_done = False # tag for lora weights sync + if self.vllm_mode == 'server': self.vllm_client: VLLMClient = vllm_client if self.accelerator.is_main_process: @@ -532,6 +537,14 @@ def prepare_vllm(self, model): from swift.llm.infer.infer_engine import GRPOVllmEngine max_num_seqs = ( self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size * self.args.steps_per_generation) + lora_kwargs = {} + if self.args.train_type == 'lora': + lora_kwargs = { + 'enable_lora': True, + 'max_loras': 1, + 'max_lora_rank': self.args.lora_rank, + } + patch_vllm_load_adapter() with Swift.grpo_context(model, self.template.processor): engine = GRPOVllmEngine( model.model_dir, @@ -551,6 +564,7 @@ def prepare_vllm(self, model): load_format='dummy', template=copy(self.template), distributed_executor_backend='external_launcher', + **lora_kwargs, ) return engine @@ -566,18 +580,52 @@ def _template_context(self, template: Template): @patch_profiling_decorator def _move_model_to_vllm(self, skip_async_check=False): - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - if self.args.async_generate and not skip_async_check: # before sync weight, we should wait async generate finish self._wait_queue() + train_type = self.args.train_type + + if train_type == 'full' or (train_type == 'lora' and not self.base_sync_done): + self._move_full_model_to_vllm() + else: + self._move_adapter_to_vllm() + + def _move_adapter_to_vllm(self): + lora_params = OrderedDict() + for i, parameter_group in enumerate(self.parameter_groups): # < this is the change + parameters = [ + parameter for name, parameter in self.model.named_parameters() + if not parameter_group or name in parameter_group + ] + gather_if_zero3 = get_gather_if_zero3_context(self) + with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): + peft_config = self.model.peft_config.get('default', None) + self.model.merge_adapter() + cur_lora_params = get_peft_model_state_dict(self.model) + cur_lora_params = { + name: param.full_tensor().detach().cpu() if hasattr(param, 'full_tensor') else param.detach().cpu() + for name, param in lora_params.items() + } + lora_params.update(cur_lora_params) + self.model.unmerge_adapter() + del cur_lora_params + lora_int_id = int(time.time_ns() % 0x7FFFFFFF) + lora_reqest = TensorLoRARequest( + lora_name=f'{lora_int_id}', + lora_int_id=lora_int_id, + lora_path='dummy_lora_path', + peft_config=asdict(peft_config), + lora_tensors=lora_params, + ) + if self.vllm_mode == 'server' and self.accelerator.is_main_process: + self.vllm_client.add_lora(lora_reqest) # TODO + elif self.vllm_mode == 'colocate': + self.engine.llm_engine.add_lora(lora_reqest) + del lora_params + + def _move_full_model_to_vllm(self): + gather_if_zero3 = get_gather_if_zero3_context(self) if is_peft_model(self.model): for i, parameter_group in enumerate(self.parameter_groups): # < this is the change parameter_group_no_lora = self.parameter_groups_no_lora[i] @@ -613,6 +661,7 @@ def _move_model_to_vllm(self, skip_async_check=False): with patch_lora_unmerge(self.model): self.model.unmerge_adapter() del state_dict + self.base_sync_done = True else: for name, param in self.model.named_parameters(): with gather_if_zero3([param]): diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 0f2d5ce261..fbae2cc30d 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -2,7 +2,7 @@ import functools import math import time -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from functools import partial from io import BytesIO from types import MethodType @@ -11,6 +11,7 @@ import datasets import torch import torch.nn.functional as F +from msgspec import field from peft.tuners import lora from peft.tuners.lora import LoraLayer from PIL import Image @@ -18,7 +19,7 @@ from torch.utils.data import DataLoader from transformers import Trainer -from swift.utils import is_swanlab_available, is_wandb_available +from swift.utils import is_swanlab_available, is_vllm_available, is_wandb_available if is_wandb_available(): import wandb @@ -28,6 +29,14 @@ if TYPE_CHECKING: from swift.llm.utils import Messages +TensorLoRARequest = None +if is_vllm_available(): + from vllm.lora.request import LoRARequest + + class TensorLoRARequest(LoRARequest): + peft_config: dict = field(default=None) + lora_tensors: dict = field(default=None) + def round_robin(num_reqs, num_workers): """Distribute requests evenly across workers using round-robin algorithm. @@ -367,6 +376,98 @@ def patched_len(self) -> int: RepeatSampler.old_len_func = origin_len_func +def get_gather_if_zero3_context(trainer): + deepspeed_plugin = trainer.accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + return gather_if_zero3 + + +def patch_vllm_load_adapter(): + # from vllm.lora.worker_manager import WorkerLoRAManager + from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager + from vllm.lora.models import LoRAModel + from vllm.lora.utils import get_adapter_absolute_path + + def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLoRARequest) -> LoRAModel: + """ + code borrowed from verl.utils.vllm.utils.py + based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors + Reason: + VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths. + To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to + load memory-based LoRA tensors. + """ + try: + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping + expected_lora_modules: list[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend(packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + expected_lora_modules = list(set(expected_lora_modules)) + # this is the patch + lora_tensors = None + from vllm.lora.peft_helper import PEFTHelper + if isinstance(lora_request, TensorLoRARequest): + peft_config = lora_request.peft_config + lora_tensors = lora_request.lora_tensors + peft_helper = PEFTHelper.from_dict(peft_config) + else: + lora_path = get_adapter_absolute_path(lora_request.lora_path) + peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings) + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. + model = self._adapter_manager.model + hf_to_vllm_mapper = getattr(model, 'hf_to_vllm_mapper', None) + if isinstance(lora_request, TensorLoRARequest): # this is the patch + lora = self._lora_model_cls.from_lora_tensors( + lora_model_id=lora_request.lora_int_id, + tensors=lora_tensors, + peft_helper=peft_helper, + device='cpu', + dtype=self.lora_config.lora_dtype, + embeddings=None, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + else: + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device='cpu', + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + except Exception as e: + raise e + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError(f'LoRA added vocab size {lora.extra_vocab_size} is greater than ' + f'lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}.') + return lora + + if not hasattr(LRUCacheWorkerLoRAManager, '_old_load_adapter'): + _old_load_adapter = LRUCacheWorkerLoRAManager._load_adapter + LRUCacheWorkerLoRAManager._load_adapter = patched_load_adapter + LRUCacheWorkerLoRAManager._old_load_adapter = _old_load_adapter + + def identity_data_collator(features): """Identity data collator that returns features as-is without any processing.""" return features From c7be012698023aebf1b5df089b1646024d1ddae6 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 8 Sep 2025 16:53:11 +0800 Subject: [PATCH 02/42] fix wip --- swift/llm/infer/infer_engine/grpo_vllm_engine.py | 8 ++------ swift/trainers/rlhf_trainer/grpo_trainer.py | 12 ++++-------- swift/trainers/rlhf_trainer/utils.py | 7 +++++++ 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index eb682b5424..7662240eba 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -101,12 +101,12 @@ def infer( ) -> List[RolloutOutput]: if not adapter_request and self.enable_lora: # TODO: check if get the latest lora - lora_int_ids = list(self.llm_engine.list_loras()) + lora_int_ids = list(self.engine.list_loras()) if lora_int_ids: adapter_request = TensorLoRARequest( lora_name=f'lora_{lora_int_ids[0]}', lora_int_id=lora_int_ids[0], - path='dummy_lora_path', + lora_path='dummy_lora_path', ) res = super().infer( @@ -210,7 +210,3 @@ def _add_adapter(self, adapter_request: Optional[Union[AdapterRequest, TensorLoR return adapter_request else: raise ValueError(f'Invalid adapter request: {adapter_request}') - - @property - def llm_engine(self): - return self.engine.llm_engine diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 11d8fb7515..dd7ae5a743 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -260,7 +260,6 @@ def __init__(self, if not is_vllm_available(): raise ImportError('vLLM is not available and `use_vllm` is set to True. ' 'Please install vLLM with `pip install vllm -U` to use it.') - self.args.train_type = 'full' self.base_sync_done = False # tag for lora weights sync if self.vllm_mode == 'server': @@ -603,9 +602,9 @@ def _move_adapter_to_vllm(self): peft_config = self.model.peft_config.get('default', None) self.model.merge_adapter() cur_lora_params = get_peft_model_state_dict(self.model) - cur_lora_params = { + cur_lora_params = { # base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.weight name: param.full_tensor().detach().cpu() if hasattr(param, 'full_tensor') else param.detach().cpu() - for name, param in lora_params.items() + for name, param in cur_lora_params.items() } lora_params.update(cur_lora_params) self.model.unmerge_adapter() @@ -621,7 +620,7 @@ def _move_adapter_to_vllm(self): if self.vllm_mode == 'server' and self.accelerator.is_main_process: self.vllm_client.add_lora(lora_reqest) # TODO elif self.vllm_mode == 'colocate': - self.engine.llm_engine.add_lora(lora_reqest) + self.engine.engine.add_lora(lora_reqest) del lora_params def _move_full_model_to_vllm(self): @@ -636,10 +635,7 @@ def _move_full_model_to_vllm(self): with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): self.model.merge_adapter() state_dict = self.model.state_dict() - state_dict = { - k.removeprefix('base_model.model.').replace('.base_layer', ''): v - for k, v in state_dict.items() - } + state_dict = {k.removeprefix('base_model.model.'): v for k, v in state_dict.items()} state_dict = {k: v for k, v in state_dict.items() if self.model.prefix not in k} # When module to save, remove its prefix and discard the original module state_dict = { diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index fbae2cc30d..9981428355 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -392,6 +392,7 @@ def patch_vllm_load_adapter(): from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.models import LoRAModel from vllm.lora.utils import get_adapter_absolute_path + from vllm.transformers_utils.tokenizer_group import TokenizerGroup def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLoRARequest) -> LoRAModel: """ @@ -462,10 +463,16 @@ def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLo f'lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}.') return lora + def patched_get_lora_tokenizer(self: TokenizerGroup, lora_request: LoRARequest): + # since we pass dummy path, skip get tokenizer from path + return self.tokenizer + if not hasattr(LRUCacheWorkerLoRAManager, '_old_load_adapter'): _old_load_adapter = LRUCacheWorkerLoRAManager._load_adapter LRUCacheWorkerLoRAManager._load_adapter = patched_load_adapter LRUCacheWorkerLoRAManager._old_load_adapter = _old_load_adapter + TokenizerGroup._old_get_lora_tokenizer = TokenizerGroup.get_lora_tokenizer + TokenizerGroup.get_lora_tokenizer = patched_get_lora_tokenizer def identity_data_collator(features): From 22042fc9cae9c733d9789b5b85e53e151b38143a Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 8 Sep 2025 22:16:09 +0800 Subject: [PATCH 03/42] wip --- swift/trainers/rlhf_trainer/grpo_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index dd7ae5a743..28066142a5 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -599,12 +599,14 @@ def _move_adapter_to_vllm(self): ] gather_if_zero3 = get_gather_if_zero3_context(self) with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): + assert len(parameters) == len(parameter_group) + state_dict = {name: p.detach().clone().cpu() for p, name in zip(parameters, parameter_group)} peft_config = self.model.peft_config.get('default', None) self.model.merge_adapter() - cur_lora_params = get_peft_model_state_dict(self.model) + cur_lora_params = get_peft_model_state_dict(self.model, state_dict) cur_lora_params = { # base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.weight name: param.full_tensor().detach().cpu() if hasattr(param, 'full_tensor') else param.detach().cpu() - for name, param in cur_lora_params.items() + for name, param in parameter_group.items() } lora_params.update(cur_lora_params) self.model.unmerge_adapter() From 1081caa0d6b5660ce209b9e0a6bff76624c0ec3b Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 9 Sep 2025 14:28:25 +0800 Subject: [PATCH 04/42] fix colocate lora --- swift/trainers/rlhf_trainer/grpo_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 28066142a5..b7f1ece948 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -600,17 +600,18 @@ def _move_adapter_to_vllm(self): gather_if_zero3 = get_gather_if_zero3_context(self) with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): assert len(parameters) == len(parameter_group) - state_dict = {name: p.detach().clone().cpu() for p, name in zip(parameters, parameter_group)} + state_dict = {name: p for p, name in zip(parameters, parameter_group)} peft_config = self.model.peft_config.get('default', None) self.model.merge_adapter() cur_lora_params = get_peft_model_state_dict(self.model, state_dict) cur_lora_params = { # base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.weight name: param.full_tensor().detach().cpu() if hasattr(param, 'full_tensor') else param.detach().cpu() - for name, param in parameter_group.items() + for name, param in cur_lora_params.items() } lora_params.update(cur_lora_params) self.model.unmerge_adapter() del cur_lora_params + lora_int_id = int(time.time_ns() % 0x7FFFFFFF) lora_reqest = TensorLoRARequest( lora_name=f'{lora_int_id}', From 4c04d36b21b95663291dba070b46415059b2bdc2 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 9 Sep 2025 15:36:30 +0800 Subject: [PATCH 05/42] add lora for server wip --- swift/llm/infer/rollout.py | 69 ++++++++++++- swift/trainers/rlhf_trainer/grpo_trainer.py | 34 +++--- swift/trainers/rlhf_trainer/utils.py | 108 +++++++++++++++++++- swift/trainers/rlhf_trainer/vllm_client.py | 38 +++++++ 4 files changed, 233 insertions(+), 16 deletions(-) diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index a130e67b89..05a06f55a7 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -18,11 +18,13 @@ import uvicorn from aiohttp import ClientConnectorError from fastapi import FastAPI -from trl.scripts.vllm_serve import WeightSyncWorkerExtension +from trl.scripts.vllm_serve import WeightSyncWorkerExtension as HFWeightSyncWorkerExtension from swift.llm import RolloutArguments, SwiftPipeline from swift.llm.template.template_inputs import RolloutInferRequest from swift.plugin.multi_turn import RolloutScheduler, multi_turns +from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, LoRARequest, + TensorLoRARequest) from swift.utils import get_logger from .infer_engine import GRPOVllmEngine, InferClient from .protocol import InitCommunicatorRequest, RequestConfig, UpdateWeightsRequest @@ -50,6 +52,53 @@ - For inference or deployment, please use the `swift infer` or `swift deploy` commands. """ + +class WeightSyncWorkerExtension(HFWeightSyncWorkerExtension): + + def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None: + """ + Receives updated weights from the client process and updates the named parameter in the model. + + Args: + name (`str`): + Name of the weight tensor being updated. + dtype (`str`): + Data type of the weight tensor as a string (e.g., `"torch.float32"`). + shape (`Sequence[int]`): + Shape of the weight tensor. + """ + if self.pynccl_comm is None: + raise RuntimeError('Communicator not initialized. Call `init_communicator` first.') + + dtype = getattr(torch, dtype.split('.')[-1]) + # Allocate memory for the incoming weight tensor on the correct device. + weight = torch.empty(shape, dtype=dtype, device=self.device) + + # Use NCCL to broadcast the updated weights from the client (src) to all workers. + self.pynccl_comm.broadcast(weight, src=self.client_rank) + self.pynccl_comm.group.barrier() + + # Load the received weights into the model. + self.model_runner.model.load_weights(weights=[(name, weight)]) + + def update_adapter_flattened_param(self, lora_request: LoRARequest, + metadatas: list[FlattenedTensorMetadata]) -> None: + """ + Receives updated weights from the client process and updates the named parameter in the model. + """ + if self.pynccl_comm is None: + raise RuntimeError('Communicator not initialized. Call `init_communicator` first.') + flatten_tensor_length = max(metadata.end_idx for metadata in metadatas) + flatten_tensor = torch.empty(flatten_tensor_length, dtype=torch.float32, device=self.device) + self.pynccl_comm.broadcast(flatten_tensor, src=self.client_rank) + self.pynccl_comm.group.barrier() + flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor) + named_params = flattened_tensor_bucket.reconstruct_tensors() + + # TODO: Check + self.add_lora(TensorLoRARequest(lora_request=lora_request, lora_tensors=named_params)) + + logger = get_logger() @@ -165,6 +214,7 @@ def _register_rl_rollout_app(self): self.app.get('/get_world_size/')(self.get_world_size) self.app.post('/init_communicator/')(self.init_communicator) self.app.post('/update_named_param/')(self.update_named_param) + self.app.post('/update_adapter_flattened_param/')(self.update_adapter_flattened_param) self.app.post('/reset_prefix_cache/')(self.reset_prefix_cache) self.app.post('/close_communicator/')(self.close_communicator) self.app.post('/infer/', response_model=None)(self.infer) @@ -311,6 +361,23 @@ async def update_named_param(self, request: UpdateWeightsRequest): return {'message': 'Request received, updating named parameter'} + async def update_adapter_flattened_param(self, lora_request, metadatas): + # Create a LoRA request object, or pass request directly + # from swift.trainers.rlhf_trainer.utils import TensorLoRARequest + # lora_request = TensorLoRARequest( + # lora_name=request.lora_name, + # lora_int_id=request.lora_int_id, + # lora_path=request.lora_path, + # peft_config=request.peft_config, + # lora_tensors=request.lora_tensors + # ) + + kwargs = {'method': 'update_adapter_flattened_param', 'args': (lora_request, metadatas)} + for connection in self.connections: + connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs}) + + return {'message': 'Request received, updating adapter parameter'} + async def reset_prefix_cache(self): """ Resets the prefix cache for the model. diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 11d8fb7515..7ac506aa66 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -50,10 +50,10 @@ unwrap_model_for_generation) from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin -from .utils import (TensorLoRARequest, _ForwardRedirection, compute_chord_loss, get_gather_if_zero3_context, - identity_data_collator, load_pil_img, make_chord_sft_dataset, patch_lora_merge, patch_lora_unmerge, - patch_profiling_context, patch_profiling_decorator, patch_save_last_checkpoint, - patch_vllm_load_adapter, replace_assistant_response_with_ids) +from .utils import (FlattenedTensorBucket, LoRARequest, TensorLoRARequest, _ForwardRedirection, compute_chord_loss, + get_gather_if_zero3_context, identity_data_collator, load_pil_img, make_chord_sft_dataset, + patch_lora_merge, patch_lora_unmerge, patch_profiling_context, patch_profiling_decorator, + patch_save_last_checkpoint, patch_vllm_load_adapter, replace_assistant_response_with_ids) from .vllm_client import VLLMClient try: @@ -611,16 +611,26 @@ def _move_adapter_to_vllm(self): self.model.unmerge_adapter() del cur_lora_params lora_int_id = int(time.time_ns() % 0x7FFFFFFF) - lora_reqest = TensorLoRARequest( - lora_name=f'{lora_int_id}', - lora_int_id=lora_int_id, - lora_path='dummy_lora_path', - peft_config=asdict(peft_config), - lora_tensors=lora_params, - ) + if self.vllm_mode == 'server' and self.accelerator.is_main_process: - self.vllm_client.add_lora(lora_reqest) # TODO + lora_reqest = LoRARequest( + lora_name=f'{lora_int_id}', + lora_int_id=lora_int_id, + lora_path='dummy_lora_path', + peft_config=asdict(peft_config), + ) + bucked = FlattenedTensorBucket(list(named_tensors=lora_params.items())) + metadatas = bucked.get_metadata() + flattened_tensor = bucked.get_flattened_tensor() + self.vllm_client.update_adapter_flattened_param(lora_reqest, metadatas, flattened_tensor) # TODO elif self.vllm_mode == 'colocate': + lora_reqest = TensorLoRARequest( + lora_name=f'{lora_int_id}', + lora_int_id=lora_int_id, + lora_path='dummy_lora_path', + peft_config=asdict(peft_config), + lora_tensors=lora_params, + ) self.engine.llm_engine.add_lora(lora_reqest) del lora_params diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index fbae2cc30d..8e6f64226e 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -3,10 +3,11 @@ import math import time from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from functools import partial from io import BytesIO from types import MethodType -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import datasets import torch @@ -31,10 +32,12 @@ TensorLoRARequest = None if is_vllm_available(): - from vllm.lora.request import LoRARequest + from vllm.lora.request import LoRARequest as VLLMLoRARequest - class TensorLoRARequest(LoRARequest): + class LoRARequest(vLLMLoRARequest): peft_config: dict = field(default=None) + + class TensorLoRARequest(LoRARequest): lora_tensors: dict = field(default=None) @@ -468,6 +471,105 @@ def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLo LRUCacheWorkerLoRAManager._old_load_adapter = _old_load_adapter +# FlattenedTensor, code borrowed from sglang/srt/weight_sync/tensor_bucket.py +@dataclass +class FlattenedTensorMetadata: + """Metadata for a tensor in a flattened bucket""" + + name: str + shape: torch.Size + dtype: torch.dtype + start_idx: int + end_idx: int + numel: int + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. + """ + + def __init__( + self, + named_tensors: List[Tuple[str, torch.Tensor]] = None, + flattened_tensor: torch.Tensor = None, + metadata: List[FlattenedTensorMetadata] = None, + ): + """ + Initialize a tensor bucket from a list of named tensors OR from pre-flattened data. + Args: + named_tensors: List of (name, tensor) tuples (for creating new bucket) + flattened_tensor: Pre-flattened tensor (for reconstruction) + metadata: Pre-computed metadata (for reconstruction) + """ + if named_tensors is not None: + # Create bucket from named tensors + self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors) + self.flattened_tensor: torch.Tensor = None + + if not named_tensors: + raise ValueError('Cannot create empty tensor bucket') + + # Collect metadata and flatten tensors + current_idx = 0 + flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors) + + for i, (name, tensor) in enumerate(named_tensors): + flattened = tensor.flatten() + flattened_tensors[i] = flattened + + # Store metadata + + numel = flattened.numel() + metadata_obj = FlattenedTensorMetadata( + name=name, + shape=tensor.shape, + dtype=tensor.dtype, + start_idx=current_idx, + end_idx=current_idx + numel, + numel=numel, + ) + self.metadata[i] = metadata_obj + current_idx += numel + + # Concatenate all flattened tensors + self.flattened_tensor = torch.cat(flattened_tensors, dim=0) + else: + # Initialize from pre-flattened data + if flattened_tensor is None or metadata is None: + raise ValueError('Must provide either named_tensors or both flattened_tensor and metadata') + self.flattened_tensor = flattened_tensor + self.metadata = metadata + + def get_flattened_tensor(self) -> torch.Tensor: + """Get the flattened tensor containing all bucket tensors""" + return self.flattened_tensor + + def get_metadata(self) -> List[FlattenedTensorMetadata]: + """Get metadata for all tensors in the bucket""" + return self.metadata + + def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + """ + Reconstruct original tensors from flattened tensor with optimized performance. + Uses memory-efficient operations to minimize allocations and copies. + """ + # preallocate the result list + reconstructed = [None] * len(self.metadata) + + for i, meta in enumerate(self.metadata): + tensor = self.flattened_tensor[meta.start_idx:meta.end_idx].reshape(meta.shape) + + # batch dtype conversion (if needed) + if tensor.dtype != meta.dtype: + tensor = tensor.to(meta.dtype) + + reconstructed[i] = (meta.name, tensor) + + return reconstructed + + def identity_data_collator(features): """Identity data collator that returns features as-is without any processing.""" return features diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index fc3dc2e614..4fde94c7da 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -242,6 +242,44 @@ def _update_single_server(i): if all_errors: raise RuntimeError(f'Multiple errors: {all_errors}') + def update_adapter_flattened_param(self, lora_request, metadatas, flattened_tensor): + """ + Adds a LoRA adapter to the model on all servers. + + Args: + lora_request: TensorLoRARequest object containing LoRA adapter information. + """ + errors = [None] * self.num_servers + + def _update_single_server(i): + try: + # Convert lora_request to dict for JSON serialization + data = { + 'lora_request': lora_request, + 'metadatas': metadatas, + } + + response = self.sessions[i].post( + f'{self.base_urls[i]}/update_adapter_flattened_param/', + json=data, + ) + if response.status_code != 200: + raise Exception(f'Server {i} update adapter failed: {response.text}') + + self.pynccl_comms[i].broadcast(flattened_tensor, src=self.pynccl_comms[i].rank) + self.pynccl_comms[i].group.barrier() + except Exception as e: + errors[i] = e + + with ThreadPoolExecutor(max_workers=self.num_servers) as executor: + futures = [executor.submit(_update_single_server, i) for i in range(self.num_servers)] + for future in futures: + future.result() + + all_errors = [e for e in errors if e is not None] + if all_errors: + raise RuntimeError(f'Multiple errors: {all_errors}') + def update_model_params(self, model: nn.Module): for name, param in model.named_parameters(): self.update_named_param(name, param.data) From 5fe36902f20757f6e1b53310535728dd8bfa4856 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 9 Sep 2025 15:59:48 +0800 Subject: [PATCH 06/42] fix import --- swift/llm/infer/rollout.py | 22 ++-------------------- swift/trainers/rlhf_trainer/utils.py | 2 +- 2 files changed, 3 insertions(+), 21 deletions(-) diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 05a06f55a7..f6a90a24d2 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -6,13 +6,14 @@ import multiprocessing import os import time +from collections.abc import Sequence from contextlib import asynccontextmanager, contextmanager from dataclasses import asdict from functools import wraps from itertools import chain from multiprocessing import Pipe, Process from multiprocessing.connection import Connection -from typing import Dict, List, Optional, Union, get_type_hints +from typing import Dict, List, Optional, Union import torch import uvicorn @@ -496,22 +497,3 @@ def run_rollout(args: RolloutArguments, return_url: bool = False): finally: process.terminate() logger.info('The deployment process has been terminated.') - - -# https://github.com/huggingface/trl/pull/3690 -# This patch handles backward compatibility for dtype parameter type changes in TRL: -# - For TRL <= 0.19: dtype_annotation is torch.dtype (needs patching) -# - For TRL > 0.19: dtype_annotation is str (no patching needed) -old_update_named_param = WeightSyncWorkerExtension.update_named_param -dtype_annotation = get_type_hints(old_update_named_param).get('dtype') - -if not hasattr(WeightSyncWorkerExtension, 'old_update_named_param') and dtype_annotation == torch.dtype: - - @wraps(old_update_named_param) - def patched_update_named_param(self, name, dtype, shape) -> None: - if isinstance(dtype, str): - dtype = getattr(torch, dtype.split('.')[-1]) - return old_update_named_param(self, name, dtype, shape) - - WeightSyncWorkerExtension.update_named_param = patched_update_named_param - WeightSyncWorkerExtension.old_update_named_param = old_update_named_param diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 9885ab5496..5b1440dea9 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -32,7 +32,7 @@ TensorLoRARequest = None if is_vllm_available(): - from vllm.lora.request import LoRARequest as VLLMLoRARequest + from vllm.lora.request import LoRARequest as vLLMLoRARequest class LoRARequest(vLLMLoRARequest): peft_config: dict = field(default=None) From 161dac8cc822bf83a5e938792157d8a3871a7bf7 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 9 Sep 2025 16:41:01 +0800 Subject: [PATCH 07/42] update extension path --- swift/llm/infer/rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index f6a90a24d2..20637af560 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -284,7 +284,7 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): # used for RL external rollout backend engine_kwargs = kwargs.get('engine_kwargs', {}) # for RL rollout model weight sync - engine_kwargs.update({'worker_extension_cls': 'trl.scripts.vllm_serve.WeightSyncWorkerExtension'}) + engine_kwargs.update({'worker_extension_cls': 'swift.llm.infer.rollout.WeightSyncWorkerExtension'}) engine_kwargs['load_format'] = 'dummy' if args.vllm_use_async_engine and args.vllm_data_parallel_size > 1: engine_kwargs['data_parallel_size'] = args.vllm_data_parallel_size From 0a14d202b3a537d2dd5ebc3c658cf061e016c6db Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 9 Sep 2025 16:41:45 +0800 Subject: [PATCH 08/42] override enable_lora for rollout --- swift/llm/infer/rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index f6a90a24d2..fa6fb2c790 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -273,7 +273,6 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): 'torch_dtype': args.torch_dtype, 'template': template, 'use_async_engine': args.vllm_use_async_engine, - 'enable_lora': args.vllm_enable_lora, 'max_lora_rank': args.vllm_max_lora_rank, }) infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend @@ -281,6 +280,7 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): infer_backend = 'vllm' logger.info('Currently, rollout only supports the vLLM backend. Set vLLM backend') kwargs.update(args.get_vllm_engine_kwargs()) + kwargs.update({'enable_lora': args.vllm_enable_lora}) # override # used for RL external rollout backend engine_kwargs = kwargs.get('engine_kwargs', {}) # for RL rollout model weight sync From efae3b267f566d23760214309ddfb58265c45bb7 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 9 Sep 2025 17:43:37 +0800 Subject: [PATCH 09/42] catch rollout exception --- swift/llm/infer/rollout.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 9bb7fd5603..89a7d3bc29 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -6,6 +6,7 @@ import multiprocessing import os import time +import traceback from collections.abc import Sequence from contextlib import asynccontextmanager, contextmanager from dataclasses import asdict @@ -158,7 +159,11 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) - result = method(*args, **kwargs) + try: + result = method(*args, **kwargs) + except Exception: + logger.error(f'Method execution failed: {method_name}\n{traceback.format_exc()}') + result = None if command['type'] == 'call': connection.send(result) elif command['type'] == 'shutdown': @@ -186,7 +191,6 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast # Handle commands if command['type'] in ['call', 'fire_and_forget']: - import traceback method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) From f454598aae55d56ca990437fb5956f94490600e2 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 9 Sep 2025 17:49:05 +0800 Subject: [PATCH 10/42] fix lora request --- swift/llm/infer/infer_engine/grpo_vllm_engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 7662240eba..278a595ade 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -15,6 +15,7 @@ os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '86400' from swift.trainers.rlhf_trainer.utils import TensorLoRARequest + from vllm.lora.request import LoRARequest except Exception: raise @@ -97,13 +98,13 @@ def infer( *, template: Optional[Template] = None, use_tqdm: Optional[bool] = None, - adapter_request: Optional[Union[AdapterRequest, TensorLoRARequest]] = None, + adapter_request: Optional[AdapterRequest] = None, ) -> List[RolloutOutput]: if not adapter_request and self.enable_lora: # TODO: check if get the latest lora lora_int_ids = list(self.engine.list_loras()) if lora_int_ids: - adapter_request = TensorLoRARequest( + adapter_request = LoRARequest( lora_name=f'lora_{lora_int_ids[0]}', lora_int_id=lora_int_ids[0], lora_path='dummy_lora_path', @@ -201,12 +202,12 @@ def _create_chat_completion_response(self, result, inputs, template: Template, r prompt_token_ids=prompt_token_ids, images_size=images_size) - def _add_adapter(self, adapter_request: Optional[Union[AdapterRequest, TensorLoRARequest]] = None): + def _add_adapter(self, adapter_request: Optional[Union[AdapterRequest, LoRARequest]] = None): assert self.enable_lora, f'adapter_request: {adapter_request}, self.enable_lora: {self.enable_lora}' from vllm.lora.request import LoRARequest if isinstance(adapter_request, AdapterRequest): return super()._add_adapter(adapter_request) - elif isinstance(adapter_request, TensorLoRARequest): + elif isinstance(adapter_request, LoRARequest): return adapter_request else: raise ValueError(f'Invalid adapter request: {adapter_request}') From d46bc1f9cd0c517882f10292b034e89c686e52df Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 10 Sep 2025 11:02:52 +0800 Subject: [PATCH 11/42] server wip --- swift/llm/infer/rollout.py | 13 +++-- swift/trainers/rlhf_trainer/grpo_trainer.py | 10 +--- swift/trainers/rlhf_trainer/utils.py | 56 +++++++++++++++++---- swift/trainers/rlhf_trainer/vllm_client.py | 14 ++++-- 4 files changed, 69 insertions(+), 24 deletions(-) diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 89a7d3bc29..9021642b36 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -27,6 +27,7 @@ from swift.plugin.multi_turn import RolloutScheduler, multi_turns from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, LoRARequest, TensorLoRARequest) +from swift.tuners.lora import LoraConfig from swift.utils import get_logger from .infer_engine import GRPOVllmEngine, InferClient from .protocol import InitCommunicatorRequest, RequestConfig, UpdateWeightsRequest @@ -83,11 +84,11 @@ def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> Non # Load the received weights into the model. self.model_runner.model.load_weights(weights=[(name, weight)]) - def update_adapter_flattened_param(self, lora_request: LoRARequest, - metadatas: list[FlattenedTensorMetadata]) -> None: + def update_adapter_flattened_param(self, peft_config: Dict, metadatas: list[FlattenedTensorMetadata]) -> None: """ Receives updated weights from the client process and updates the named parameter in the model. """ + peft_config = LoraConfig(**peft_config) if self.pynccl_comm is None: raise RuntimeError('Communicator not initialized. Call `init_communicator` first.') flatten_tensor_length = max(metadata.end_idx for metadata in metadatas) @@ -96,7 +97,13 @@ def update_adapter_flattened_param(self, lora_request: LoRARequest, self.pynccl_comm.group.barrier() flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor) named_params = flattened_tensor_bucket.reconstruct_tensors() - + lora_int_id = int(time.time_ns() % 0x7FFFFFFF) + lora_request = TensorLoRARequest( + lora_name=f'{lora_int_id}', + lora_int_id=lora_int_id, + lora_path='dummy_lora_path', + peft_config=peft_config, + lora_tensors=named_params) # TODO: Check self.add_lora(TensorLoRARequest(lora_request=lora_request, lora_tensors=named_params)) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 6e14942733..c5bf78ff60 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -615,16 +615,10 @@ def _move_adapter_to_vllm(self): lora_int_id = int(time.time_ns() % 0x7FFFFFFF) if self.vllm_mode == 'server' and self.accelerator.is_main_process: - lora_reqest = LoRARequest( - lora_name=f'{lora_int_id}', - lora_int_id=lora_int_id, - lora_path='dummy_lora_path', - peft_config=asdict(peft_config), - ) - bucked = FlattenedTensorBucket(list(named_tensors=lora_params.items())) + bucked = FlattenedTensorBucket(named_tensors=list(lora_params.items())) metadatas = bucked.get_metadata() flattened_tensor = bucked.get_flattened_tensor() - self.vllm_client.update_adapter_flattened_param(lora_reqest, metadatas, flattened_tensor) # TODO + self.vllm_client.update_adapter_flattened_param(peft_config, metadatas, flattened_tensor) # TODO elif self.vllm_mode == 'colocate': lora_reqest = TensorLoRARequest( lora_name=f'{lora_int_id}', diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 5b1440dea9..37972c6fbd 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -3,19 +3,21 @@ import math import time from contextlib import contextmanager, nullcontext -from dataclasses import dataclass +from dataclasses import asdict, dataclass from functools import partial from io import BytesIO from types import MethodType -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union import datasets +import json import torch import torch.nn.functional as F from msgspec import field from peft.tuners import lora from peft.tuners.lora import LoraLayer from PIL import Image +from pydantic import BaseModel, field_validator from torch import nn from torch.utils.data import DataLoader from transformers import Trainer @@ -479,17 +481,37 @@ def patched_get_lora_tokenizer(self: TokenizerGroup, lora_request: LoRARequest): # FlattenedTensor, code borrowed from sglang/srt/weight_sync/tensor_bucket.py -@dataclass -class FlattenedTensorMetadata: +class FlattenedTensorMetadata(BaseModel): """Metadata for a tensor in a flattened bucket""" - name: str - shape: torch.Size - dtype: torch.dtype + shape: Tuple[int, ...] + dtype: str start_idx: int end_idx: int numel: int + @field_validator('shape', mode='before') + @classmethod + def ensure_shape_tuple(cls, v: Any) -> Tuple[int, ...]: + # accept tuple/list, torch.Size, or other iterable of ints + if torch is not None and isinstance(v, torch.Size): + return tuple(int(x) for x in v) + if isinstance(v, (list, tuple)): + return tuple(int(x) for x in v) + if isinstance(v, Iterable): + return tuple(int(x) for x in v) + raise ValueError('shape must be an iterable of ints (e.g. tuple/list/torch.Size)') + + @field_validator('dtype', mode='before') + @classmethod + def ensure_dtype_str(cls, v: Any) -> str: + # accept torch.dtype or str + if torch is not None and isinstance(v, torch.dtype): + return str(v) + if isinstance(v, str): + return v + raise ValueError('dtype must be a torch.dtype or str') + class FlattenedTensorBucket: """ @@ -531,8 +553,8 @@ def __init__( numel = flattened.numel() metadata_obj = FlattenedTensorMetadata( name=name, - shape=tensor.shape, - dtype=tensor.dtype, + shape=tuple(tensor.shape), + dtype=str(tensor.dtype), start_idx=current_idx, end_idx=current_idx + numel, numel=numel, @@ -727,3 +749,19 @@ def compute_chord_loss(trainer, grpo_loss: torch.Tensor) -> torch.Tensor: chord_sft_loss = torch.tensor(0.0, device=grpo_loss.device, dtype=grpo_loss.dtype) loss = (1 - mu) * grpo_loss + mu * chord_sft_loss return loss + + +def serialize_peft_config(peft_config): + if not isinstance(peft_config, dict): + peft_config = asdict(peft_config) + # turn set to list to serializable + if 'target_modules' in peft_config and isinstance(peft_config['target_modules'], set): + peft_config['target_modules'] = list(peft_config['target_modules']) + + return json.dumps(peft_config, ensure_ascii=False) + + +def deserialize_peft_config(data: str): + data = json.loads(data) + from swift.tuners.lora import LoraConfig + return LoraConfig(**data) diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index 4fde94c7da..a0ff6b9aee 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -4,13 +4,15 @@ import threading import time from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict from typing import List, Optional, Union from urllib.parse import urlparse +import json import requests import torch from packaging import version -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from requests import ConnectionError from torch import nn from transformers.utils import is_torch_cuda_available @@ -19,6 +21,7 @@ from swift.llm.infer.protocol import ChatCompletionResponse, RequestConfig, RolloutOutput from swift.plugin import Metric from swift.utils import is_trl_available, is_vllm_ascend_available, is_vllm_available +from .utils import serialize_peft_config if is_vllm_available(): from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator @@ -242,7 +245,7 @@ def _update_single_server(i): if all_errors: raise RuntimeError(f'Multiple errors: {all_errors}') - def update_adapter_flattened_param(self, lora_request, metadatas, flattened_tensor): + def update_adapter_flattened_param(self, peft_config, metadatas, flattened_tensor): """ Adds a LoRA adapter to the model on all servers. @@ -250,12 +253,15 @@ def update_adapter_flattened_param(self, lora_request, metadatas, flattened_tens lora_request: TensorLoRARequest object containing LoRA adapter information. """ errors = [None] * self.num_servers + peft_config = serialize_peft_config(peft_config) + metadatas = [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas] def _update_single_server(i): try: - # Convert lora_request to dict for JSON serialization data = { - 'lora_request': lora_request, + 'peft_config': { + **peft_config + }, 'metadatas': metadatas, } From 986ac8d6cd14bd123a1feaea8a832515482ca651 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 11 Sep 2025 16:26:31 +0800 Subject: [PATCH 12/42] server add_lora wip --- .../infer/infer_engine/grpo_vllm_engine.py | 1 - swift/llm/infer/protocol.py | 7 +++++ swift/llm/infer/rollout.py | 27 ++++++++++------- swift/trainers/rlhf_trainer/grpo_trainer.py | 4 +-- swift/trainers/rlhf_trainer/utils.py | 29 ++++++++++--------- 5 files changed, 42 insertions(+), 26 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 278a595ade..77df3f82de 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -14,7 +14,6 @@ try: os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '86400' - from swift.trainers.rlhf_trainer.utils import TensorLoRARequest from vllm.lora.request import LoRARequest except Exception: raise diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index e2e0d02783..742dfc734d 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -12,6 +12,8 @@ from PIL import Image from pydantic import BaseModel, Field, field_validator +from swift.trainers.rlhf_trainer.utils import FlattenedTensorMetadata +from swift.tuners.lora import LoraConfig from ..template import InferRequest from ..utils import Messages, Tool @@ -459,3 +461,8 @@ class UpdateWeightsRequest(BaseModel): name: str dtype: str shape: list[int] + + +class UpdateFlattenedAdapterRequest(BaseModel): + peft_config: LoraConfig + metadatas: List[FlattenedTensorMetadata] diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 9021642b36..daa5c5e642 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -25,12 +25,11 @@ from swift.llm import RolloutArguments, SwiftPipeline from swift.llm.template.template_inputs import RolloutInferRequest from swift.plugin.multi_turn import RolloutScheduler, multi_turns -from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, LoRARequest, - TensorLoRARequest) -from swift.tuners.lora import LoraConfig +from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, TensorLoRARequest, + patch_vllm_load_adapter) from swift.utils import get_logger from .infer_engine import GRPOVllmEngine, InferClient -from .protocol import InitCommunicatorRequest, RequestConfig, UpdateWeightsRequest +from .protocol import InitCommunicatorRequest, RequestConfig, UpdateFlattenedAdapterRequest, UpdateWeightsRequest try: from vllm.utils import get_open_port @@ -55,6 +54,8 @@ - For inference or deployment, please use the `swift infer` or `swift deploy` commands. """ +patch_vllm_load_adapter() + class WeightSyncWorkerExtension(HFWeightSyncWorkerExtension): @@ -84,11 +85,13 @@ def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> Non # Load the received weights into the model. self.model_runner.model.load_weights(weights=[(name, weight)]) - def update_adapter_flattened_param(self, peft_config: Dict, metadatas: list[FlattenedTensorMetadata]) -> None: + def update_adapter_flattened_param(self, peft_config: Dict, metadatas: list[Dict]) -> None: """ Receives updated weights from the client process and updates the named parameter in the model. """ - peft_config = LoraConfig(**peft_config) + # peft_config = json.loads(peft_config) + # peft_config = LoraConfig(**peft_config) + metadatas = [FlattenedTensorMetadata(**metadata) for metadata in metadatas] if self.pynccl_comm is None: raise RuntimeError('Communicator not initialized. Call `init_communicator` first.') flatten_tensor_length = max(metadata.end_idx for metadata in metadatas) @@ -105,7 +108,7 @@ def update_adapter_flattened_param(self, peft_config: Dict, metadatas: list[Flat peft_config=peft_config, lora_tensors=named_params) # TODO: Check - self.add_lora(TensorLoRARequest(lora_request=lora_request, lora_tensors=named_params)) + self.add_lora(lora_request) logger = get_logger() @@ -373,7 +376,7 @@ async def update_named_param(self, request: UpdateWeightsRequest): return {'message': 'Request received, updating named parameter'} - async def update_adapter_flattened_param(self, lora_request, metadatas): + async def update_adapter_flattened_param(self, request: UpdateFlattenedAdapterRequest): # Create a LoRA request object, or pass request directly # from swift.trainers.rlhf_trainer.utils import TensorLoRARequest # lora_request = TensorLoRARequest( @@ -383,8 +386,12 @@ async def update_adapter_flattened_param(self, lora_request, metadatas): # peft_config=request.peft_config, # lora_tensors=request.lora_tensors # ) - - kwargs = {'method': 'update_adapter_flattened_param', 'args': (lora_request, metadatas)} + peft_config = asdict(request.peft_config) + metadatas = [ + metadata.model_dump() if hasattr(metadata, 'model_dump') else metadata.dict() + for metadata in request.metadatas + ] + kwargs = {'method': 'update_adapter_flattened_param', 'args': (peft_config, metadatas)} for connection in self.connections: connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs}) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index c5bf78ff60..fff28fddef 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -50,7 +50,7 @@ unwrap_model_for_generation) from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin -from .utils import (FlattenedTensorBucket, LoRARequest, TensorLoRARequest, _ForwardRedirection, compute_chord_loss, +from .utils import (FlattenedTensorBucket, TensorLoRARequest, _ForwardRedirection, compute_chord_loss, get_gather_if_zero3_context, identity_data_collator, load_pil_img, make_chord_sft_dataset, patch_lora_merge, patch_lora_unmerge, patch_profiling_context, patch_profiling_decorator, patch_save_last_checkpoint, patch_vllm_load_adapter, replace_assistant_response_with_ids) @@ -605,7 +605,7 @@ def _move_adapter_to_vllm(self): self.model.merge_adapter() cur_lora_params = get_peft_model_state_dict(self.model, state_dict) cur_lora_params = { # base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.weight - name: param.full_tensor().detach().cpu() if hasattr(param, 'full_tensor') else param.detach().cpu() + name: param.full_tensor().detach() if hasattr(param, 'full_tensor') else param.detach() for name, param in cur_lora_params.items() } lora_params.update(cur_lora_params) diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 37972c6fbd..eba6b683c9 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -7,7 +7,7 @@ from functools import partial from io import BytesIO from types import MethodType -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union import datasets import json @@ -34,12 +34,10 @@ TensorLoRARequest = None if is_vllm_available(): - from vllm.lora.request import LoRARequest as vLLMLoRARequest - - class LoRARequest(vLLMLoRARequest): - peft_config: dict = field(default=None) + from vllm.lora.request import LoRARequest class TensorLoRARequest(LoRARequest): + peft_config: dict = field(default=None) lora_tensors: dict = field(default=None) @@ -393,6 +391,11 @@ def get_gather_if_zero3_context(trainer): def patch_vllm_load_adapter(): + """Patch vLLM's WorkerLoRAManager to support loading LoRA from tensors. + + This function applies a monkey patch to WorkerLoRAManager._load_adapter method + to support TensorLoRARequest + """ # from vllm.lora.worker_manager import WorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.models import LoRAModel @@ -579,22 +582,22 @@ def get_metadata(self) -> List[FlattenedTensorMetadata]: """Get metadata for all tensors in the bucket""" return self.metadata - def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]: + def reconstruct_tensors(self) -> Dict[str, torch.Tensor]: """ Reconstruct original tensors from flattened tensor with optimized performance. Uses memory-efficient operations to minimize allocations and copies. """ # preallocate the result list - reconstructed = [None] * len(self.metadata) + reconstructed = {} - for i, meta in enumerate(self.metadata): + for meta in self.metadata: tensor = self.flattened_tensor[meta.start_idx:meta.end_idx].reshape(meta.shape) - + dtype = getattr(torch, meta.dtype.split('.')[-1]) # batch dtype conversion (if needed) - if tensor.dtype != meta.dtype: - tensor = tensor.to(meta.dtype) + if tensor.dtype != dtype: + tensor = tensor.to(dtype) - reconstructed[i] = (meta.name, tensor) + reconstructed[meta.name] = tensor return reconstructed @@ -758,7 +761,7 @@ def serialize_peft_config(peft_config): if 'target_modules' in peft_config and isinstance(peft_config['target_modules'], set): peft_config['target_modules'] = list(peft_config['target_modules']) - return json.dumps(peft_config, ensure_ascii=False) + return peft_config def deserialize_peft_config(data: str): From 0dc8c6e9202dbe1452dde771bcefa7af8fb87a44 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 14:51:48 +0800 Subject: [PATCH 13/42] fix server tp --- swift/llm/infer/protocol.py | 1 + swift/llm/infer/rollout.py | 22 +++++---------------- swift/trainers/rlhf_trainer/grpo_trainer.py | 5 ++--- swift/trainers/rlhf_trainer/vllm_client.py | 2 ++ 4 files changed, 10 insertions(+), 20 deletions(-) diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index 742dfc734d..f9ba849452 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -464,5 +464,6 @@ class UpdateWeightsRequest(BaseModel): class UpdateFlattenedAdapterRequest(BaseModel): + lora_int_id: int peft_config: LoraConfig metadatas: List[FlattenedTensorMetadata] diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index daa5c5e642..1c33b86e8f 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -85,29 +85,26 @@ def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> Non # Load the received weights into the model. self.model_runner.model.load_weights(weights=[(name, weight)]) - def update_adapter_flattened_param(self, peft_config: Dict, metadatas: list[Dict]) -> None: + def update_adapter_flattened_param(self, lora_int_id: int, peft_config: Dict, metadatas: list[Dict]) -> None: """ Receives updated weights from the client process and updates the named parameter in the model. """ - # peft_config = json.loads(peft_config) - # peft_config = LoraConfig(**peft_config) metadatas = [FlattenedTensorMetadata(**metadata) for metadata in metadatas] if self.pynccl_comm is None: raise RuntimeError('Communicator not initialized. Call `init_communicator` first.') - flatten_tensor_length = max(metadata.end_idx for metadata in metadatas) - flatten_tensor = torch.empty(flatten_tensor_length, dtype=torch.float32, device=self.device) + flatten_tensor_length = metadatas[-1].end_idx + dtype = getattr(torch, metadatas[-1].dtype.split('.')[-1]) + flatten_tensor = torch.empty(flatten_tensor_length, dtype=dtype, device=self.device) self.pynccl_comm.broadcast(flatten_tensor, src=self.client_rank) self.pynccl_comm.group.barrier() flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor) named_params = flattened_tensor_bucket.reconstruct_tensors() - lora_int_id = int(time.time_ns() % 0x7FFFFFFF) lora_request = TensorLoRARequest( lora_name=f'{lora_int_id}', lora_int_id=lora_int_id, lora_path='dummy_lora_path', peft_config=peft_config, lora_tensors=named_params) - # TODO: Check self.add_lora(lora_request) @@ -377,21 +374,12 @@ async def update_named_param(self, request: UpdateWeightsRequest): return {'message': 'Request received, updating named parameter'} async def update_adapter_flattened_param(self, request: UpdateFlattenedAdapterRequest): - # Create a LoRA request object, or pass request directly - # from swift.trainers.rlhf_trainer.utils import TensorLoRARequest - # lora_request = TensorLoRARequest( - # lora_name=request.lora_name, - # lora_int_id=request.lora_int_id, - # lora_path=request.lora_path, - # peft_config=request.peft_config, - # lora_tensors=request.lora_tensors - # ) peft_config = asdict(request.peft_config) metadatas = [ metadata.model_dump() if hasattr(metadata, 'model_dump') else metadata.dict() for metadata in request.metadatas ] - kwargs = {'method': 'update_adapter_flattened_param', 'args': (peft_config, metadatas)} + kwargs = {'method': 'update_adapter_flattened_param', 'args': (request.lora_int_id, peft_config, metadatas)} for connection in self.connections: connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs}) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index fff28fddef..7068553bd2 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -612,14 +612,13 @@ def _move_adapter_to_vllm(self): self.model.unmerge_adapter() del cur_lora_params - lora_int_id = int(time.time_ns() % 0x7FFFFFFF) - if self.vllm_mode == 'server' and self.accelerator.is_main_process: bucked = FlattenedTensorBucket(named_tensors=list(lora_params.items())) metadatas = bucked.get_metadata() flattened_tensor = bucked.get_flattened_tensor() - self.vllm_client.update_adapter_flattened_param(peft_config, metadatas, flattened_tensor) # TODO + self.vllm_client.update_adapter_flattened_param(peft_config, metadatas, flattened_tensor) elif self.vllm_mode == 'colocate': + lora_int_id = int(time.time_ns() % 0x7FFFFFFF) lora_reqest = TensorLoRARequest( lora_name=f'{lora_int_id}', lora_int_id=lora_int_id, diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index a0ff6b9aee..958131de12 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -255,10 +255,12 @@ def update_adapter_flattened_param(self, peft_config, metadatas, flattened_tenso errors = [None] * self.num_servers peft_config = serialize_peft_config(peft_config) metadatas = [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas] + lora_int_id = int(time.time_ns() % 0x7FFFFFFF) def _update_single_server(i): try: data = { + 'lora_int_id': lora_int_id, 'peft_config': { **peft_config }, From ba284baf0df9c3d1c6fcdcf0c437713107b0134a Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 8 Sep 2025 15:39:13 +0800 Subject: [PATCH 14/42] merge main --- .../infer/infer_engine/grpo_vllm_engine.py | 2 +- swift/llm/infer/rollout.py | 1 + swift/trainers/rlhf_trainer/grpo_trainer.py | 11 +++ swift/trainers/rlhf_trainer/utils.py | 98 ++++++++++++++++++- 4 files changed, 107 insertions(+), 5 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 278a595ade..601bbfd8cd 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -98,7 +98,7 @@ def infer( *, template: Optional[Template] = None, use_tqdm: Optional[bool] = None, - adapter_request: Optional[AdapterRequest] = None, + adapter_request: Optional[Union[AdapterRequest, TensorLoRARequest]] = None, ) -> List[RolloutOutput]: if not adapter_request and self.enable_lora: # TODO: check if get the latest lora diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 89a7d3bc29..30f1d05bdf 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -277,6 +277,7 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): 'torch_dtype': args.torch_dtype, 'template': template, 'use_async_engine': args.vllm_use_async_engine, + 'enable_lora': args.vllm_enable_lora, 'max_lora_rank': args.vllm_max_lora_rank, }) infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 6e14942733..594a2aa58c 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -8,6 +8,7 @@ import time import uuid from collections import OrderedDict, defaultdict, deque +from collections import OrderedDict, defaultdict, deque from concurrent.futures import Future from contextlib import contextmanager, nullcontext from copy import copy, deepcopy @@ -26,6 +27,7 @@ from dacite import from_dict from packaging import version from peft.utils.save_and_load import get_peft_model_state_dict +from peft.utils.save_and_load import get_peft_model_state_dict from torch.nn import ModuleList from torch.utils.data import DataLoader from transformers import PreTrainedModel, TrainerCallback @@ -537,6 +539,14 @@ def prepare_vllm(self, model): max_num_seqs = ( self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size * self.args.steps_per_generation) lora_kwargs = {} + if self.args.train_type == 'lora': + lora_kwargs = { + 'enable_lora': True, + 'max_loras': 1, + 'max_lora_rank': self.args.lora_rank, + } + patch_vllm_load_adapter() + lora_kwargs = {} if self.args.train_type == 'lora': lora_kwargs = { 'enable_lora': True, @@ -564,6 +574,7 @@ def prepare_vllm(self, model): template=copy(self.template), distributed_executor_backend='external_launcher', **lora_kwargs, + **lora_kwargs, ) return engine diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 5b1440dea9..b936e959d3 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -32,12 +32,10 @@ TensorLoRARequest = None if is_vllm_available(): - from vllm.lora.request import LoRARequest as vLLMLoRARequest - - class LoRARequest(vLLMLoRARequest): - peft_config: dict = field(default=None) + from vllm.lora.request import LoRARequest class TensorLoRARequest(LoRARequest): + peft_config: dict = field(default=None) lora_tensors: dict = field(default=None) @@ -390,6 +388,98 @@ def get_gather_if_zero3_context(trainer): return gather_if_zero3 +def patch_vllm_load_adapter(): + # from vllm.lora.worker_manager import WorkerLoRAManager + from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager + from vllm.lora.models import LoRAModel + from vllm.lora.utils import get_adapter_absolute_path + + def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLoRARequest) -> LoRAModel: + """ + code borrowed from verl.utils.vllm.utils.py + based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors + Reason: + VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths. + To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to + load memory-based LoRA tensors. + """ + try: + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping + expected_lora_modules: list[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend(packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + expected_lora_modules = list(set(expected_lora_modules)) + # this is the patch + lora_tensors = None + from vllm.lora.peft_helper import PEFTHelper + if isinstance(lora_request, TensorLoRARequest): + peft_config = lora_request.peft_config + lora_tensors = lora_request.lora_tensors + peft_helper = PEFTHelper.from_dict(peft_config) + else: + lora_path = get_adapter_absolute_path(lora_request.lora_path) + peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings) + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. + model = self._adapter_manager.model + hf_to_vllm_mapper = getattr(model, 'hf_to_vllm_mapper', None) + if isinstance(lora_request, TensorLoRARequest): # this is the patch + lora = self._lora_model_cls.from_lora_tensors( + lora_model_id=lora_request.lora_int_id, + tensors=lora_tensors, + peft_helper=peft_helper, + device='cpu', + dtype=self.lora_config.lora_dtype, + embeddings=None, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + else: + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device='cpu', + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + except Exception as e: + raise e + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError(f'LoRA added vocab size {lora.extra_vocab_size} is greater than ' + f'lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}.') + return lora + + if not hasattr(LRUCacheWorkerLoRAManager, '_old_load_adapter'): + _old_load_adapter = LRUCacheWorkerLoRAManager._load_adapter + LRUCacheWorkerLoRAManager._load_adapter = patched_load_adapter + LRUCacheWorkerLoRAManager._old_load_adapter = _old_load_adapter + + +def get_gather_if_zero3_context(trainer): + deepspeed_plugin = trainer.accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + return gather_if_zero3 + + def patch_vllm_load_adapter(): # from vllm.lora.worker_manager import WorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager From 849696f7769692e21bcd6ab31f59a6ac1333f3a1 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 16:24:34 +0800 Subject: [PATCH 15/42] doc wip --- .../Instruction/Command-line-parameters.md | 2 ++ docs/source_en/Instruction/GRPO/GetStarted/GRPO.md | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 59fa5cd9b5..01aff30075 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -612,6 +612,8 @@ Deployment Arguments inherit from the [inference arguments](#inference-arguments - Rollout Parameters - multi_turn_scheduler: Multi-turn GRPO parameter; pass the corresponding plugin name, and make sure to implement it in plugin/multi_turn.py. - max_turns: Maximum number of rounds for multi-turn GRPO. The default is None, which means there is no limit. + - vllm_enable_lora: Enable the vLLM engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training. See the [documentation](./GRPO/GetStarted/GRPO.md#weight-sync-acceleration) for details. + - vllm_max_lora_rank: LoRA parameter for the vLLM engine. Must be greater than or equal to the training lora_rank; it is recommended to set them equal. Defaults to 16. ### Rollout Arguments The rollout parameters inherit from the [deployment parameters](#deployment-arguments). diff --git a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md index 73b4772005..80a251e870 100644 --- a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md @@ -194,6 +194,20 @@ To configure the external vLLM server during training, use the following paramet --vllm_server_port \ --vllm_server_timeout \ ``` + +#### Weight-Sync Acceleration +Swift 3.9 optimizes weight synchronization for LoRA training, achieving ~10× speed-up over Swift 3.8. + +To enable the optimized LoRA weight sync, add the following arguments to your rollout command: + +```bash + --vllm_enable_lora true + --vllm_max_lora_rank xxx # set to the same value as lora_rank in the training script +``` +Note: For multimodal model training, vLLM supports loading adapters only for the language-model part. If you need to train the ViT layers of a multimodal model (freeze_vit false), set `vllm_enable_lora false`. + +For implementation details, please refer to the [PR](https://github.com/modelscope/ms-swift/pull/5773) + ## logged metrics - completions/mean_length: The average length of generated completions. - completions/min_length: The minimum length among generated completions. From f70e82740235c1e63b66d553bee03ba91a9fbd2c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 16:25:24 +0800 Subject: [PATCH 16/42] doc --- .../Instruction/GRPO/GetStarted/GRPO.md | 13 ++++- ...44\350\241\214\345\217\202\346\225\260.md" | 3 +- examples/train/grpo/external/README.md | 6 +++ examples/train/grpo/external/mllm_lora.sh | 52 +++++++++++++++++++ 4 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 examples/train/grpo/external/mllm_lora.sh diff --git a/docs/source/Instruction/GRPO/GetStarted/GRPO.md b/docs/source/Instruction/GRPO/GetStarted/GRPO.md index f0bb51bdf4..9be2c23a93 100644 --- a/docs/source/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source/Instruction/GRPO/GetStarted/GRPO.md @@ -185,7 +185,7 @@ swift rollout \ 更多 rollout 参数参考[vLLM参数](../../../Instruction/命令行参数.md#vllm参数)和[rollout 参数](../../../Instruction/命令行参数.md#rollout参数) -注意:在使用 use_async_engine 时,仅开启 DP 可能会导致错误,相关问题参考: [vllm issue](https://github.com/vllm-project/vllm/issues/18567)。如果出现错误,请尝试同时启用 TP 和 DP。 +注意:在使用 use_async_engine 时,仅开启 DP 可能会导致错误,相关问题参考: [vllm issue](https://github.com/vllm-project/vllm/issues/18567)。如果出现错误,请尝试同时启用 TP 和 DP,或升级vLLM 训练使用以下参数配置外部 vLLM 服务器 @@ -196,6 +196,17 @@ swift rollout \ --vllm_server_port <服务端口> \ --vllm_server_timeout <超时时间> \ ``` +#### 权重同步加速 +swift 3.9对 LoRA 训练的权重同步进行了优化(相比swift3.8加速约10倍) + +为开启LoRA权重同步优化,请在rollout命令中设置以下参数 +```bash + --vllm_enable_lora true + --vllm_max_lora_rank xxx # 与训练脚本lora_rank一致 +``` +注意:对于多模态模型训练,vLLM 仅支持多模态模型的语言模型部分的adapter加载,如果需要训练多模型模型的ViT层(freeze_vit false),请设置`vllm_enable_lora false` + +优化实现细节请参考该[PR](https://github.com/modelscope/ms-swift/pull/5773) ## logged metrics - completions/mean_length:生成的 completion 的平均长度。 diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 24e613576e..2b40a0a4f4 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -171,7 +171,6 @@ - enable_dft_loss: 是否在SFT训练中使用[DFT](https://arxiv.org/abs/2508.05629) (Dynamic Fine-Tuning) loss,默认为False。 - enable_channel_loss: 打开channel loss,默认为`False`。你需要在数据集中准备"channel"字段,ms-swift会根据该字段分组统计loss。数据集格式参考[channel loss](../Customization/自定义数据集.md#channel-loss)。channel loss兼容packing/padding_free/loss_scale等技术。 - 注意:该参数为"ms-swift>=3.8"新增,若要在"ms-swift<3.8"使用channel loss,请查看v3.7文档。 - - 注意:该功能暂不兼容序列并行,待修复。 - logging_dir: tensorboard日志路径。默认为None,即设置为`f'{self.output_dir}/runs'`。 - predict_with_generate: 验证时使用生成式的方式,默认为False。 - metric_for_best_model: 默认为None,即当`predict_with_generate`设置为False时,设置为'loss',否则设置为'rouge-l'(在PPO训练时,不进行默认值设置;GRPO训练设置为'reward')。 @@ -593,6 +592,8 @@ soft overlong 奖励参数 - Rollout 参数 - multi_turn_scheduler: 多轮GRPO参数, 传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。 - max_turns: 多轮GRPO的轮数上限。默认为None,不做限制。 + - vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速) + - vllm_max_lora_rank: vLLM Engine LoRA参数,需大于等于训练的lora_rank,建议等于。默认为16。 ### Rollout参数 Rollout参数继承于[部署参数](#部署参数) diff --git a/examples/train/grpo/external/README.md b/examples/train/grpo/external/README.md index 733199dd4c..a4808a5c9d 100644 --- a/examples/train/grpo/external/README.md +++ b/examples/train/grpo/external/README.md @@ -7,6 +7,12 @@ 1. vLLM version 0.8.3 or higher. 2. trl version 0.17.0 or higher +For LoRA Training, set following parameters to speed up weight update +```bash + --vllm_enable_lora true + --vllm_max_lora_rank xxx # same as lora_rank in training script +``` + ## **Introduction** The GRPO (Group Relative Policy Optimization) training framework supports high-performance inference engines like vLLM to accelerate the sampling process. The **External Mode** allows you to connect to an external vLLM inference server, separating the inference service from the training process. This mode is ideal for scenarios where you want to offload inference to dedicated hardware or servers, improving resource utilization and scalability. diff --git a/examples/train/grpo/external/mllm_lora.sh b/examples/train/grpo/external/mllm_lora.sh new file mode 100644 index 0000000000..e2b4e38bf3 --- /dev/null +++ b/examples/train/grpo/external/mllm_lora.sh @@ -0,0 +1,52 @@ +# For LoRA Training, set following parameters to speed up weight update +# ```bash +# --vllm_enable_lora true +# --vllm_max_lora_rank xxx # same as lora_rank in training script +# ``` + +# CUDA_VISIBLE_DEVICES=4,5,6,7 \ +# swift rollout \ +# --model Qwen/Qwen2.5-VL-7B-Instruct \ +# --vllm_data_parallel_size 2 \ +# --vllm_tensor_parallel_size 2 \ +# --vllm_enable_lora true \ +# --vllm_max_lora_rank 16 + + +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +NPROC_PER_NODE=4 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --external_plugins examples/train/grpo/plugin/plugin.py \ + --reward_funcs external_r1v_acc format \ + --use_vllm true \ + --vllm_mode server \ + --vllm_server_host 127.0.0.1 \ + --vllm_server_port 8000 \ + --train_type lora \ + --lora_rank 16 \ + --lora_alpha 32 \ + --torch_dtype bfloat16 \ + --dataset 'AI-ModelScope/clevr_cogen_a_train' \ + --max_completion_length 1024 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 2 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 16 \ + --temperature 1.0 \ + --system 'examples/train/grpo/prompt.txt' \ + --deepspeed zero3 \ + --log_completions true \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --beta 0.001 From 274f6db3e69e19eb591f3bd586d3511e170ff057 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 16:28:18 +0800 Subject: [PATCH 17/42] check lora --- swift/llm/infer/rollout.py | 19 +++++++++++-------- swift/trainers/rlhf_trainer/grpo_trainer.py | 17 +++++------------ swift/trainers/rlhf_trainer/vllm_client.py | 1 + 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 5d25aa1800..9e48d2b8f3 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -284,7 +284,6 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): 'torch_dtype': args.torch_dtype, 'template': template, 'use_async_engine': args.vllm_use_async_engine, - 'enable_lora': args.vllm_enable_lora, 'max_lora_rank': args.vllm_max_lora_rank, }) infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend @@ -417,13 +416,17 @@ async def get_engine_type(self): enable_multi_turn = False if self.args.multi_turn_scheduler: enable_multi_turn = True - - if self.use_async_engine: - if self.use_gym_env: - return {'engine_type': 'AsyncLLMEngine', 'gym_env': True, 'enable_multi_turn': True} - return {'engine_type': 'AsyncLLMEngine', 'enable_multi_turn': enable_multi_turn} - else: - return {'engine_type': 'LLMEngine', 'enable_multi_turn': enable_multi_turn} + use_gym_env = False + if self.use_async_engine and self.use_gym_env: + use_gym_env = True + engine_type = 'AsyncLLMEngine' if self.use_async_engine else 'LLMEngine' + enable_lora = self.args.vllm_enable_lora + return { + 'engine_type': engine_type, + 'enable_multi_turn': enable_multi_turn, + 'use_gym_env': use_gym_env, + 'enable_lora': enable_lora, + } async def close_communicator(self): """ diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 3525fb065f..c1bf30b50c 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -8,7 +8,6 @@ import time import uuid from collections import OrderedDict, defaultdict, deque -from collections import OrderedDict, defaultdict, deque from concurrent.futures import Future from contextlib import contextmanager, nullcontext from copy import copy, deepcopy @@ -27,7 +26,6 @@ from dacite import from_dict from packaging import version from peft.utils.save_and_load import get_peft_model_state_dict -from peft.utils.save_and_load import get_peft_model_state_dict from torch.nn import ModuleList from torch.utils.data import DataLoader from transformers import PreTrainedModel, TrainerCallback @@ -256,6 +254,7 @@ def __init__(self, self.enable_offload = False self.use_gym_env = False self.enable_server_multi_turn = False + self.rollout_enable_lora = False # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs self.dynamic_num_samples = False if self.use_vllm: @@ -271,13 +270,16 @@ def __init__(self, vllm_use_async_engine = [self.vllm_client.use_async_engine] use_gym_env = [self.vllm_client.use_gym_env] enable_multi_turn = [self.vllm_client.enable_multi_turn] + enable_lora = [self.vllm_client.enable_lora] else: vllm_use_async_engine = [False] use_gym_env = [False] enable_multi_turn = [self.enable_server_multi_turn] + enable_lora = [False] self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] + self.rollout_enable_lora = broadcast_object_list(enable_lora, from_process=0)[0] if self.use_gym_env: self.reward_func_names = ['gym_reward'] @@ -539,14 +541,6 @@ def prepare_vllm(self, model): max_num_seqs = ( self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size * self.args.steps_per_generation) lora_kwargs = {} - if self.args.train_type == 'lora': - lora_kwargs = { - 'enable_lora': True, - 'max_loras': 1, - 'max_lora_rank': self.args.lora_rank, - } - patch_vllm_load_adapter() - lora_kwargs = {} if self.args.train_type == 'lora': lora_kwargs = { 'enable_lora': True, @@ -574,7 +568,6 @@ def prepare_vllm(self, model): template=copy(self.template), distributed_executor_backend='external_launcher', **lora_kwargs, - **lora_kwargs, ) return engine @@ -596,7 +589,7 @@ def _move_model_to_vllm(self, skip_async_check=False): train_type = self.args.train_type - if train_type == 'full' or (train_type == 'lora' and not self.base_sync_done): + if train_type == 'full' or (train_type == 'lora' and not self.base_sync_done) or not self.rollout_enable_lora: self._move_full_model_to_vllm() else: self._move_adapter_to_vllm() diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index 958131de12..9f04d53e81 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -321,6 +321,7 @@ def get_engine_type(self): self.use_async_engine = result['engine_type'] == 'AsyncLLMEngine' self.enable_multi_turn = result.get('enable_multi_turn', False) self.use_gym_env = result.get('gym_env', False) + self.enable_lora = result.get('enable_lora', False) return result def close_communicator(self): From 6069888484be49f938ea699c102e6891bb4b7564 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 16:57:48 +0800 Subject: [PATCH 18/42] add args for lora script --- ...43\347\240\201\350\256\255\347\273\203.md" | 6 +- swift/llm/argument/deploy_args.py | 1 - .../infer/infer_engine/grpo_vllm_engine.py | 1 + swift/trainers/rlhf_trainer/utils.py | 97 ------------------- 4 files changed, 6 insertions(+), 99 deletions(-) diff --git "a/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" "b/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" index 04776eab4e..bd9bfc8ad4 100644 --- "a/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" +++ "b/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" @@ -42,7 +42,9 @@ ```bash CUDA_VISIBLE_DEVICES=7 \ swift rollout \ - --model Qwen/Qwen2.5-7B-Instruct + --model Qwen/Qwen2.5-7B-Instruct \ + --vllm_enable_lora true \ + --vllm_max_lora_rank 16 ``` ```bash @@ -61,6 +63,8 @@ swift rlhf \ --vllm_server_host 127.0.0.1 \ --vllm_server_port 8000 \ --train_type lora \ + --lora_rank 16 \ + --lora_alpha 32 \ --torch_dtype bfloat16 \ --dataset 'open-r1/verifiable-coding-problems-python-10k' \ --max_completion_length 2048 \ diff --git a/swift/llm/argument/deploy_args.py b/swift/llm/argument/deploy_args.py index fe326a9df5..71ae22b2bd 100644 --- a/swift/llm/argument/deploy_args.py +++ b/swift/llm/argument/deploy_args.py @@ -86,7 +86,6 @@ class RolloutArguments(DeployArguments): # only for GRPO rollout with AsyncEngine, see details in swift/plugin/multi_turn multi_turn_scheduler: Optional[str] = None max_turns: Optional[int] = None - # lora, TODO: modify example script for lora vllm_enable_lora: bool = False vllm_max_lora_rank: int = 16 # GYM env diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index ba026f611e..145b26df28 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -8,6 +8,7 @@ from swift.llm import InferRequest, Template, VllmEngine from swift.plugin import Metric +from swift.trainers.rlhf_trainer.utils import TensorLoRARequest from ..protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, RequestConfig, RolloutOutput from .utils import AdapterRequest diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 37baf94e2c..db0de53561 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -390,103 +390,6 @@ def get_gather_if_zero3_context(trainer): return gather_if_zero3 -def patch_vllm_load_adapter(): - """Patch vLLM's WorkerLoRAManager to support loading LoRA from tensors. - - This function applies a monkey patch to WorkerLoRAManager._load_adapter method - to support TensorLoRARequest - """ - # from vllm.lora.worker_manager import WorkerLoRAManager - from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager - from vllm.lora.models import LoRAModel - from vllm.lora.utils import get_adapter_absolute_path - - def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLoRARequest) -> LoRAModel: - """ - code borrowed from verl.utils.vllm.utils.py - based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors - Reason: - VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths. - To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to - load memory-based LoRA tensors. - """ - try: - supported_lora_modules = self._adapter_manager.supported_lora_modules - packed_modules_mapping = self._adapter_manager.packed_modules_mapping - expected_lora_modules: list[str] = [] - for module in supported_lora_modules: - if module in packed_modules_mapping: - expected_lora_modules.extend(packed_modules_mapping[module]) - else: - expected_lora_modules.append(module) - expected_lora_modules = list(set(expected_lora_modules)) - # this is the patch - lora_tensors = None - from vllm.lora.peft_helper import PEFTHelper - if isinstance(lora_request, TensorLoRARequest): - peft_config = lora_request.peft_config - lora_tensors = lora_request.lora_tensors - peft_helper = PEFTHelper.from_dict(peft_config) - else: - lora_path = get_adapter_absolute_path(lora_request.lora_path) - peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings) - # Validates the LoRA configuration against requirements before - # loading weights, throwing an exception if validation fails. - peft_helper.validate_legal(self.lora_config) - # For some models like Qwen2VL, we need to use hf_to_vllm_mapper - # to ensure correct loading of lora weights. - model = self._adapter_manager.model - hf_to_vllm_mapper = getattr(model, 'hf_to_vllm_mapper', None) - if isinstance(lora_request, TensorLoRARequest): # this is the patch - lora = self._lora_model_cls.from_lora_tensors( - lora_model_id=lora_request.lora_int_id, - tensors=lora_tensors, - peft_helper=peft_helper, - device='cpu', - dtype=self.lora_config.lora_dtype, - embeddings=None, - target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, - embedding_modules=self.embedding_modules, - embedding_padding_modules=self.embedding_padding_modules, - weights_mapper=hf_to_vllm_mapper, - ) - else: - lora = self._lora_model_cls.from_local_checkpoint( - lora_path, - expected_lora_modules, - peft_helper=peft_helper, - lora_model_id=lora_request.lora_int_id, - device='cpu', - dtype=self.lora_config.lora_dtype, - target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, - embedding_modules=self.embedding_modules, - embedding_padding_modules=self.embedding_padding_modules, - weights_mapper=hf_to_vllm_mapper, - ) - except Exception as e: - raise e - if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: - raise ValueError(f'LoRA added vocab size {lora.extra_vocab_size} is greater than ' - f'lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}.') - return lora - - if not hasattr(LRUCacheWorkerLoRAManager, '_old_load_adapter'): - _old_load_adapter = LRUCacheWorkerLoRAManager._load_adapter - LRUCacheWorkerLoRAManager._load_adapter = patched_load_adapter - LRUCacheWorkerLoRAManager._old_load_adapter = _old_load_adapter - - -def get_gather_if_zero3_context(trainer): - deepspeed_plugin = trainer.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - return gather_if_zero3 - - def patch_vllm_load_adapter(): # from vllm.lora.worker_manager import WorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager From 0cf5a6286fbda69214c85a55579930a7fcfce261 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 16:59:14 +0800 Subject: [PATCH 19/42] update script --- docs/source_en/BestPractices/GRPO-Code-Training.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source_en/BestPractices/GRPO-Code-Training.md b/docs/source_en/BestPractices/GRPO-Code-Training.md index d41818addf..adc73c9ea4 100644 --- a/docs/source_en/BestPractices/GRPO-Code-Training.md +++ b/docs/source_en/BestPractices/GRPO-Code-Training.md @@ -46,7 +46,9 @@ launch external vLLM server using following script ```bash CUDA_VISIBLE_DEVICES=7 \ swift rollout \ - --model Qwen/Qwen2.5-7B-Instruct + --model Qwen/Qwen2.5-7B-Instruct \ + --vllm_enable_lora true \ + --vllm_max_lora_rank 16 ``` ```bash @@ -65,6 +67,8 @@ swift rlhf \ --vllm_server_host 127.0.0.1 \ --vllm_server_port 8000 \ --train_type lora \ + --lora_rank 16 \ + --lora_alpha 32 \ --torch_dtype bfloat16 \ --dataset 'open-r1/verifiable-coding-problems-python-10k' \ --max_completion_length 2048 \ From 691a5dffe38bbc102c8e69b4bf6dc068b395be0a Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 17:13:36 +0800 Subject: [PATCH 20/42] fix --- swift/llm/infer/infer_engine/grpo_vllm_engine.py | 4 ++-- swift/trainers/rlhf_trainer/utils.py | 10 ++-------- swift/trainers/rlhf_trainer/vllm_client.py | 4 ++-- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 145b26df28..5236dcab74 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -98,12 +98,12 @@ def infer( *, template: Optional[Template] = None, use_tqdm: Optional[bool] = None, - adapter_request: Optional[Union[AdapterRequest, TensorLoRARequest]] = None, + adapter_request: Optional[AdapterRequest] = None, ) -> List[RolloutOutput]: if not adapter_request and self.enable_lora: - # TODO: check if get the latest lora lora_int_ids = list(self.engine.list_loras()) if lora_int_ids: + # since max_lora = 1, pick the first lora adapter_request = LoRARequest( lora_name=f'lora_{lora_int_ids[0]}', lora_int_id=lora_int_ids[0], diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index db0de53561..ef64d1f9f3 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -749,17 +749,11 @@ def compute_chord_loss(trainer, grpo_loss: torch.Tensor) -> torch.Tensor: return loss -def serialize_peft_config(peft_config): +def peft_config(peft_config): if not isinstance(peft_config, dict): peft_config = asdict(peft_config) # turn set to list to serializable if 'target_modules' in peft_config and isinstance(peft_config['target_modules'], set): peft_config['target_modules'] = list(peft_config['target_modules']) - return peft_config - - -def deserialize_peft_config(data: str): - data = json.loads(data) - from swift.tuners.lora import LoraConfig - return LoraConfig(**data) + return peft_config \ No newline at end of file diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index 9f04d53e81..d1ab6b7762 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -21,7 +21,7 @@ from swift.llm.infer.protocol import ChatCompletionResponse, RequestConfig, RolloutOutput from swift.plugin import Metric from swift.utils import is_trl_available, is_vllm_ascend_available, is_vllm_available -from .utils import serialize_peft_config +from .utils import peft_config_to_dict if is_vllm_available(): from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator @@ -253,7 +253,7 @@ def update_adapter_flattened_param(self, peft_config, metadatas, flattened_tenso lora_request: TensorLoRARequest object containing LoRA adapter information. """ errors = [None] * self.num_servers - peft_config = serialize_peft_config(peft_config) + peft_config = peft_config_to_dict(peft_config) metadatas = [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas] lora_int_id = int(time.time_ns() % 0x7FFFFFFF) From 5cab78de5323d18c024ee2fe9a29be7c1f122f79 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 17:17:59 +0800 Subject: [PATCH 21/42] remove unused import --- swift/llm/infer/infer_engine/grpo_vllm_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 5236dcab74..6274f4d0ca 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -8,7 +8,6 @@ from swift.llm import InferRequest, Template, VllmEngine from swift.plugin import Metric -from swift.trainers.rlhf_trainer.utils import TensorLoRARequest from ..protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, RequestConfig, RolloutOutput from .utils import AdapterRequest From e43a0da62bde37a0ec685ec66a92350655a3c3e6 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 17:19:47 +0800 Subject: [PATCH 22/42] fix --- swift/trainers/rlhf_trainer/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index ef64d1f9f3..a124c37c4c 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -749,11 +749,11 @@ def compute_chord_loss(trainer, grpo_loss: torch.Tensor) -> torch.Tensor: return loss -def peft_config(peft_config): +def peft_config_to_dict(peft_config): if not isinstance(peft_config, dict): peft_config = asdict(peft_config) # turn set to list to serializable if 'target_modules' in peft_config and isinstance(peft_config['target_modules'], set): peft_config['target_modules'] = list(peft_config['target_modules']) - return peft_config \ No newline at end of file + return peft_config From 688bf64c88f17d9a1efd55114d0b8dbb16f9fd16 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 17:24:07 +0800 Subject: [PATCH 23/42] fix typo --- docs/source/Instruction/GRPO/GetStarted/GRPO.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/Instruction/GRPO/GetStarted/GRPO.md b/docs/source/Instruction/GRPO/GetStarted/GRPO.md index 9be2c23a93..d2a8dd16ae 100644 --- a/docs/source/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source/Instruction/GRPO/GetStarted/GRPO.md @@ -204,7 +204,7 @@ swift 3.9对 LoRA 训练的权重同步进行了优化(相比swift3.8加速约 --vllm_enable_lora true --vllm_max_lora_rank xxx # 与训练脚本lora_rank一致 ``` -注意:对于多模态模型训练,vLLM 仅支持多模态模型的语言模型部分的adapter加载,如果需要训练多模型模型的ViT层(freeze_vit false),请设置`vllm_enable_lora false` +注意:对于多模态模型训练,vLLM 仅支持多模态模型的语言模型部分的adapter加载,如果需要训练多模态模型的ViT层(freeze_vit false),请设置`vllm_enable_lora false` 优化实现细节请参考该[PR](https://github.com/modelscope/ms-swift/pull/5773) From 4fa2d2ff9397e204859863f9ce206eb377815a2c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 12 Sep 2025 19:16:41 +0800 Subject: [PATCH 24/42] fix unmerge --- swift/trainers/rlhf_trainer/grpo_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 6fac55acb0..bf5a51a1ed 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -611,7 +611,8 @@ def _move_adapter_to_vllm(self): for name, param in cur_lora_params.items() } lora_params.update(cur_lora_params) - self.model.unmerge_adapter() + with patch_lora_unmerge(self.model): + self.model.unmerge_adapter() del cur_lora_params if self.vllm_mode == 'server' and self.accelerator.is_main_process: From 78f9473a1c83f0e2cd3dacdf9e7e117748d38b60 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sun, 28 Sep 2025 11:33:20 +0800 Subject: [PATCH 25/42] wip --- swift/trainers/rlhf_trainer/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index a124c37c4c..cda0d80630 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -39,6 +39,15 @@ class TensorLoRARequest(LoRARequest): peft_config: dict = field(default=None) lora_tensors: dict = field(default=None) + lora_embeddings: Optional[Dict[str, torch.Tensor]] = None + + @property + def config(self): + return self.peft_config + + @property + def embeddings(self): + return self.lora_embeddings def round_robin(num_reqs, num_workers): From 6745666d9eabf30f3c8ae9e3b5c8e56b1a1d9d06 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 9 Oct 2025 18:02:41 +0800 Subject: [PATCH 26/42] bucket for full training in server mode --- swift/llm/infer/protocol.py | 4 ++ swift/llm/infer/rollout.py | 47 ++++++++++++++- swift/trainers/rlhf_trainer/grpo_trainer.py | 59 ++++++++++++++----- swift/trainers/rlhf_trainer/utils.py | 65 ++++++++++++++++++++- swift/trainers/rlhf_trainer/vllm_client.py | 38 ++++++++++++ 5 files changed, 197 insertions(+), 16 deletions(-) diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index f9ba849452..a3926df386 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -467,3 +467,7 @@ class UpdateFlattenedAdapterRequest(BaseModel): lora_int_id: int peft_config: LoraConfig metadatas: List[FlattenedTensorMetadata] + + +class UpdateFlattenedParamsRequest(BaseModel): + metadatas: List[FlattenedTensorMetadata] diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index e6a466f98e..205220b682 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -29,7 +29,8 @@ patch_vllm_load_adapter) from swift.utils import get_logger from .infer_engine import GRPOVllmEngine, InferClient -from .protocol import InitCommunicatorRequest, RequestConfig, UpdateFlattenedAdapterRequest, UpdateWeightsRequest +from .protocol import (InitCommunicatorRequest, RequestConfig, UpdateFlattenedAdapterRequest, + UpdateFlattenedParamsRequest, UpdateWeightsRequest) try: from vllm.utils import get_open_port @@ -107,6 +108,30 @@ def update_adapter_flattened_param(self, lora_int_id: int, peft_config: Dict, me lora_tensors=named_params) self.add_lora(lora_request) + def update_flattened_params(self, metadatas: list[Dict]) -> None: + """ + Receives updated flattened weights from the client process and updates the model parameters. + + Args: + metadatas (list[Dict]): List of metadata dictionaries for the flattened tensors. + """ + metadatas = [FlattenedTensorMetadata(**metadata) for metadata in metadatas] + if self.pynccl_comm is None: + raise RuntimeError('Communicator not initialized. Call `init_communicator` first.') + + flatten_tensor_length = metadatas[-1].end_idx + dtype = getattr(torch, metadatas[-1].dtype.split('.')[-1]) + flatten_tensor = torch.empty(flatten_tensor_length, dtype=dtype, device=self.device) + + self.pynccl_comm.broadcast(flatten_tensor, src=self.client_rank) + self.pynccl_comm.group.barrier() + + flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor) + named_params = flattened_tensor_bucket.reconstruct_tensors() + + # Load the reconstructed parameters into the model + self.model_runner.model.load_weights(weights=list(named_params.items())) + logger = get_logger() @@ -229,6 +254,7 @@ def _register_rl_rollout_app(self): self.app.post('/init_communicator/')(self.init_communicator) self.app.post('/update_named_param/')(self.update_named_param) self.app.post('/update_adapter_flattened_param/')(self.update_adapter_flattened_param) + self.app.post('/update_flattened_params/')(self.update_flattened_params) self.app.post('/reset_prefix_cache/')(self.reset_prefix_cache) self.app.post('/close_communicator/')(self.close_communicator) self.app.post('/infer/', response_model=None)(self.infer) @@ -387,6 +413,25 @@ async def update_adapter_flattened_param(self, request: UpdateFlattenedAdapterRe return {'message': 'Request received, updating adapter parameter'} + async def update_flattened_params(self, request: UpdateFlattenedParamsRequest): + """ + Updates the model weights with flattened tensor data. + + Args: + request (UpdateFlattenedParamsRequest): + - metadatas (List[FlattenedTensorMetadata]): Metadata for the flattened tensors. + + """ + metadatas = [ + metadata.model_dump() if hasattr(metadata, 'model_dump') else metadata.dict() + for metadata in request.metadatas + ] + kwargs = {'method': 'update_flattened_params', 'args': (metadatas, )} + for connection in self.connections: + connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs}) + + return {'message': 'Request received, updating flattened parameters'} + async def reset_prefix_cache(self): """ Resets the prefix cache for the model. diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 532b341e0d..ec3a96567f 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -49,11 +49,11 @@ unwrap_model_for_generation) from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin -from .utils import (FlattenedTensorBucket, TensorLoRARequest, _ForwardRedirection, compute_chord_loss, - get_gather_if_zero3_context, identity_data_collator, load_pil_img, make_chord_sft_dataset, - patch_lora_merge, patch_lora_unmerge, patch_profiling_context, patch_profiling_decorator, - patch_save_last_checkpoint, patch_vllm_load_adapter, replace_assistant_response_with_ids, - set_expandable_segments) +from .utils import (FlattenedTensorBucket, TensorLoRARequest, _create_parameter_buckets, _ForwardRedirection, + _process_bucket_with_flattened_tensor, compute_chord_loss, get_gather_if_zero3_context, + identity_data_collator, load_pil_img, make_chord_sft_dataset, patch_lora_merge, patch_lora_unmerge, + patch_profiling_context, patch_profiling_decorator, patch_save_last_checkpoint, + patch_vllm_load_adapter, replace_assistant_response_with_ids, set_expandable_segments) from .vllm_client import VLLMClient try: @@ -663,8 +663,15 @@ def _move_full_model_to_vllm(self): [state.shape != torch.Size([0]) for state in state_dict.values()]) if self.vllm_mode == 'server' and self.accelerator.is_main_process: - for name, param in state_dict.items(): - self.vllm_client.update_named_param(name, param) + # Create parameter buckets and process them efficiently + named_params = list(state_dict.items()) + parameter_buckets = _create_parameter_buckets(named_params) + + # Process each bucket using flattened tensor approach + for bucket in parameter_buckets: + _process_bucket_with_flattened_tensor(self, bucket) + + del named_params, parameter_buckets elif self.vllm_mode == 'colocate': llm_model = self.engine.inner_model llm_model.load_weights(state_dict.items()) @@ -673,13 +680,37 @@ def _move_full_model_to_vllm(self): del state_dict self.base_sync_done = True else: - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): - if self.vllm_mode == 'server' and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == 'colocate': - llm_model = self.engine.inner_model - llm_model.load_weights([(name, param.data)]) + if self.vllm_mode == 'server' and self.accelerator.is_main_process: + # For non-PEFT models, use streaming bucket approach to avoid memory peaks + # Collect parameters in small batches and process them immediately + current_bucket = [] + current_size = 0 + bucket_size_bytes = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) * 1024 * 1024 + for name, param in self.model.named_parameters(): + with gather_if_zero3([param]): + # Ensure we have the full parameter (not sharded) and COPY it + # to avoid the parameter being re-partitioned when exiting context + full_param = param.data.clone() if hasattr(param, 'data') else param.clone() + param_size = full_param.numel() * full_param.element_size() + + # If adding this param would exceed bucket size, process current bucket first + if current_size + param_size > bucket_size_bytes and current_bucket: + _process_bucket_with_flattened_tensor(self, current_bucket) + current_bucket = [] + current_size = 0 + + current_bucket.append((name, full_param)) + current_size += param_size + + # Process remaining parameters in the last bucket + if current_bucket: + _process_bucket_with_flattened_tensor(self, current_bucket) + else: + for name, param in self.model.named_parameters(): + with gather_if_zero3([param]): + if self.vllm_mode == 'colocate': + llm_model = self.engine.inner_model + llm_model.load_weights([(name, param.data)]) if self.vllm_mode == 'server' and self.accelerator.is_main_process: self.vllm_client.reset_prefix_cache() diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index bc85ad295d..53ed7142c3 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -23,6 +23,7 @@ from torch.utils.data import DataLoader, RandomSampler from transformers import Trainer +from swift.trainers.rlhf_trainer.grpo_trainer import GRPOTrainer from swift.utils import is_swanlab_available, is_vllm_available, is_wandb_available if is_wandb_available(): @@ -45,7 +46,7 @@ class TensorLoRARequest(LoRARequest): @property def config(self): return self.peft_config - + @property def embeddings(self): return self.lora_embeddings @@ -812,3 +813,65 @@ def peft_config_to_dict(peft_config): peft_config['target_modules'] = list(peft_config['target_modules']) return peft_config + + +def _create_parameter_buckets(named_params, bucket_size_mb=100): + """Create parameter buckets grouped by dtype for efficient processing""" + buckets = [] + current_bucket = [] + current_size = 0 + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + + # Group parameters by dtype first, then by size + dtype_groups = {} + for name, param in named_params: + dtype = param.dtype + if dtype not in dtype_groups: + dtype_groups[dtype] = [] + dtype_groups[dtype].append((name, param)) + + # Create buckets within each dtype group + for dtype, params in dtype_groups.items(): + for name, param in params: + param_size = param.numel() * param.element_size() + + # If adding this param would exceed bucket size, start a new bucket + if current_size + param_size > bucket_size_bytes and current_bucket: + buckets.append(current_bucket) + current_bucket = [] + current_size = 0 + + current_bucket.append((name, param)) + current_size += param_size + + # Add remaining params in current bucket + if current_bucket: + buckets.append(current_bucket) + current_bucket = [] + current_size = 0 + + return buckets + + +def _process_bucket_with_flattened_tensor(trainer: GRPOTrainer, bucket_params): + """Process a bucket of parameters using FlattenedTensorBucket for efficiency""" + if not bucket_params: + return + + # Create FlattenedTensorBucket for efficient processing + bucket = FlattenedTensorBucket(named_tensors=bucket_params) + metadatas = bucket.get_metadata() + flattened_tensor = bucket.get_flattened_tensor() + + # Use the new flattened parameter update method + # If not available, fall back to individual parameter updates + try: + trainer.vllm_client.update_flattened_params(metadatas, flattened_tensor) + except AttributeError: + # Fallback to individual parameter updates + reconstructed = bucket.reconstruct_tensors() + for name, param in reconstructed.items(): + trainer.vllm_client.update_named_param(name, param) + + # Clean up + del bucket, metadatas, flattened_tensor diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index d1ab6b7762..8440a220bb 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -288,6 +288,44 @@ def _update_single_server(i): if all_errors: raise RuntimeError(f'Multiple errors: {all_errors}') + def update_flattened_params(self, metadatas, flattened_tensor): + """ + Updates model parameters using flattened tensor data. + + Args: + metadatas: List of FlattenedTensorMetadata objects + flattened_tensor: The flattened tensor containing all parameters + """ + errors = [None] * self.num_servers + metadatas = [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas] + + def _update_single_server(i): + try: + data = { + 'metadatas': metadatas, + } + + response = self.sessions[i].post( + f'{self.base_urls[i]}/update_flattened_params/', + json=data, + ) + if response.status_code != 200: + raise Exception(f'Server {i} update flattened params failed: {response.text}') + + self.pynccl_comms[i].broadcast(flattened_tensor, src=self.pynccl_comms[i].rank) + self.pynccl_comms[i].group.barrier() + except Exception as e: + errors[i] = e + + with ThreadPoolExecutor(max_workers=self.num_servers) as executor: + futures = [executor.submit(_update_single_server, i) for i in range(self.num_servers)] + for future in futures: + future.result() + + all_errors = [e for e in errors if e is not None] + if all_errors: + raise RuntimeError(f'Multiple errors: {all_errors}') + def update_model_params(self, model: nn.Module): for name, param in model.named_parameters(): self.update_named_param(name, param.data) From dfccf15f06f77d5caf93563af88abb780844ee8e Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 10 Oct 2025 14:27:46 +0800 Subject: [PATCH 27/42] remove circle import --- swift/trainers/rlhf_trainer/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 53ed7142c3..12c6cdcf2c 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -23,7 +23,6 @@ from torch.utils.data import DataLoader, RandomSampler from transformers import Trainer -from swift.trainers.rlhf_trainer.grpo_trainer import GRPOTrainer from swift.utils import is_swanlab_available, is_vllm_available, is_wandb_available if is_wandb_available(): @@ -853,7 +852,7 @@ def _create_parameter_buckets(named_params, bucket_size_mb=100): return buckets -def _process_bucket_with_flattened_tensor(trainer: GRPOTrainer, bucket_params): +def _process_bucket_with_flattened_tensor(trainer, bucket_params): """Process a bucket of parameters using FlattenedTensorBucket for efficiency""" if not bucket_params: return From 4652f7713645e03c233ad60cd25d649cfc934d02 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 10 Oct 2025 14:38:25 +0800 Subject: [PATCH 28/42] fix TokenizerGroup removed in vllm 0.11.0 --- swift/trainers/rlhf_trainer/utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 12c6cdcf2c..c340ad28f5 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -405,7 +405,12 @@ def patch_vllm_load_adapter(): from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.models import LoRAModel from vllm.lora.utils import get_adapter_absolute_path - from vllm.transformers_utils.tokenizer_group import TokenizerGroup + + try: + from vllm.transformers_utils.tokenizer_group import TokenizerGroup + except ImportError: + # removed in https://github.com/vllm-project/vllm/pull/24078 + TokenizerGroup = None def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLoRARequest) -> LoRAModel: """ @@ -484,8 +489,9 @@ def patched_get_lora_tokenizer(self: TokenizerGroup, lora_request: LoRARequest): _old_load_adapter = LRUCacheWorkerLoRAManager._load_adapter LRUCacheWorkerLoRAManager._load_adapter = patched_load_adapter LRUCacheWorkerLoRAManager._old_load_adapter = _old_load_adapter - TokenizerGroup._old_get_lora_tokenizer = TokenizerGroup.get_lora_tokenizer - TokenizerGroup.get_lora_tokenizer = patched_get_lora_tokenizer + if TokenizerGroup is not None: + TokenizerGroup._old_get_lora_tokenizer = TokenizerGroup.get_lora_tokenizer + TokenizerGroup.get_lora_tokenizer = patched_get_lora_tokenizer # FlattenedTensor, code borrowed from sglang/srt/weight_sync/tensor_bucket.py From 8c73590cdf82d436addcaa2b501ab7ef3ef848f5 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 10 Oct 2025 16:01:01 +0800 Subject: [PATCH 29/42] rm comments --- swift/trainers/rlhf_trainer/grpo_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index ec3a96567f..1d8628564d 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -611,7 +611,7 @@ def _move_adapter_to_vllm(self): peft_config = self.model.peft_config.get('default', None) self.model.merge_adapter() cur_lora_params = get_peft_model_state_dict(self.model, state_dict) - cur_lora_params = { # base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.weight + cur_lora_params = { name: param.full_tensor().detach() if hasattr(param, 'full_tensor') else param.detach() for name, param in cur_lora_params.items() } @@ -681,8 +681,6 @@ def _move_full_model_to_vllm(self): self.base_sync_done = True else: if self.vllm_mode == 'server' and self.accelerator.is_main_process: - # For non-PEFT models, use streaming bucket approach to avoid memory peaks - # Collect parameters in small batches and process them immediately current_bucket = [] current_size = 0 bucket_size_bytes = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) * 1024 * 1024 From cc565889601a22828f5e1a3d26a30d4128471c74 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 13 Oct 2025 15:36:40 +0800 Subject: [PATCH 30/42] move model batches for full parameters --- swift/trainers/rlhf_trainer/grpo_trainer.py | 65 +++++++++++++-------- swift/trainers/rlhf_trainer/utils.py | 23 ++++---- 2 files changed, 53 insertions(+), 35 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index ec3a96567f..18871f0695 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -247,8 +247,7 @@ def __init__(self, # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but # it's safer to set it in all cases. set_seed(args.seed, device_specific=True) - if is_peft_model(self.model): - self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() + self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() self.use_fast_infer = self.use_vllm # whether to use the PT backend self.vllm_use_async_engine = False self.enable_offload = False @@ -548,6 +547,7 @@ def prepare_vllm(self, model): 'max_loras': 1, 'max_lora_rank': self.args.lora_rank, } + self.rollout_enable_lora = True patch_vllm_load_adapter() with Swift.grpo_context(model, self.template.processor): set_expandable_segments(False) @@ -640,7 +640,7 @@ def _move_adapter_to_vllm(self): def _move_full_model_to_vllm(self): gather_if_zero3 = get_gather_if_zero3_context(self) if is_peft_model(self.model): - for i, parameter_group in enumerate(self.parameter_groups): # < this is the change + for i, parameter_group in enumerate(self.parameter_groups): parameter_group_no_lora = self.parameter_groups_no_lora[i] parameters = [ parameter for name, parameter in self.model.named_parameters() @@ -649,7 +649,11 @@ def _move_full_model_to_vllm(self): with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): self.model.merge_adapter() state_dict = self.model.state_dict() - state_dict = {k.removeprefix('base_model.model.'): v for k, v in state_dict.items()} + prefix_removed = {k.removeprefix('base_model.model.'): v for k, v in state_dict.items()} + state_dict = prefix_removed if self.rollout_enable_lora else { + k.replace('.base_layer', ''): v + for k, v in prefix_removed.items() + } state_dict = {k: v for k, v in state_dict.items() if self.model.prefix not in k} # When module to save, remove its prefix and discard the original module state_dict = { @@ -680,31 +684,44 @@ def _move_full_model_to_vllm(self): del state_dict self.base_sync_done = True else: - if self.vllm_mode == 'server' and self.accelerator.is_main_process: - # For non-PEFT models, use streaming bucket approach to avoid memory peaks - # Collect parameters in small batches and process them immediately - current_bucket = [] - current_size = 0 + if self.vllm_mode == 'server': bucket_size_bytes = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) * 1024 * 1024 - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): - # Ensure we have the full parameter (not sharded) and COPY it - # to avoid the parameter being re-partitioned when exiting context - full_param = param.data.clone() if hasattr(param, 'data') else param.clone() - param_size = full_param.numel() * full_param.element_size() - - # If adding this param would exceed bucket size, process current bucket first - if current_size + param_size > bucket_size_bytes and current_bucket: - _process_bucket_with_flattened_tensor(self, current_bucket) + for i, parameter_group in enumerate(self.parameter_groups): + parameter_group_no_lora = self.parameter_groups_no_lora[i] + parameters = [ + parameter for name, parameter in self.model.named_parameters() + if not parameter_group or name in parameter_group + ] + with gather_if_zero3(parameters): + if self.accelerator.is_main_process: + # Get state_dict AFTER gather to get full parameters + state_dict = self.model.state_dict() + + # Filter by parameter_group_no_lora if specified + if parameter_group_no_lora: + state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + + # Split gathered parameters into buckets current_bucket = [] current_size = 0 - current_bucket.append((name, full_param)) - current_size += param_size + for name, param in state_dict.items(): + param_size = param.numel() * param.element_size() + + # If adding this param would exceed bucket size, process current bucket first + if current_size + param_size > bucket_size_bytes and current_bucket: + _process_bucket_with_flattened_tensor(self, current_bucket) + current_bucket = [] + current_size = 0 + + current_bucket.append((name, param)) + current_size += param_size + + # Process remaining parameters in the last bucket + if current_bucket: + _process_bucket_with_flattened_tensor(self, current_bucket) - # Process remaining parameters in the last bucket - if current_bucket: - _process_bucket_with_flattened_tensor(self, current_bucket) + del state_dict else: for name, param in self.model.named_parameters(): with gather_if_zero3([param]): diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index c340ad28f5..a17cb3f935 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -554,17 +554,11 @@ def __init__( if not named_tensors: raise ValueError('Cannot create empty tensor bucket') - # Collect metadata and flatten tensors + # First pass: compute total size and metadata current_idx = 0 - flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors) - + total_numel = 0 for i, (name, tensor) in enumerate(named_tensors): - flattened = tensor.flatten() - flattened_tensors[i] = flattened - - # Store metadata - - numel = flattened.numel() + numel = tensor.numel() metadata_obj = FlattenedTensorMetadata( name=name, shape=tuple(tensor.shape), @@ -575,9 +569,16 @@ def __init__( ) self.metadata[i] = metadata_obj current_idx += numel + total_numel += numel + + # Pre-allocate the final flattened tensor to avoid intermediate copies + # Use the dtype and device of the first tensor + first_tensor = named_tensors[0][1] + self.flattened_tensor = torch.empty(total_numel, dtype=first_tensor.dtype, device=first_tensor.device) - # Concatenate all flattened tensors - self.flattened_tensor = torch.cat(flattened_tensors, dim=0) + # Second pass: copy data directly into pre-allocated tensor + for meta, (name, tensor) in zip(self.metadata, named_tensors): + self.flattened_tensor[meta.start_idx:meta.end_idx].copy_(tensor.flatten()) else: # Initialize from pre-flattened data if flattened_tensor is None or metadata is None: From 1561e8f82b7dc52775bf03927e2c718358e0f950 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 13 Oct 2025 17:52:13 +0800 Subject: [PATCH 31/42] fix lora training with rollout enable_lora --- swift/llm/argument/rlhf_args.py | 1 + swift/trainers/rlhf_trainer/grpo_trainer.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py index 0b081a911c..dca42e5c18 100644 --- a/swift/llm/argument/rlhf_args.py +++ b/swift/llm/argument/rlhf_args.py @@ -265,6 +265,7 @@ def _init_external_vllm(self): hosts=self.vllm_server_host, server_ports=self.vllm_server_port, connection_timeout=self.vllm_server_timeout) + self.vllm_client.close_communicator() self.vllm_client.init_communicator(device=get_current_device()) logger.info('Connected to vLLM server') diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 909d15c84d..2a957180d0 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -247,7 +247,6 @@ def __init__(self, # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but # it's safer to set it in all cases. set_seed(args.seed, device_specific=True) - self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() self.use_fast_infer = self.use_vllm # whether to use the PT backend self.vllm_use_async_engine = False self.enable_offload = False @@ -309,6 +308,8 @@ def __init__(self, infer_template.padding_free = False self.engine = PtEngine.from_model_template(self.model, infer_template, max_batch_size=0) # 0: no limit + self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() + if not self.reward_funcs and not self.use_gym_env: raise ValueError('You must specify reward_funcs or reward_model') @@ -493,7 +494,10 @@ def replace_lora(name): if 'lora_' in name: return '' else: - return name.replace('base_layer.', '') + if not self.rollout_enable_lora: + return name.replace('base_layer.', '') + else: + return name def remove_lora_and_prefix(names): names = set([re.sub(r'^_model\.', '', replace_lora(n)) for n in names]) From 24220bf2ef1c805399996f635959d3c3a0cf039b Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 14 Oct 2025 10:37:01 +0800 Subject: [PATCH 32/42] check should merge adapter --- swift/llm/argument/rlhf_args.py | 1 - swift/trainers/rlhf_trainer/grpo_trainer.py | 41 +++++++++++++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py index dca42e5c18..0b081a911c 100644 --- a/swift/llm/argument/rlhf_args.py +++ b/swift/llm/argument/rlhf_args.py @@ -265,7 +265,6 @@ def _init_external_vllm(self): hosts=self.vllm_server_host, server_ports=self.vllm_server_port, connection_timeout=self.vllm_server_timeout) - self.vllm_client.close_communicator() self.vllm_client.init_communicator(device=get_current_device()) logger.info('Connected to vLLM server') diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 2a957180d0..c2e32e0e85 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -495,7 +495,7 @@ def replace_lora(name): return '' else: if not self.rollout_enable_lora: - return name.replace('base_layer.', '') + return re.sub(r'\.base_layer\.', '.', name) else: return name @@ -651,7 +651,9 @@ def _move_full_model_to_vllm(self): if not parameter_group or name in parameter_group ] with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): - self.model.merge_adapter() + if self.should_merge_adapter: + # if rollout enable lora, we will only execute once before the first rollout + self.model.merge_adapter() state_dict = self.model.state_dict() prefix_removed = {k.removeprefix('base_model.model.'): v for k, v in state_dict.items()} state_dict = prefix_removed if self.rollout_enable_lora else { @@ -683,8 +685,9 @@ def _move_full_model_to_vllm(self): elif self.vllm_mode == 'colocate': llm_model = self.engine.inner_model llm_model.load_weights(state_dict.items()) - with patch_lora_unmerge(self.model): - self.model.unmerge_adapter() + if self.should_merge_adapter: + with patch_lora_unmerge(self.model): + self.model.unmerge_adapter() del state_dict self.base_sync_done = True else: @@ -3046,3 +3049,33 @@ def get_chunked_inputs(self, inputs, start_idx, end_idx): chunk_inputs.update(to_device(template.data_collator(encoded_data), self.model.device)) chunk_inputs.pop('labels', None) return chunk_inputs + + @property + def should_merge_adapter(self): + """ + Determine whether the LoRA adapter should be merged into the base model during weight synchronization. + + Note: + Merging or unmerging adapters in MoE models is computationally expensive and should be minimized. + + Raises: + AssertionError: If full-parameter training is used, as adapter merging is not supported. + + Returns: + bool: True if the adapter should be merged; False otherwise. + - Returns True when LoRA is not enabled for rollout. + - Returns True when loading from a checkpoint or using pre-trained adapters. + - Returns False during normal LoRA training (weights are already synchronized). + """ + assert self.args.train_type != 'full', 'Full-parameter training should not merge adapter' + + # Rollout does not support LoRA + if not self.rollout_enable_lora: + return True + + if self.args.resume_from_checkpoint: + # Resuming training: merge into base model + return True + + # base model weights are synced before training; no need to merge + return False From c66a80109f0f9b03c75392966ce35983be5e170b Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 15 Oct 2025 16:07:22 +0800 Subject: [PATCH 33/42] update accuracy & test accuracy & moe script --- examples/train/grpo/external/moe.sh | 48 ++++++ swift/plugin/orm.py | 52 ++++--- tests/utils/test_rewards.py | 232 ++++++++++++++++++++++++++++ 3 files changed, 310 insertions(+), 22 deletions(-) create mode 100644 examples/train/grpo/external/moe.sh create mode 100644 tests/utils/test_rewards.py diff --git a/examples/train/grpo/external/moe.sh b/examples/train/grpo/external/moe.sh new file mode 100644 index 0000000000..977cad76a8 --- /dev/null +++ b/examples/train/grpo/external/moe.sh @@ -0,0 +1,48 @@ +# 8*80G A100 +# 200s/it + +# CUDA_VISIBLE_DEVICES=0 \ +# swift rollout \ +# --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ +# --vllm_max_model_len 16384 + +CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 \ +NPROC_PER_NODE=7 \ +MAX_PIXELS=262144 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --reward_funcs accuracy \ + --use_vllm true \ + --vllm_mode server \ + --vllm_server_host 127.0.0.1 \ + --vllm_server_port 8000 \ + --train_type full \ + --torch_dtype bfloat16 \ + --dataset AI-MO/NuminaMath-TIR#1000 \ + --max_length 12000 \ + --max_completion_length 8192 \ + --overlong_filter true \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 4 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 14 \ + --temperature 1.0 \ + --system 'swift/examples/train/grpo/prompt.txt' \ + --deepspeed zero3_offload \ + --log_completions true \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --async_generate false \ + --move_model_batches 2 \ + --beta 0.001 \ + --move_model_batches 5 diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py index d8f2b30042..75c8187fcc 100644 --- a/swift/plugin/orm.py +++ b/swift/plugin/orm.py @@ -246,29 +246,37 @@ def __call__(self, completions, solution, **kwargs) -> List[float]: from math_verify import LatexExtractionConfig, parse, verify rewards = [] for content, sol in zip(completions, solution): - gold_parsed = parse(sol, extraction_mode='first_match') + content_match = re.search(r'(.*?)', content, re.DOTALL) + content_to_parse = content_match.group(1).strip() if content_match else content + has_answer_tag = content_match is not None + + sol_match = re.search(r'(.*?)', sol, re.DOTALL) + sol_to_parse = sol_match.group(1).strip() if sol_match else sol + + gold_parsed = parse(sol_to_parse, extraction_mode='first_match') if len(gold_parsed) != 0: - # We require the answer to be provided in correct latex (no malformed operators) - answer_parsed = parse( - content, - extraction_config=[ - LatexExtractionConfig( - normalization_config=NormalizationConfig( - nits=False, - malformed_operators=False, - basic_latex=True, - equations=True, - boxed=True, - units=True, - ), - # Ensures that boxed is tried first - boxed_match_priority=0, - try_extract_without_anchor=False, - ) - ], - extraction_mode='first_match', - ) - # edge case + if has_answer_tag: + answer_parsed = parse(content_to_parse, extraction_mode='first_match') + else: + answer_parsed = parse( + content_to_parse, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed=True, + units=True, + ), + # Ensures that boxed is tried first + boxed_match_priority=0, + try_extract_without_anchor=True, # Allow extraction of plain numbers without \boxed{} + ) + ], + extraction_mode='first_match', + ) try: reward = float(verify(gold_parsed, answer_parsed)) except Exception: diff --git a/tests/utils/test_rewards.py b/tests/utils/test_rewards.py new file mode 100644 index 0000000000..858e658334 --- /dev/null +++ b/tests/utils/test_rewards.py @@ -0,0 +1,232 @@ +import unittest + + +class TestMathAccuracy(unittest.TestCase): + + @classmethod + def setUpClass(cls): + try: + from swift.plugin.orm import MathAccuracy + cls.math_accuracy = MathAccuracy() + cls.available = True + except (ImportError, AssertionError) as e: + print(f'Warning: MathAccuracy not available: {e}') + cls.available = False + + def setUp(self): + if not self.available: + self.skipTest('MathAccuracy not available (math_verify not installed)') + + def test_pure_latex_format(self): + completions = ['The answer is \\boxed{42}'] + solutions = ['\\boxed{42}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_latex_in_long_text(self): + completions = ['After careful calculation, the final answer is \\boxed{100}'] + solutions = ['\\boxed{100}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_multiple_steps_with_boxed(self): + completions = [ + 'Let me solve step by step:\n' + '1. First we have x = 2\n' + '2. Then y = 3x = 6\n' + '3. Finally z = x + y = 8\n' + '\nFinal answer: \\boxed{8}' + ] + solutions = ['\\boxed{8}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_wrong_answer_no_tag(self): + completions = ['The answer is \\boxed{42}'] + solutions = ['\\boxed{100}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 0.0) + + def test_batch_processing_no_tag(self): + completions = ['\\boxed{42}', '\\boxed{100}', '\\boxed{8}'] + solutions = ['\\boxed{42}', '\\boxed{100}', '\\boxed{8}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 3) + self.assertEqual(rewards[0], 1.0) + self.assertEqual(rewards[1], 1.0) + self.assertEqual(rewards[2], 1.0) + + def test_answer_tag_with_plain_number(self): + completions = ['84'] + solutions = ['\\boxed{84}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_answer_tag_with_latex(self): + completions = ['\\boxed{100}'] + solutions = ['\\boxed{100}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_long_text_with_answer_tag(self): + completions = [ + 'Let me solve:\n' + 'Step 1: Calculate x = 10\n' + 'Step 2: Calculate y = 20\n' + 'Step 3: Sum = 30\n' + '\n54' + ] + solutions = ['\\boxed{54}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_answer_tag_with_complex_expression(self): + completions = ['\\frac{1}{2}'] + solutions = ['\\boxed{\\frac{1}{2}}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_solution_with_answer_tag(self): + completions = ['84'] + solutions = ['\\boxed{84}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_answer_tag_wrong_answer(self): + completions = ['42'] + solutions = ['\\boxed{100}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 0.0) + + def test_mixed_batch_with_and_without_tags(self): + completions = [ + '\\boxed{42}', + '100', + 'The answer is \\boxed{8}', + ] + solutions = [ + '\\boxed{42}', + '\\boxed{100}', + '\\boxed{8}', + ] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 3) + self.assertEqual(rewards[0], 1.0) + self.assertEqual(rewards[1], 1.0) + self.assertEqual(rewards[2], 1.0) + + def test_empty_solution(self): + completions = ['42'] + solutions = [''] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 0.0) + + def test_malformed_latex(self): + completions = ['\\boxed{42'] + solutions = ['\\boxed{42}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 0.0) + + def test_answer_tag_with_extra_whitespace(self): + completions = [' 84 '] + solutions = ['\\boxed{84}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_multiple_answer_tags(self): + completions = ['42 Some text 100'] + solutions = ['\\boxed{42}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_real_world_example_from_user(self): + completions = [ + 'We are given a geometric sequence $\\{a_n\\}$ with:\n\n' + '- $a_3 = 2$\n- $a_5 = 6$\n\n' + 'We are to find $a_9$.\n\n---\n\n' + '### Step 1: Recall the formula\n\n' + '$$a_n = a_1 \\cdot r^{n-1}$$\n\n---\n\n' + '### Step 2: Use the given terms\n\n' + '$$a_3 = a_1 \\cdot r^2 = 2$$\n' + '$$a_5 = a_1 \\cdot r^4 = 6$$\n\n' + 'Divide equation (2) by equation (1):\n' + '$$r^2 = 3$$\n\n---\n\n' + '### Step 3: Find $a_9$\n\n' + '$$a_9 = a_1 \\cdot r^8 = \\frac{2}{3} \\cdot 81 = 54$$\n\n' + '### ✅ Final Answer:\n\n' + '54' + ] + solutions = ['\\boxed{54}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_equivalent_fractions(self): + completions = ['0.5'] + solutions = ['\\boxed{\\frac{1}{2}}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_different_forms_same_answer(self): + completions = ['2'] + solutions = ['\\boxed{\\sqrt{4}}'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + +if __name__ == '__main__': + unittest.main() From 190c20194f7d22e3db8bb7adb06c68ec83c0f533 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 16 Oct 2025 14:36:12 +0800 Subject: [PATCH 34/42] add test cases --- swift/plugin/orm.py | 4 +--- tests/utils/test_rewards.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py index 75c8187fcc..2df0c753d8 100644 --- a/swift/plugin/orm.py +++ b/swift/plugin/orm.py @@ -266,13 +266,11 @@ def __call__(self, completions, solution, **kwargs) -> List[float]: nits=False, malformed_operators=False, basic_latex=True, - equations=True, boxed=True, units=True, ), - # Ensures that boxed is tried first boxed_match_priority=0, - try_extract_without_anchor=True, # Allow extraction of plain numbers without \boxed{} + try_extract_without_anchor=False, ) ], extraction_mode='first_match', diff --git a/tests/utils/test_rewards.py b/tests/utils/test_rewards.py index 858e658334..0f2e3da383 100644 --- a/tests/utils/test_rewards.py +++ b/tests/utils/test_rewards.py @@ -227,6 +227,34 @@ def test_different_forms_same_answer(self): self.assertEqual(len(rewards), 1) self.assertEqual(rewards[0], 1.0) + def test_latex_inline_math_delimiters(self): + completions = ['84', '3'] + solutions = ['\n\n\\[\n\\boxed{84}\n\\]', 'Therefore, the value of \\(a^2 - a + 2\\) is \\(\\boxed{3}\\).'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 2) + self.assertEqual(rewards[0], 1.0) + self.assertEqual(rewards[1], 1.0) + + def test_latex_display_math_delimiters(self): + completions = ['100'] + solutions = ['\\[\\boxed{100}\\]'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + + def test_mixed_latex_delimiters(self): + completions = ['\\(x = 42\\)'] + solutions = ['\\[\\boxed{x = 42}\\]'] + + rewards = self.math_accuracy(completions, solutions) + + self.assertEqual(len(rewards), 1) + self.assertEqual(rewards[0], 1.0) + if __name__ == '__main__': unittest.main() From c18650729980e38b23b8e44f46a0518ac2c98851 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 17 Oct 2025 00:14:19 +0800 Subject: [PATCH 35/42] colocate moe script & fix moe colocate lora training --- examples/train/grpo/external/moe.sh | 6 ++-- examples/train/grpo/internal/moe_full.sh | 38 ++++++++++++++++++++ examples/train/grpo/internal/moe_lora.sh | 40 +++++++++++++++++++++ swift/trainers/rlhf_trainer/grpo_trainer.py | 38 ++++++++++---------- 4 files changed, 99 insertions(+), 23 deletions(-) create mode 100644 examples/train/grpo/internal/moe_full.sh create mode 100755 examples/train/grpo/internal/moe_lora.sh diff --git a/examples/train/grpo/external/moe.sh b/examples/train/grpo/external/moe.sh index 977cad76a8..00a160cea7 100644 --- a/examples/train/grpo/external/moe.sh +++ b/examples/train/grpo/external/moe.sh @@ -4,11 +4,11 @@ # CUDA_VISIBLE_DEVICES=0 \ # swift rollout \ # --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ -# --vllm_max_model_len 16384 +# --vllm_max_model_len 16384 \ +# --vllm_enable_prefix_caching true CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 \ NPROC_PER_NODE=7 \ -MAX_PIXELS=262144 \ swift rlhf \ --rlhf_type grpo \ --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ @@ -37,12 +37,10 @@ swift rlhf \ --dataloader_num_workers 4 \ --num_generations 14 \ --temperature 1.0 \ - --system 'swift/examples/train/grpo/prompt.txt' \ --deepspeed zero3_offload \ --log_completions true \ --report_to tensorboard swanlab \ --num_iterations 1 \ --async_generate false \ - --move_model_batches 2 \ --beta 0.001 \ --move_model_batches 5 diff --git a/examples/train/grpo/internal/moe_full.sh b/examples/train/grpo/internal/moe_full.sh new file mode 100644 index 0000000000..cdc51a1418 --- /dev/null +++ b/examples/train/grpo/internal/moe_full.sh @@ -0,0 +1,38 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --reward_funcs accuracy \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.4 \ + --vllm_tensor_parallel_size 2 \ + --vllm_max_model_len 16384 \ + --train_type full \ + --torch_dtype bfloat16 \ + --dataset AI-MO/NuminaMath-TIR#1000 \ + --max_length 12000 \ + --max_completion_length 8192 \ + --overlong_filter true \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 4 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 16 \ + --temperature 1.0 \ + --deepspeed zero3_offload \ + --log_completions true \ + --sleep_level 1 \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --beta 0.001 \ + --move_model_batches 10 diff --git a/examples/train/grpo/internal/moe_lora.sh b/examples/train/grpo/internal/moe_lora.sh new file mode 100755 index 0000000000..611e13583a --- /dev/null +++ b/examples/train/grpo/internal/moe_lora.sh @@ -0,0 +1,40 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --reward_funcs accuracy \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.4 \ + --vllm_tensor_parallel_size 2 \ + --vllm_max_model_len 16384 \ + --train_type lora \ + --torch_dtype bfloat16 \ + --dataset AI-MO/NuminaMath-TIR#1000 \ + --max_length 12000 \ + --max_completion_length 8192 \ + --overlong_filter true \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 4 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 16 \ + --temperature 1.0 \ + --deepspeed zero2 \ + --log_completions true \ + --sleep_level 1 \ + --offload_model true \ + --offload_optimizer true \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --beta 0.001 \ + --move_model_batches 5 diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index c2e32e0e85..c3b6d1eb6b 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -545,7 +545,9 @@ def prepare_vllm(self, model): vllm_template = copy(self.template) vllm_template.padding_free = False lora_kwargs = {} - if self.args.train_type == 'lora': + is_moe = model.model_info.is_moe_model + if self.args.train_type == 'lora' and not is_moe: + # MoE LoRA is not supported now lora_kwargs = { 'enable_lora': True, 'max_loras': 1, @@ -691,22 +693,22 @@ def _move_full_model_to_vllm(self): del state_dict self.base_sync_done = True else: - if self.vllm_mode == 'server': - bucket_size_bytes = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) * 1024 * 1024 - for i, parameter_group in enumerate(self.parameter_groups): - parameter_group_no_lora = self.parameter_groups_no_lora[i] - parameters = [ - parameter for name, parameter in self.model.named_parameters() - if not parameter_group or name in parameter_group - ] - with gather_if_zero3(parameters): + for i, parameter_group in enumerate(self.parameter_groups): + parameter_group_no_lora = self.parameter_groups_no_lora[i] + parameters = [ + parameter for name, parameter in self.model.named_parameters() + if not parameter_group or name in parameter_group + ] + with gather_if_zero3(parameters): + state_dict = self.model.state_dict() + # Filter by parameter_group_no_lora if specified + if parameter_group_no_lora: + state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + + if self.vllm_mode == 'server': + bucket_size_bytes = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) * 1024 * 1024 if self.accelerator.is_main_process: # Get state_dict AFTER gather to get full parameters - state_dict = self.model.state_dict() - - # Filter by parameter_group_no_lora if specified - if parameter_group_no_lora: - state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} # Split gathered parameters into buckets current_bucket = [] @@ -729,12 +731,10 @@ def _move_full_model_to_vllm(self): _process_bucket_with_flattened_tensor(self, current_bucket) del state_dict - else: - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): + else: if self.vllm_mode == 'colocate': llm_model = self.engine.inner_model - llm_model.load_weights([(name, param.data)]) + llm_model.load_weights(state_dict.items()) if self.vllm_mode == 'server' and self.accelerator.is_main_process: self.vllm_client.reset_prefix_cache() From b9a1827fa65d69dd8db02e5e8b0cbac80695167c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 17 Oct 2025 10:42:00 +0800 Subject: [PATCH 36/42] add script --- .../grpo/external/{moe.sh => moe_full.sh} | 4 +- examples/train/grpo/external/moe_lora.sh | 44 +++++++++++++++++++ examples/train/grpo/internal/moe_full.sh | 2 + examples/train/grpo/internal/moe_lora.sh | 6 ++- 4 files changed, 51 insertions(+), 5 deletions(-) rename examples/train/grpo/external/{moe.sh => moe_full.sh} (95%) create mode 100755 examples/train/grpo/external/moe_lora.sh diff --git a/examples/train/grpo/external/moe.sh b/examples/train/grpo/external/moe_full.sh similarity index 95% rename from examples/train/grpo/external/moe.sh rename to examples/train/grpo/external/moe_full.sh index 00a160cea7..634f67024f 100644 --- a/examples/train/grpo/external/moe.sh +++ b/examples/train/grpo/external/moe_full.sh @@ -1,5 +1,4 @@ -# 8*80G A100 -# 200s/it +# 8*80G # CUDA_VISIBLE_DEVICES=0 \ # swift rollout \ @@ -41,6 +40,5 @@ swift rlhf \ --log_completions true \ --report_to tensorboard swanlab \ --num_iterations 1 \ - --async_generate false \ --beta 0.001 \ --move_model_batches 5 diff --git a/examples/train/grpo/external/moe_lora.sh b/examples/train/grpo/external/moe_lora.sh new file mode 100755 index 0000000000..3bd3ec5d9c --- /dev/null +++ b/examples/train/grpo/external/moe_lora.sh @@ -0,0 +1,44 @@ +# 8*80G + +# CUDA_VISIBLE_DEVICES=0 \ +# swift rollout \ +# --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ +# --vllm_max_model_len 16384 \ +# --vllm_enable_prefix_caching true + +CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 \ +NPROC_PER_NODE=7 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --reward_funcs accuracy \ + --use_vllm true \ + --vllm_mode server \ + --vllm_server_host 127.0.0.1 \ + --vllm_server_port 8000 \ + --train_type lora \ + --torch_dtype bfloat16 \ + --dataset AI-MO/NuminaMath-TIR#1000 \ + --max_length 12000 \ + --max_completion_length 8192 \ + --overlong_filter true \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 4 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 14 \ + --temperature 1.0 \ + --deepspeed zero3 \ + --log_completions true \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --beta 0.001 \ + --move_model_batches 5 diff --git a/examples/train/grpo/internal/moe_full.sh b/examples/train/grpo/internal/moe_full.sh index cdc51a1418..7f74c8b736 100644 --- a/examples/train/grpo/internal/moe_full.sh +++ b/examples/train/grpo/internal/moe_full.sh @@ -1,3 +1,5 @@ +# 8*80G + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ NPROC_PER_NODE=8 \ swift rlhf \ diff --git a/examples/train/grpo/internal/moe_lora.sh b/examples/train/grpo/internal/moe_lora.sh index 611e13583a..3ebdec2a65 100755 --- a/examples/train/grpo/internal/moe_lora.sh +++ b/examples/train/grpo/internal/moe_lora.sh @@ -1,3 +1,5 @@ +# 8*80G + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ NPROC_PER_NODE=8 \ swift rlhf \ @@ -29,7 +31,7 @@ swift rlhf \ --dataloader_num_workers 4 \ --num_generations 16 \ --temperature 1.0 \ - --deepspeed zero2 \ + --deepspeed zero3 \ --log_completions true \ --sleep_level 1 \ --offload_model true \ @@ -37,4 +39,4 @@ swift rlhf \ --report_to tensorboard swanlab \ --num_iterations 1 \ --beta 0.001 \ - --move_model_batches 5 + --move_model_batches 10 From eb3e23003778a8f0c8ce6b84ce167dc22211d894 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 17 Oct 2025 11:16:44 +0800 Subject: [PATCH 37/42] doc update --- docs/source/Instruction/GRPO/GetStarted/GRPO.md | 9 +++++---- ...44\273\244\350\241\214\345\217\202\346\225\260.md" | 3 ++- docs/source_en/Instruction/Command-line-parameters.md | 4 +++- docs/source_en/Instruction/GRPO/GetStarted/GRPO.md | 11 ++++++----- swift/llm/infer/infer_engine/grpo_vllm_engine.py | 2 +- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/docs/source/Instruction/GRPO/GetStarted/GRPO.md b/docs/source/Instruction/GRPO/GetStarted/GRPO.md index d2a8dd16ae..2fef0413b8 100644 --- a/docs/source/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source/Instruction/GRPO/GetStarted/GRPO.md @@ -197,14 +197,15 @@ swift rollout \ --vllm_server_timeout <超时时间> \ ``` #### 权重同步加速 -swift 3.9对 LoRA 训练的权重同步进行了优化(相比swift3.8加速约10倍) - -为开启LoRA权重同步优化,请在rollout命令中设置以下参数 +swift 3.10 优化了权重同步,同时在 rollout 命令(server mode)中设置以下参数可以进一步优化 LoRA 训练的权重同步速度。 ```bash --vllm_enable_lora true --vllm_max_lora_rank xxx # 与训练脚本lora_rank一致 ``` -注意:对于多模态模型训练,vLLM 仅支持多模态模型的语言模型部分的adapter加载,如果需要训练多模态模型的ViT层(freeze_vit false),请设置`vllm_enable_lora false` +注意:以下情况无法使用该优化: + +- 训练多模态模型的ViT层(freeze_vit false) +- MoE 模型 优化实现细节请参考该[PR](https://github.com/modelscope/ms-swift/pull/5773) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 66afff6745..bb753e877c 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -531,6 +531,7 @@ reward模型参数将在PPO、GRPO中使用。 - vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。 - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。 - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`. + - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE:环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 - vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。) - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 - vllm_max_model_len: vllm透传参数,默认为None。 @@ -549,7 +550,7 @@ reward模型参数将在PPO、GRPO中使用。 - sync_ref_model: 是否定期同步ref_model,默认为False。 - ref_model_mixup_alpha: 控制在更新过程中model和先前ref_model之间的混合。更新公式为 $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$。默认为0.6。 - ref_model_sync_steps:同步频率,默认为512。 -- move_model_batches: 在模型向vLLM等快速推理框架移动参数时,将layers分为多少个batch. 默认为None, 代表整个模型不进行拆分,否则拆分为move_model_batches+1(非layer参数)+1(多模态部分参数)个。注意:该参数仅对LoRA(PEFT)训练有意义。 +- move_model_batches: 在模型向vLLM等快速推理框架移动参数时,将layers分为多少个batch. 默认为None, 代表整个模型不进行拆分,否则拆分为move_model_batches+1(非layer参数)+1(多模态部分参数)个。 - multi_turn_scheduler: 多轮GRPO参数, 传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。 - max_turns: 多轮GRPO的轮数上限。默认为None,不做限制。 - dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index d609b21066..0b6d486669 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -540,11 +540,13 @@ The meanings of the following parameters can be referenced [here](https://huggin - vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds. - vllm_server_pass_dataset: pass additional dataset information through to the vLLM server for multi-turn training. - async_generate: Use async rollout to improve train speed. Note that rollout will use the model updated in the previous round when enabled. Multi-turn scenarios are not supported. Default is `false`. + - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: An environment variable that controls the bucket size (in MB) for weight synchronization during full-parameter training in Server Mode. Default is 512 MB. - vllm_mode colocate parameter (For more parameter support, refer to the [vLLM Arguments](#vLLM-Arguments).) - vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9. - vllm_max_model_len: vLLM passthrough parameter, the total length limit of model, default is None. - vllm_enforce_eager: vLLM passthrough parameter, default is False. - vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None. + - vllm_enable_prefix_caching: A pass-through parameter for vLLM, default is True. - vllm_tensor_parallel_size: the tensor parallel size of vLLM engine, default is 1. - sleep_level: make vllm sleep when model is training. Options are 0 or 1, default is 0, no sleep - offload_optimizer: Whether to offload optimizer parameters during inference with vLLM. The default is `False`. @@ -563,7 +565,7 @@ The meanings of the following parameters can be referenced [here](https://huggin - sync_ref_model: Whether to synchronize the reference model. Default is False。 - ref_model_mixup_alpha: The Parameter controls the mix between the current policy and the previous reference policy during updates. The reference policy is updated according to the equation: $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$. Default is 0.6. - ref_model_sync_steps:The parameter determines how frequently the current policy is synchronized with the reference policy. Default is 512. -- move_model_batches: When moving model parameters to fast inference frameworks such as vLLM/LMDeploy, determines how many batches to divide the layers into. The default is `None`, which means the entire model is not split. Otherwise, the model is split into `move_model_batches + 1` (non-layer parameters) + `1` (multi-modal component parameters) batches. This parameter is only meaningful for LoRA (PEFT). +- move_model_batches: When moving model parameters to fast inference frameworks such as vLLM/LMDeploy, determines how many batches to divide the layers into. The default is `None`, which means the entire model is not split. Otherwise, the model is split into `move_model_batches + 1` (non-layer parameters) + `1` (multi-modal component parameters) batches. - multi_turn_scheduler: Multi-turn GRPO parameter; pass the corresponding plugin name, and make sure to implement it in plugin/multi_turn.py. - max_turns: Maximum number of rounds for multi-turn GRPO. The default is None, which means there is no limit. - dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False. diff --git a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md index 80a251e870..aed3d7f41a 100644 --- a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md @@ -196,15 +196,16 @@ To configure the external vLLM server during training, use the following paramet ``` #### Weight-Sync Acceleration -Swift 3.9 optimizes weight synchronization for LoRA training, achieving ~10× speed-up over Swift 3.8. - -To enable the optimized LoRA weight sync, add the following arguments to your rollout command: +Swift 3.10 optimizes weight synchronization, and setting the following parameters in the rollout command (server mode) can further improve the weight synchronization speed for LoRA training: ```bash --vllm_enable_lora true - --vllm_max_lora_rank xxx # set to the same value as lora_rank in the training script + --vllm_max_lora_rank xxx # Should match the lora_rank in the training script ``` -Note: For multimodal model training, vLLM supports loading adapters only for the language-model part. If you need to train the ViT layers of a multimodal model (freeze_vit false), set `vllm_enable_lora false`. +Note: This optimization cannot be used in the following cases: + +- Training the ViT layers of multimodal models (freeze_vit set to false) +- MoE models For implementation details, please refer to the [PR](https://github.com/modelscope/ms-swift/pull/5773) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 6274f4d0ca..dcf185586b 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -104,7 +104,7 @@ def infer( if lora_int_ids: # since max_lora = 1, pick the first lora adapter_request = LoRARequest( - lora_name=f'lora_{lora_int_ids[0]}', + lora_name=f'{lora_int_ids[0]}', lora_int_id=lora_int_ids[0], lora_path='dummy_lora_path', ) From 9b503a723ccc50e3d3fb4b5e9c356dfec72c4d85 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 17 Oct 2025 13:54:06 +0800 Subject: [PATCH 38/42] rm sricpt & update doc --- .../Instruction/GRPO/GetStarted/GRPO.md | 2 +- examples/train/grpo/external/mllm_lora.sh | 52 ------------------- 2 files changed, 1 insertion(+), 53 deletions(-) delete mode 100644 examples/train/grpo/external/mllm_lora.sh diff --git a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md index aed3d7f41a..a667f0b053 100644 --- a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md @@ -183,7 +183,7 @@ swift rollout \ ``` For more rollout parameters, refer to the [vllm arguments](../../../Instruction/Command-line-parameters.md#vllm-arguments) and [rollout arguments](../../../Instruction/Command-line-parameters.md#rollout-arguments) -Note: When set `use_async_engine`, enabling only DP (Data Parallelism) may cause errors. [Related issue](https://github.com/vllm-project/vllm/issues/18567). If errors occur, try enabling both TP (Tensor Parallelism) and DP. +Note: When set `use_async_engine`, enabling only DP (Data Parallelism) may cause errors. [Related issue](https://github.com/vllm-project/vllm/issues/18567). If errors occur, try enabling both TP (Tensor Parallelism) and DP or upgrading vLLM. To configure the external vLLM server during training, use the following parameters: diff --git a/examples/train/grpo/external/mllm_lora.sh b/examples/train/grpo/external/mllm_lora.sh deleted file mode 100644 index e2b4e38bf3..0000000000 --- a/examples/train/grpo/external/mllm_lora.sh +++ /dev/null @@ -1,52 +0,0 @@ -# For LoRA Training, set following parameters to speed up weight update -# ```bash -# --vllm_enable_lora true -# --vllm_max_lora_rank xxx # same as lora_rank in training script -# ``` - -# CUDA_VISIBLE_DEVICES=4,5,6,7 \ -# swift rollout \ -# --model Qwen/Qwen2.5-VL-7B-Instruct \ -# --vllm_data_parallel_size 2 \ -# --vllm_tensor_parallel_size 2 \ -# --vllm_enable_lora true \ -# --vllm_max_lora_rank 16 - - -CUDA_VISIBLE_DEVICES=0,1,2,3 \ -NPROC_PER_NODE=4 \ -swift rlhf \ - --rlhf_type grpo \ - --model Qwen/Qwen2.5-VL-7B-Instruct \ - --external_plugins examples/train/grpo/plugin/plugin.py \ - --reward_funcs external_r1v_acc format \ - --use_vllm true \ - --vllm_mode server \ - --vllm_server_host 127.0.0.1 \ - --vllm_server_port 8000 \ - --train_type lora \ - --lora_rank 16 \ - --lora_alpha 32 \ - --torch_dtype bfloat16 \ - --dataset 'AI-ModelScope/clevr_cogen_a_train' \ - --max_completion_length 1024 \ - --num_train_epochs 1 \ - --per_device_train_batch_size 4 \ - --learning_rate 1e-6 \ - --gradient_accumulation_steps 2 \ - --save_strategy 'steps' \ - --eval_strategy 'steps' \ - --eval_steps 1000 \ - --save_steps 1000 \ - --save_total_limit 10 \ - --logging_steps 1 \ - --warmup_ratio 0.01 \ - --dataloader_num_workers 4 \ - --num_generations 16 \ - --temperature 1.0 \ - --system 'examples/train/grpo/prompt.txt' \ - --deepspeed zero3 \ - --log_completions true \ - --report_to tensorboard swanlab \ - --num_iterations 1 \ - --beta 0.001 From 67f36a6decc4d5ab0d03a44253056146f0f25278 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 17 Oct 2025 14:48:40 +0800 Subject: [PATCH 39/42] streamline weight sync --- examples/train/grpo/external/moe_lora.sh | 0 examples/train/grpo/internal/moe_lora.sh | 0 swift/trainers/rlhf_trainer/grpo_trainer.py | 128 ++++++++------------ swift/trainers/rlhf_trainer/utils.py | 38 ++---- 4 files changed, 65 insertions(+), 101 deletions(-) mode change 100755 => 100644 examples/train/grpo/external/moe_lora.sh mode change 100755 => 100644 examples/train/grpo/internal/moe_lora.sh diff --git a/examples/train/grpo/external/moe_lora.sh b/examples/train/grpo/external/moe_lora.sh old mode 100755 new mode 100644 diff --git a/examples/train/grpo/internal/moe_lora.sh b/examples/train/grpo/internal/moe_lora.sh old mode 100755 new mode 100644 diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index c3b6d1eb6b..ba47e2bd5a 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -643,103 +643,79 @@ def _move_adapter_to_vllm(self): self.engine.llm_engine.add_lora(lora_reqest) del lora_params + def _load_state_dict_to_vllm(self, state_dict): + """Load state_dict to vLLM engine (server or colocate mode)""" + if self.vllm_mode == 'server' and self.accelerator.is_main_process: + bucket_size_mb = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) + named_params = list(state_dict.items()) + parameter_buckets = _create_parameter_buckets(named_params, bucket_size_mb=bucket_size_mb) + + for bucket in parameter_buckets: + _process_bucket_with_flattened_tensor(self, bucket) + + del named_params, parameter_buckets + elif self.vllm_mode == 'colocate': + llm_model = self.engine.inner_model + llm_model.load_weights(state_dict.items()) + del state_dict + def _move_full_model_to_vllm(self): gather_if_zero3 = get_gather_if_zero3_context(self) - if is_peft_model(self.model): - for i, parameter_group in enumerate(self.parameter_groups): - parameter_group_no_lora = self.parameter_groups_no_lora[i] - parameters = [ - parameter for name, parameter in self.model.named_parameters() - if not parameter_group or name in parameter_group - ] - with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): - if self.should_merge_adapter: - # if rollout enable lora, we will only execute once before the first rollout - self.model.merge_adapter() - state_dict = self.model.state_dict() + is_peft = is_peft_model(self.model) + + for i, parameter_group in enumerate(self.parameter_groups): + parameter_group_no_lora = self.parameter_groups_no_lora[i] + parameters = [ + parameter for name, parameter in self.model.named_parameters() + if not parameter_group or name in parameter_group + ] + + # Use patch_lora_merge for PEFT models, nullcontext otherwise + context_manager = patch_lora_merge(self.model, parameter_group) if is_peft else nullcontext() + + with gather_if_zero3(parameters), context_manager: + if is_peft and self.should_merge_adapter: + self.model.merge_adapter() + + state_dict = self.model.state_dict() + + # Process state_dict for PEFT models + if is_peft: prefix_removed = {k.removeprefix('base_model.model.'): v for k, v in state_dict.items()} state_dict = prefix_removed if self.rollout_enable_lora else { k.replace('.base_layer', ''): v for k, v in prefix_removed.items() } state_dict = {k: v for k, v in state_dict.items() if self.model.prefix not in k} - # When module to save, remove its prefix and discard the original module state_dict = { k.replace('modules_to_save.default.', ''): v for k, v in state_dict.items() if 'original_module' not in k } - if parameter_group_no_lora: + + # Filter by parameter_group_no_lora + if parameter_group_no_lora: + if is_peft: parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora] - state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + + if is_peft: assert len(state_dict) > 0 and all( [state.shape != torch.Size([0]) for state in state_dict.values()]) - if self.vllm_mode == 'server' and self.accelerator.is_main_process: - # Create parameter buckets and process them efficiently - named_params = list(state_dict.items()) - parameter_buckets = _create_parameter_buckets(named_params) - - # Process each bucket using flattened tensor approach - for bucket in parameter_buckets: - _process_bucket_with_flattened_tensor(self, bucket) - - del named_params, parameter_buckets - elif self.vllm_mode == 'colocate': - llm_model = self.engine.inner_model - llm_model.load_weights(state_dict.items()) - if self.should_merge_adapter: - with patch_lora_unmerge(self.model): - self.model.unmerge_adapter() - del state_dict + # Load to vLLM + self._load_state_dict_to_vllm(state_dict) + + if is_peft and self.should_merge_adapter: + with patch_lora_unmerge(self.model): + self.model.unmerge_adapter() + + if is_peft: self.base_sync_done = True - else: - for i, parameter_group in enumerate(self.parameter_groups): - parameter_group_no_lora = self.parameter_groups_no_lora[i] - parameters = [ - parameter for name, parameter in self.model.named_parameters() - if not parameter_group or name in parameter_group - ] - with gather_if_zero3(parameters): - state_dict = self.model.state_dict() - # Filter by parameter_group_no_lora if specified - if parameter_group_no_lora: - state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} - - if self.vllm_mode == 'server': - bucket_size_bytes = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) * 1024 * 1024 - if self.accelerator.is_main_process: - # Get state_dict AFTER gather to get full parameters - - # Split gathered parameters into buckets - current_bucket = [] - current_size = 0 - - for name, param in state_dict.items(): - param_size = param.numel() * param.element_size() - - # If adding this param would exceed bucket size, process current bucket first - if current_size + param_size > bucket_size_bytes and current_bucket: - _process_bucket_with_flattened_tensor(self, current_bucket) - current_bucket = [] - current_size = 0 - - current_bucket.append((name, param)) - current_size += param_size - - # Process remaining parameters in the last bucket - if current_bucket: - _process_bucket_with_flattened_tensor(self, current_bucket) - - del state_dict - else: - if self.vllm_mode == 'colocate': - llm_model = self.engine.inner_model - llm_model.load_weights(state_dict.items()) + # Reset prefix cache if self.vllm_mode == 'server' and self.accelerator.is_main_process: self.vllm_client.reset_prefix_cache() elif self.vllm_mode == 'colocate': - # since vLLM model weights has been updated, we should reset the prefix cache self.engine.engine.reset_prefix_cache() def _wait_queue(self): diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index a17cb3f935..d919db8446 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -821,41 +821,29 @@ def peft_config_to_dict(peft_config): return peft_config -def _create_parameter_buckets(named_params, bucket_size_mb=100): - """Create parameter buckets grouped by dtype for efficient processing""" +def _create_parameter_buckets(named_params, bucket_size_mb=512): + """Create parameter buckets for efficient processing""" buckets = [] current_bucket = [] current_size = 0 bucket_size_bytes = bucket_size_mb * 1024 * 1024 - # Group parameters by dtype first, then by size - dtype_groups = {} for name, param in named_params: - dtype = param.dtype - if dtype not in dtype_groups: - dtype_groups[dtype] = [] - dtype_groups[dtype].append((name, param)) - - # Create buckets within each dtype group - for dtype, params in dtype_groups.items(): - for name, param in params: - param_size = param.numel() * param.element_size() - - # If adding this param would exceed bucket size, start a new bucket - if current_size + param_size > bucket_size_bytes and current_bucket: - buckets.append(current_bucket) - current_bucket = [] - current_size = 0 - - current_bucket.append((name, param)) - current_size += param_size - - # Add remaining params in current bucket - if current_bucket: + param_size = param.numel() * param.element_size() + + # If adding this param would exceed bucket size, process current bucket first + if current_size + param_size > bucket_size_bytes and current_bucket: buckets.append(current_bucket) current_bucket = [] current_size = 0 + current_bucket.append((name, param)) + current_size += param_size + + # Process remaining parameters in the last bucket + if current_bucket: + buckets.append(current_bucket) + return buckets From a406f6d5e7cea6073f8b10871499737197e6769a Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 17 Oct 2025 15:11:07 +0800 Subject: [PATCH 40/42] clean comments and import --- examples/train/grpo/internal/vllm_72b_4gpu.sh | 1 - examples/train/grpo/internal/vllm_lora_qwenvl72b.sh | 1 - examples/train/grpo/internal/vllm_multi_turn.sh | 1 - swift/llm/infer/rollout.py | 2 +- swift/trainers/rlhf_trainer/grpo_trainer.py | 2 +- swift/trainers/rlhf_trainer/utils.py | 5 +---- 6 files changed, 3 insertions(+), 9 deletions(-) diff --git a/examples/train/grpo/internal/vllm_72b_4gpu.sh b/examples/train/grpo/internal/vllm_72b_4gpu.sh index d63af49e8f..c6871760c2 100644 --- a/examples/train/grpo/internal/vllm_72b_4gpu.sh +++ b/examples/train/grpo/internal/vllm_72b_4gpu.sh @@ -36,7 +36,6 @@ swift rlhf \ --top_p 1.0 \ --top_k 80 \ --log_completions true \ - --async_generate false \ --move_model_batches 16 \ --offload_optimizer true \ --offload_model true \ diff --git a/examples/train/grpo/internal/vllm_lora_qwenvl72b.sh b/examples/train/grpo/internal/vllm_lora_qwenvl72b.sh index bf054de523..f41273c39f 100755 --- a/examples/train/grpo/internal/vllm_lora_qwenvl72b.sh +++ b/examples/train/grpo/internal/vllm_lora_qwenvl72b.sh @@ -40,7 +40,6 @@ swift rlhf \ --top_p 1.0 \ --top_k 80 \ --log_completions true \ - --async_generate false \ --offload_optimizer true \ --offload_model true \ --move_model_batches 40 \ diff --git a/examples/train/grpo/internal/vllm_multi_turn.sh b/examples/train/grpo/internal/vllm_multi_turn.sh index 352e64b890..1cc8d5b500 100644 --- a/examples/train/grpo/internal/vllm_multi_turn.sh +++ b/examples/train/grpo/internal/vllm_multi_turn.sh @@ -36,7 +36,6 @@ swift rlhf \ --top_p 1.0 \ --top_k 80 \ --log_completions true \ - --async_generate false \ --offload_optimizer true \ --offload_model true \ --sleep_level 1 \ diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 205220b682..b0a8dc7071 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -88,7 +88,7 @@ def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> Non def update_adapter_flattened_param(self, lora_int_id: int, peft_config: Dict, metadatas: list[Dict]) -> None: """ - Receives updated weights from the client process and updates the named parameter in the model. + Receives and applies a flattened LoRA adapter to the model. """ metadatas = [FlattenedTensorMetadata(**metadata) for metadata in metadatas] if self.pynccl_comm is None: diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index ba47e2bd5a..4ec153f664 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -495,7 +495,7 @@ def replace_lora(name): return '' else: if not self.rollout_enable_lora: - return re.sub(r'\.base_layer\.', '.', name) + return name.replace('.base_layer', '') else: return name diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index d919db8446..4bd92b2d6c 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -4,14 +4,13 @@ import os import time from contextlib import contextmanager, nullcontext -from dataclasses import asdict, dataclass +from dataclasses import asdict from functools import partial from io import BytesIO from types import MethodType from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union import datasets -import json import torch import torch.nn.functional as F from msgspec import field @@ -21,7 +20,6 @@ from pydantic import BaseModel, field_validator from torch import nn from torch.utils.data import DataLoader, RandomSampler -from transformers import Trainer from swift.utils import is_swanlab_available, is_vllm_available, is_wandb_available @@ -401,7 +399,6 @@ def get_gather_if_zero3_context(trainer): def patch_vllm_load_adapter(): - # from vllm.lora.worker_manager import WorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.models import LoRAModel from vllm.lora.utils import get_adapter_absolute_path From 0380c08048e14bc6f7fc35469dfeecfeceb7b6fa Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 17 Oct 2025 17:05:50 +0800 Subject: [PATCH 41/42] add vllm_enable_lora for colocate --- .../Instruction/GRPO/GetStarted/GRPO.md | 16 +++++++++++++-- ...44\350\241\214\345\217\202\346\225\260.md" | 8 +++++--- .../Instruction/Command-line-parameters.md | 3 ++- .../Instruction/GRPO/GetStarted/GRPO.md | 17 +++++++++++++--- swift/trainers/arguments.py | 2 +- swift/trainers/rlhf_trainer/grpo_trainer.py | 20 +++++++++++++++---- 6 files changed, 52 insertions(+), 14 deletions(-) diff --git a/docs/source/Instruction/GRPO/GetStarted/GRPO.md b/docs/source/Instruction/GRPO/GetStarted/GRPO.md index 2fef0413b8..4715ebe9b4 100644 --- a/docs/source/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source/Instruction/GRPO/GetStarted/GRPO.md @@ -197,11 +197,23 @@ swift rollout \ --vllm_server_timeout <超时时间> \ ``` #### 权重同步加速 -swift 3.10 优化了权重同步,同时在 rollout 命令(server mode)中设置以下参数可以进一步优化 LoRA 训练的权重同步速度。 +swift 3.10 优化了权重同步,设置以下参数可以进一步优化 LoRA 训练的权重同步速度。 + ```bash - --vllm_enable_lora true +# rollout(server mode) +swift rollout \ + --vllm_enable_lora true \ --vllm_max_lora_rank xxx # 与训练脚本lora_rank一致 + ... + +# grpo(colocate mode) +swift rlhf \ + --rlhf_type grpo \ + --vllm_mode colocate \ + --vllm_enable_lora true \ + ... ``` + 注意:以下情况无法使用该优化: - 训练多模态模型的ViT层(freeze_vit false) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index bb753e877c..5c736f248c 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -526,7 +526,7 @@ reward模型参数将在PPO、GRPO中使用。 - vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。使用server端时, - vllm_mode server 参数 - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。 - - vllm_server_host:vLLM server host地址,默认为None,使用外部vLLM server时使用。 + - vllm_server_host:vLLM server host地址,默认为None。 - vllm_server_port vLLM server 服务端口,默认为8000。 - vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。 - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。 @@ -538,6 +538,8 @@ reward模型参数将在PPO、GRPO中使用。 - vllm_enforce_eager: vllm透传参数,默认为False。 - vllm_limit_mm_per_prompt: vllm透传参数,默认为None。 - vllm_enable_prefix_caching: vllm透传参数,默认为True。 + - vllm_tensor_parallel_size: tp并行数,默认为`1`。 + - vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速)。 - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放 - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。 - offload_model: 是否在vLLM推理时 offload 模型,默认为False。 @@ -605,9 +607,9 @@ soft overlong 奖励参数 ### Rollout参数 Rollout参数继承于[部署参数](#部署参数) -- multi_turn_scheduler: 多轮GRPO训练规划器,传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。默认为None,具体参考[文档](./GRPO/DeveloperGuide/多轮训练.md) +- multi_turn_scheduler: 多轮GRPO训练规划器,传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。默认为None,具体参考[文档](./GRPO/DeveloperGuide/多轮训练.md)。 - max_turns: 多轮GRPO训练下的最大轮数,默认为None,即不做约束。 -- vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速) +- vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速)。 - vllm_max_lora_rank: vLLM Engine LoRA参数,需大于等于训练的lora_rank,建议等于。默认为16。 ### Web-UI参数 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 0b6d486669..c4160ee525 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -535,7 +535,7 @@ The meanings of the following parameters can be referenced [here](https://huggin - vllm_mode: Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or `colocate` - vllm_mode server parameter - vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " "and `vllm_server_port` are ignored. Default is None. - - vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server. + - vllm_server_host: The host address of the vLLM server. Default is None. - vllm_server_port: The service port of the vLLM server. Default is 8000. - vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds. - vllm_server_pass_dataset: pass additional dataset information through to the vLLM server for multi-turn training. @@ -548,6 +548,7 @@ The meanings of the following parameters can be referenced [here](https://huggin - vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None. - vllm_enable_prefix_caching: A pass-through parameter for vLLM, default is True. - vllm_tensor_parallel_size: the tensor parallel size of vLLM engine, default is 1. + - vllm_enable_lora: Enable the vLLM engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training. See the [documentation](./GRPO/GetStarted/GRPO.md#weight-sync-acceleration) for details. - sleep_level: make vllm sleep when model is training. Options are 0 or 1, default is 0, no sleep - offload_optimizer: Whether to offload optimizer parameters during inference with vLLM. The default is `False`. - offload_model: Whether to offload the model during inference with vLLM. The default is `False`. diff --git a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md index a667f0b053..f19faf3f9d 100644 --- a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md @@ -196,12 +196,23 @@ To configure the external vLLM server during training, use the following paramet ``` #### Weight-Sync Acceleration -Swift 3.10 optimizes weight synchronization, and setting the following parameters in the rollout command (server mode) can further improve the weight synchronization speed for LoRA training: +Swift 3.10 optimizes weight synchronization, and setting the following parameters can further improve the weight synchronization speed for LoRA training: ```bash - --vllm_enable_lora true - --vllm_max_lora_rank xxx # Should match the lora_rank in the training script +# rollout(server mode) +swift rollout \ + --vllm_enable_lora true \ + --vllm_max_lora_rank xxx # 与训练脚本lora_rank一致 + ... + +# grpo(colocate mode) +swift rlhf \ + --rlhf_type grpo \ + --vllm_mode colocate \ + --vllm_enable_lora true \ + ... ``` + Note: This optimization cannot be used in the following cases: - Training the ViT layers of multimodal models (freeze_vit set to false) diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index b2256da0b7..f5afe1de89 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -271,7 +271,7 @@ class GRPOArgumentsMixin(VllmArguments): vllm_mode: Literal['server', 'colocate'] = 'colocate' # internal vllm (colocate) vllm_enable_prefix_caching: bool = True # overwrite - + vllm_enable_lora: bool = False # external vllm (server) vllm_server_base_url: Optional[List[str]] = None vllm_server_host: Optional[List[str]] = None diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 4ec153f664..f3addf5d93 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -546,14 +546,26 @@ def prepare_vllm(self, model): vllm_template.padding_free = False lora_kwargs = {} is_moe = model.model_info.is_moe_model - if self.args.train_type == 'lora' and not is_moe: - # MoE LoRA is not supported now + vllm_enable_lora = self.args.vllm_enable_lora + if self.args.train_type == 'lora' and vllm_enable_lora: lora_kwargs = { - 'enable_lora': True, + 'enable_lora': self.args.vllm_enable_lora, 'max_loras': 1, 'max_lora_rank': self.args.lora_rank, } self.rollout_enable_lora = True + + if is_moe: + logger.warning( + 'vLLM LoRA is enabled for an MoE model. This may cause errors when applying LoRA to expert layers, ' + 'as vLLM currently does not support LoRA in MoE configurations. If you encounter errors, ' + 'please set vllm_enable_lora to False.') + + if self.is_multimodal: + logger.warning('vLLM LoRA is enabled for a multimodal model. This may lead to unexpected issues ' + 'when applying LoRA to the ViT component, as vLLM does not yet support this setup. ' + 'If errors occur, please disable LoRA by setting vllm_enable_lora to False.') + patch_vllm_load_adapter() with Swift.grpo_context(model, self.template.processor): set_expandable_segments(False) @@ -640,7 +652,7 @@ def _move_adapter_to_vllm(self): peft_config=asdict(peft_config), lora_tensors=lora_params, ) - self.engine.llm_engine.add_lora(lora_reqest) + self.engine.engine.add_lora(lora_reqest) del lora_params def _load_state_dict_to_vllm(self, state_dict): From 9f98473181b4e0d71d1cf2c5f252c9693d7f1721 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 17 Oct 2025 17:11:24 +0800 Subject: [PATCH 42/42] fix zh comment in en doc --- docs/source_en/Instruction/GRPO/GetStarted/GRPO.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md index f19faf3f9d..7234621584 100644 --- a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md @@ -202,7 +202,7 @@ Swift 3.10 optimizes weight synchronization, and setting the following parameter # rollout(server mode) swift rollout \ --vllm_enable_lora true \ - --vllm_max_lora_rank xxx # 与训练脚本lora_rank一致 + --vllm_max_lora_rank xxx # match the lora_rank in the training script ... # grpo(colocate mode) @@ -212,7 +212,6 @@ swift rlhf \ --vllm_enable_lora true \ ... ``` - Note: This optimization cannot be used in the following cases: - Training the ViT layers of multimodal models (freeze_vit set to false)