diff --git "a/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" "b/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" index 67f12e0165..ac00f9d180 100644 --- "a/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" +++ "b/docs/source/BestPractices/GRPO\344\273\243\347\240\201\350\256\255\347\273\203.md" @@ -42,7 +42,9 @@ ```bash CUDA_VISIBLE_DEVICES=7 \ swift rollout \ - --model Qwen/Qwen2.5-7B-Instruct + --model Qwen/Qwen2.5-7B-Instruct \ + --vllm_enable_lora true \ + --vllm_max_lora_rank 16 ``` ```bash @@ -61,6 +63,8 @@ swift rlhf \ --vllm_server_host 127.0.0.1 \ --vllm_server_port 8000 \ --train_type lora \ + --lora_rank 16 \ + --lora_alpha 32 \ --torch_dtype bfloat16 \ --dataset 'open-r1/verifiable-coding-problems-python-10k' \ --load_from_cache_file true \ diff --git a/docs/source/Instruction/GRPO/GetStarted/GRPO.md b/docs/source/Instruction/GRPO/GetStarted/GRPO.md index f0bb51bdf4..4715ebe9b4 100644 --- a/docs/source/Instruction/GRPO/GetStarted/GRPO.md +++ b/docs/source/Instruction/GRPO/GetStarted/GRPO.md @@ -185,7 +185,7 @@ swift rollout \ 更多 rollout 参数参考[vLLM参数](../../../Instruction/命令行参数.md#vllm参数)和[rollout 参数](../../../Instruction/命令行参数.md#rollout参数) -注意:在使用 use_async_engine 时,仅开启 DP 可能会导致错误,相关问题参考: [vllm issue](https://github.com/vllm-project/vllm/issues/18567)。如果出现错误,请尝试同时启用 TP 和 DP。 +注意:在使用 use_async_engine 时,仅开启 DP 可能会导致错误,相关问题参考: [vllm issue](https://github.com/vllm-project/vllm/issues/18567)。如果出现错误,请尝试同时启用 TP 和 DP,或升级vLLM 训练使用以下参数配置外部 vLLM 服务器 @@ -196,6 +196,30 @@ swift rollout \ --vllm_server_port <服务端口> \ --vllm_server_timeout <超时时间> \ ``` +#### 权重同步加速 +swift 3.10 优化了权重同步,设置以下参数可以进一步优化 LoRA 训练的权重同步速度。 + +```bash +# rollout(server mode) +swift rollout \ + --vllm_enable_lora true \ + --vllm_max_lora_rank xxx # 与训练脚本lora_rank一致 + ... + +# grpo(colocate mode) +swift rlhf \ + --rlhf_type grpo \ + --vllm_mode colocate \ + --vllm_enable_lora true \ + ... +``` + +注意:以下情况无法使用该优化: + +- 训练多模态模型的ViT层(freeze_vit false) +- MoE 模型 + +优化实现细节请参考该[PR](https://github.com/modelscope/ms-swift/pull/5773) ## logged metrics - completions/mean_length:生成的 completion 的平均长度。 diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 7917eddce2..5c736f248c 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -526,17 +526,20 @@ reward模型参数将在PPO、GRPO中使用。 - vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。使用server端时, - vllm_mode server 参数 - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。 - - vllm_server_host:vLLM server host地址,默认为None,使用外部vLLM server时使用。 + - vllm_server_host:vLLM server host地址,默认为None。 - vllm_server_port vLLM server 服务端口,默认为8000。 - vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。 - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。 - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`. 
+ - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE:环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 - vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。) - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 - vllm_max_model_len: vllm透传参数,默认为None。 - vllm_enforce_eager: vllm透传参数,默认为False。 - vllm_limit_mm_per_prompt: vllm透传参数,默认为None。 - vllm_enable_prefix_caching: vllm透传参数,默认为True。 + - vllm_tensor_parallel_size: tp并行数,默认为`1`。 + - vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速)。 - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放 - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。 - offload_model: 是否在vLLM推理时 offload 模型,默认为False。 @@ -549,7 +552,7 @@ reward模型参数将在PPO、GRPO中使用。 - sync_ref_model: 是否定期同步ref_model,默认为False。 - ref_model_mixup_alpha: 控制在更新过程中model和先前ref_model之间的混合。更新公式为 $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$。默认为0.6。 - ref_model_sync_steps:同步频率,默认为512。 -- move_model_batches: 在模型向vLLM等快速推理框架移动参数时,将layers分为多少个batch. 默认为None, 代表整个模型不进行拆分,否则拆分为move_model_batches+1(非layer参数)+1(多模态部分参数)个。注意:该参数仅对LoRA(PEFT)训练有意义。 +- move_model_batches: 在模型向vLLM等快速推理框架移动参数时,将layers分为多少个batch. 默认为None, 代表整个模型不进行拆分,否则拆分为move_model_batches+1(非layer参数)+1(多模态部分参数)个。 - multi_turn_scheduler: 多轮GRPO参数, 传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。 - max_turns: 多轮GRPO的轮数上限。默认为None,不做限制。 - dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。 @@ -604,8 +607,10 @@ soft overlong 奖励参数 ### Rollout参数 Rollout参数继承于[部署参数](#部署参数) -- multi_turn_scheduler: 多轮GRPO训练规划器,传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。默认为None,具体参考[文档](./GRPO/DeveloperGuide/多轮训练.md) +- multi_turn_scheduler: 多轮GRPO训练规划器,传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。默认为None,具体参考[文档](./GRPO/DeveloperGuide/多轮训练.md)。 - max_turns: 多轮GRPO训练下的最大轮数,默认为None,即不做约束。 +- vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速)。 +- vllm_max_lora_rank: vLLM Engine LoRA参数,需大于等于训练的lora_rank,建议等于。默认为16。 ### Web-UI参数 - server_name: web-ui的host,默认为'0.0.0.0'。 diff --git a/docs/source_en/BestPractices/GRPO-Code-Training.md b/docs/source_en/BestPractices/GRPO-Code-Training.md index aa24f56985..3b822ec97c 100644 --- a/docs/source_en/BestPractices/GRPO-Code-Training.md +++ b/docs/source_en/BestPractices/GRPO-Code-Training.md @@ -46,7 +46,9 @@ launch external vLLM server using following script ```bash CUDA_VISIBLE_DEVICES=7 \ swift rollout \ - --model Qwen/Qwen2.5-7B-Instruct + --model Qwen/Qwen2.5-7B-Instruct \ + --vllm_enable_lora true \ + --vllm_max_lora_rank 16 ``` ```bash @@ -65,6 +67,8 @@ swift rlhf \ --vllm_server_host 127.0.0.1 \ --vllm_server_port 8000 \ --train_type lora \ + --lora_rank 16 \ + --lora_alpha 32 \ --torch_dtype bfloat16 \ --dataset 'open-r1/verifiable-coding-problems-python-10k' \ --load_from_cache_file true \ diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index ead203be2d..c4160ee525 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -535,17 +535,20 @@ The meanings of the following parameters can be referenced [here](https://huggin - vllm_mode: Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or `colocate` - vllm_mode server parameter - vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " "and `vllm_server_port` are ignored. Default is None. 
- - vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
+ - vllm_server_host: The host address of the vLLM server. Default is None.
 - vllm_server_port: The service port of the vLLM server. Default is 8000.
 - vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds.
 - vllm_server_pass_dataset: Pass additional dataset information through to the vLLM server for multi-turn training.
 - async_generate: Use async rollout to improve train speed. Note that rollout will use the model updated in the previous round when enabled. Multi-turn scenarios are not supported. Default is `false`.
+ - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: An environment variable that controls the bucket size (in MB) for weight synchronization during full-parameter training in Server Mode. Default is 512 MB.
 - vllm_mode colocate parameter (For more parameter support, refer to the [vLLM Arguments](#vLLM-Arguments).)
 - vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
 - vllm_max_model_len: vLLM passthrough parameter, the total length limit of model, default is None.
 - vllm_enforce_eager: vLLM passthrough parameter, default is False.
 - vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None.
+ - vllm_enable_prefix_caching: A pass-through parameter for vLLM, default is True.
 - vllm_tensor_parallel_size: the tensor parallel size of vLLM engine, default is 1.
+ - vllm_enable_lora: Enable the vLLM engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training. See the [documentation](./GRPO/GetStarted/GRPO.md#weight-sync-acceleration) for details.
 - sleep_level: make vLLM sleep while the model is training. Options are 0 or 1; default is 0 (no sleep).
 - offload_optimizer: Whether to offload optimizer parameters during inference with vLLM. The default is `False`.
 - offload_model: Whether to offload the model during inference with vLLM. The default is `False`.
@@ -563,7 +566,7 @@ The meanings of the following parameters can be referenced [here](https://huggin
 - sync_ref_model: Whether to synchronize the reference model. Default is False.
 - ref_model_mixup_alpha: This parameter controls the mix between the current policy and the previous reference policy during updates. The reference policy is updated according to the equation: $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$. Default is 0.6.
 - ref_model_sync_steps: The parameter determines how frequently the current policy is synchronized with the reference policy. Default is 512.
-- move_model_batches: When moving model parameters to fast inference frameworks such as vLLM/LMDeploy, determines how many batches to divide the layers into. The default is `None`, which means the entire model is not split. Otherwise, the model is split into `move_model_batches + 1` (non-layer parameters) + `1` (multi-modal component parameters) batches. This parameter is only meaningful for LoRA (PEFT).
+- move_model_batches: When moving model parameters to fast inference frameworks such as vLLM/LMDeploy, determines how many batches to divide the layers into. The default is `None`, which means the entire model is not split. Otherwise, the model is split into `move_model_batches + 1` (non-layer parameters) + `1` (multi-modal component parameters) batches.
 - multi_turn_scheduler: Multi-turn GRPO parameter; pass the corresponding plugin name, and make sure to implement it in plugin/multi_turn.py. 
- max_turns: Maximum number of rounds for multi-turn GRPO. The default is None, which means there is no limit.
- dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False.
@@ -623,6 +626,8 @@ Deployment Arguments inherit from the [inference arguments](#inference-arguments
 The rollout parameters inherit from the [deployment parameters](#deployment-arguments).
 - multi_turn_scheduler: The scheduler for multi-turn GRPO training. Pass the corresponding plugin name, and ensure the implementation is added in `plugin/multi_turn.py`. Default is `None`. See [documentation](./GRPO/DeveloperGuide/multi_turn.md) for details.
 - max_turns: Maximum number of turns in multi-turn GRPO training. Default is `None`, meaning no limit.
+- vllm_enable_lora: Enable the vLLM engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training. See the [documentation](./GRPO/GetStarted/GRPO.md#weight-sync-acceleration) for details.
+- vllm_max_lora_rank: LoRA parameter for the vLLM engine. Must be greater than or equal to the training lora_rank; it is recommended to set them equal. Defaults to 16.
 
 ### Web-UI Arguments
 - server_name: Host for the web UI, default is '0.0.0.0'.
diff --git a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md
index 73b4772005..7234621584 100644
--- a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md
+++ b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md
@@ -183,7 +183,7 @@ swift rollout \
 ```
 For more rollout parameters, refer to the [vllm arguments](../../../Instruction/Command-line-parameters.md#vllm-arguments) and [rollout arguments](../../../Instruction/Command-line-parameters.md#rollout-arguments)
-Note: When set `use_async_engine`, enabling only DP (Data Parallelism) may cause errors. [Related issue](https://github.com/vllm-project/vllm/issues/18567). If errors occur, try enabling both TP (Tensor Parallelism) and DP.
+Note: When using `use_async_engine`, enabling only DP (Data Parallelism) may cause errors. [Related issue](https://github.com/vllm-project/vllm/issues/18567). If errors occur, try enabling both TP (Tensor Parallelism) and DP, or upgrade vLLM.
 
 To configure the external vLLM server during training, use the following parameters:
 
@@ -194,6 +194,31 @@ To configure the external vLLM server during the paramet
 --vllm_server_port <server port> \
 --vllm_server_timeout <timeout> \
 ```
+
+#### Weight-Sync Acceleration
+Swift 3.10 optimizes weight synchronization; for LoRA training, setting the following parameters speeds it up further:
+
+```bash
+# rollout (server mode)
+swift rollout \
+    --vllm_enable_lora true \
+    --vllm_max_lora_rank xxx # match the lora_rank in the training script
+    ...
+
+# grpo (colocate mode)
+swift rlhf \
+    --rlhf_type grpo \
+    --vllm_mode colocate \
+    --vllm_enable_lora true \
+    ...
+```
+Note: This optimization cannot be used in the following cases:
+
+- Training the ViT layers of multimodal models (freeze_vit set to false)
+- MoE models
+
+For implementation details, please refer to the [PR](https://github.com/modelscope/ms-swift/pull/5773).
+
 ## logged metrics
 - completions/mean_length: The average length of generated completions.
 - completions/min_length: The minimum length among generated completions. 
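The gain from `vllm_enable_lora` is a bandwidth story: after the one-time sync of the base weights, only the adapter tensors travel on each update. A rough back-of-envelope comparison in Python (all dimensions are illustrative assumptions, not read from any real model config):

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in one LoRA pair (A: rank x d_in, B: d_out x rank)."""
    return rank * d_in + d_out * rank

hidden = 4096        # assumed hidden size
n_layers = 32        # assumed number of decoder layers
rank = 16            # --lora_rank / --vllm_max_lora_rank
bytes_per_param = 2  # bf16

# Assume LoRA targets the four attention projections in every layer.
adapter_bytes = n_layers * 4 * lora_params(hidden, hidden, rank) * bytes_per_param
full_bytes = 7e9 * bytes_per_param  # a 7B-parameter base model in bf16

print(f'adapter sync: {adapter_bytes / 2**20:.0f} MiB')  # ~32 MiB
print(f'full sync:    {full_bytes / 2**30:.1f} GiB')     # ~13 GiB
```

Moving tens of megabytes instead of a dozen gigabytes per sync is why the optimization pays off most for large models trained with small adapters.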
diff --git a/examples/train/grpo/external/README.md b/examples/train/grpo/external/README.md index 733199dd4c..a4808a5c9d 100644 --- a/examples/train/grpo/external/README.md +++ b/examples/train/grpo/external/README.md @@ -7,6 +7,12 @@ 1. vLLM version 0.8.3 or higher. 2. trl version 0.17.0 or higher +For LoRA Training, set following parameters to speed up weight update +```bash + --vllm_enable_lora true + --vllm_max_lora_rank xxx # same as lora_rank in training script +``` + ## **Introduction** The GRPO (Group Relative Policy Optimization) training framework supports high-performance inference engines like vLLM to accelerate the sampling process. The **External Mode** allows you to connect to an external vLLM inference server, separating the inference service from the training process. This mode is ideal for scenarios where you want to offload inference to dedicated hardware or servers, improving resource utilization and scalability. diff --git a/examples/train/grpo/external/moe_full.sh b/examples/train/grpo/external/moe_full.sh new file mode 100644 index 0000000000..634f67024f --- /dev/null +++ b/examples/train/grpo/external/moe_full.sh @@ -0,0 +1,44 @@ +# 8*80G + +# CUDA_VISIBLE_DEVICES=0 \ +# swift rollout \ +# --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ +# --vllm_max_model_len 16384 \ +# --vllm_enable_prefix_caching true + +CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 \ +NPROC_PER_NODE=7 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --reward_funcs accuracy \ + --use_vllm true \ + --vllm_mode server \ + --vllm_server_host 127.0.0.1 \ + --vllm_server_port 8000 \ + --train_type full \ + --torch_dtype bfloat16 \ + --dataset AI-MO/NuminaMath-TIR#1000 \ + --max_length 12000 \ + --max_completion_length 8192 \ + --overlong_filter true \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 4 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 14 \ + --temperature 1.0 \ + --deepspeed zero3_offload \ + --log_completions true \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --beta 0.001 \ + --move_model_batches 5 diff --git a/examples/train/grpo/external/moe_lora.sh b/examples/train/grpo/external/moe_lora.sh new file mode 100644 index 0000000000..3bd3ec5d9c --- /dev/null +++ b/examples/train/grpo/external/moe_lora.sh @@ -0,0 +1,44 @@ +# 8*80G + +# CUDA_VISIBLE_DEVICES=0 \ +# swift rollout \ +# --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ +# --vllm_max_model_len 16384 \ +# --vllm_enable_prefix_caching true + +CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 \ +NPROC_PER_NODE=7 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --reward_funcs accuracy \ + --use_vllm true \ + --vllm_mode server \ + --vllm_server_host 127.0.0.1 \ + --vllm_server_port 8000 \ + --train_type lora \ + --torch_dtype bfloat16 \ + --dataset AI-MO/NuminaMath-TIR#1000 \ + --max_length 12000 \ + --max_completion_length 8192 \ + --overlong_filter true \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 4 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 14 \ + --temperature 1.0 \ + --deepspeed zero3 
\ + --log_completions true \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --beta 0.001 \ + --move_model_batches 5 diff --git a/examples/train/grpo/internal/moe_full.sh b/examples/train/grpo/internal/moe_full.sh new file mode 100644 index 0000000000..7f74c8b736 --- /dev/null +++ b/examples/train/grpo/internal/moe_full.sh @@ -0,0 +1,40 @@ +# 8*80G + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --reward_funcs accuracy \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.4 \ + --vllm_tensor_parallel_size 2 \ + --vllm_max_model_len 16384 \ + --train_type full \ + --torch_dtype bfloat16 \ + --dataset AI-MO/NuminaMath-TIR#1000 \ + --max_length 12000 \ + --max_completion_length 8192 \ + --overlong_filter true \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 4 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 16 \ + --temperature 1.0 \ + --deepspeed zero3_offload \ + --log_completions true \ + --sleep_level 1 \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --beta 0.001 \ + --move_model_batches 10 diff --git a/examples/train/grpo/internal/moe_lora.sh b/examples/train/grpo/internal/moe_lora.sh new file mode 100644 index 0000000000..3ebdec2a65 --- /dev/null +++ b/examples/train/grpo/internal/moe_lora.sh @@ -0,0 +1,42 @@ +# 8*80G + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +swift rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --reward_funcs accuracy \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.4 \ + --vllm_tensor_parallel_size 2 \ + --vllm_max_model_len 16384 \ + --train_type lora \ + --torch_dtype bfloat16 \ + --dataset AI-MO/NuminaMath-TIR#1000 \ + --max_length 12000 \ + --max_completion_length 8192 \ + --overlong_filter true \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-6 \ + --gradient_accumulation_steps 4 \ + --save_strategy 'steps' \ + --eval_strategy 'steps' \ + --eval_steps 1000 \ + --save_steps 1000 \ + --save_total_limit 10 \ + --logging_steps 1 \ + --warmup_ratio 0.01 \ + --dataloader_num_workers 4 \ + --num_generations 16 \ + --temperature 1.0 \ + --deepspeed zero3 \ + --log_completions true \ + --sleep_level 1 \ + --offload_model true \ + --offload_optimizer true \ + --report_to tensorboard swanlab \ + --num_iterations 1 \ + --beta 0.001 \ + --move_model_batches 10 diff --git a/examples/train/grpo/internal/vllm_72b_4gpu.sh b/examples/train/grpo/internal/vllm_72b_4gpu.sh index d63af49e8f..c6871760c2 100644 --- a/examples/train/grpo/internal/vllm_72b_4gpu.sh +++ b/examples/train/grpo/internal/vllm_72b_4gpu.sh @@ -36,7 +36,6 @@ swift rlhf \ --top_p 1.0 \ --top_k 80 \ --log_completions true \ - --async_generate false \ --move_model_batches 16 \ --offload_optimizer true \ --offload_model true \ diff --git a/examples/train/grpo/internal/vllm_lora_qwenvl72b.sh b/examples/train/grpo/internal/vllm_lora_qwenvl72b.sh index bf054de523..f41273c39f 100755 --- a/examples/train/grpo/internal/vllm_lora_qwenvl72b.sh +++ b/examples/train/grpo/internal/vllm_lora_qwenvl72b.sh @@ -40,7 +40,6 @@ swift rlhf \ --top_p 1.0 \ --top_k 80 \ --log_completions true \ - --async_generate false \ 
--offload_optimizer true \ --offload_model true \ --move_model_batches 40 \ diff --git a/examples/train/grpo/internal/vllm_multi_turn.sh b/examples/train/grpo/internal/vllm_multi_turn.sh index 352e64b890..1cc8d5b500 100644 --- a/examples/train/grpo/internal/vllm_multi_turn.sh +++ b/examples/train/grpo/internal/vllm_multi_turn.sh @@ -36,7 +36,6 @@ swift rlhf \ --top_p 1.0 \ --top_k 80 \ --log_completions true \ - --async_generate false \ --offload_optimizer true \ --offload_model true \ --sleep_level 1 \ diff --git a/swift/llm/argument/deploy_args.py b/swift/llm/argument/deploy_args.py index c6a762bec0..71ae22b2bd 100644 --- a/swift/llm/argument/deploy_args.py +++ b/swift/llm/argument/deploy_args.py @@ -86,7 +86,8 @@ class RolloutArguments(DeployArguments): # only for GRPO rollout with AsyncEngine, see details in swift/plugin/multi_turn multi_turn_scheduler: Optional[str] = None max_turns: Optional[int] = None - + vllm_enable_lora: bool = False + vllm_max_lora_rank: int = 16 # GYM env gym_env: Optional[str] = None context_manager: Optional[str] = None diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index ac546f4b0d..dcf185586b 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -14,6 +14,7 @@ try: os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '86400' + from vllm.lora.request import LoRARequest except Exception: raise @@ -98,6 +99,16 @@ def infer( use_tqdm: Optional[bool] = None, adapter_request: Optional[AdapterRequest] = None, ) -> List[RolloutOutput]: + if not adapter_request and self.enable_lora: + lora_int_ids = list(self.engine.list_loras()) + if lora_int_ids: + # since max_lora = 1, pick the first lora + adapter_request = LoRARequest( + lora_name=f'{lora_int_ids[0]}', + lora_int_id=lora_int_ids[0], + lora_path='dummy_lora_path', + ) + res = super().infer( infer_requests, request_config, @@ -189,3 +200,13 @@ def _create_chat_completion_response(self, result, inputs, template: Template, r id=request_id, prompt_token_ids=prompt_token_ids, images_size=images_size) + + def _add_adapter(self, adapter_request: Optional[Union[AdapterRequest, LoRARequest]] = None): + assert self.enable_lora, f'adapter_request: {adapter_request}, self.enable_lora: {self.enable_lora}' + from vllm.lora.request import LoRARequest + if isinstance(adapter_request, AdapterRequest): + return super()._add_adapter(adapter_request) + elif isinstance(adapter_request, LoRARequest): + return adapter_request + else: + raise ValueError(f'Invalid adapter request: {adapter_request}') diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index e2e0d02783..a3926df386 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -12,6 +12,8 @@ from PIL import Image from pydantic import BaseModel, Field, field_validator +from swift.trainers.rlhf_trainer.utils import FlattenedTensorMetadata +from swift.tuners.lora import LoraConfig from ..template import InferRequest from ..utils import Messages, Tool @@ -459,3 +461,13 @@ class UpdateWeightsRequest(BaseModel): name: str dtype: str shape: list[int] + + +class UpdateFlattenedAdapterRequest(BaseModel): + lora_int_id: int + peft_config: LoraConfig + metadatas: List[FlattenedTensorMetadata] + + +class UpdateFlattenedParamsRequest(BaseModel): + metadatas: List[FlattenedTensorMetadata] diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 
d56c18e301..b0a8dc7071 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -6,26 +6,31 @@ import multiprocessing import os import time +import traceback +from collections.abc import Sequence from contextlib import asynccontextmanager, contextmanager from dataclasses import asdict from functools import wraps from itertools import chain from multiprocessing import Pipe, Process from multiprocessing.connection import Connection -from typing import Dict, List, Optional, Union, get_type_hints +from typing import Dict, List, Optional, Union import torch import uvicorn from aiohttp import ClientConnectorError from fastapi import FastAPI -from trl.scripts.vllm_serve import WeightSyncWorkerExtension +from trl.scripts.vllm_serve import WeightSyncWorkerExtension as HFWeightSyncWorkerExtension from swift.llm import RolloutArguments, SwiftPipeline from swift.llm.template.template_inputs import RolloutInferRequest from swift.plugin.multi_turn import RolloutScheduler, multi_turns +from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, TensorLoRARequest, + patch_vllm_load_adapter) from swift.utils import get_logger from .infer_engine import GRPOVllmEngine, InferClient -from .protocol import InitCommunicatorRequest, RequestConfig, UpdateWeightsRequest +from .protocol import (InitCommunicatorRequest, RequestConfig, UpdateFlattenedAdapterRequest, + UpdateFlattenedParamsRequest, UpdateWeightsRequest) try: from vllm.utils import get_open_port @@ -50,6 +55,84 @@ - For inference or deployment, please use the `swift infer` or `swift deploy` commands. """ +patch_vllm_load_adapter() + + +class WeightSyncWorkerExtension(HFWeightSyncWorkerExtension): + + def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None: + """ + Receives updated weights from the client process and updates the named parameter in the model. + + Args: + name (`str`): + Name of the weight tensor being updated. + dtype (`str`): + Data type of the weight tensor as a string (e.g., `"torch.float32"`). + shape (`Sequence[int]`): + Shape of the weight tensor. + """ + if self.pynccl_comm is None: + raise RuntimeError('Communicator not initialized. Call `init_communicator` first.') + + dtype = getattr(torch, dtype.split('.')[-1]) + # Allocate memory for the incoming weight tensor on the correct device. + weight = torch.empty(shape, dtype=dtype, device=self.device) + + # Use NCCL to broadcast the updated weights from the client (src) to all workers. + self.pynccl_comm.broadcast(weight, src=self.client_rank) + self.pynccl_comm.group.barrier() + + # Load the received weights into the model. + self.model_runner.model.load_weights(weights=[(name, weight)]) + + def update_adapter_flattened_param(self, lora_int_id: int, peft_config: Dict, metadatas: list[Dict]) -> None: + """ + Receives and applies a flattened LoRA adapter to the model. + """ + metadatas = [FlattenedTensorMetadata(**metadata) for metadata in metadatas] + if self.pynccl_comm is None: + raise RuntimeError('Communicator not initialized. 
Call `init_communicator` first.') + flatten_tensor_length = metadatas[-1].end_idx + dtype = getattr(torch, metadatas[-1].dtype.split('.')[-1]) + flatten_tensor = torch.empty(flatten_tensor_length, dtype=dtype, device=self.device) + self.pynccl_comm.broadcast(flatten_tensor, src=self.client_rank) + self.pynccl_comm.group.barrier() + flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor) + named_params = flattened_tensor_bucket.reconstruct_tensors() + lora_request = TensorLoRARequest( + lora_name=f'{lora_int_id}', + lora_int_id=lora_int_id, + lora_path='dummy_lora_path', + peft_config=peft_config, + lora_tensors=named_params) + self.add_lora(lora_request) + + def update_flattened_params(self, metadatas: list[Dict]) -> None: + """ + Receives updated flattened weights from the client process and updates the model parameters. + + Args: + metadatas (list[Dict]): List of metadata dictionaries for the flattened tensors. + """ + metadatas = [FlattenedTensorMetadata(**metadata) for metadata in metadatas] + if self.pynccl_comm is None: + raise RuntimeError('Communicator not initialized. Call `init_communicator` first.') + + flatten_tensor_length = metadatas[-1].end_idx + dtype = getattr(torch, metadatas[-1].dtype.split('.')[-1]) + flatten_tensor = torch.empty(flatten_tensor_length, dtype=dtype, device=self.device) + + self.pynccl_comm.broadcast(flatten_tensor, src=self.client_rank) + self.pynccl_comm.group.barrier() + + flattened_tensor_bucket = FlattenedTensorBucket(metadata=metadatas, flattened_tensor=flatten_tensor) + named_params = flattened_tensor_bucket.reconstruct_tensors() + + # Load the reconstructed parameters into the model + self.model_runner.model.load_weights(weights=list(named_params.items())) + + logger = get_logger() @@ -109,7 +192,11 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) - result = method(*args, **kwargs) + try: + result = method(*args, **kwargs) + except Exception: + logger.error(f'Method execution failed: {method_name}\n{traceback.format_exc()}') + result = None if command['type'] == 'call': connection.send(result) elif command['type'] == 'shutdown': @@ -138,7 +225,6 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast # Handle commands if command['type'] in ['call', 'fire_and_forget']: - import traceback method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) @@ -167,6 +253,8 @@ def _register_rl_rollout_app(self): self.app.get('/get_world_size/')(self.get_world_size) self.app.post('/init_communicator/')(self.init_communicator) self.app.post('/update_named_param/')(self.update_named_param) + self.app.post('/update_adapter_flattened_param/')(self.update_adapter_flattened_param) + self.app.post('/update_flattened_params/')(self.update_flattened_params) self.app.post('/reset_prefix_cache/')(self.reset_prefix_cache) self.app.post('/close_communicator/')(self.close_communicator) self.app.post('/infer/', response_model=None)(self.infer) @@ -224,16 +312,18 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): 'torch_dtype': args.torch_dtype, 'template': template, 'use_async_engine': 
args.vllm_use_async_engine, + 'max_lora_rank': args.vllm_max_lora_rank, }) infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend if infer_backend != 'vllm': infer_backend = 'vllm' logger.info('Currently, rollout only supports the vLLM backend. Set vLLM backend') kwargs.update(args.get_vllm_engine_kwargs()) + kwargs.update({'enable_lora': args.vllm_enable_lora}) # override # used for RL external rollout backend engine_kwargs = kwargs.get('engine_kwargs', {}) # for RL rollout model weight sync - engine_kwargs.update({'worker_extension_cls': 'trl.scripts.vllm_serve.WeightSyncWorkerExtension'}) + engine_kwargs.update({'worker_extension_cls': 'swift.llm.infer.rollout.WeightSyncWorkerExtension'}) engine_kwargs['load_format'] = 'dummy' if args.vllm_use_async_engine and args.vllm_data_parallel_size > 1: engine_kwargs['data_parallel_size'] = args.vllm_data_parallel_size @@ -311,6 +401,37 @@ async def update_named_param(self, request: UpdateWeightsRequest): return {'message': 'Request received, updating named parameter'} + async def update_adapter_flattened_param(self, request: UpdateFlattenedAdapterRequest): + peft_config = asdict(request.peft_config) + metadatas = [ + metadata.model_dump() if hasattr(metadata, 'model_dump') else metadata.dict() + for metadata in request.metadatas + ] + kwargs = {'method': 'update_adapter_flattened_param', 'args': (request.lora_int_id, peft_config, metadatas)} + for connection in self.connections: + connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs}) + + return {'message': 'Request received, updating adapter parameter'} + + async def update_flattened_params(self, request: UpdateFlattenedParamsRequest): + """ + Updates the model weights with flattened tensor data. + + Args: + request (UpdateFlattenedParamsRequest): + - metadatas (List[FlattenedTensorMetadata]): Metadata for the flattened tensors. + + """ + metadatas = [ + metadata.model_dump() if hasattr(metadata, 'model_dump') else metadata.dict() + for metadata in request.metadatas + ] + kwargs = {'method': 'update_flattened_params', 'args': (metadatas, )} + for connection in self.connections: + connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs}) + + return {'message': 'Request received, updating flattened parameters'} + async def reset_prefix_cache(self): """ Resets the prefix cache for the model. 
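The flattened-parameter endpoints above are one half of a two-channel protocol: tensor metadata (names, shapes, dtypes, bucket offsets) travels over HTTP, while the flattened payload itself arrives through the NCCL broadcast that each worker enters inside `update_flattened_params` / `update_adapter_flattened_param`. A minimal sketch of the client-side ordering (hypothetical helper; the real counterpart is `VLLMClient.update_flattened_params`):

```python
import requests
import torch


def send_flattened_params(base_url: str, metadatas: list[dict],
                          flat_tensor: torch.Tensor, pynccl_comm, src_rank: int) -> None:
    """Hypothetical sketch of the trainer side of flattened weight sync."""
    # 1) Metadata over HTTP. The server fans it out to its workers via
    #    collective_rpc; each worker allocates an empty buffer of
    #    metadatas[-1]['end_idx'] elements and blocks in broadcast().
    requests.post(f'{base_url}/update_flattened_params/',
                  json={'metadatas': metadatas}).raise_for_status()

    # 2) Bulk payload over NCCL, from the trainer rank into the waiting
    #    workers; the barrier mirrors the one on the worker side.
    pynccl_comm.broadcast(flat_tensor, src=src_rank)
    pynccl_comm.group.barrier()
```

Sending the metadata first is what lets every worker size its receive buffer; the barrier keeps the next bucket's broadcast from racing this one.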
@@ -342,13 +463,17 @@ async def get_engine_type(self):
         enable_multi_turn = False
         if self.args.multi_turn_scheduler:
             enable_multi_turn = True
-
-        if self.use_async_engine:
-            if self.use_gym_env:
-                return {'engine_type': 'AsyncLLMEngine', 'gym_env': True, 'enable_multi_turn': True}
-            return {'engine_type': 'AsyncLLMEngine', 'enable_multi_turn': enable_multi_turn}
-        else:
-            return {'engine_type': 'LLMEngine', 'enable_multi_turn': enable_multi_turn}
+        use_gym_env = False
+        if self.use_async_engine and self.use_gym_env:
+            use_gym_env = True
+        engine_type = 'AsyncLLMEngine' if self.use_async_engine else 'LLMEngine'
+        enable_lora = self.args.vllm_enable_lora
+        return {
+            'engine_type': engine_type,
+            'enable_multi_turn': enable_multi_turn,
+            'use_gym_env': use_gym_env,
+            'enable_lora': enable_lora,
+        }
 
     async def close_communicator(self):
         """
@@ -429,22 +554,3 @@ def run_rollout(args: RolloutArguments, return_url: bool = False):
     finally:
         process.terminate()
         logger.info('The deployment process has been terminated.')
-
-
-# https://github.com/huggingface/trl/pull/3690
-# This patch handles backward compatibility for dtype parameter type changes in TRL:
-# - For TRL <= 0.19: dtype_annotation is torch.dtype (needs patching)
-# - For TRL > 0.19: dtype_annotation is str (no patching needed)
-old_update_named_param = WeightSyncWorkerExtension.update_named_param
-dtype_annotation = get_type_hints(old_update_named_param).get('dtype')
-
-if not hasattr(WeightSyncWorkerExtension, 'old_update_named_param') and dtype_annotation == torch.dtype:
-
-    @wraps(old_update_named_param)
-    def patched_update_named_param(self, name, dtype, shape) -> None:
-        if isinstance(dtype, str):
-            dtype = getattr(torch, dtype.split('.')[-1])
-        return old_update_named_param(self, name, dtype, shape)
-
-    WeightSyncWorkerExtension.update_named_param = patched_update_named_param
-    WeightSyncWorkerExtension.old_update_named_param = old_update_named_param
diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py
index d8f2b30042..2df0c753d8 100644
--- a/swift/plugin/orm.py
+++ b/swift/plugin/orm.py
@@ -246,29 +246,35 @@ def __call__(self, completions, solution, **kwargs) -> List[float]:
         from math_verify import LatexExtractionConfig, parse, verify
         rewards = []
         for content, sol in zip(completions, solution):
-            gold_parsed = parse(sol, extraction_mode='first_match')
+            content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
+            content_to_parse = content_match.group(1).strip() if content_match else content
+            has_answer_tag = content_match is not None
+
+            sol_match = re.search(r'<answer>(.*?)</answer>', sol, re.DOTALL)
+            sol_to_parse = sol_match.group(1).strip() if sol_match else sol
+
+            gold_parsed = parse(sol_to_parse, extraction_mode='first_match')
             if len(gold_parsed) != 0:
-                # We require the answer to be provided in correct latex (no malformed operators)
-                answer_parsed = parse(
-                    content,
-                    extraction_config=[
-                        LatexExtractionConfig(
-                            normalization_config=NormalizationConfig(
-                                nits=False,
-                                malformed_operators=False,
-                                basic_latex=True,
-                                equations=True,
-                                boxed=True,
-                                units=True,
-                            ),
-                            # Ensures that boxed is tried first
-                            boxed_match_priority=0,
-                            try_extract_without_anchor=False,
-                        )
-                    ],
-                    extraction_mode='first_match',
-                )
-                # edge case
+                if has_answer_tag:
+                    answer_parsed = parse(content_to_parse, extraction_mode='first_match')
+                else:
+                    answer_parsed = parse(
+                        content_to_parse,
+                        extraction_config=[
+                            LatexExtractionConfig(
+                                normalization_config=NormalizationConfig(
+                                    nits=False,
+                                    malformed_operators=False,
+                                    basic_latex=True,
+                                    boxed=True,
+                                    
units=True, + ), + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode='first_match', + ) try: reward = float(verify(gold_parsed, answer_parsed)) except Exception: diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index b2256da0b7..f5afe1de89 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -271,7 +271,7 @@ class GRPOArgumentsMixin(VllmArguments): vllm_mode: Literal['server', 'colocate'] = 'colocate' # internal vllm (colocate) vllm_enable_prefix_caching: bool = True # overwrite - + vllm_enable_lora: bool = False # external vllm (server) vllm_server_base_url: Optional[List[str]] = None vllm_server_host: Optional[List[str]] = None diff --git a/swift/trainers/rlhf_arguments.py b/swift/trainers/rlhf_arguments.py index b8423dc340..89fb020e08 100644 --- a/swift/trainers/rlhf_arguments.py +++ b/swift/trainers/rlhf_arguments.py @@ -51,6 +51,7 @@ class GKDConfig(SwiftArgumentsMixin, HfGKDConfig): @dataclass class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig): stop_words: List[str] = field(default_factory=list) + lora_rank: int = 8 # for vllm lora adapter def __post_init__(self): GRPOArgumentsMixin.__post_init__(self) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index f53004e536..f3addf5d93 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -7,7 +7,7 @@ import re import time import uuid -from collections import defaultdict, deque +from collections import OrderedDict, defaultdict, deque from concurrent.futures import Future from contextlib import contextmanager, nullcontext from copy import copy, deepcopy @@ -25,6 +25,7 @@ from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from dacite import from_dict from packaging import version +from peft.utils.save_and_load import get_peft_model_state_dict from torch.nn import ModuleList from torch.utils.data import DataLoader from transformers import PreTrainedModel, TrainerCallback @@ -48,10 +49,11 @@ unwrap_model_for_generation) from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin -from .utils import (_ForwardRedirection, compute_chord_loss, identity_data_collator, load_pil_img, - make_chord_sft_dataset, patch_lora_merge, patch_lora_unmerge, patch_profiling_context, - patch_profiling_decorator, patch_save_last_checkpoint, replace_assistant_response_with_ids, - set_expandable_segments) +from .utils import (FlattenedTensorBucket, TensorLoRARequest, _create_parameter_buckets, _ForwardRedirection, + _process_bucket_with_flattened_tensor, compute_chord_loss, get_gather_if_zero3_context, + identity_data_collator, load_pil_img, make_chord_sft_dataset, patch_lora_merge, patch_lora_unmerge, + patch_profiling_context, patch_profiling_decorator, patch_save_last_checkpoint, + patch_vllm_load_adapter, replace_assistant_response_with_ids, set_expandable_segments) from .vllm_client import VLLMClient try: @@ -245,19 +247,20 @@ def __init__(self, # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but # it's safer to set it in all cases. 
set_seed(args.seed, device_specific=True) - if is_peft_model(self.model): - self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() self.use_fast_infer = self.use_vllm # whether to use the PT backend self.vllm_use_async_engine = False self.enable_offload = False self.use_gym_env = False self.enable_server_multi_turn = False + self.rollout_enable_lora = False # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs self.dynamic_num_samples = False if self.use_vllm: if not is_vllm_available(): raise ImportError('vLLM is not available and `use_vllm` is set to True. ' 'Please install vLLM with `pip install vllm -U` to use it.') + self.base_sync_done = False # tag for lora weights sync + if self.vllm_mode == 'server': self.vllm_client: VLLMClient = vllm_client if self.accelerator.is_main_process: @@ -265,13 +268,16 @@ def __init__(self, vllm_use_async_engine = [self.vllm_client.use_async_engine] use_gym_env = [self.vllm_client.use_gym_env] enable_multi_turn = [self.vllm_client.enable_multi_turn] + enable_lora = [self.vllm_client.enable_lora] else: vllm_use_async_engine = [False] use_gym_env = [False] enable_multi_turn = [self.enable_server_multi_turn] + enable_lora = [False] self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0] self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] + self.rollout_enable_lora = broadcast_object_list(enable_lora, from_process=0)[0] if self.use_gym_env: self.reward_func_names = ['gym_reward'] @@ -302,6 +308,8 @@ def __init__(self, infer_template.padding_free = False self.engine = PtEngine.from_model_template(self.model, infer_template, max_batch_size=0) # 0: no limit + self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() + if not self.reward_funcs and not self.use_gym_env: raise ValueError('You must specify reward_funcs or reward_model') @@ -486,7 +494,10 @@ def replace_lora(name): if 'lora_' in name: return '' else: - return name.replace('base_layer.', '') + if not self.rollout_enable_lora: + return name.replace('.base_layer', '') + else: + return name def remove_lora_and_prefix(names): names = set([re.sub(r'^_model\.', '', replace_lora(n)) for n in names]) @@ -533,6 +544,29 @@ def prepare_vllm(self, model): self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size * self.args.steps_per_generation) vllm_template = copy(self.template) vllm_template.padding_free = False + lora_kwargs = {} + is_moe = model.model_info.is_moe_model + vllm_enable_lora = self.args.vllm_enable_lora + if self.args.train_type == 'lora' and vllm_enable_lora: + lora_kwargs = { + 'enable_lora': self.args.vllm_enable_lora, + 'max_loras': 1, + 'max_lora_rank': self.args.lora_rank, + } + self.rollout_enable_lora = True + + if is_moe: + logger.warning( + 'vLLM LoRA is enabled for an MoE model. This may cause errors when applying LoRA to expert layers, ' + 'as vLLM currently does not support LoRA in MoE configurations. If you encounter errors, ' + 'please set vllm_enable_lora to False.') + + if self.is_multimodal: + logger.warning('vLLM LoRA is enabled for a multimodal model. This may lead to unexpected issues ' + 'when applying LoRA to the ViT component, as vLLM does not yet support this setup. 
' + 'If errors occur, please disable LoRA by setting vllm_enable_lora to False.') + + patch_vllm_load_adapter() with Swift.grpo_context(model, self.template.processor): set_expandable_segments(False) engine = GRPOVllmEngine( @@ -553,6 +587,7 @@ def prepare_vllm(self, model): load_format='dummy', template=vllm_template, distributed_executor_backend='external_launcher', + **lora_kwargs, ) set_expandable_segments(True) return engine @@ -569,66 +604,130 @@ def _template_context(self, template: Template): @patch_profiling_decorator def _move_model_to_vllm(self, skip_async_check=False): - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - if zero_stage_3: - import deepspeed - gather_if_zero3 = deepspeed.zero.GatheredParameters - else: - gather_if_zero3 = nullcontext - if self.args.async_generate and not skip_async_check: # before sync weight, we should wait async generate finish self._wait_queue() - if is_peft_model(self.model): - for i, parameter_group in enumerate(self.parameter_groups): # < this is the change - parameter_group_no_lora = self.parameter_groups_no_lora[i] - parameters = [ - parameter for name, parameter in self.model.named_parameters() - if not parameter_group or name in parameter_group - ] - with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): + train_type = self.args.train_type + + if train_type == 'full' or (train_type == 'lora' and not self.base_sync_done) or not self.rollout_enable_lora: + self._move_full_model_to_vllm() + else: + self._move_adapter_to_vllm() + + def _move_adapter_to_vllm(self): + lora_params = OrderedDict() + for i, parameter_group in enumerate(self.parameter_groups): # < this is the change + parameters = [ + parameter for name, parameter in self.model.named_parameters() + if not parameter_group or name in parameter_group + ] + gather_if_zero3 = get_gather_if_zero3_context(self) + with gather_if_zero3(parameters), patch_lora_merge(self.model, parameter_group): + assert len(parameters) == len(parameter_group) + state_dict = {name: p for p, name in zip(parameters, parameter_group)} + peft_config = self.model.peft_config.get('default', None) + self.model.merge_adapter() + cur_lora_params = get_peft_model_state_dict(self.model, state_dict) + cur_lora_params = { + name: param.full_tensor().detach() if hasattr(param, 'full_tensor') else param.detach() + for name, param in cur_lora_params.items() + } + lora_params.update(cur_lora_params) + with patch_lora_unmerge(self.model): + self.model.unmerge_adapter() + del cur_lora_params + + if self.vllm_mode == 'server' and self.accelerator.is_main_process: + bucked = FlattenedTensorBucket(named_tensors=list(lora_params.items())) + metadatas = bucked.get_metadata() + flattened_tensor = bucked.get_flattened_tensor() + self.vllm_client.update_adapter_flattened_param(peft_config, metadatas, flattened_tensor) + elif self.vllm_mode == 'colocate': + lora_int_id = int(time.time_ns() % 0x7FFFFFFF) + lora_reqest = TensorLoRARequest( + lora_name=f'{lora_int_id}', + lora_int_id=lora_int_id, + lora_path='dummy_lora_path', + peft_config=asdict(peft_config), + lora_tensors=lora_params, + ) + self.engine.engine.add_lora(lora_reqest) + del lora_params + + def _load_state_dict_to_vllm(self, state_dict): + """Load state_dict to vLLM engine (server or colocate mode)""" + if self.vllm_mode == 'server' and self.accelerator.is_main_process: + bucket_size_mb = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) + 
named_params = list(state_dict.items()) + parameter_buckets = _create_parameter_buckets(named_params, bucket_size_mb=bucket_size_mb) + + for bucket in parameter_buckets: + _process_bucket_with_flattened_tensor(self, bucket) + + del named_params, parameter_buckets + elif self.vllm_mode == 'colocate': + llm_model = self.engine.inner_model + llm_model.load_weights(state_dict.items()) + del state_dict + + def _move_full_model_to_vllm(self): + gather_if_zero3 = get_gather_if_zero3_context(self) + is_peft = is_peft_model(self.model) + + for i, parameter_group in enumerate(self.parameter_groups): + parameter_group_no_lora = self.parameter_groups_no_lora[i] + parameters = [ + parameter for name, parameter in self.model.named_parameters() + if not parameter_group or name in parameter_group + ] + + # Use patch_lora_merge for PEFT models, nullcontext otherwise + context_manager = patch_lora_merge(self.model, parameter_group) if is_peft else nullcontext() + + with gather_if_zero3(parameters), context_manager: + if is_peft and self.should_merge_adapter: self.model.merge_adapter() - state_dict = self.model.state_dict() - state_dict = { - k.removeprefix('base_model.model.').replace('.base_layer', ''): v - for k, v in state_dict.items() + + state_dict = self.model.state_dict() + + # Process state_dict for PEFT models + if is_peft: + prefix_removed = {k.removeprefix('base_model.model.'): v for k, v in state_dict.items()} + state_dict = prefix_removed if self.rollout_enable_lora else { + k.replace('.base_layer', ''): v + for k, v in prefix_removed.items() } state_dict = {k: v for k, v in state_dict.items() if self.model.prefix not in k} - # When module to save, remove its prefix and discard the original module state_dict = { k.replace('modules_to_save.default.', ''): v for k, v in state_dict.items() if 'original_module' not in k } - if parameter_group_no_lora: + + # Filter by parameter_group_no_lora + if parameter_group_no_lora: + if is_peft: parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora] - state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + + if is_peft: assert len(state_dict) > 0 and all( [state.shape != torch.Size([0]) for state in state_dict.values()]) - if self.vllm_mode == 'server' and self.accelerator.is_main_process: - for name, param in state_dict.items(): - self.vllm_client.update_named_param(name, param) - elif self.vllm_mode == 'colocate': - llm_model = self.engine.inner_model - llm_model.load_weights(state_dict.items()) + # Load to vLLM + self._load_state_dict_to_vllm(state_dict) + + if is_peft and self.should_merge_adapter: with patch_lora_unmerge(self.model): self.model.unmerge_adapter() - del state_dict - else: - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): - if self.vllm_mode == 'server' and self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - elif self.vllm_mode == 'colocate': - llm_model = self.engine.inner_model - llm_model.load_weights([(name, param.data)]) + if is_peft: + self.base_sync_done = True + + # Reset prefix cache if self.vllm_mode == 'server' and self.accelerator.is_main_process: self.vllm_client.reset_prefix_cache() elif self.vllm_mode == 'colocate': - # since vLLM model weights has been updated, we should reset the prefix cache self.engine.engine.reset_prefix_cache() def _wait_queue(self): @@ -2938,3 +3037,33 @@ def 
get_chunked_inputs(self, inputs, start_idx, end_idx): chunk_inputs.update(to_device(template.data_collator(encoded_data), self.model.device)) chunk_inputs.pop('labels', None) return chunk_inputs + + @property + def should_merge_adapter(self): + """ + Determine whether the LoRA adapter should be merged into the base model during weight synchronization. + + Note: + Merging or unmerging adapters in MoE models is computationally expensive and should be minimized. + + Raises: + AssertionError: If full-parameter training is used, as adapter merging is not supported. + + Returns: + bool: True if the adapter should be merged; False otherwise. + - Returns True when LoRA is not enabled for rollout. + - Returns True when loading from a checkpoint or using pre-trained adapters. + - Returns False during normal LoRA training (weights are already synchronized). + """ + assert self.args.train_type != 'full', 'Full-parameter training should not merge adapter' + + # Rollout does not support LoRA + if not self.rollout_enable_lora: + return True + + if self.args.resume_from_checkpoint: + # Resuming training: merge into base model + return True + + # base model weights are synced before training; no need to merge + return False diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index c168d69c3a..4bd92b2d6c 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -3,23 +3,25 @@ import math import os import time -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext +from dataclasses import asdict from functools import partial from io import BytesIO from types import MethodType -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union import datasets import torch import torch.nn.functional as F +from msgspec import field from peft.tuners import lora from peft.tuners.lora import LoraLayer from PIL import Image +from pydantic import BaseModel, field_validator from torch import nn from torch.utils.data import DataLoader, RandomSampler -from transformers import Trainer -from swift.utils import is_swanlab_available, is_wandb_available +from swift.utils import is_swanlab_available, is_vllm_available, is_wandb_available if is_wandb_available(): import wandb @@ -29,6 +31,23 @@ if TYPE_CHECKING: from swift.llm.utils import Messages +TensorLoRARequest = None +if is_vllm_available(): + from vllm.lora.request import LoRARequest + + class TensorLoRARequest(LoRARequest): + peft_config: dict = field(default=None) + lora_tensors: dict = field(default=None) + lora_embeddings: Optional[Dict[str, torch.Tensor]] = None + + @property + def config(self): + return self.peft_config + + @property + def embeddings(self): + return self.lora_embeddings + def round_robin(num_reqs, num_workers): """Distribute requests evenly across workers using round-robin algorithm. 
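`TensorLoRARequest`, defined in the `swift/trainers/rlhf_trainer/utils.py` hunk above, widens vLLM's `LoRARequest` so an adapter can be handed over as in-memory tensors instead of a checkpoint directory, which is exactly what the weight-sync path needs. A usage sketch, assuming a PEFT-wrapped `peft_model` and a vLLM `engine` whose LoRA manager has been patched by `patch_vllm_load_adapter` (both names are placeholders):

```python
import time
from dataclasses import asdict

from peft.utils.save_and_load import get_peft_model_state_dict

from swift.trainers.rlhf_trainer.utils import TensorLoRARequest

# Adapter weights that would normally be serialized to disk.
lora_tensors = get_peft_model_state_dict(peft_model)     # {name: tensor}
peft_config = asdict(peft_model.peft_config['default'])  # plain dict

# Any unique positive id works; the path is a dummy that the patched
# loader never touches on disk.
lora_int_id = int(time.time_ns() % 0x7FFFFFFF)
request = TensorLoRARequest(
    lora_name=f'{lora_int_id}',
    lora_int_id=lora_int_id,
    lora_path='dummy_lora_path',
    peft_config=peft_config,
    lora_tensors=lora_tensors,
)

# The patched LRUCacheWorkerLoRAManager._load_adapter recognizes the
# subclass and builds the LoRAModel straight from the tensors.
engine.add_lora(request)
```

This mirrors what `_move_adapter_to_vllm` does in colocate mode, where the request is submitted through `self.engine.engine.add_lora(...)`.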
@@ -368,6 +387,230 @@ def patched_len(self) -> int: RepeatSampler.old_len_func = origin_len_func +def get_gather_if_zero3_context(trainer): + deepspeed_plugin = trainer.accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + return gather_if_zero3 + + +def patch_vllm_load_adapter(): + from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager + from vllm.lora.models import LoRAModel + from vllm.lora.utils import get_adapter_absolute_path + + try: + from vllm.transformers_utils.tokenizer_group import TokenizerGroup + except ImportError: + # removed in https://github.com/vllm-project/vllm/pull/24078 + TokenizerGroup = None + + def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLoRARequest) -> LoRAModel: + """ + code borrowed from verl.utils.vllm.utils.py + based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors + Reason: + VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths. + To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to + load memory-based LoRA tensors. + """ + try: + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping + expected_lora_modules: list[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend(packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + expected_lora_modules = list(set(expected_lora_modules)) + # this is the patch + lora_tensors = None + from vllm.lora.peft_helper import PEFTHelper + if isinstance(lora_request, TensorLoRARequest): + peft_config = lora_request.peft_config + lora_tensors = lora_request.lora_tensors + peft_helper = PEFTHelper.from_dict(peft_config) + else: + lora_path = get_adapter_absolute_path(lora_request.lora_path) + peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings) + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. 
+ model = self._adapter_manager.model + hf_to_vllm_mapper = getattr(model, 'hf_to_vllm_mapper', None) + if isinstance(lora_request, TensorLoRARequest): # this is the patch + lora = self._lora_model_cls.from_lora_tensors( + lora_model_id=lora_request.lora_int_id, + tensors=lora_tensors, + peft_helper=peft_helper, + device='cpu', + dtype=self.lora_config.lora_dtype, + embeddings=None, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + else: + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device='cpu', + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper, + ) + except Exception as e: + raise e + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError(f'LoRA added vocab size {lora.extra_vocab_size} is greater than ' + f'lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}.') + return lora + + def patched_get_lora_tokenizer(self: TokenizerGroup, lora_request: LoRARequest): + # since we pass dummy path, skip get tokenizer from path + return self.tokenizer + + if not hasattr(LRUCacheWorkerLoRAManager, '_old_load_adapter'): + _old_load_adapter = LRUCacheWorkerLoRAManager._load_adapter + LRUCacheWorkerLoRAManager._load_adapter = patched_load_adapter + LRUCacheWorkerLoRAManager._old_load_adapter = _old_load_adapter + if TokenizerGroup is not None: + TokenizerGroup._old_get_lora_tokenizer = TokenizerGroup.get_lora_tokenizer + TokenizerGroup.get_lora_tokenizer = patched_get_lora_tokenizer + + +# FlattenedTensor, code borrowed from sglang/srt/weight_sync/tensor_bucket.py +class FlattenedTensorMetadata(BaseModel): + """Metadata for a tensor in a flattened bucket""" + name: str + shape: Tuple[int, ...] + dtype: str + start_idx: int + end_idx: int + numel: int + + @field_validator('shape', mode='before') + @classmethod + def ensure_shape_tuple(cls, v: Any) -> Tuple[int, ...]: + # accept tuple/list, torch.Size, or other iterable of ints + if torch is not None and isinstance(v, torch.Size): + return tuple(int(x) for x in v) + if isinstance(v, (list, tuple)): + return tuple(int(x) for x in v) + if isinstance(v, Iterable): + return tuple(int(x) for x in v) + raise ValueError('shape must be an iterable of ints (e.g. tuple/list/torch.Size)') + + @field_validator('dtype', mode='before') + @classmethod + def ensure_dtype_str(cls, v: Any) -> str: + # accept torch.dtype or str + if torch is not None and isinstance(v, torch.dtype): + return str(v) + if isinstance(v, str): + return v + raise ValueError('dtype must be a torch.dtype or str') + + +class FlattenedTensorBucket: + """ + A bucket that flattens multiple tensors into a single tensor for efficient processing + while preserving all metadata needed for reconstruction. + """ + + def __init__( + self, + named_tensors: List[Tuple[str, torch.Tensor]] = None, + flattened_tensor: torch.Tensor = None, + metadata: List[FlattenedTensorMetadata] = None, + ): + """ + Initialize a tensor bucket from a list of named tensors OR from pre-flattened data. 
+        Args:
+            named_tensors: List of (name, tensor) tuples (for creating a new bucket)
+            flattened_tensor: Pre-flattened tensor (for reconstruction)
+            metadata: Pre-computed metadata (for reconstruction)
+        """
+        if named_tensors is not None:
+            # Create bucket from named tensors
+            self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
+            self.flattened_tensor: torch.Tensor = None
+
+            if not named_tensors:
+                raise ValueError('Cannot create empty tensor bucket')
+
+            # First pass: compute total size and metadata
+            current_idx = 0
+            total_numel = 0
+            for i, (name, tensor) in enumerate(named_tensors):
+                numel = tensor.numel()
+                metadata_obj = FlattenedTensorMetadata(
+                    name=name,
+                    shape=tuple(tensor.shape),
+                    dtype=str(tensor.dtype),
+                    start_idx=current_idx,
+                    end_idx=current_idx + numel,
+                    numel=numel,
+                )
+                self.metadata[i] = metadata_obj
+                current_idx += numel
+                total_numel += numel
+
+            # Pre-allocate the final flattened tensor to avoid intermediate copies.
+            # Use the dtype and device of the first tensor.
+            first_tensor = named_tensors[0][1]
+            self.flattened_tensor = torch.empty(total_numel, dtype=first_tensor.dtype, device=first_tensor.device)
+
+            # Second pass: copy data directly into the pre-allocated tensor
+            for meta, (name, tensor) in zip(self.metadata, named_tensors):
+                self.flattened_tensor[meta.start_idx:meta.end_idx].copy_(tensor.flatten())
+        else:
+            # Initialize from pre-flattened data
+            if flattened_tensor is None or metadata is None:
+                raise ValueError('Must provide either named_tensors or both flattened_tensor and metadata')
+            self.flattened_tensor = flattened_tensor
+            self.metadata = metadata
+
+    def get_flattened_tensor(self) -> torch.Tensor:
+        """Get the flattened tensor containing all bucket tensors"""
+        return self.flattened_tensor
+
+    def get_metadata(self) -> List[FlattenedTensorMetadata]:
+        """Get metadata for all tensors in the bucket"""
+        return self.metadata
+
+    def reconstruct_tensors(self) -> Dict[str, torch.Tensor]:
+        """
+        Reconstruct the original tensors from the flattened tensor with optimized performance.
+        Uses memory-efficient operations to minimize allocations and copies.
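+
+        Example (illustrative round trip, assuming torch is available)::
+
+            bucket = FlattenedTensorBucket(named_tensors=[('w', torch.ones(2, 3))])
+            restored = bucket.reconstruct_tensors()
+            assert restored['w'].shape == (2, 3)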
+ """ + # preallocate the result list + reconstructed = {} + + for meta in self.metadata: + tensor = self.flattened_tensor[meta.start_idx:meta.end_idx].reshape(meta.shape) + dtype = getattr(torch, meta.dtype.split('.')[-1]) + # batch dtype conversion (if needed) + if tensor.dtype != dtype: + tensor = tensor.to(dtype) + + reconstructed[meta.name] = tensor + + return reconstructed + + def identity_data_collator(features): """Identity data collator that returns features as-is without any processing.""" return features @@ -563,3 +806,63 @@ def set_expandable_segments(enable: bool) -> None: if torch.cuda.is_available(): torch.cuda.memory._set_allocator_settings(f'expandable_segments:{enable}') os.environ['PYTORCH_CUDA_ALLOC_CONF'] = f'expandable_segments:{enable}' + + +def peft_config_to_dict(peft_config): + if not isinstance(peft_config, dict): + peft_config = asdict(peft_config) + # turn set to list to serializable + if 'target_modules' in peft_config and isinstance(peft_config['target_modules'], set): + peft_config['target_modules'] = list(peft_config['target_modules']) + + return peft_config + + +def _create_parameter_buckets(named_params, bucket_size_mb=512): + """Create parameter buckets for efficient processing""" + buckets = [] + current_bucket = [] + current_size = 0 + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + + for name, param in named_params: + param_size = param.numel() * param.element_size() + + # If adding this param would exceed bucket size, process current bucket first + if current_size + param_size > bucket_size_bytes and current_bucket: + buckets.append(current_bucket) + current_bucket = [] + current_size = 0 + + current_bucket.append((name, param)) + current_size += param_size + + # Process remaining parameters in the last bucket + if current_bucket: + buckets.append(current_bucket) + + return buckets + + +def _process_bucket_with_flattened_tensor(trainer, bucket_params): + """Process a bucket of parameters using FlattenedTensorBucket for efficiency""" + if not bucket_params: + return + + # Create FlattenedTensorBucket for efficient processing + bucket = FlattenedTensorBucket(named_tensors=bucket_params) + metadatas = bucket.get_metadata() + flattened_tensor = bucket.get_flattened_tensor() + + # Use the new flattened parameter update method + # If not available, fall back to individual parameter updates + try: + trainer.vllm_client.update_flattened_params(metadatas, flattened_tensor) + except AttributeError: + # Fallback to individual parameter updates + reconstructed = bucket.reconstruct_tensors() + for name, param in reconstructed.items(): + trainer.vllm_client.update_named_param(name, param) + + # Clean up + del bucket, metadatas, flattened_tensor diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index fc3dc2e614..8440a220bb 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -4,13 +4,15 @@ import threading import time from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict from typing import List, Optional, Union from urllib.parse import urlparse +import json import requests import torch from packaging import version -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from requests import ConnectionError from torch import nn from transformers.utils import is_torch_cuda_available @@ -19,6 +21,7 @@ from swift.llm.infer.protocol import ChatCompletionResponse, RequestConfig, RolloutOutput from 
 from swift.plugin import Metric
 from swift.utils import is_trl_available, is_vllm_ascend_available, is_vllm_available
+from .utils import peft_config_to_dict
 
 if is_vllm_available():
     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
@@ -242,6 +245,87 @@ def _update_single_server(i):
         if all_errors:
             raise RuntimeError(f'Multiple errors: {all_errors}')
 
+    def update_adapter_flattened_param(self, peft_config, metadatas, flattened_tensor):
+        """
+        Adds a LoRA adapter to the model on all servers.
+
+        Args:
+            peft_config: PEFT LoRA configuration (dataclass or dict) describing the adapter.
+            metadatas: List of FlattenedTensorMetadata for the flattened adapter tensors.
+            flattened_tensor: The flattened tensor containing all adapter weights.
+        """
+        errors = [None] * self.num_servers
+        peft_config = peft_config_to_dict(peft_config)
+        metadatas = [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas]
+        lora_int_id = int(time.time_ns() % 0x7FFFFFFF)
+
+        def _update_single_server(i):
+            try:
+                data = {
+                    'lora_int_id': lora_int_id,
+                    'peft_config': {
+                        **peft_config
+                    },
+                    'metadatas': metadatas,
+                }
+
+                response = self.sessions[i].post(
+                    f'{self.base_urls[i]}/update_adapter_flattened_param/',
+                    json=data,
+                )
+                if response.status_code != 200:
+                    raise Exception(f'Server {i} update adapter failed: {response.text}')
+
+                self.pynccl_comms[i].broadcast(flattened_tensor, src=self.pynccl_comms[i].rank)
+                self.pynccl_comms[i].group.barrier()
+            except Exception as e:
+                errors[i] = e
+
+        with ThreadPoolExecutor(max_workers=self.num_servers) as executor:
+            futures = [executor.submit(_update_single_server, i) for i in range(self.num_servers)]
+            for future in futures:
+                future.result()
+
+        all_errors = [e for e in errors if e is not None]
+        if all_errors:
+            raise RuntimeError(f'Multiple errors: {all_errors}')
+
+    def update_flattened_params(self, metadatas, flattened_tensor):
+        """
+        Updates model parameters using flattened tensor data.
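+        Only the small metadata list travels over HTTP; the flattened tensor
+        itself is broadcast to each server through the NCCL communicator.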
+
+        Args:
+            metadatas: List of FlattenedTensorMetadata objects
+            flattened_tensor: The flattened tensor containing all parameters
+        """
+        errors = [None] * self.num_servers
+        metadatas = [m.model_dump() if hasattr(m, 'model_dump') else m.dict() for m in metadatas]
+
+        def _update_single_server(i):
+            try:
+                data = {
+                    'metadatas': metadatas,
+                }
+
+                response = self.sessions[i].post(
+                    f'{self.base_urls[i]}/update_flattened_params/',
+                    json=data,
+                )
+                if response.status_code != 200:
+                    raise Exception(f'Server {i} update flattened params failed: {response.text}')
+
+                self.pynccl_comms[i].broadcast(flattened_tensor, src=self.pynccl_comms[i].rank)
+                self.pynccl_comms[i].group.barrier()
+            except Exception as e:
+                errors[i] = e
+
+        with ThreadPoolExecutor(max_workers=self.num_servers) as executor:
+            futures = [executor.submit(_update_single_server, i) for i in range(self.num_servers)]
+            for future in futures:
+                future.result()
+
+        all_errors = [e for e in errors if e is not None]
+        if all_errors:
+            raise RuntimeError(f'Multiple errors: {all_errors}')
+
     def update_model_params(self, model: nn.Module):
         for name, param in model.named_parameters():
             self.update_named_param(name, param.data)
@@ -275,6 +359,7 @@ def get_engine_type(self):
         self.use_async_engine = result['engine_type'] == 'AsyncLLMEngine'
         self.enable_multi_turn = result.get('enable_multi_turn', False)
         self.use_gym_env = result.get('gym_env', False)
+        self.enable_lora = result.get('enable_lora', False)
         return result
 
     def close_communicator(self):
diff --git a/tests/utils/test_rewards.py b/tests/utils/test_rewards.py
new file mode 100644
index 0000000000..0f2e3da383
--- /dev/null
+++ b/tests/utils/test_rewards.py
@@ -0,0 +1,260 @@
+import unittest
+
+
+class TestMathAccuracy(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        try:
+            from swift.plugin.orm import MathAccuracy
+            cls.math_accuracy = MathAccuracy()
+            cls.available = True
+        except (ImportError, AssertionError) as e:
+            print(f'Warning: MathAccuracy not available: {e}')
+            cls.available = False
+
+    def setUp(self):
+        if not self.available:
+            self.skipTest('MathAccuracy not available (math_verify not installed)')
+
+    def test_pure_latex_format(self):
+        completions = ['The answer is \\boxed{42}']
+        solutions = ['\\boxed{42}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_latex_in_long_text(self):
+        completions = ['After careful calculation, the final answer is \\boxed{100}']
+        solutions = ['\\boxed{100}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_multiple_steps_with_boxed(self):
+        completions = [
+            'Let me solve step by step:\n'
+            '1. First we have x = 2\n'
+            '2. Then y = 3x = 6\n'
+            '3. Finally z = x + y = 8\n'
+            '\nFinal answer: \\boxed{8}'
+        ]
+        solutions = ['\\boxed{8}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_wrong_answer_no_tag(self):
+        completions = ['The answer is \\boxed{42}']
+        solutions = ['\\boxed{100}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.0)
+
+    def test_batch_processing_no_tag(self):
+        completions = ['\\boxed{42}', '\\boxed{100}', '\\boxed{8}']
+        solutions = ['\\boxed{42}', '\\boxed{100}', '\\boxed{8}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 3)
+        self.assertEqual(rewards[0], 1.0)
+        self.assertEqual(rewards[1], 1.0)
+        self.assertEqual(rewards[2], 1.0)
+
+    def test_answer_tag_with_plain_number(self):
+        completions = ['84']
+        solutions = ['\\boxed{84}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_answer_tag_with_latex(self):
+        completions = ['\\boxed{100}']
+        solutions = ['\\boxed{100}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_long_text_with_answer_tag(self):
+        completions = [
+            'Let me solve:\n'
+            'Step 1: Calculate x = 10\n'
+            'Step 2: Calculate y = 20\n'
+            'Step 3: Sum = 30\n'
+            '\n54'
+        ]
+        solutions = ['\\boxed{54}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_answer_tag_with_complex_expression(self):
+        completions = ['\\frac{1}{2}']
+        solutions = ['\\boxed{\\frac{1}{2}}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_solution_with_answer_tag(self):
+        completions = ['84']
+        solutions = ['\\boxed{84}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_answer_tag_wrong_answer(self):
+        completions = ['42']
+        solutions = ['\\boxed{100}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.0)
+
+    def test_mixed_batch_with_and_without_tags(self):
+        completions = [
+            '\\boxed{42}',
+            '100',
+            'The answer is \\boxed{8}',
+        ]
+        solutions = [
+            '\\boxed{42}',
+            '\\boxed{100}',
+            '\\boxed{8}',
+        ]
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 3)
+        self.assertEqual(rewards[0], 1.0)
+        self.assertEqual(rewards[1], 1.0)
+        self.assertEqual(rewards[2], 1.0)
+
+    def test_empty_solution(self):
+        completions = ['42']
+        solutions = ['']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.0)
+
+    def test_malformed_latex(self):
+        completions = ['\\boxed{42']
+        solutions = ['\\boxed{42}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 0.0)
+
+    def test_answer_tag_with_extra_whitespace(self):
+        completions = [' 84 ']
+        solutions = ['\\boxed{84}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_multiple_answer_tags(self):
+        completions = ['42 Some text 100']
+        solutions = ['\\boxed{42}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_real_world_example_from_user(self):
+        completions = [
+            'We are given a geometric sequence $\\{a_n\\}$ with:\n\n'
+            '- $a_3 = 2$\n- $a_5 = 6$\n\n'
+            'We are to find $a_9$.\n\n---\n\n'
+            '### Step 1: Recall the formula\n\n'
+            '$$a_n = a_1 \\cdot r^{n-1}$$\n\n---\n\n'
+            '### Step 2: Use the given terms\n\n'
+            '$$a_3 = a_1 \\cdot r^2 = 2$$\n'
+            '$$a_5 = a_1 \\cdot r^4 = 6$$\n\n'
+            'Divide equation (2) by equation (1):\n'
+            '$$r^2 = 3$$\n\n---\n\n'
+            '### Step 3: Find $a_9$\n\n'
+            '$$a_9 = a_1 \\cdot r^8 = \\frac{2}{3} \\cdot 81 = 54$$\n\n'
+            '### ✅ Final Answer:\n\n'
+            '54'
+        ]
+        solutions = ['\\boxed{54}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_equivalent_fractions(self):
+        completions = ['0.5']
+        solutions = ['\\boxed{\\frac{1}{2}}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_different_forms_same_answer(self):
+        completions = ['2']
+        solutions = ['\\boxed{\\sqrt{4}}']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_latex_inline_math_delimiters(self):
+        completions = ['84', '3']
+        solutions = ['\n\n\\[\n\\boxed{84}\n\\]', 'Therefore, the value of \\(a^2 - a + 2\\) is \\(\\boxed{3}\\).']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 2)
+        self.assertEqual(rewards[0], 1.0)
+        self.assertEqual(rewards[1], 1.0)
+
+    def test_latex_display_math_delimiters(self):
+        completions = ['100']
+        solutions = ['\\[\\boxed{100}\\]']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+    def test_mixed_latex_delimiters(self):
+        completions = ['\\(x = 42\\)']
+        solutions = ['\\[\\boxed{x = 42}\\]']
+
+        rewards = self.math_accuracy(completions, solutions)
+
+        self.assertEqual(len(rewards), 1)
+        self.assertEqual(rewards[0], 1.0)
+
+
+if __name__ == '__main__':
+    unittest.main()
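
For orientation, a minimal sketch (illustrative only, not part of the patch) of how the new helpers compose for a bucketed weight sync in server mode; `model` and `trainer` stand in for the GRPO trainer's module and trainer objects:

```python
# Bucket the parameters, flatten each bucket, and push it to the rollout
# servers via the trainer's vLLM client.
named_params = list(model.named_parameters())
for bucket_params in _create_parameter_buckets(named_params, bucket_size_mb=512):
    # POSTs the metadata, then broadcasts the flattened tensor over NCCL;
    # falls back to per-parameter updates if the client lacks the method.
    _process_bucket_with_flattened_tensor(trainer, bucket_params)
```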