
[WIP] Add StatefulDataLoader to all recipes except knowledge_single #2441

Open · wants to merge 1 commit into main
Conversation

krammnic (Contributor) commented:
Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.
#2431
#2439

Changelog

What are the changes made in this PR?
Added StatefulDataLoader to all recipes except knowledge_single (this recipe will be updated by @jxtngx)

Note that I haven't done verification runs yet; I will add them gradually as soon as possible.
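For readers skimming the diff, the change in each recipe roughly amounts to the sketch below. This is illustrative only, not the exact recipe code: the real `_setup_data` signatures differ per recipe, and parameter names here are simplified.

```python
# Minimal sketch of the swap, assuming torchdata's stateful primitives.
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler


def _setup_data(ds, batch_size, shuffle, collate_fn, world_size, rank, seed=0):
    # StatefulDistributedSampler replaces DistributedSampler so the sampler's
    # position can be captured in a checkpoint.
    sampler = StatefulDistributedSampler(
        ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=seed
    )
    # StatefulDataLoader is a drop-in replacement for DataLoader that exposes
    # state_dict() / load_state_dict() for mid-epoch resumption.
    dataloader = StatefulDataLoader(
        dataset=ds,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn,
        drop_last=True,
    )
    return sampler, dataloader
```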

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Feb 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2441

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f9dabe7 with merge base 0afea1a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Feb 26, 2025.
krammnic (Contributor, Author) commented Feb 28, 2025:

@joecummings Could you give a brief review of the ppo_single_device and qat_distributed recipes, please? Maybe I'm missing something obvious.

@@ -245,10 +245,17 @@ def setup(self, cfg: DictConfig) -> None:

# Dataloader depends on the tokenizer and loss_fn and should be
# setup after all of these are setup
self._sampler, self._dataloader = self._setup_data(
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
Contributor commented:
This is not being passed to _setup_data. The default here should also be "padded_collate_dpo".

Would you mind checking collate_name in all the recipes you changed? I think the same error is present in others too.
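A hedged sketch of what the requested fix might look like in the setup method. This fragment mirrors the surrounding recipe code; the keyword names passed to _setup_data are made up for illustration.

```python
# Illustrative fix: read collate_fn with the DPO-appropriate default and
# actually forward it to _setup_data (other kwarg names are assumptions).
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_dpo")
self._sampler, self._dataloader = self._setup_data(
    cfg_dataset=cfg.dataset,
    shuffle=cfg.shuffle,
    batch_size=cfg.batch_size,
    collate_fn=collate_name,
)
```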

@@ -644,16 +647,16 @@ def _setup_data(
raise RuntimeError("left_pad_sequence collator is only for inference.")
collate_fn = _get_component_from_path(collate_fn)

sampler = DistributedSampler(
ds, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, seed=0
sampler = StatefulDistributedSampler(
Contributor commented:
We should also be adding "seed" to every sampler. I think we also missed it in Joe's original PR. It's not a bug, since they hardcode seed=0, but it should probably use the same seed as the one given in the config.
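For illustration, the suggested change might look like this, assuming the recipe already stores the configured seed (e.g. as self.seed set from cfg.seed); that attribute name is an assumption here.

```python
# Illustrative: thread the recipe's configured seed through instead of the
# previously hardcoded seed=0 (self.seed assumed to come from cfg.seed).
sampler = StatefulDistributedSampler(
    ds,
    num_replicas=self.world_size,
    rank=self.rank,
    shuffle=shuffle,
    seed=self.seed,
)
```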

== self.max_steps_per_epoch
):
break
self._dataloader.sampler.set_epoch(curr_epoch)
Contributor commented:
Some recipes are missing this.
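The pattern being pointed at is roughly the following training-loop skeleton; attribute names are assumed from the existing recipe structure and the loop body is elided.

```python
# Illustrative: set_epoch must be called once per epoch so shuffling varies
# across epochs and a resumed run lines up deterministically.
for curr_epoch in range(self.epochs_run, self.total_epochs):
    self._dataloader.sampler.set_epoch(curr_epoch)
    for idx, batch in enumerate(self._dataloader):
        ...  # forward/backward/optimizer step elided
```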

write_hf_ckpt_config(ckpt_dir)
write_hf_ckpt_config(tmpdir)

# Train for two epochs
Contributor commented:
I don't see either epochs or max_steps_per_epoch in this config. Shouldn't we hardcode these and make max_steps_per_epoch something small?
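One way to pin this down in the test, sketched with made-up override values, assuming the recipe tests pass config overrides as key=value arguments appended to the test command.

```python
# Illustrative only: keep the run short and deterministic by overriding the
# epoch settings in the test command rather than relying on config defaults.
cmd = cmd + [
    "epochs=2",
    "max_steps_per_epoch=2",
]
```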

@@ -123,18 +123,18 @@ def test_loss(self, tmpdir, monkeypatch):

loss_values = get_loss_values_from_metric_logger(log_file)
expected_loss_values = [
Contributor commented:
How were these values generated?

Collaborator commented:
@krammnic How come these values were changed?

krammnic (Contributor, Author) commented:
@SalmanMohammadi Maybe I misunderstood the question, but these values came from a run of the PPO integration test.

Comment on lines 48 to 74
"llama2": [
10.523505210876465,
10.522541999816895,
10.484564781188965,
10.550897598266602,
10.519064903259277,
10.475532531738281,
10.478732109069824,
10.447160720825195,
10.512746810913086,
10.506056785583496,
10.509842872619629,
10.574836730957031,
10.444534301757812,
10.466689109802246,
10.503318786621094,
10.464300155639648,
10.458215713500977,
10.477818489074707,
10.396238327026367,
10.40851879119873,
10.433064460754395,
10.500737190246582,
10.483240127563477,
10.43812084197998,
],
"llama3": [
Contributor commented:
How did we generate these?

Comment on lines 694 to 696
if dataloader_state_dict is not None:
dataloader.load_state_dict(dataloader_state_dict)
list(dataloader) # Hack to force dataloader to finish iteration
Contributor commented:
If tests pass here, we have to delete this section in every recipe: #2490
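For reference, the simplification tracked in #2490 would roughly reduce this to the sketch below, assuming StatefulDataLoader restores its iteration position from the loaded state on its own.

```python
# Sketch of the simplified resume path: with state restored by
# load_state_dict, forcing a full pass via list(dataloader) should no
# longer be needed.
if dataloader_state_dict is not None:
    dataloader.load_state_dict(dataloader_state_dict)
```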

krammnic (Contributor, Author) commented:
Thanks for the review! Let me fix all the points and then we will iterate on this PR.
