diff --git a/.claude/skills/verl-upgrade/SKILL.md b/.claude/skills/verl-upgrade/SKILL.md
new file mode 100644
index 00000000000..8b688d6bbb2
--- /dev/null
+++ b/.claude/skills/verl-upgrade/SKILL.md
@@ -0,0 +1,29 @@
+---
+name: verl-upgrade
+description: "Use when handling veRL version upgrades in Trinity, including the three-way merge strategy, boundary checks, and retained customization decisions"
+---
+
+# veRL Upgrade Skill
+
+## Primary Sources
+
+1. `docs/agents/verl_upgrade/verl_upgrade_checklist.md`
+2. A version-specific migration plan in `docs/agents/verl_upgrade/` matching `verl_*_migration_plan.md`
+
+## Workflow
+
+1. Read `docs/agents/verl_upgrade/verl_upgrade_checklist.md` first.
+2. Confirm the current version, target version, upgrade scope, and target files.
+3. Generate or select the corresponding version-specific migration plan (`verl_*_migration_plan.md`).
+4. During execution, review only the detailed content for the target upgrade version.
+5. Run a three-way comparison against the `trinity/trainer/verl/build/<version>/` snapshots.
+6. Preserve Trinity responsibility boundaries.
+7. Keep required Trinity customizations and remove redundant upstream copies.
+8. Validate config-to-implementation wiring and output-field contracts.
+9. Export the remote GPU regression checklist after local static checks pass.
+
+## Hard Constraints
+
+1. Do not overwrite whole files from upstream.
+2. Do not reintroduce reward/rollout/validation trainer logic unless responsibilities have changed.
+3. Keep the checkpoint monitor/synchronizer collaboration where required.
diff --git a/.codex/AGENTS.md b/.codex/AGENTS.md
new file mode 100644
index 00000000000..758a7a74de5
--- /dev/null
+++ b/.codex/AGENTS.md
@@ -0,0 +1,17 @@
+# Codex Repository Guide
+
+## Canonical Documentation Roots
+
+- Agent documentation root: `docs/agents/`
+- veRL upgrade knowledge root: `docs/agents/verl_upgrade/`
+
+## Agent Entry Files
+
+- Workspace-level guide: `AGENTS.md`
+- Codex-specific guide: `.codex/AGENTS.md`
+- Copilot instructions: `.github/instructions/`
+- Claude skills: `.claude/skills/`
+
+## Repository Convention
+
+Treat `docs/agents/` as the single source of truth for agent-facing process and navigation documents.
diff --git a/.github/instructions/verl-upgrade.instructions.md b/.github/instructions/verl-upgrade.instructions.md
new file mode 100644
index 00000000000..027560d0808
--- /dev/null
+++ b/.github/instructions/verl-upgrade.instructions.md
@@ -0,0 +1,20 @@
+---
+applyTo: "trinity/trainer/verl/**/*.py,docs/agents/**/*.md"
+description: "Use veRL migration guardrails and docs navigation when editing Trinity veRL upgrade related files"
+---
+
+# veRL Upgrade Instructions
+
+When the task is related to a veRL upgrade/migration in Trinity:
+
+1. Read `docs/agents/verl_upgrade/verl_upgrade_checklist.md` first.
+2. Use the current and target versions to generate or select a version-specific plan in `docs/agents/verl_upgrade/`, following the `verl_*_migration_plan.md` naming.
+3. During implementation and review, focus only on the detailed content for the target upgrade version.
+4. Preserve Trinity boundaries:
+   - Do not restore reward/rollout/validation main loops into the Trinity trainer path by default.
+   - Avoid whole-file overwrites from upstream snapshots.
+5. Prefer three-way merge reasoning:
+   - Trinity current vs. old upstream baseline
+   - old upstream vs. new upstream
+   - then current Trinity vs. new upstream
+6. If a subclass override is identical to upstream parent behavior, prefer removing the override.
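The three-way merge reasoning above can be exercised directly against the snapshot layout. A minimal sketch, assuming snapshots for the old and new upstream versions (v0.7.0 and v0.7.1 here, matching this upgrade) have been generated under `trinity/trainer/verl/build/<version>/`; the file chosen is just an example:

```shell
# Three-way comparison for one migrated file (dp_actor.py as an example).
OLD=trinity/trainer/verl/build/v0.7.0
NEW=trinity/trainer/verl/build/v0.7.1

# 1) What Trinity added on top of the old upstream baseline
diff -u "$OLD/dp_actor.py" trinity/trainer/verl/dp_actor.py

# 2) What upstream changed from the old version to the new version
diff -u "$OLD/dp_actor.py" "$NEW/dp_actor.py"

# 3) Reconcile: current Trinity vs. the new upstream snapshot
diff -u trinity/trainer/verl/dp_actor.py "$NEW/dp_actor.py"
```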
diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml index 12649f11d5e..4440d5e8fa9 100644 --- a/.github/workflows/docker/docker-compose.yaml +++ b/.github/workflows/docker/docker-compose.yaml @@ -1,6 +1,6 @@ services: trinity-node-1: - image: trinity-rft-unittest:20260310 + image: trinity-rft-unittest:20260407 cap_add: - SYS_PTRACE pull_policy: never @@ -15,8 +15,8 @@ services: - TRINITY_SFT_DATASET_PATH=/mnt/data - TRINITY_MODEL_PATH=/mnt/models/Qwen3-0.6B - TRINITY_API_MODEL_PATH=/mnt/models/Qwen3-1.7B - - TRINITY_VLM_MODEL_PATH=/mnt/models/Qwen2.5-VL-3B - - TRINITY_ALTERNATIVE_VLM_MODEL_PATH=/mnt/models/Qwen3-VL-2B-Instruct + - TRINITY_VLM_MODEL_PATH=/mnt/models/Qwen3.5-0.8B + - TRINITY_ALTERNATIVE_VLM_MODEL_PATH=/mnt/models/Qwen3.5-0.8B - VIRTUAL_ENV=/opt/venv working_dir: /workspace networks: @@ -34,7 +34,7 @@ services: capabilities: [gpu] trinity-node-2: - image: trinity-rft-unittest:20260310 + image: trinity-rft-unittest:20260407 cap_add: - SYS_PTRACE pull_policy: never @@ -44,7 +44,7 @@ services: - HF_HUB_DISABLE_PROGRESS_BARS=1 - TRINITY_CHECKPOINT_ROOT_DIR=/mnt/checkpoints - TRINITY_TASKSET_PATH=/mnt/data - - TRINITY_MODEL_PATH=/mnt/models/Qwen3-1.7B + - TRINITY_MODEL_PATH=/mnt/models/Qwen3-0.6B - VIRTUAL_ENV=/opt/venv working_dir: /workspace volumes: diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index f4111a1b59f..755627a9a54 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -113,15 +113,6 @@ jobs: fi fi - - name: Convert report.json time to ms - working-directory: trinity-${{ github.run_id }} - if: env.tests_run == 'true' || failure() - run: | - REPORT=report.json - if [ -f "$REPORT" ]; then - jq '(.results.summary.start, .results.summary.stop) |= (. * 1000)' "$REPORT" > "$REPORT.tmp" && mv "$REPORT.tmp" "$REPORT" - fi - - name: Clean checkpoint dir working-directory: trinity-${{ github.run_id }}/.github/workflows/docker if: always() diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000000..01a4dfe896b --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,18 @@ +# Multi-Agent Entry Guide + +This repository supports multiple coding agents. + +## Canonical Knowledge Location + +- Agent documentation root: `docs/agents/` +- veRL upgrade docs: `docs/agents/verl_upgrade/` + +## Agent-Specific Templates + +- Copilot instructions: `.github/instructions/verl-upgrade.instructions.md` +- Claude skill: `.claude/skills/verl-upgrade/SKILL.md` +- Codex template: `.codex/AGENTS.md` + +## Shared Rule + +All agents should follow this order for veRL upgrades: read checklist first, generate/select a version-specific migration plan from current->target version, then review only target-version detailed content during execution. 
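The `build/<version>/` snapshots that this workflow compares against can be produced with `scripts/migrate_from_verl/init_migration.py` (added later in this change). A typical invocation, assuming a local veRL clone (the clone path is illustrative):

```shell
# Snapshot the target upstream version into trinity/trainer/verl/build/<version>/
python scripts/migrate_from_verl/init_migration.py \
    --repo-dir /path/to/verl \
    --version v0.7.1
```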
diff --git a/benchmark/bench.py b/benchmark/bench.py
index 0aede232659..07e64430739 100644
--- a/benchmark/bench.py
+++ b/benchmark/bench.py
@@ -210,6 +210,8 @@ def prepare_configs(args, rank, current_time):
         config["synchronizer"]["sync_offset"] = args.sync_offset
     if args.sync_style:
         config["synchronizer"]["sync_style"] = args.sync_style
+    if args.trainer_strategy:
+        config["trainer"]["trainer_strategy"] = args.trainer_strategy
     with open(config_path, "w") as f:
         yaml.dump(config, f, allow_unicode=True, sort_keys=False)
@@ -320,5 +322,12 @@ def main(args):
         default=None,
         choices=[sync_style.value for sync_style in SyncStyle],
     )
+    parser.add_argument(
+        "--trainer_strategy",
+        type=str,
+        default=None,
+        choices=["fsdp", "fsdp2", "megatron"],
+        help="Specify the trainer strategy.",
+    )
     args = parser.parse_args()
     main(args)
diff --git a/docs/README.md b/docs/README.md
index dcda8248abf..4619657f839 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -1,5 +1,11 @@
 # Trinity-RFT Documentation
 
+## Documentation Layout
+
+- `docs/sphinx_doc/`: Sphinx source and build scripts for API/user docs.
+- `docs/agents/`: Agent-oriented operational docs and migration knowledge.
+- `docs/agents/verl_upgrade/`: Canonical veRL upgrade checklist and migration plans.
+
 Please use the following commands to build sphinx doc of Trinity-RFT.
 
 ```shell
@@ -20,3 +26,5 @@ cd docs/sphinx_doc
 ```
 
 The sphinx doc is built in `docs/sphinx_doc/build/html`.
+
+For code-agent workflows (Copilot/Codex/Claude), start from `docs/agents/README.md`.
diff --git a/docs/agents/README.md b/docs/agents/README.md
new file mode 100644
index 00000000000..81a8e986f21
--- /dev/null
+++ b/docs/agents/README.md
@@ -0,0 +1,20 @@
+# Agent Knowledge Hub
+
+This directory stores agent-oriented documentation for upgrade workflows, runbooks, and operating constraints.
+
+## Structure
+
+- `verl_upgrade/`: veRL upgrade knowledge, including planning, checklists, and future postmortems.
+
+## How To Use
+
+1. Start from `verl_upgrade/verl_upgrade_checklist.md` before a version upgrade.
+2. Use the current and target versions to generate or select the corresponding plan in `verl_upgrade/`, following the `verl_*_migration_plan.md` naming.
+3. During execution, review only the detailed content for the target upgrade version.
+4. Add new version migration records in `verl_upgrade/` using versioned file names.
+
+## Naming Convention
+
+- Checklist: `verl_upgrade_checklist_<version>.md`
+- Plan: `verl_<old_version>_to_<new_version>_migration_plan.md`
+- Postmortem: `verl_upgrade_postmortem_<version>.md`
diff --git a/docs/agents/verl_upgrade/verl_upgrade_checklist.md b/docs/agents/verl_upgrade/verl_upgrade_checklist.md
new file mode 100644
index 00000000000..6ca1e85b089
--- /dev/null
+++ b/docs/agents/verl_upgrade/verl_upgrade_checklist.md
@@ -0,0 +1,128 @@
+# Pre-Upgrade Checklist for veRL
+
+This checklist is for quick verification before the next veRL upgrade in Trinity.
+
+## 1. Confirm Upgrade Scope
+
+1. Confirm the target veRL version.
+2. Confirm the current Trinity baseline version.
+3. Confirm the upstream snapshots for comparison have been generated under `trinity/trainer/verl/build/<version>/`.
+4. Confirm this upgrade still focuses on the same 7 core migration files:
+   - `fsdp_workers.py`
+   - `dp_actor.py`
+   - `fsdp_checkpoint_manager.py`
+   - `megatron_workers.py`
+   - `megatron_actor.py`
+   - `megatron_checkpoint_manager.py`
+   - `verl_trainer.py` (corresponds to upstream `ray_trainer.py`)
+
+## 2. Prepare Three-Way Comparison
+
+1. For each file, compare all three sources together:
+   - The current Trinity file
+   - `build/<old_version>/...`
+   - `build/<new_version>/...`
+2. Do not overwrite whole files.
+3. Prioritize recording two categories of diffs:
+   - What Trinity added on top of the old-version baseline
+   - What upstream changed from the old version to the new version
+
+## 3. Verify Repository Responsibility Boundaries
+
+Before the next upgrade, verify these boundaries are still valid:
+
+1. Reward computation is not executed in Trinity `verl_trainer.py`.
+2. Rollout is not executed in the Trinity veRL trainer main loop.
+3. Trainer-side validation is currently not implemented.
+4. Trinity does not run upstream `RayPPOTrainer.fit()` directly. It follows the path defined in `trinity/trainer/trainer.py`: `prepare()`, `train_step()`, `save_checkpoint()`, `save_state_dict()`, and `upload_state_dict()`.
+
+If any boundary above changes, re-evaluate all following steps in this checklist.
+
+## 4. Upstream Logic That Should Not Be Accidentally Reintroduced
+
+Unless Trinity training responsibilities change, do not migrate these back by default:
+
+1. The full reward pipeline inside `fit()`.
+2. The validation main flow.
+3. The reward loop / async rollout manager.
+4. `CheckpointEngineManager` orchestration logic that is only used by the upstream trainer main loop.
+
+## 5. Must-Check Configuration Wiring
+
+Before the upgrade, verify whether these config items still need end-to-end wiring into the implementation:
+
+1. `trust_remote_code`
+2. `use_prefix_grouper`
+3. `calculate_sum_pi_squared`
+4. `sum_pi_squared_checkpointing`
+5. Compatibility reads for `lora.rank` and `lora_rank`
+6. `rollout_correction`
+7. Compatibility structure for `reward.reward_model` and `reward_model`
+
+## 6. File-Level Priority Order
+
+Recommended processing order:
+
+1. `dp_actor.py`
+2. `fsdp_workers.py`
+3. `megatron_actor.py`
+4. `megatron_workers.py`
+5. `fsdp_checkpoint_manager.py`
+6. `megatron_checkpoint_manager.py`
+7. `verl_trainer.py`
+8. `verl_config.py`
+
+Reason: the first four files define data fields and config wiring; the next three depend on those contracts being stable.
+
+## 7. Convergence Checks Required for Every File
+
+For every migration file, ask:
+
+1. Is this subclass implementation only a copy of parent-class code?
+2. If it is fully identical to the upstream parent implementation, can we delete the override directly?
+3. If this is only a historical workaround, has upstream already absorbed it?
+4. If this is a true Trinity-specific responsibility, has the reason to keep it been documented?
+
+## 8. Trinity Customizations Confirmed as Non-Removable
+
+1. Algorithm integration and loss composition logic in `dp_actor.py` and `megatron_actor.py`.
+2. `CheckpointMonitor` / `Synchronizer` collaboration logic in `fsdp_checkpoint_manager.py` and `megatron_checkpoint_manager.py`.
+3. `CheckpointMonitor`, the Trinity custom `train_step()`, and the state sync path in `verl_trainer.py`.
+4. Trinity's independent experience pipeline and its trainer scheduling relationship.
+
+## 9. Known Migration Sensitive Points
+
+1. `use_prefix_grouper` is an end-to-end chain from config to monkey patch to the actor/ref worker.
+2. `sum_pi_squared` must be passed from the actor output all the way to the advantage consumer.
+3. Megatron LoRA reference logprob follows the actor/no-adapter path, not the regular ref worker path.
+4. To collect MFU in multimodal training, `images_seqlens` must be added to `batch.meta_info` in the trainer.
+5. The checkpoint managers cannot be replaced by a whole-file upstream overwrite; otherwise Trinity's async save threads and monitoring logic are lost.
+
+## 10. Local Checks After Upgrade
+
+1. Run the IDE Problems check on all migrated files.
+2. Run `python -m py_compile` on all migrated files.
+3. Verify new config items are wired end to end across the dataclass, defaults, and loading path.
+4. Verify actor output fields match worker/trainer consumer fields.
+5. Verify function signatures for checkpoint save and restore are consistent.
+
+## 11. Minimal Remote GPU Regression
+
+After local checks pass, run at least:
+
+1. FSDP single-step training.
+2. Megatron single-step training.
+3. The recompute path for old logprob / ref logprob.
+4. Megatron reference logprob under LoRA.
+5. Checkpoint save and restore.
+6. A minimal regression for `use_prefix_grouper`.
+7. A minimal regression for `calculate_sum_pi_squared`.
+
+## 12. Final Confirmation
+
+Before submitting the upgrade, reconfirm:
+
+1. Reward, rollout, and validation logic were not accidentally moved back into the Trinity trainer.
+2. Duplicate subclass implementations that are already identical to upstream were not kept.
+3. Features previously trimmed by Trinity were not restored only for version alignment.
+4. Documentation has been updated with newly added repository constraints and the reasons for retained customizations.
diff --git a/pyproject.toml b/pyproject.toml
index 27e4d1b0ac4..92dc9d89d4a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
 ]
 requires-python = ">=3.10,<3.13"
 dependencies = [
-    "verl==0.7.0",
+    "verl==0.7.1",
     "ray[default]>=2.50.0",
     "tensordict",
     "wandb",
@@ -42,7 +42,7 @@ dependencies = [
     "sortedcontainers",
     "word2number",
     "matplotlib",
-    "transformers>=4.51.0",
+    "transformers>=4.51.0,<=5.3.0",
     "datasets>=4.0.0",
     "typer>=0.20.1",
 ]
@@ -52,18 +52,15 @@ trinity = "trinity.cli.launcher:main"
 
 [project.optional-dependencies]
 vllm = [
-    "vllm>=0.10.2,<=0.18.0,!=0.11.0,!=0.12.0",
-    # v0.11 has bug when prefix-caching is enabled so we exclude it
-    # v0.12 has a huge performance regression so we exclude it
-    # v0.10.2 is the most stable version, but we allow up to 0.17.1 for new features
-    # For v0.16 to v0.18, the default dependencies require transformers < 5.
+    "vllm>=0.17.0,<=0.19.0",
+    # For v0.17 to v0.19, the default dependencies require transformers < 5.
     # We have patched vLLM to support transformers >= 5.0.0.
] data = [ "py-data-juicer>=1.4.3" ] agent = [ - "agentscope>=1.0.12" + "agentscope[tuner]>=1.0.18" ] rm_gallery = [ "rm-gallery>=0.1.5" @@ -81,13 +78,15 @@ dev = [ "viztracer", ] megatron = [ - "megatron-core[mlm]>=0.15.0", + "megatron-core[mlm]==0.16.1", # if you found "undefined symbol" error in transformer engine # reinstall it with --no-build-isolation and `--no-cache-dir` flag - "transformer_engine[pytorch]>=2.10.0", + "transformer_engine[pytorch]==2.13.0", # Install mbridge from main branch (unreleased version) - # "mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612", + # "mbridge @ git+https://github.com/ISEEKYAN/mbridge.git@90c4633a6cdcfe5d29723d7b145d32f6f5e73303", + # or Megatron bridge, which is not tested + # "megatron-bridge==0.3.1", ] tinker = [ "tinker>=0.10.0; python_version >= '3.11'", @@ -113,6 +112,11 @@ flash_attn = [ "flash-attn>=2.8.1" ] +qwen3_5 = [ + "flash-linear-attention>=0.4.2", + "causal_conv1d>=1.6.0", +] + [tool.setuptools.packages.find] where = ["."] include = ["trinity*"] diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv index a33026d98de..c76806ae98c 100644 --- a/scripts/docker/Dockerfile.uv +++ b/scripts/docker/Dockerfile.uv @@ -34,16 +34,14 @@ COPY . . # Install uv RUN pip install uv && uv venv /opt/venv --python=python3.12 -# Install minimal Trinity-RFT +# Install Trinity-RFT RUN . /opt/venv/bin/activate && \ - uv pip install -e.[vllm,mm,dev] - -# Install flash_attn and Megatron -RUN . /opt/venv/bin/activate && \ - uv pip install -e .[megatron] && \ - uv pip install flash_attn==2.8.1 --no-build-isolation && \ - uv pip install git+https://github.com/ISEEKYAN/mbridge.git@20e9ffbbe72ae7b1df83bfe1bc3c11f7382f2612 && \ - uv pip install transformer_engine[pytorch]==2.10.0 --no-build-isolation --no-cache-dir && \ + uv pip install -e.[mm,dev,tinker,data,agent] && \ + uv pip install vllm==0.19.0 && \ + uv pip install flash_attn==2.8.3 --no-build-isolation && \ + uv pip install -e .[megatron,qwen3_5] --no-build-isolation && \ + uv pip install git+https://github.com/ISEEKYAN/mbridge.git@90c4633a6cdcfe5d29723d7b145d32f6f5e73303 && \ + uv pip install transformers==5.3.0 && \ NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 \ uv pip install -v --no-build-isolation \ --config-settings="--build-option=--cpp_ext" \ diff --git a/scripts/migrate_from_verl/init_migration.py b/scripts/migrate_from_verl/init_migration.py new file mode 100644 index 00000000000..7fe789f5917 --- /dev/null +++ b/scripts/migrate_from_verl/init_migration.py @@ -0,0 +1,62 @@ +import argparse +import shutil +import subprocess +from pathlib import Path + + +def main(args): + # cd to verl repo dir and checkout the specified version + subprocess.run(["git", "fetch", "origin"], cwd=args.repo_dir, check=True) + subprocess.run(["git", "checkout", args.version], cwd=args.repo_dir, check=True) + + # copy files from verl repo to trinity repo with new names, and add them to git + trinity_path_prefix = Path(__file__).parent.parent.parent / "trinity" / "trainer" / "verl" + verl_path_prefix = Path(args.repo_dir) + file_maps = { + "fsdp_workers": ("verl", "workers", "fsdp_workers.py"), + "dp_actor": ("verl", "workers", "actor", "dp_actor.py"), + "fsdp_checkpoint_manager": ("verl", "utils", "checkpoint", "fsdp_checkpoint_manager.py"), + "megatron_workers": ("verl", "workers", "megatron_workers.py"), + "megatron_actor": ("verl", "workers", "actor", "megatron_actor.py"), + "megatron_checkpoint_manager": ( + "verl", + "utils", + "checkpoint", + 
"megatron_checkpoint_manager.py", + ), + "ray_trainer": ("verl", "trainer", "ppo", "ray_trainer.py"), + } + + for filename, path_parts in file_maps.items(): + src_path = verl_path_prefix / Path(*path_parts) + dst_path = trinity_path_prefix / f"{filename}-{args.version}.py" + dst_path.parent.mkdir(parents=True, exist_ok=True) + dst_path.write_bytes(src_path.read_bytes()) + subprocess.run(["git", "add", str(dst_path)], cwd=trinity_path_prefix, check=True) + print(f"Copied {src_path} to {dst_path}") + + print("Running pre-commit on the migrated files...") + subprocess.run( + ["pre-commit", "run", "--all-files"], cwd=Path(__file__).parent.parent.parent, check=True + ) + + # move the files to the build directory and reset git status to keep history clean + target_dir = trinity_path_prefix / "build" / f"{args.version}" + target_dir.mkdir(parents=True, exist_ok=True) + for filename, path_parts in file_maps.items(): + src_path = trinity_path_prefix / f"{filename}-{args.version}.py" + dst_path = target_dir / f"{filename}.py" + dst_path.parent.mkdir(parents=True, exist_ok=True) + subprocess.run(["git", "add", str(src_path)], cwd=trinity_path_prefix, check=True) + subprocess.run(["git", "reset", "HEAD", str(src_path)], cwd=trinity_path_prefix, check=True) + shutil.move(src_path, dst_path) + print(f"Moved {src_path} to {dst_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--repo-dir", required=True, help="Path to the verl repository") + parser.add_argument("--version", required=True, help="Version of the verl repository") + + args = parser.parse_args() + main(args) diff --git a/tests/buffer/formatter_test.py b/tests/buffer/formatter_test.py index e5381a7249f..d6cf5405255 100644 --- a/tests/buffer/formatter_test.py +++ b/tests/buffer/formatter_test.py @@ -313,12 +313,13 @@ def test_task_formatter(self): self.assertEqual(task.raw_task, sample) def test_multi_modal_sft_formatter(self): - IMAGE_TOKEN_ID = 151655 # only for Qwen2.5 VL, if changed, please update this test storage_config = get_unittest_dataset_config("geometry") formatter = FORMATTER.get("sft")( tokenizer_path=get_vision_language_model_path(), format_config=storage_config.format ) + self.assertIsNotNone(formatter.processor) + IMAGE_TOKEN_ID = formatter.processor.image_token_id ds = load_dataset(storage_config.path, split=storage_config.split) count = 0 for sample in ds: diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 7b475c8b01f..2e45804de4c 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -812,8 +812,11 @@ async def test_metric_calculation_with_non_repeatable_workflow( statuses, exps = await scheduler.get_results(batch_id=0) self.assertEqual(len(statuses), 2) self.assertEqual(len(exps), 1 * 4 * 3 + 1 * 5 * 8) - self.assertAlmostEqual(statuses[0].metrics[0]["run_metrics"], 2.0) # (1+2+3)/3 - self.assertAlmostEqual(statuses[1].metrics[0]["run_metrics"], 7.0) # (0+2+4+6+8+10+12+14)/8 + # (1+2+3)/3 = 2.0 + # (0+2+4+6+8+10+12+14)/8 = 7.0 + self.assertSetEqual( + set(status.metrics[0]["run_metrics"] for status in statuses), {2.0, 7.0} + ) async def test_over_rollout_min_wait(self): self.config.explorer.over_rollout.ratio = 0.5 diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 4524a9d138f..70d55c36874 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -43,6 +43,7 @@ explorer: seed: 42 trainer: trainer_type: verl + trainer_strategy: fsdp save_interval: 100 
save_hf_checkpoint: never grad_clip: 1.0 @@ -55,6 +56,11 @@ trainer: checkpoint: load_contents: ['model', 'optimizer', 'extra'] save_contents: ['model', 'optimizer', 'extra'] + megatron: + tensor_model_parallel_size: 2 + ref: + megatron: + tensor_model_parallel_size: 2 critic: optim: lr: 1e-5 @@ -64,6 +70,8 @@ trainer: checkpoint: load_contents: ['model', 'optimizer', 'extra'] save_contents: ['model', 'optimizer', 'extra'] + megatron: + tensor_model_parallel_size: 2 trainer: max_actor_ckpt_to_keep: 1 max_critic_ckpt_to_keep: 1 diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index d6e77a363ba..17c333f8695 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -107,10 +107,7 @@ def test_trainer(self): self.config.check_and_update() _trainer_config = self.config.trainer.trainer_config if self.strategy == "megatron": - _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2 - _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2 _trainer_config.critic.strategy = "megatron" - _trainer_config.critic.megatron.tensor_model_parallel_size = 2 _trainer_config.trainer.max_actor_ckpt_to_keep = 2 _trainer_config.trainer.max_critic_ckpt_to_keep = 2 both(self.config) @@ -151,8 +148,6 @@ def test_trainer(self): hf_dir_step_8 = os.listdir(os.path.join(checkpoint_step_8, "actor", "huggingface")) self.assertGreater(len(hf_dir_step_4), 0) self.assertGreater(len(hf_dir_step_8), 0) - self.assertNotIn("model.safetensors", hf_dir_step_4) - self.assertNotIn("model.safetensors", hf_dir_step_8) # test checkpoint convert convert(self.config.checkpoint_job_dir) hf_dir_step_4 = os.listdir(os.path.join(checkpoint_step_4, "actor", "huggingface")) @@ -616,10 +611,7 @@ def test_fully_async_mode(self): trainer_config.check_and_update() if self.strategy == "megatron": _trainer_config = trainer_config.trainer.trainer_config - _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2 - _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2 _trainer_config.critic.strategy = "megatron" - _trainer_config.critic.megatron.tensor_model_parallel_size = 2 explorer1_config = deepcopy(config) explorer1_config.trainer = deepcopy(trainer_config.trainer) @@ -802,10 +794,6 @@ def setUp(self): def test_trainer(self): # noqa: C901 """Test the checkpoint saving.""" _trainer_config = self.config.trainer.trainer_config - if self.strategy == "megatron": - _trainer_config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2 - _trainer_config.actor_rollout_ref.ref.megatron.tensor_model_parallel_size = 2 - _trainer_config.critic.megatron.tensor_model_parallel_size = 2 stop_event = multiprocessing.Event() trainer_process = multiprocessing.Process(target=run_both, args=(self.config, stop_event)) diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index e6276d5f7fa..6ed1120ee9a 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -281,11 +281,6 @@ def get_verl_checkpoint_info( # modified from verl/model_merger/fsdp_model_merger.py def load_fsdp_state_dict_from_verl_checkpoint(checkpoint_path: str) -> dict: # noqa: C901 """Load state dict from a Verl checkpoint.""" - # start of patch for verl to support transformers v5 - from trinity.trainer.verl import patch_for_transformers_v5 - - patch_for_transformers_v5() - # end of patch for verl to support transformers v5 from verl.model_merger.base_model_merger import ModelMergerConfig from 
verl.model_merger.fsdp_model_merger import FSDPModelMerger @@ -324,18 +319,14 @@ def load_huggingface_state_dict(checkpoint_path: str): def get_megatron_converter(checkpoint_path: str): - # start of patch for verl to support transformers v5 - from trinity.trainer.verl import patch_for_transformers_v5 - - patch_for_transformers_v5() - # end of patch for verl to support transformers v5 - import builtins from contextlib import contextmanager from verl.model_merger.base_model_merger import ModelMergerConfig from verl.model_merger.megatron_model_merger import MegatronModelMerger + from trinity.trainer.verl.utils import patch_rope_theta_in_hf_config + # modified from verl/model_merger/megatron_model_merger.py class MegatronStateDictConverter(MegatronModelMerger): def __init__(self, config: ModelMergerConfig): @@ -353,10 +344,7 @@ def __init__(self, config: ModelMergerConfig): torch.distributed.get_world_size = original_get_world_size # start of patch for verl to support transformers v5 - if not hasattr(self.hf_config, "rope_theta"): - rope_theta = self.hf_config.rope_parameters.get("rope_theta", None) - if rope_theta is not None: - setattr(self.hf_config, "rope_theta", rope_theta) + patch_rope_theta_in_hf_config(self.hf_config) # end of patch for verl to support transformers v5 @contextmanager diff --git a/trinity/common/models/vllm_patch/worker_patch.py b/trinity/common/models/vllm_patch/worker_patch.py index a1bde2b5235..9a84c16fa80 100644 --- a/trinity/common/models/vllm_patch/worker_patch.py +++ b/trinity/common/models/vllm_patch/worker_patch.py @@ -13,10 +13,10 @@ def patch_vllm_prompt_logprobs(model_runner: GPUModelRunner): # noqa: C901 """Patch vLLM model runner to support prompt logprobs extraction.""" version = get_vllm_version() - if version < parse_version("0.10.2") or version > parse_version("0.18.0"): + if version < parse_version("0.10.2") or version > parse_version("0.19.0"): raise ValueError( f"Unsupported vllm version: {vllm.__version__}. " - "This patch requires vllm version >= 0.10.2, <= 0.18.0." + "This patch requires vllm version >= 0.10.2, <= 0.19.0." 
) is_v0102 = version == parse_version("0.10.2") diff --git a/trinity/common/patch/qwen3_5.py b/trinity/common/patch/qwen3_5.py index 13a308dc34b..4a7fef18368 100644 --- a/trinity/common/patch/qwen3_5.py +++ b/trinity/common/patch/qwen3_5.py @@ -9,7 +9,6 @@ BaseModelOutputWithPast, Cache, Qwen3_5CausalLMOutputWithPast, - Qwen3_5DynamicCache, Qwen3_5ForConditionalGeneration, Qwen3_5ModelOutputWithPast, TransformersKwargs, @@ -76,9 +75,7 @@ def ulysses_gated_delta_net_forward_decorator(func): @wraps(func) def wrapper( hidden_states: torch.Tensor, - cache_params: Qwen3_5DynamicCache | None = None, - cache_position: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, + **kwargs, ): from verl.utils.ulysses import ( gather_outputs_and_unpad, @@ -90,7 +87,7 @@ def wrapper( if ulysses_sp_size > 1: hidden_states = gather_outputs_and_unpad(hidden_states, gather_dim=1) - output = func(hidden_states, cache_params, cache_position, attention_mask) + output = func(hidden_states, **kwargs) if ulysses_sp_size > 1: group = get_ulysses_sequence_parallel_group() @@ -120,6 +117,8 @@ def qwen35_text_forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache + past_key_values = Qwen3_5DynamicCache(config=self.config) if cache_position is None: diff --git a/trinity/service/data_juicer/server/utils.py b/trinity/service/data_juicer/server/utils.py index 0a8d6cc7844..a6fd5b79b5c 100644 --- a/trinity/service/data_juicer/server/utils.py +++ b/trinity/service/data_juicer/server/utils.py @@ -116,11 +116,11 @@ def _parse_task_pipeline_config(config: DJConfig) -> Namespace: def group_scores(dataset: Dataset) -> Dataset: - if Fields.stats not in dataset.features: + if Fields.stats not in dataset.features or len(dataset) == 0: return dataset # for perplexity, normalize them with the max value. stats_min_max = {} - for stats in dataset.features[Fields.stats]: + for stats in dataset[Fields.stats][0]: all_stats = [ sample[Fields.stats][stats] for sample in dataset.data if Fields.stats in sample ] diff --git a/trinity/trainer/verl/__init__.py b/trinity/trainer/verl/__init__.py index 547e8ca61eb..e69de29bb2d 100644 --- a/trinity/trainer/verl/__init__.py +++ b/trinity/trainer/verl/__init__.py @@ -1,18 +0,0 @@ -import sys - -import transformers - - -# start of patch for verl to support transformers v5 -def patch_for_transformers_v5(): - if not hasattr(sys.modules["transformers"], "AutoModelForVision2Seq"): - setattr( - sys.modules["transformers"], - "AutoModelForVision2Seq", - transformers.AutoModelForImageTextToText, - ) - sys.modules["transformers"].__all__.append("AutoModelForVision2Seq") - - -patch_for_transformers_v5() -# end of patch for verl to support transformers v5 diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 419001e2265..c13d25a0bea 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -15,32 +15,19 @@ # limitations under the License. """ Single Process Actor. 
-Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/actor/dp_actor.py +Modified from https://github.com/volcengine/verl/blob/v0.7.1/verl/workers/actor/dp_actor.py """ import logging import os import torch -import verl.utils.torch_functional as verl_F from torch import nn from verl import DataProto -from verl.utils.attention_utils import ( - index_first_axis, - pad_input, - rearrange, - unpad_input, -) from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch -from verl.utils.torch_functional import logprobs_from_logits -from verl.utils.ulysses import ( - gather_outputs_and_unpad, - ulysses_pad, - ulysses_pad_and_slice_inputs, -) +from verl.utils.seqlen_balancing import prepare_dynamic_batch from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN @@ -82,6 +69,7 @@ def update_policy(self, data: DataProto): # noqa: C901 # temperature must be in the data.meta_info to avoid silent error temperature = data.meta_info["temperature"] + pad_token_id = data.meta_info.get("pad_token_id", 0) select_keys = [ "input_ids", "position_ids", @@ -89,6 +77,8 @@ def update_policy(self, data: DataProto): # noqa: C901 "responses", "response_mask", ] + if self.use_prefix_grouper and "prompts" in data.batch.keys(): + select_keys.append("prompts") select_keys.extend(self.policy_loss_fn.select_keys) if not isinstance(self.kl_loss_fn, DummyKLFn): select_keys.append("ref_log_prob") @@ -98,6 +88,8 @@ def update_policy(self, data: DataProto): # noqa: C901 has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + if self.use_prefix_grouper and "uid" in data.non_tensor_batch.keys(): + non_tensor_select_keys.append("uid") data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) @@ -137,7 +129,11 @@ def update_policy(self, data: DataProto): # noqa: C901 for micro_batch in micro_batches: micro_batch = micro_batch.to(get_device_id()) micro_batch_metrics = {} - model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + model_inputs = { + **micro_batch.batch, + **micro_batch.non_tensor_batch, + "pad_token_id": pad_token_id, + } response_mask = model_inputs["response_mask"] loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") @@ -236,315 +232,3 @@ def update_policy(self, data: DataProto): # noqa: C901 append_to_dict(metrics, mini_batch_metrics) self.actor_optimizer.zero_grad() return metrics - - # TODO: remove this method after upgrading verl - def _forward_micro_batch( # type: ignore # noqa: C901 - self, micro_batch, temperature, calculate_entropy=False - ) -> dict[str, torch.Tensor]: - """ - Returns: - dict[str, torch.Tensor]: a dict containing keys - - ``entropy``: tensor of shape [batch_size, response_length]. torch.float32. - - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32. 
- """ - response_length = micro_batch["responses"].size(-1) - multi_modal_inputs = {} - if "multi_modal_inputs" in micro_batch.keys(): - from verl.utils.model import extract_multi_modal_inputs - - multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) - - with torch.autocast(device_type=self.device_name, dtype=self.param_dtype): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - entropy = None - if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) - - if self.use_remove_padding: - input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - if position_ids.dim() == 3: - position_ids_rmpad = ( - index_first_axis( - rearrange(position_ids, "c b s ... -> (b s) c ..."), indices - ) - .transpose(0, 1) - .unsqueeze(1) - ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) - else: - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - is_mask_all_zero = attention_mask.sum() == 0 - if is_mask_all_zero: - input_ids_rmpad = torch.zeros( - (1, self.ulysses_sequence_parallel_size), - device=input_ids.device, - dtype=input_ids.dtype, - ) - if position_ids.dim() == 3: - position_ids_rmpad = torch.zeros( - (position_ids.shape[0], 1, self.ulysses_sequence_parallel_size), - device=position_ids.device, - dtype=position_ids.dtype, - ) - else: - position_ids_rmpad = torch.zeros( - (1, self.ulysses_sequence_parallel_size), - device=position_ids.device, - dtype=position_ids.dtype, - ) - - if "image_bound" in multi_modal_inputs: - from verl.utils.dataset.vision_utils import ( - process_multi_modal_inputs_for_minicpmo, - ) - - multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( - input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs - ) - - # for compute the log_prob - input_ids_rmpad_rolled = torch.roll( - input_ids_rmpad, shifts=-1, dims=1 - ) # (1, total_nnz) - - # pad and slice the inputs if sp > 1 - if self.use_ulysses_sp: - is_vlm_model = hasattr( - getattr(self.actor_module, "module", self.actor_module).config, - "vision_config", - ) - if is_vlm_model: - # vlm model's inputs will be sliced after embedding - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( - input_ids_rmpad, - position_ids_rmpad=position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) - else: - ( - input_ids_rmpad, - position_ids_rmpad, - pad_size, - ) = ulysses_pad_and_slice_inputs( - input_ids_rmpad, - position_ids_rmpad=position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, - position_ids_rmpad=None, - sp_size=self.ulysses_sequence_parallel_size, - ) - - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( - 0 - ) # ((total_nnz / sp) + pad) - - # only pass input_ids and position_ids to enable flash_attn_varlen - extra_args = {} - if self.use_fused_kernels: - extra_args["temperature"] = temperature - extra_args["return_dict"] = True - - output = self.actor_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - **multi_modal_inputs, - 
use_cache=False, - **extra_args, - ) # prevent model thinks we are generating - - if self.use_fused_kernels: - log_probs = output.log_probs.squeeze(0) # (total_nnz,) - entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) - - else: - logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - logits_rmpad.div_(temperature) - - # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - inplace_backward = True - if calculate_entropy: - inplace_backward = False - log_probs = logprobs_from_logits( - logits=logits_rmpad, - labels=input_ids_rmpad_rolled, - inplace_backward=inplace_backward, - ) - - # compute entropy - if calculate_entropy: - if not self.config.entropy_checkpointing: - entropy_rmpad = self.compute_entropy_from_logits( - logits_rmpad - ) # ((total_nnz / sp) + pad) - else: - entropy_rmpad = torch.utils.checkpoint.checkpoint( - self.compute_entropy_from_logits, logits_rmpad - ) - - # gather log_prob if sp > 1 - if self.use_ulysses_sp: - # gather and unpad for the ulysses sp - log_probs = gather_outputs_and_unpad( - log_probs, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size, - ) - if calculate_entropy: - entropy_rmpad = gather_outputs_and_unpad( - entropy_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size, - ) - - if is_mask_all_zero: - log_probs = log_probs[:0] - if calculate_entropy: - entropy_rmpad = entropy_rmpad[:0] - - # pad back to (bsz, seqlen) - if calculate_entropy: - full_entropy = pad_input( - hidden_states=entropy_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - full_log_probs = pad_input( - hidden_states=log_probs.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - - # only return response part: - if calculate_entropy: - entropy = full_entropy.squeeze(-1)[ - :, -response_length - 1 : -1 - ] # (bsz, response_length) - log_probs = full_log_probs.squeeze(-1)[ - :, -response_length - 1 : -1 - ] # (bsz, response_length) - - else: # not using rmpad and no ulysses sp - extra_args = {} - if self.use_fused_kernels: - extra_args["temperature"] = temperature - extra_args["return_dict"] = True - - output = self.actor_module( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False, - **extra_args, - ) # prevent model thinks we are generating - - if self.use_fused_kernels: - log_probs = output.log_probs[:, -response_length - 1 : -1] - entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) - - else: - logits = output.logits - - logits.div_(temperature) - logits = logits[ - :, -response_length - 1 : -1, : - ] # (bsz, response_length, vocab_size) - log_probs = logprobs_from_logits(logits, micro_batch["responses"]) - if calculate_entropy: - if not self.config.entropy_checkpointing: - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) - else: - entropy = torch.utils.checkpoint.checkpoint( - verl_F.entropy_from_logits, logits - ) - - outputs = {"log_probs": log_probs} - if calculate_entropy: - outputs["entropys"] = entropy - return outputs - - # TODO: remove this method after upgrading verl - @GPUMemoryLogger(role="dp actor", logger=logger) - def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> dict[str, torch.Tensor]: - """Compute the log probability of the responses given input_ids, attention_mask and position_ids - - Args: - data (DataProto): a DataProto containing keys - - ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. 
Note that input_ids is the - concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. - - ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``responses``: tensor of shape [batch_size, response_length]. torch.int64. - - Returns: - dict[str, torch.Tensor]: a dict containing keys - - ``log_probs``: tensor of shape [batch_size, response_length]. torch.float32. - - ``entropys``: tensor of shape [batch_size, response_length]. torch.float32. - """ - # set to eval - self.actor_module.eval() - - micro_batch_size = data.meta_info["micro_batch_size"] - temperature = data.meta_info[ - "temperature" - ] # temperature must be in the data.meta_info to avoid silent error - use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - - data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) - - if use_dynamic_bsz: - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) - else: - micro_batches = data.split(micro_batch_size) - - log_probs_lst = [] - entropy_lst = [] - for micro_batch in micro_batches: - micro_batch = micro_batch.to(get_device_id()) - model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} - with torch.no_grad(): - outputs = self._forward_micro_batch( - model_inputs, temperature=temperature, calculate_entropy=calculate_entropy - ) - log_probs_lst.append(outputs["log_probs"]) - if calculate_entropy: - entropy_lst.append(outputs["entropys"]) - - log_probs = torch.concat(log_probs_lst, dim=0) - if calculate_entropy: - entropys = torch.concat(entropy_lst, dim=0) - - if use_dynamic_bsz: - log_probs = restore_dynamic_batch(log_probs, batch_idx_list) - if calculate_entropy: - entropys = restore_dynamic_batch(entropys, batch_idx_list) - - outputs = {"log_probs": log_probs} - if calculate_entropy: - outputs["entropys"] = entropys - return outputs diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 483a20606e7..1051a4a169c 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -13,7 +13,7 @@ # limitations under the License. """ FSDP Checkpoint Manager. 
-Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/utils/checkpoint/fsdp_checkpoint_manager.py +Modified from https://github.com/volcengine/verl/blob/v0.7.1/verl/utils/checkpoint/fsdp_checkpoint_manager.py """ import json @@ -83,6 +83,16 @@ def __init__(self, *args, ray_namespace: str = "", trust_remote_code: bool = Fal self.latest_hf_model_save_step = None self.latest_tokenizer_save_step = None + def _is_latest_registered_checkpoint(self, path: str) -> bool: + if not self.previous_saved_paths: + return False + return os.path.abspath(path) == os.path.abspath(self.previous_saved_paths[-1]) + + def register_checkpoint(self, new_path: str, max_ckpt_to_keep: Optional[int] = None): + if self._is_latest_registered_checkpoint(new_path): + return + super().register_checkpoint(new_path, max_ckpt_to_keep) + def _upload_state_dict(self, state_dict: Union[dict, None], global_step: int): """ Internal method to upload a state dict to the Synchronizer actor. @@ -418,16 +428,16 @@ def save_checkpoint( """ Modified from verl.utils.checkpoint.fsdp_checkpoint_manager.py:save_checkpoint - Saves the model checkpoint to disk, optionally uploads it to a remote Synchronizer, - and uses background threads to prevent blocking the main training loop. + Saves the model checkpoint to disk and uses background threads to prevent + blocking the main training loop. Main improvements over the base class: - Uses separate threads for saving model/optimizer/extras. - - Implements synchronization with a remote actor. If the model is not trained (`global_step == 0`) or continues from a breakpoint, `Synchonizer` will be notified and the model will not be saved. + - Registers background work with CheckpointMonitor so trainer-side coordination + can wait on state-dict and checkpoint completion. Args: local_path (str): Local directory path to save the checkpoint. - hdfs_path (str, optional): HDFS path for saving the checkpoint (not implemented here). global_step (int): Current training step. max_ckpt_to_keep (int, optional): Maximum number of checkpoints to keep locally. save_as_hf (bool): Whether to force save the model in Hugging Face format. 
@@ -437,24 +447,15 @@ def save_checkpoint( # record the previous global step self.previous_global_step = global_step - local_path = local_mkdir_safe(local_path) - # remove previous local_path, only rank 0 should do this - if ( - self.rank == 0 - and max_ckpt_to_keep - and isinstance(max_ckpt_to_keep, int) - and max_ckpt_to_keep > 0 - and len(self.previous_saved_paths) >= max_ckpt_to_keep # type: ignore - and local_path != self.previous_saved_paths[-1] # type: ignore - ): # last step may save twice - keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore - self.logger.info( - "Checkpoint manager is removing previous checkpoints at " - + str(self.previous_saved_paths[:keep_start]) # type: ignore - ) - self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore - self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore + skip_retention_rotation = self.rank == 0 and self._is_latest_registered_checkpoint( + local_path + ) + + if self.rank == 0 and not skip_retention_rotation: + self.ensure_checkpoint_capacity(max_ckpt_to_keep) + + local_path = local_mkdir_safe(local_path) torch.distributed.barrier() @@ -504,10 +505,8 @@ def save_checkpoint( checkpoint_thread_count=checkpoint_thread_count, ) ) - if ( - len(self.previous_saved_paths) == 0 or local_path != self.previous_saved_paths[-1] - ): # last step may save twice - self.previous_saved_paths.append(local_path) + if self.rank == 0: + self.register_checkpoint(local_path, max_ckpt_to_keep) def wait_on_save_thread(self) -> None: """ diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 29c087b72f0..6121080d4cb 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -13,13 +13,12 @@ # limitations under the License. """ The main entry point to run the PPO algorithm. 
-Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/fsdp_workers.py +Modified from https://github.com/volcengine/verl/blob/v0.7.1/verl/workers/fsdp_workers.py """ import datetime import json import os -import sys import warnings from contextlib import contextmanager from dataclasses import asdict @@ -37,18 +36,6 @@ from torch.distributed.fsdp import FlatParameter from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import FSDP_PREFIX - -# start of patch for verl to support transformers v5 -if not hasattr(sys.modules["transformers"], "AutoModelForVision2Seq"): - setattr( - sys.modules["transformers"], - "AutoModelForVision2Seq", - sys.modules["transformers"].AutoModelForImageTextToText, - ) - sys.modules["transformers"].__all__.append("AutoModelForVision2Seq") -# end of patch for verl to support transformers v5 - - from verl import DataProto from verl.single_controller.base import Worker from verl.single_controller.base.decorator import ( @@ -310,6 +297,7 @@ def _build_model_optimizer( # noqa: C901 use_liger=False, role="actor", enable_activation_offload=False, + use_prefix_grouper=False, use_tiled_mlp=False, tiled_mlp_shards=4, ): @@ -433,6 +421,7 @@ def _build_model_optimizer( # noqa: C901 ulysses_sp_size=self.ulysses_sequence_parallel_size, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend, + use_prefix_grouper=use_prefix_grouper, use_tiled_mlp=use_tiled_mlp, tiled_mlp_shards=tiled_mlp_shards, ) @@ -644,6 +633,7 @@ def init_model(self): use_shm = self.config.model.get("use_shm", False) use_fused_kernels = self.config.model.get("use_fused_kernels", False) trust_remote_code = self.config.model.get("trust_remote_code", False) + use_prefix_grouper = self.config.actor.get("use_prefix_grouper", False) if self._is_actor: # we need the model for actor @@ -675,6 +665,7 @@ def init_model(self): use_liger=self.config.model.get("use_liger", False), role="actor", enable_activation_offload=self.config.model.get("enable_activation_offload", False), + use_prefix_grouper=use_prefix_grouper, use_tiled_mlp=use_tiled_mlp, tiled_mlp_shards=tiled_mlp_shards, ) @@ -698,6 +689,7 @@ def init_model(self): with open_dict(self.config.actor): self.config.actor.use_remove_padding = use_remove_padding self.config.actor.use_fused_kernels = use_fused_kernels + self.config.actor.use_prefix_grouper = use_prefix_grouper self.actor = DataParallelPPOActor( config=self.config.actor, actor_module=self.actor_module_fsdp, @@ -731,6 +723,7 @@ def init_model(self): trust_remote_code=trust_remote_code, use_liger=self.config.model.get("use_liger", False), role="ref", + use_prefix_grouper=use_prefix_grouper, use_tiled_mlp=ref_use_tiled_mlp, tiled_mlp_shards=ref_tiled_mlp_shards, )[0] @@ -738,6 +731,8 @@ def init_model(self): with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding self.config.ref.use_fused_kernels = use_fused_kernels + if use_prefix_grouper: + self.config.ref.use_prefix_grouper = use_prefix_grouper self.ref_policy = DataParallelPPOActor( config=self.config.ref, actor_module=self.ref_module_fsdp ) @@ -868,6 +863,7 @@ def update_actor(self, data: DataProto): data = data.to( "cpu" ) # data will to device with each micro batch on actor.update_policy + data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) # perform training with Timer(name="update_policy", logger=None) as timer: @@ -930,6 +926,7 @@ def compute_log_prob(self, data: DataProto): 
data.meta_info["max_token_len"] = config_source.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature + data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) # perform recompute log_prob calculate_entropy = not is_lora with self.ulysses_sharding_manager: @@ -941,6 +938,8 @@ def compute_log_prob(self, data: DataProto): tensors = {"ref_log_prob": outputs["log_probs"]} if calculate_entropy: tensors["entropys"] = outputs["entropys"] + if "sum_pi_squared" in outputs: + tensors["sum_pi_squared"] = outputs["sum_pi_squared"] output = DataProto.from_dict( tensors=tensors, meta_info={"temperature": self.config.rollout.temperature}, @@ -977,6 +976,7 @@ def compute_ref_log_prob(self, data: DataProto): data.meta_info["temperature"] = self.config.rollout.temperature data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz + data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id) with self.ulysses_sharding_manager: data = data.to( "cpu" diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py index f146a972102..ed053ecb1ca 100644 --- a/trinity/trainer/verl/megatron_actor.py +++ b/trinity/trainer/verl/megatron_actor.py @@ -18,7 +18,7 @@ Note that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer -Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/actor/megatron_actor.py +Modified from https://github.com/volcengine/verl/blob/v0.7.1/verl/workers/actor/megatron_actor.py """ from functools import partial @@ -42,7 +42,7 @@ vocab_parallel_entropy, vocab_parallel_log_probs_from_logits, ) -from verl.utils.megatron_utils import unwrap_model +from verl.utils.megatron_utils import get_megatron_mtp_loss, unwrap_model from verl.utils.profiler import GPUMemoryLogger from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import rearrange_micro_batches @@ -126,6 +126,7 @@ def forward_backward_batch( # noqa: C901 assert ( max_token_len is not None ), "max_token_len must be set when use_dynamic_bsz is True" + dp_group = mpu.get_data_parallel_group() vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() if vpp_size is not None and vpp_size > 1: microbatch_group_size_per_vp_stage = ( @@ -135,6 +136,7 @@ def forward_backward_batch( # noqa: C901 batch=mini_batch.batch, num_batches_divided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len, + dp_group=dp_group, ) assert ( len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0 @@ -144,7 +146,9 @@ def forward_backward_batch( # noqa: C901 ) else: micro_batches, indices = rearrange_micro_batches( - batch=mini_batch.batch, max_token_len=max_token_len + batch=mini_batch.batch, + max_token_len=max_token_len, + dp_group=dp_group, ) total_seqlen = max_token_len else: @@ -393,6 +397,7 @@ def logits_processor(logits, label, label_mask): logits_processor=logits_processor, logits_processor_args=logits_processor_args, data_format="thd" if self.config.megatron.use_remove_padding else "bshd", + mtp_config=None if forward_only else getattr(self, "mtp_config", None), ) if forward_only: @@ -475,6 +480,13 @@ def logits_processor(logits, label, label_mask): ) self.mini_layer_topk_idx_list = [] + if ( + not forward_only + and getattr(self, "mtp_config", None) is not None + 
and self.mtp_config.enable_train + ): + losses_reduced["mtp_losses"] = [get_megatron_mtp_loss(n_micro_batch)] + return losses_reduced @GPUMemoryLogger(role="megatron actor", logger=logger) @@ -491,8 +503,6 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: """ metrics = {} - if self.use_torch_profiler and self.prof and self.prof.enable: - self.prof.start() for data in dataloader: if self.config.router_replay.mode in ["R2", "R3"]: RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) @@ -521,6 +531,10 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: max_token_len=max_token_len, mini_batch_size=self.config.ppo_mini_batch_size, ) + mtp_losses = metric_micro_batch.get("mtp_losses", None) + if mtp_losses is not None: + for mtp_metrics_dict in mtp_losses: + append_to_dict(metrics, mtp_metrics_dict) metric_micro_batch = metric_micro_batch["output"] for metric in metric_micro_batch: # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask @@ -537,17 +551,13 @@ def update_policy(self, dataloader: Iterable[DataProto]) -> dict: pass else: raise NotImplementedError - if self.use_torch_profiler and self.prof and self.prof.enable: - self.prof.step() if self.config.router_replay.mode in ["R2", "R3"]: RouterReplay.clear_global_router_replay_action() RouterReplay.clear_global_indices() # add empty cache after each compute - if self.use_torch_profiler and self.prof and self.prof.enable: - self.prof.stop_and_save() - self.prof.stop_trace() + self.actor_optimizer.zero_grad() get_torch_device().empty_cache() return metrics diff --git a/trinity/trainer/verl/megatron_checkpoint_manager.py b/trinity/trainer/verl/megatron_checkpoint_manager.py index c95a1c2ac70..159aabc7cde 100644 --- a/trinity/trainer/verl/megatron_checkpoint_manager.py +++ b/trinity/trainer/verl/megatron_checkpoint_manager.py @@ -13,36 +13,28 @@ # limitations under the License. """ Megatron Checkpoint Manager. 
-Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/utils/checkpoint/megatron_checkpoint_manager.py +Modified from https://github.com/volcengine/verl/blob/v0.7.1/verl/utils/checkpoint/megatron_checkpoint_manager.py """ +import inspect import json import os from collections.abc import Callable from dataclasses import asdict +from typing import Optional -import megatron import ray import torch import torch.distributed -from megatron.core import dist_checkpointing, mpu from megatron.core.transformer.enums import AttnBackend -from packaging import version from transformers import GenerationConfig from verl.utils.checkpoint.megatron_checkpoint_manager import ( MegatronCheckpointManager as OldMegatronCheckpointManager, ) -from verl.utils.checkpoint.megatron_checkpoint_manager import ( - is_non_local, - load_dist_checkpointing, - logger, -) +from verl.utils.checkpoint.megatron_checkpoint_manager import logger from verl.utils.fs import local_mkdir_safe from verl.utils.logger import log_with_rank -from verl.utils.megatron.dist_checkpointing import ( - FullyParallelSaveStrategyWrapper, - get_default_save_sharded_strategy, -) +from verl.utils.megatron.dist_checkpointing import save_dist_checkpointing from verl.utils.megatron_utils import ( get_dist_checkpoint_path, get_hf_model_checkpoint_path, @@ -53,42 +45,6 @@ from trinity.trainer.verl.verl_trainer import CheckpointMonitor from trinity.utils.log import get_logger -mcore_ge_014 = version.parse(megatron.core.__version__) >= version.parse("0.14.0") -if not mcore_ge_014: - logger.warning( - "Detected megatron.core %s, recommend upgrading to >= 0.14.0 for better checkpoint compatibility", - megatron.core.__version__, - ) - - -# TODO: removed after upgrading verl > 0.7.0; https://github.com/verl-project/verl/pull/5154 -def save_dist_checkpointing( - sharded_state_dict, - ckpt_path, - async_save=False, - content_metadata=None, -): - validate_sharding_integrity = True - # Get checkpointing strategies - save_strategy = get_default_save_sharded_strategy("torch_dist") - save_strategy = FullyParallelSaveStrategyWrapper( - save_strategy, mpu.get_data_parallel_group(with_context_parallel=True) - ) - - # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.14.0/megatron/core/optimizer/distrib_optimizer.py#L1109-L1123 - mcore_ge_014 = version.parse(megatron.core.__version__) >= version.parse("0.14.0") - # Save model sharded state dicts - save_kwargs = dict( - sharded_strategy=save_strategy, - async_sharded_save=async_save, - validate_access_integrity=validate_sharding_integrity, - ) - if content_metadata is not None: - if mcore_ge_014: - save_kwargs["content_metadata"] = content_metadata - - return dist_checkpointing.save(sharded_state_dict, ckpt_path, **save_kwargs) - class MegatronCheckpointManager(OldMegatronCheckpointManager): """ @@ -117,257 +73,17 @@ def __init__( self.latest_extra_state_save_step = None self.latest_hf_model_save_step = None - # TODO: removed after upgrading verl > 0.7.0; https://github.com/verl-project/verl/pull/5154 - def generate_state_dict( - self, - generate_model: bool = True, - generate_optimizer: bool = True, - generate_extra: bool = True, - is_loading: bool = False, - metadata: dict | None = None, - ): - # For save dist checkpointing - state_dict = {} - base_metadata = metadata or self._build_sharded_state_dict_metadata() - - # Should always generate model state dict - # All ranks Save Model to reduce memory pressure - # Get sharded state dict, notice that state_dict will collect among dp groups, causing memory 
pressure - for vpp_rank, model in enumerate(self.model): - if len(self.model) > 1: - mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) - key = f"model{vpp_rank}" if len(self.model) > 1 else "model" - else: - key = "model" - if hasattr(model, "module"): - model = model.module - - # GPTModel's sharded_state_dict function when having mtp requires metadata['dp_cp_group'] - model_metadata = dict(base_metadata) - model_metadata["dp_cp_group"] = mpu.get_data_parallel_group(with_context_parallel=True) - kwargs = {"metadata": model_metadata} - state_dict[key] = model.sharded_state_dict(**kwargs) - - # Optimizer State Dict - if generate_optimizer: - torch.distributed.barrier() - sharded_state_dict_kwargs = {"is_loading": is_loading} - if base_metadata is not None: - # https://github.com/NVIDIA/Megatron-LM/blob/core_v0.14.0/megatron/core/optimizer/distrib_optimizer.py#L1109-L1123 - if mcore_ge_014: - sharded_state_dict_kwargs["metadata"] = base_metadata - optimizer_sharded_states = self.optimizer.sharded_state_dict( - state_dict, **sharded_state_dict_kwargs - ) - state_dict["optimizer"] = optimizer_sharded_states - - if self.lr_scheduler is not None: - lr_state_dict = self.lr_scheduler.state_dict() - state_dict["lr_scheduler"] = lr_state_dict - - if not generate_model: - state_dict.pop("model", None) - - # RNG States State Dict - if generate_extra: - torch.distributed.barrier() - rng_state = self.get_rng_state() - state_dict["rng_state"] = rng_state - - return state_dict - - # TODO: removed after upgrading verl > 0.7.0; https://github.com/verl-project/verl/pull/5154 - def _build_sharded_state_dict_metadata(self) -> dict: - """Builds metadata used for sharded_state_dict versioning. - - - The whole content metadata is passed to ``sharded_state_dict`` model and optimizer methods - and therefore affects only the logic behind sharded_state_dict creation. - The content metadata should be minimalistic, ideally flat (or with a single nesting level) - and with semantically meaningful flag names (e.g. `distrib_optim_sharding_type`). - In particular, a simple integer (or SemVer) versioning flag (e.g. `metadata['version'] = 3.4`) - is discouraged, because the metadata serves for all models and optimizers and it's practically - impossible to enforce a linearly increasing versioning for this whole space. 
- """ - metadata: dict = {} - - if not mcore_ge_014: - # For backward compatibility with Megatron core < v0.14.0 - if self.use_distributed_optimizer: - metadata["distrib_optim_sharding_type"] = "fully_sharded_model_space" - return metadata - - if self.use_distributed_optimizer: - megatron_config = getattr(self.config, self.role, self.config).megatron - dist_ckpt_optim_fully_reshardable = megatron_config.dist_ckpt_optim_fully_reshardable - distrib_optim_fully_reshardable_mem_efficient = ( - megatron_config.distrib_optim_fully_reshardable_mem_efficient - ) - if dist_ckpt_optim_fully_reshardable: - metadata["distrib_optim_sharding_type"] = "fully_reshardable" - metadata[ - "distrib_optim_fully_reshardable_mem_efficient" - ] = distrib_optim_fully_reshardable_mem_efficient - else: - metadata["distrib_optim_sharding_type"] = "dp_reshardable" - - metadata["singleton_local_shards"] = False - metadata["chained_optim_avoid_prefix"] = True - return metadata - - # TODO: removed after upgrading verl > 0.7.0; https://github.com/verl-project/verl/pull/5154 - def load_checkpoint( # noqa: C901 - self, local_path: str, hdfs_path: str = None, del_local_after_load=False - ): - if local_path is not None: - assert os.path.exists(local_path), f"Checkpoint path {local_path} does not exist." - - # For load optimizer dist_ckpt - try: - import transformer_engine - - torch.serialization.add_safe_globals([torch.optim.AdamW]) - torch.serialization.add_safe_globals( - [transformer_engine.pytorch.optimizers.fused_adam.FusedAdam] - ) - except Exception: - pass - - dist_checkpoint_path = get_dist_checkpoint_path(local_path) - - load_content_metadata = getattr(dist_checkpointing, "load_content_metadata", None) - if load_content_metadata is None: - # For backward compatibility - sharded_sd_metadata = None - else: - sharded_sd_metadata = load_content_metadata(checkpoint_dir=dist_checkpoint_path) - if sharded_sd_metadata is None: - if self.use_distributed_optimizer: - # Backward-compatibility with old checkpoints which don't have content versioning - # Can be removed after ending support for MLM optimizer checkpoints with MCore < v0.13 - # (for MCore v0.13+ checkpoints `sharded_sd_metadata is not None`) - sharded_sd_metadata = { - "distrib_optim_sharding_type": "fully_sharded_model_space", - } - else: - sharded_sd_metadata = self._build_sharded_state_dict_metadata() - - # Get State Dict for loading - sharded_state_dict = self.generate_state_dict( - self.should_load_model and self.use_dist_checkpointing, - self.should_load_optimizer, - self.should_load_extra, - is_loading=True, - metadata=sharded_sd_metadata, - ) - log_with_rank( - f"Generated state dict for loading: {sharded_state_dict.keys()}", - rank=self.rank, - logger=logger, - ) - - # Load Dist Checkpointing - state_dict = load_dist_checkpointing( - sharded_state_dict=sharded_state_dict, - ckpt_dir=dist_checkpoint_path, - ) - - if self.should_load_model and self.use_dist_checkpointing: - assert "model" in state_dict or any( - f"model{vpp_rank}" in state_dict for vpp_rank in range(len(self.model)) - ), f"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." 
- for vpp_rank, model in enumerate(self.model): - if len(self.model) == 1: - model_state_dict = state_dict["model"] - else: - assert ( - f"model{vpp_rank}" in state_dict - ), f"model{vpp_rank} not found in state_dict" - model_state_dict = state_dict[f"model{vpp_rank}"] - mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) - self.model[vpp_rank].load_state_dict(model_state_dict) - log_with_rank( - f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger - ) - - # Skip HF checkpoint loading if PEFT is used - elif self.should_load_model and self.use_hf_checkpoint and self.peft_cls is None: - hf_model_path = get_hf_model_checkpoint_path(local_path) - if self.vanilla_bridge: - self.bridge.load_weights(self.model, hf_model_path) - else: - self.bridge.load_hf_weights(self.model, hf_model_path) - log_with_rank( - f"Loaded HF model checkpoint from {hf_model_path} with bridge", - rank=self.rank, - logger=logger, - ) - # Load PEFT adapter checkpoint if available - if self.should_load_model and self.peft_cls is not None: - adapter_ckpt_path = os.path.join(local_path, "adapter_checkpoint") - if os.path.exists(adapter_ckpt_path): - from verl.utils.megatron_peft_utils import load_adapter_checkpoint - - # TODO: a better format for adapter checkpoint, waiting megatron-bridge support - - load_adapter_checkpoint( - self.model, - adapter_ckpt_path, - ) - log_with_rank( - f"Loaded adapter checkpoint from {adapter_ckpt_path}", - rank=self.rank, - logger=logger, - ) - else: - log_with_rank( - f"PEFT config is set but no adapter checkpoint found at {adapter_ckpt_path}", - rank=self.rank, - logger=logger, - ) - - if self.should_load_optimizer: - assert ( - "optimizer" in state_dict - ), f"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." - optimizer_state_dict = state_dict["optimizer"] - self.optimizer.load_state_dict(optimizer_state_dict) - log_with_rank( - f"Loaded optimizer checkpoint from {local_path}", rank=self.rank, logger=logger - ) - if self.use_checkpoint_opt_param_scheduler: - assert "lr_scheduler" in state_dict, ( - f"LR scheduler state dict not found in {state_dict.keys()}. Please check the checkpoint file " - f"{local_path}." - ) - lr_scheduler_state_dict = state_dict["lr_scheduler"] - if self.lr_scheduler is not None: - self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) - log_with_rank( - f"Loaded LR scheduler checkpoint from {local_path}", - rank=self.rank, - logger=logger, - ) + def _is_latest_registered_checkpoint(self, path: str) -> bool: + if not self.previous_saved_paths: + return False + return os.path.abspath(path) == os.path.abspath(self.previous_saved_paths[-1]) - if self.should_load_extra: - assert ( - "rng_state" in state_dict - ), f"RNG state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." 
- rng_state = state_dict["rng_state"] - self.load_rng_states(rng_state) - log_with_rank(f"Loaded RNG states from {local_path}", rank=self.rank, logger=logger) - - if del_local_after_load: - try: - os.remove(local_path) if is_non_local(local_path) else None - except Exception as e: - log_with_rank( - f"remove local resume ckpt file after loading failed, exception {e} will be ignored", - rank=self.rank, - logger=logger, - ) + def register_checkpoint(self, new_path: str, max_ckpt_to_keep: Optional[int] = None): + if self._is_latest_registered_checkpoint(new_path): + return + super().register_checkpoint(new_path, max_ckpt_to_keep) - def _save_state_dict(self, local_path, global_step) -> bool: + def _save_state_dict(self, local_path, global_step, max_ckpt_to_keep=None) -> bool: """ Save the model state dict to the specified local path. @@ -467,7 +183,7 @@ def _save_state_dict(self, local_path, global_step) -> bool: logger=logger, log_only_rank_0=True, ) - if self.use_hf_checkpoint: + elif self.use_hf_checkpoint: # Use mbridge to save HF model checkpoint log_with_rank( f"Saving HF model checkpoint to {local_path} with bridge", @@ -476,9 +192,14 @@ def _save_state_dict(self, local_path, global_step) -> bool: ) hf_ckpt_path = get_hf_model_checkpoint_path(local_path) if self.vanilla_bridge: - self.bridge.save_weights( - self.model, hf_ckpt_path, distributed_filesystem=True, memory_efficient=True - ) + extended_args = {} + mbridge_config = getattr(self.checkpoint_config, "mbridge_config", None) or {} + for sig in inspect.signature(self.bridge.save_weights).parameters: + if sig == "weights_path" or sig == "models": + continue + if sig in mbridge_config: + extended_args[sig] = mbridge_config[sig] + self.bridge.save_weights(self.model, hf_ckpt_path, **extended_args) else: self.bridge.save_hf_weights(self.model, hf_ckpt_path) @@ -582,6 +303,7 @@ def _save_extra_state(self, local_path, global_step) -> bool: "grad_sync_func", "param_sync_func", "generation_config", + "_pg_collection", ] backup = {} for k in bypass_keys: @@ -629,12 +351,14 @@ def _save_hf_model(self, local_path, global_step) -> bool: if self.bridge is not None: hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) if self.vanilla_bridge: - self.bridge.save_weights( - self.model, - hf_model_ckpt_path, - distributed_filesystem=True, - memory_efficient=True, - ) + extended_args = {} + mbridge_config = getattr(self.checkpoint_config, "mbridge_config", None) or {} + for sig in inspect.signature(self.bridge.save_weights).parameters: + if sig == "weights_path" or sig == "models": + continue + if sig in mbridge_config: + extended_args[sig] = mbridge_config[sig] + self.bridge.save_weights(self.model, hf_model_ckpt_path, **extended_args) else: self.bridge.save_hf_weights(self.model, hf_model_ckpt_path) else: @@ -717,29 +441,18 @@ def save_checkpoint( ): # record the previous global step self.previous_global_step = global_step - local_path = local_mkdir_safe(local_path) - # remove previous local_path - if ( - not self.checkpoint_config.async_save - and max_ckpt_to_keep - and isinstance(max_ckpt_to_keep, int) - and max_ckpt_to_keep > 0 - and len(self.previous_saved_paths) >= max_ckpt_to_keep # type: ignore - and local_path != self.previous_saved_paths[-1] # type: ignore - ): # last step may save twice - keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore - self.logger.info( - "Checkpoint manager is removing previous checkpoints at " - + str(self.previous_saved_paths[:keep_start]) # type: ignore - ) - 
self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore - self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore + skip_retention_rotation = self._is_latest_registered_checkpoint(local_path) + + if not self.checkpoint_config.async_save and not skip_retention_rotation: + self.ensure_checkpoint_capacity(max_ckpt_to_keep) + + local_path = local_mkdir_safe(local_path) torch.distributed.barrier() state_dict_thread_count = 0 - if self._save_state_dict(local_path, global_step): + if self._save_state_dict(local_path, global_step, max_ckpt_to_keep): state_dict_thread_count += 1 self._save_tokenizer(local_path, global_step) @@ -755,7 +468,6 @@ def save_checkpoint( global_step, state_dict_thread_count=state_dict_thread_count ) ) - if ( - len(self.previous_saved_paths) == 0 or local_path != self.previous_saved_paths[-1] - ): # last step may save twice - self.previous_saved_paths.append(local_path) + + if self.rank == 0: + self.register_checkpoint(local_path, max_ckpt_to_keep) diff --git a/trinity/trainer/verl/megatron_workers.py b/trinity/trainer/verl/megatron_workers.py index eda9cd6a4c0..36576ad4550 100644 --- a/trinity/trainer/verl/megatron_workers.py +++ b/trinity/trainer/verl/megatron_workers.py @@ -13,13 +13,13 @@ # limitations under the License. """ The main entry point to run the PPO algorithm. -Modified from https://github.com/volcengine/verl/blob/v0.7.0/verl/workers/megatron_workers.py +Modified from https://github.com/volcengine/verl/blob/v0.7.1/verl/workers/megatron_workers.py """ import datetime import os -import sys import time +from contextlib import nullcontext import psutil import ray @@ -30,27 +30,11 @@ from megatron.core import parallel_state as mpu from omegaconf import DictConfig, OmegaConf, open_dict -from trinity.utils.log import get_logger - try: from verl.workers.engine.mindspeed.transformer_impl import repatch except ImportError: repatch = None -# start of patch for verl to support transformers v5 -if not hasattr(sys.modules["transformers"], "AutoModelForVision2Seq"): - setattr( - sys.modules["transformers"], - "AutoModelForVision2Seq", - sys.modules["transformers"].AutoModelForImageTextToText, - ) - sys.modules["transformers"].__all__.append("AutoModelForVision2Seq") - - import accelerate - - setattr(accelerate, "init_empty_weights", lambda: torch.device("cpu")) -# end of patch for verl to support transformers v5 - from verl import DataProto from verl.models.mcore import get_mcore_weight_converter from verl.single_controller.base import Worker @@ -62,6 +46,7 @@ from verl.utils.config import omega_conf_to_dataclass from verl.utils.device import ( get_device_id, + get_device_name, get_nccl_backend, get_torch_device, set_expandable_segments, @@ -104,7 +89,9 @@ from trinity.manager.synchronizer import Synchronizer from trinity.trainer.verl.megatron_actor import MegatronPPOActor from trinity.trainer.verl.megatron_checkpoint_manager import MegatronCheckpointManager +from trinity.trainer.verl.utils import patch_rope_theta_in_hf_config from trinity.utils.distributed import init_process_group +from trinity.utils.log import get_logger class MegatronWorker(Worker): @@ -117,6 +104,7 @@ def _init_hf_config_and_tf_config( # noqa: C901 override_transformer_config, trust_remote_code=False, megatron_config=None, + enable_mtp=False, ): from transformers import AutoConfig from verl.models.mcore import hf_to_mcore_config @@ -163,13 +151,23 @@ def _init_hf_config_and_tf_config( # noqa: C901 hf_config.rope_theta = 
self.config.model.rope_theta # start of patch for verl to support transformers v5 - if not hasattr(hf_config, "rope_theta"): - rope_theta = hf_config.rope_parameters.get("rope_theta", None) - if rope_theta is not None: - setattr(hf_config, "rope_theta", rope_theta) + patch_rope_theta_in_hf_config(hf_config) # end of patch for verl to support transformers v5 self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) + + if enable_mtp: + assert ( + getattr(hf_config, "num_nextn_predict_layers", 0) > 0 + ), "MTP requires at least one nextn_predict_layer" + assert megatron_config.use_mbridge, "MTP requires use_mbridge to be True" + override_transformer_config[ + "mtp_loss_scaling_factor" + ] = self.config.model.mtp.mtp_loss_scaling_factor + elif hasattr(hf_config, "num_nextn_predict_layers"): + hf_config.num_nextn_predict_layers = 0 + + self.enable_mtp = enable_mtp update_model_config(hf_config, override_config_kwargs=override_config_kwargs) self.architectures = getattr(hf_config, "architectures", None) if self.rank == 0: @@ -188,6 +186,44 @@ def _init_hf_config_and_tf_config( # noqa: C901 self.vanilla_bridge = megatron_config.get("vanilla_mbridge", True) if megatron_config.use_mbridge: if self.vanilla_bridge: + # start of patch for mbridge + import json + from glob import glob + + from mbridge.core.safetensor_io import SafeTensorIO + from safetensors import safe_open + + if not getattr(SafeTensorIO, "_is_patched", False): + + def new_init(self, hf_dir: str): + index_file = os.path.join(hf_dir, "model.safetensors.index.json") + config = AutoConfig.from_pretrained(hf_dir, trust_remote_code=True) + + self.index = {} + self.origin_index = {} + if os.path.exists(index_file): + with open(index_file, "r") as f: + origin_index = json.load(f) + self.index = origin_index["weight_map"] + self.origin_index = origin_index + else: + src_files = glob(os.path.join(hf_dir, "*.safetensors")) + if len(src_files) == 1: + for file in src_files: + with safe_open(file, framework="pt", device="cpu") as f: + filename = os.path.basename(file) + for key in f.keys(): + self.index[key] = filename + if getattr(config, "tie_word_embeddings", False): + if "lm_head.weight" in self.index.keys(): + self.index.pop("lm_head.weight") + + self.hf_dir = hf_dir + + SafeTensorIO.__init__ = new_init + SafeTensorIO._is_patched = True + # end of patch for mbridge + from verl.models.mcore.mbridge import AutoBridge bridge = AutoBridge.from_config(hf_config, dtype=dtype) @@ -208,6 +244,10 @@ def _init_hf_config_and_tf_config( # noqa: C901 # In case of invalid overrides, we need to make sure some critical params are set correctly provider.params_dtype = dtype + # Ensure dtype settings propagate to Megatron-Bridge/TE + provider.fp16 = fp16 + provider.bf16 = bf16 + # Pass distributed info provider.tensor_model_parallel_size = megatron_config.tensor_model_parallel_size provider.pipeline_model_parallel_size = megatron_config.pipeline_model_parallel_size @@ -282,7 +322,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs): set_numa_affinity() rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group( - backend=get_nccl_backend(), + backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), init_method=os.environ.get("DIST_INIT_METHOD", None), ) @@ -351,6 +391,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs): self._is_offload_param = False self._is_offload_grad = False self._is_offload_optimizer 
= False + self._hf_export_conversion_tasks = None # normalize config if self._is_actor: @@ -409,6 +450,7 @@ def _build_model_optimizer( override_transformer_config, self.config.model.get("trust_remote_code", False), self.config.actor.megatron if not self._is_ref else self.config.ref.megatron, + self.config.model.get("mtp", {}).get("enable", False), ) self.generation_config = get_generation_config( self.local_path, @@ -606,10 +648,16 @@ def init_model(self): tf_config=self.tf_config, actor_module=self.actor_module, actor_optimizer=self.actor_optimizer, + mtp_config=self.config.model.mtp if self.config.model.mtp.enable else None, ) self.logger.info(f"routing replay layers: {len(RouterReplay.router_instances)}") log_gpu_memory_usage("After MegatronPPOActor init", logger=self.logger) + if self.bridge is not None and not self.vanilla_bridge: + self._hf_export_conversion_tasks = self.bridge.get_conversion_tasks( + self.actor.actor_module + ) + if self._is_ref: self.ref_module, self.ref_model_config = self._build_model_optimizer( model_path=self.config.model.path, @@ -679,7 +727,11 @@ def _get_tensor_generator(self): if self.vanilla_bridge: per_tensor_param = self.bridge.export_weights(self.actor.actor_module) else: - per_tensor_param = self.bridge.export_hf_weights(self.actor.actor_module) + per_tensor_param = self.bridge.export_hf_weights( + self.actor.actor_module, + show_progress=False, + conversion_tasks=self._hf_export_conversion_tasks, + ) else: per_tensor_param = per_tensor_generator( self.actor.actor_module, @@ -798,8 +850,9 @@ def update_actor(self, data: DataProto): metrics = self.actor.update_policy(dataloader=dataloader) delta_time = timer.last global_num_tokens = data.meta_info["global_token_num"] + images_seqlens = data.meta_info.get("images_seqlens", None) estimated_flops, promised_flops = self.flops_counter.estimate_flops( - global_num_tokens, delta_time + global_num_tokens, delta_time, images_seqlens=images_seqlens ) metrics["perf/mfu/actor"] = ( estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size @@ -838,6 +891,9 @@ def update_actor(self, data: DataProto): @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger) @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") def compute_ref_log_prob(self, data: DataProto): + if self.peft_cls is not None: + data.meta_info["is_lora"] = True + return self.compute_log_prob(data) assert self._is_ref if self._ref_is_offload_param: load_megatron_model_to_gpu(self.ref_module, load_grad=False) @@ -870,10 +926,12 @@ def compute_log_prob(self, data: DataProto): log_gpu_memory_usage( "After load actor params and grad during compute_log_prob", logger=self.logger ) - # we should always recompute old_log_probs when it is HybridEngine - data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz + is_lora = data.meta_info.pop("is_lora", False) + adapter_ctx = self.peft_cls.disable_adapter(self.actor_module) if is_lora else nullcontext() + config_source = self.config.ref if is_lora else self.config.rollout + data.meta_info["micro_batch_size"] = config_source.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = config_source.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz data.meta_info["temperature"] = self.config.rollout.temperature 
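Reviewer note (not part of the patch): the LoRA reference-policy path above deserves a pointer. `compute_ref_log_prob` now short-circuits into `compute_log_prob` with `is_lora=True`, and the worker scores the batch inside `peft_cls.disable_adapter(...)`, so the frozen base weights act as the reference policy and no separate reference worker is needed. A minimal runnable sketch of that control flow, with `DummyPeft` as an illustrative stand-in for the real PEFT helper (none of these names come from the patch):

```python
from contextlib import contextmanager, nullcontext


class DummyPeft:
    """Stand-in for the real PEFT helper class used by the worker."""

    @staticmethod
    @contextmanager
    def disable_adapter(module):
        module["adapter_enabled"] = False  # adapters off -> frozen base (reference) policy
        try:
            yield
        finally:
            module["adapter_enabled"] = True  # restore the trained (actor) policy


def compute_log_prob(module, meta_info, peft_cls=DummyPeft):
    # Mirrors the worker: a ref-log-prob request is just a log-prob request
    # with the adapter disabled for the duration of the forward pass.
    is_lora = meta_info.pop("is_lora", False)
    ctx = peft_cls.disable_adapter(module) if is_lora else nullcontext()
    with ctx:
        score = 0.0 if not module["adapter_enabled"] else 1.0  # placeholder "log prob"
    return {"ref_log_prob": score} if is_lora else {"old_log_probs": score}


module = {"adapter_enabled": True}
print(compute_log_prob(module, {"is_lora": True}))  # {'ref_log_prob': 0.0}
print(compute_log_prob(module, {}))                 # {'old_log_probs': 1.0}
```

This is also why the output tensor key switches between `ref_log_prob` and `old_log_probs` in the hunk above: the same method serves both roles depending on the `is_lora` flag in `meta_info`.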
diff --git a/trinity/trainer/verl/monkey_patch.py b/trinity/trainer/verl/monkey_patch.py
index e5b10e4b056..bfc842aab23 100644
--- a/trinity/trainer/verl/monkey_patch.py
+++ b/trinity/trainer/verl/monkey_patch.py
@@ -187,11 +189,13 @@ def apply_monkey_patch(  # noqa: C901
     use_remove_padding: bool = True,
     use_fused_kernels: bool = False,
     fused_kernels_backend: str = None,
+    use_prefix_grouper: bool = False,
     use_tiled_mlp: bool = False,
     tiled_mlp_shards: int = 4,
 ):
     """
-    Apply monkey patch to the models for ulysses sequence parallel, fused kernel, and tiled MLP.
+    Apply monkey patch to the models for ulysses sequence parallel, fused kernel, prefix grouper,
+    and tiled MLP.
 
     In the end of this function forward function of the model is patched for fused kernel.
     If the model is not supported with fused kernel, please return after patch.
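Reviewer note (not part of the patch): the new `use_prefix_grouper` flag follows the same guarded monkey-patch idiom used elsewhere in this PR (compare the `_is_patched` flag on `SafeTensorIO` in megatron_workers.py above). A runnable toy of the idiom, under purely illustrative names (`Target`, `fast_forward` are not from the patch):

```python
class Target:
    def forward(self, x):
        return x + 1


def apply_fast_forward_patch():
    # Idempotence guard: applying the patch twice must be a no-op,
    # exactly like the `_is_patched` flag on SafeTensorIO.
    if getattr(Target, "_is_patched", False):
        return

    def fast_forward(self, x):
        return x + 2

    Target.forward = fast_forward
    Target._is_patched = True


use_fast_forward = True  # analogous to the new `use_prefix_grouper` flag
if use_fast_forward:
    apply_fast_forward_patch()
    apply_fast_forward_patch()  # safe: the second call is a no-op

print(Target().forward(1))  # 3
```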
@@ -207,6 +209,7 @@ def apply_monkey_patch(  # noqa: C901
     """
     from verl.models.transformers.monkey_patch import (
         _ulysses_flash_attention_forward,
+        apply_prefix_grouper_patch,
         patch_vlm_for_ulysses_input_slicing,
     )
     from verl.utils.import_utils import is_trl_available
@@ -221,6 +224,9 @@
         model_type = getattr(model.config, "model_type", None)
         apply_tiled_mlp_monkey_patch(num_shards=tiled_mlp_shards, model_type=model_type)
 
+    if use_prefix_grouper:
+        apply_prefix_grouper_patch()
+
     """Replace _flash_attention_forward to _ulysses_flash_attention_forward"""
     module = sys.modules[model.__module__]
 
@@ -338,12 +344,12 @@ def state_dict(self, *args, **kwargs):
     )
 
     # Step 3: patch verl.utils.flops_counter
-    from verl.utils.flops_counter import ESTIMATE_FUNC, _estimate_qwen2_flops
+    from verl.utils.flops_counter import ESTIMATE_FUNC, _estimate_qwen3_vl_flops
 
     ESTIMATE_FUNC.update(
         {
-            "qwen3_5": _estimate_qwen2_flops,
-            "qwen3_5_moe": _estimate_qwen2_flops,
+            "qwen3_5": _estimate_qwen3_vl_flops,
+            "qwen3_5_moe": _estimate_qwen3_vl_flops,
         }
     )
diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py
index 3b8b03d8321..ee0b67314c6 100644
--- a/trinity/trainer/verl/utils.py
+++ b/trinity/trainer/verl/utils.py
@@ -389,3 +389,20 @@ def rearrange_micro_batches(
         micro_batches.append(curr_micro_batch)
 
     return micro_batches, micro_bsz_idx
+
+
+# add rope_theta to hf config for backward compatibility, can be removed after verl is updated
+def patch_rope_theta_in_hf_config(hf_config):
+    if not hasattr(hf_config, "rope_theta"):
+        if hasattr(hf_config, "rope_parameters"):
+            rope_parameters = hf_config.rope_parameters
+        elif hasattr(hf_config, "text_config") and hasattr(
+            hf_config.text_config, "rope_parameters"
+        ):
+            rope_parameters = hf_config.text_config.rope_parameters
+        else:
+            rope_parameters = {}
+
+        rope_theta = rope_parameters.get("rope_theta", None)
+        if rope_theta is not None:
+            setattr(hf_config, "rope_theta", rope_theta)
diff --git a/trinity/trainer/verl/verl_config.py b/trinity/trainer/verl/verl_config.py
index da5c12c8253..9bdcaa8f831 100644
--- a/trinity/trainer/verl/verl_config.py
+++ b/trinity/trainer/verl/verl_config.py
@@ -4,7 +4,13 @@
 from typing import Any, Dict, List, Optional, Union
 
 from omegaconf import OmegaConf
-from verl.workers.config import PolicyLossConfig, RouterReplayConfig
+from verl.trainer.config import CheckpointConfig, RolloutCorrectionConfig
+from verl.workers.config import (
+    McoreEngineConfig,
+    MtpConfig,
+    PolicyLossConfig,
+    RouterReplayConfig,
+)
 
 from trinity.algorithm import ALGORITHM_TYPE
 from trinity.common.config import Config, SynchronizerConfig, set_if_none
@@ -47,6 +53,9 @@ class ActorModel:
     exclude_modules: Optional[str] = None
     lora_adapter_path: Optional[str] = None
 
+    # mtp configs
+    mtp: MtpConfig = field(default_factory=MtpConfig)
+
     # rope configs
     rope_scaling: Optional[dict] = None
     rope_theta: Optional[float] = None
@@ -100,10 +109,10 @@ class FSDPConfig:
 
 
 @dataclass
-class Checkpoint:
-    load_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
-    save_contents: List[str] = field(default_factory=lambda: ["model", "optimizer", "extra"])
-    async_save: bool = False  # TODO: testing async save
+class _CheckpointConfig(CheckpointConfig):
+    mbridge_config: dict[str, Any] = field(
+        default_factory=lambda: dict(distributed_filesystem=True, memory_efficient=True)
+    )
 
 
 @dataclass
@@ -115,30 +124,16 @@ class OverrideTransformerConfig:
 
 
 @dataclass
-class MegatronConfig:
-    param_offload: bool = False
-    grad_offload: bool = False
-    optimizer_offload: bool = False
-    tensor_model_parallel_size: int = 1
-    expert_model_parallel_size: int = 1
-    expert_tensor_parallel_size: Optional[int] = None
-    pipeline_model_parallel_size: int = 1
-    virtual_pipeline_model_parallel_size: Optional[int] = None
-    context_parallel_size: int = 1
-    sequence_parallel: bool = True
-    use_distributed_optimizer: bool = True
-    use_dist_checkpointing: bool = False
-    dist_checkpointing_path: Optional[str] = None
-    dist_ckpt_optim_fully_reshardable: bool = False
-    distrib_optim_fully_reshardable_mem_efficient: bool = False
-    seed: int = 42
-    override_ddp_config: dict = field(default_factory=dict)
-    override_transformer_config: OverrideTransformerConfig = field(
-        default_factory=OverrideTransformerConfig
-    )
-    use_mbridge: bool = False
-    dtype: str = "bfloat16"
-    use_remove_padding: bool = True
+class _McoreEngineConfig(McoreEngineConfig):
+    # use_dist_checkpointing: bool = True
+    # whether to use the vanilla mbridge without verl-specific optimizations
+    # TODO: failed to run with vanilla_mbridge = False, need to investigate further
+    vanilla_mbridge: bool = True
+
+    max_token_len_per_gpu: Optional[int] = None
+    micro_batch_size_per_gpu: Optional[int] = None
+    infer_max_token_len_per_gpu: Optional[int] = None
+    infer_micro_batch_size_per_gpu: Optional[int] = None
 
 
 @dataclass
@@ -157,6 +152,9 @@ class Actor:
     ppo_micro_batch_size: Optional[int] = None
     ppo_micro_batch_size_per_gpu: int = 1
     use_dynamic_bsz: Optional[bool] = None
+    use_prefix_grouper: bool = False
+    calculate_sum_pi_squared: bool = False
+    sum_pi_squared_checkpointing: bool = False
     ppo_max_token_len_per_gpu: Optional[int] = None
     fix_actor_microbatch_loss_scale: Optional[bool] = None  # EXPERIMENTAL
     grad_clip: Optional[float] = None
@@ -165,10 +163,10 @@
     ulysses_sequence_parallel_size: Optional[int] = None
     entropy_from_logits_with_chunking: bool = False
     entropy_checkpointing: bool = False
-    checkpoint: Checkpoint = field(default_factory=Checkpoint)
+    checkpoint: _CheckpointConfig = field(default_factory=_CheckpointConfig)
     optim: Optim = field(default_factory=Optim)
     fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
-    megatron: MegatronConfig = field(default_factory=MegatronConfig)
+    megatron: _McoreEngineConfig = field(default_factory=_McoreEngineConfig)
     profile: ProfileConfig = field(default_factory=ProfileConfig)
     data_loader_seed: Optional[int] = None
     load_weight: bool = True
@@ -192,14 +190,15 @@ class Ref:
     log_prob_micro_batch_size: Optional[int] = None
     log_prob_micro_batch_size_per_gpu: int = 1
     log_prob_use_dynamic_bsz: Optional[bool] = None
+    use_prefix_grouper: bool = False
     log_prob_max_token_len_per_gpu: Optional[int] = None
     ulysses_sequence_parallel_size: Optional[int] = None
     entropy_from_logits_with_chunking: bool = False
     entropy_checkpointing: bool = False
-    checkpoint: Checkpoint = field(
-        default_factory=lambda: Checkpoint(load_contents=["model"], save_contents=["model"])
+    checkpoint: _CheckpointConfig = field(
+        default_factory=lambda: _CheckpointConfig(load_contents=["model"], save_contents=["model"])
     )
-    megatron: MegatronConfig = field(default_factory=MegatronConfig)
+    megatron: _McoreEngineConfig = field(default_factory=_McoreEngineConfig)
     profile: ProfileConfig = field(default_factory=ProfileConfig)
     load_weight: bool = True
     profiler: dict = field(default_factory=dict)
@@ -276,10 +275,10 @@ class Critic:
     shuffle: bool = False
     grad_clip: Optional[float] = None
     cliprange_value: float = 0.0
-    checkpoint: Checkpoint = field(default_factory=Checkpoint)
+    checkpoint: _CheckpointConfig = field(default_factory=_CheckpointConfig)
     rollout_n: int = 1
     loss_agg_mode: str = "token-mean"
-    megatron: MegatronConfig = field(default_factory=MegatronConfig)
+    megatron: _McoreEngineConfig = field(default_factory=_McoreEngineConfig)
     profile: ProfileConfig = field(default_factory=ProfileConfig)
     data_loader_seed: Optional[int] = None
     load_weight: bool = True
@@ -311,6 +310,11 @@ class RewardModel:
     use_reward_loop: bool = True
 
 
+@dataclass
+class Reward:
+    reward_model: RewardModel = field(default_factory=RewardModel)
+
+
 @dataclass
 class CustomRewardFunction:
     path: Optional[str] = None
@@ -326,23 +330,13 @@ class KL_Ctrl:
 
 
 @dataclass
-class RolloutCorrection:
-    rollout_is: Optional[str] = None
-    rollout_is_threshold: float = 2.0
-    rollout_rs: Optional[str] = None
-    rollout_rs_threshold: Optional[float] = None
-    rollout_rs_threshold_lower: Optional[float] = None
-    rollout_token_veto_threshold: Optional[float] = None
-    # Because rollout and training in Trinity runs separately,
-    # rollout_is_batch_normalize is default to True
+class _RolloutCorrectionConfig(RolloutCorrectionConfig):
     bypass_mode: bool = True
-    loss_type: str = "ppo_clip"
-    rollout_is_batch_normalize: bool = False
 
 
 @dataclass
 class Algorithm:
-    rollout_correction: RolloutCorrection = field(default_factory=RolloutCorrection)
+    rollout_correction: _RolloutCorrectionConfig = field(default_factory=_RolloutCorrectionConfig)
     # ! DO NOT SET gamma or lam below; they are kept here merely for compatibility with verl,
     # and their values will be overwritten by those in AlgorithmConfig.advantage_fn_args
     # if they are really needed (e.g., for GAE advantage/returns computation)
@@ -393,6 +387,7 @@ class veRLConfig:
     actor_rollout_ref: ActorRolloutRef = field(default_factory=ActorRolloutRef)
     critic: Critic = field(default_factory=Critic)
     reward_model: RewardModel = field(default_factory=RewardModel)
+    reward: Reward = field(default_factory=Reward)
     custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
     algorithm: Algorithm = field(default_factory=Algorithm)
     trainer: Trainer = field(default_factory=Trainer)
diff --git a/trinity/trainer/verl/verl_trainer.py b/trinity/trainer/verl/verl_trainer.py
index b22af2715f4..1796c88e041 100644
--- a/trinity/trainer/verl/verl_trainer.py
+++ b/trinity/trainer/verl/verl_trainer.py
@@ -18,6 +18,7 @@
 from verl.trainer.ppo.metric_utils import (
     compute_throughout_metrics,
     compute_timing_metrics,
+    compute_variance_proxy_metrics,
 )
 from verl.trainer.ppo.ray_trainer import (
     RayClassWithInitArgs,
@@ -440,9 +441,14 @@ async def prepare(self):
     def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
         # Do not use verl's dataloader
         self.train_dataloader = None
+        self.val_dataloader = None
         self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize
-        self.config.actor_rollout_ref.actor.optim.total_training_steps = self.total_training_steps
-        self.config.critic.optim.total_training_steps = self.total_training_steps
+        if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim") is not None:
+            self.config.actor_rollout_ref.actor.optim.total_training_steps = (
+                self.total_training_steps
+            )
+        if OmegaConf.select(self.config, "critic.optim") is not None:
+            self.config.critic.optim.total_training_steps = self.total_training_steps
 
     async def save_state_dict(self):  # checkpoint sync
         actor_local_path = os.path.join(
@@ -475,6 +481,16 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict:  # noqa C901
         batch.meta_info["global_token_num"] = torch.sum(
             batch.batch["attention_mask"], dim=-1
         ).tolist()
+        images_seqlens_all = []
+        for multi_modal_input in batch.non_tensor_batch.get("multi_modal_inputs", []):
+            if "image_grid_thw" not in multi_modal_input:
+                continue
+            images_seqlens = multi_modal_input.get("images_seqlens", None)
+            if images_seqlens is None:
+                continue
+            images_seqlens_all.extend(images_seqlens.tolist())
+        if images_seqlens_all:
+            batch.meta_info["images_seqlens"] = images_seqlens_all
 
         # Operating Mode Selection:
         # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ)
@@ -592,6 +608,8 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict:  # noqa C901
         metrics.update(
             compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)
         )
+        gradient_norm = metrics.get("actor/grad_norm", None)
+        metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm))
 
         return metrics
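Reviewer note (not part of the patch): two metric paths are threaded through `train_step` above. First, per-sample vision-token counts (`images_seqlens`) are collected from `multi_modal_inputs` into `batch.meta_info`, where `update_actor` in megatron_workers.py pops them and forwards them to `flops_counter.estimate_flops(..., images_seqlens=...)` so vision tokens count toward `perf/mfu/actor`. Second, `actor/grad_norm` is read back out of `metrics` and handed to verl's `compute_variance_proxy_metrics`. A runnable toy of the first path, with made-up batch contents (the real code calls `.tolist()` on array-valued `images_seqlens`; plain lists are used here):

```python
batch_meta_info = {}
multi_modal_inputs = [
    {"image_grid_thw": [[1, 16, 16]], "images_seqlens": [64]},  # vision sample
    {},                                                         # text-only sample
]

images_seqlens_all = []
for mm_input in multi_modal_inputs:
    # Only samples that actually carry image patches contribute.
    if "image_grid_thw" not in mm_input:
        continue
    seqlens = mm_input.get("images_seqlens")
    if seqlens is not None:
        images_seqlens_all.extend(seqlens)
if images_seqlens_all:
    batch_meta_info["images_seqlens"] = images_seqlens_all

print(batch_meta_info)  # {'images_seqlens': [64]}

# Downstream (in megatron_workers.update_actor), these counts feed the MFU estimate:
#     estimated_flops, promised_flops = self.flops_counter.estimate_flops(
#         global_num_tokens, delta_time, images_seqlens=images_seqlens
#     )
```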