GRPO LoRA Single Device #2467

Open · wants to merge 7 commits into main
Conversation


@ianbarber ianbarber commented Mar 9, 2025

Context

What is the purpose of this PR? Is it to

  • [x] add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

#2421 - exploring a LoRA recipe.

Changelog

What are the changes made in this PR?

  • Add new configs and a recipe for LoRA single device (see the sketch after this list)
  • Add an SFT config for LoRA
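
For reference, a minimal sketch of the kind of LoRA model the new configs point at, assuming torchtune's standard Llama 3.2 3B LoRA builder (the rank, alpha, and target modules below are illustrative placeholders, not this PR's defaults — the real values live in the YAML configs):

```python
from torchtune.models.llama3_2 import lora_llama3_2_3b

# Illustrative LoRA setup; actual values come from the recipe's YAML config.
model = lora_llama3_2_3b(
    lora_attn_modules=["q_proj", "v_proj"],  # which attention projections get adapters
    apply_lora_to_mlp=False,
    lora_rank=8,    # low-rank dimension r
    lora_alpha=16,  # adapter output is scaled by alpha / r
)
```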

Following the pattern in the original GRPO PR, I ran SFT first and then GRPO. The model reaches around 45% on the train set after 2 epochs.

By reducing the number of generations I got the recipe running on a single 24GB card, which is a good baseline for ease of access!
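
For context on why the generation count dominates memory: GRPO scores each prompt against a group of sampled completions and normalizes rewards within that group, so every extra generation is another live completion (and its logits) on the card. A minimal sketch of the group-relative advantage step, assuming a plain rewards tensor rather than the recipe's actual data structures:

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # rewards: [num_prompts, num_generations], one scalar reward per completion.
    # GRPO's advantage is the reward normalized within its own group, so
    # shrinking num_generations shrinks both this tensor and, more importantly,
    # the number of completions that must be generated and scored per step.
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)
```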

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • [x] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • [x] manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • [x] I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Mar 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2467

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ecec2aa with merge base 1241231:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 9, 2025
@ianbarber ianbarber changed the title (draft/discussion) GRPO LoRA GRPO LoRA Single Device Mar 19, 2025
@ianbarber ianbarber (Author) commented Mar 22, 2025

Updated this - I validated the recipe for single device, so figured it was cleaner to make this a single-device-only diff. I reverted the checkpointing changes, as the LoRA from SFT gets merged in and it's cleaner to just start a fresh adapter (see the sketch at the end of this comment). The recipe reaches a decent success rate during training - testing after 4 epochs, running the gsm8k task from lm-eval gives:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5436|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.5421|±  |0.0137|

This is (clearly) evaluating on the training data, but for reference the base Llama 3B gets 0.2646 on flexible-extract (after SFT it reaches 0.29, and one epoch of GRPO gets to 0.48). I was a little dubious about whether LoRA would work here, and it appears it does!
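
On the merge step mentioned above: folding the SFT adapter into the base weights is just the standard low-rank update, after which GRPO starts from a fresh LoRA. A minimal sketch on plain tensors, not the recipe's checkpoint utilities:

```python
import torch

def merge_lora_weight(
    w: torch.Tensor,       # base weight, [out_dim, in_dim]
    lora_a: torch.Tensor,  # LoRA A matrix, [rank, in_dim]
    lora_b: torch.Tensor,  # LoRA B matrix, [out_dim, rank]
    rank: int,
    alpha: float,
) -> torch.Tensor:
    # Standard LoRA merge: W' = W + (alpha / r) * B @ A.
    # Once merged, the adapter is gone and a new one can be initialized on top.
    return w + (alpha / rank) * (lora_b @ lora_a)
```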


codecov-commenter commented Mar 22, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 66.42%. Comparing base (1241231) to head (ecec2aa).
Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2467      +/-   ##
==========================================
- Coverage   66.98%   66.42%   -0.57%     
==========================================
  Files         378      373       -5     
  Lines       22527    21986     -541     
==========================================
- Hits        15090    14604     -486     
+ Misses       7437     7382      -55     
