[WIP] Add StatefulDataLoader to all recipes except knowledge_single #2441
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2441
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f9dabe7 with merge base 0afea1a. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@joecummings Can you give a brief review of the ppo_single_device and qat_distributed recipes, please? Maybe I'm missing something obvious?
recipes/lora_dpo_single_device.py
Outdated
@@ -245,10 +245,17 @@ def setup(self, cfg: DictConfig) -> None:

    # Dataloader depends on the tokenizer and loss_fn and should be
    # setup after all of these are setup
    self._sampler, self._dataloader = self._setup_data(
    collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
This is not being passed to _setup_data. The default should also be "padded_collate_dpo".
Do you mind checking the collate_name in all the recipes you changed? I think this is an error in the others too.
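For reference, a minimal sketch of the fix being asked for; the keyword arguments of `_setup_data` here are assumptions based on the surrounding diff, not the exact recipe signature:

```python
# Sketch only: read the collate function from the config with the DPO-specific
# default and actually pass it through to _setup_data.
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_dpo")
self._sampler, self._dataloader = self._setup_data(
    cfg_dataset=cfg.dataset,
    shuffle=cfg.shuffle,
    batch_size=cfg.batch_size,
    collate_fn=collate_name,  # assumed keyword name; check the recipe's signature
)
```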
@@ -644,16 +647,16 @@ def _setup_data(
        raise RuntimeError("left_pad_sequence collator is only for inference.")
    collate_fn = _get_component_from_path(collate_fn)

    sampler = DistributedSampler(
        ds, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, seed=0
    sampler = StatefulDistributedSampler(
We should also be adding "seed" to every sampler. I think we also missed it in Joe's original PR. Not a bug, since they hardcode seed=0, but it should probably use the same seed as the one given in the config.
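A sketch of what that could look like, assuming the recipe stores the config seed (e.g. `self.seed = cfg.seed`) and that `StatefulDistributedSampler` is the torchdata drop-in for `DistributedSampler`:

```python
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

# Sketch only: thread the run seed from the config instead of hardcoding seed=0.
sampler = StatefulDistributedSampler(
    ds,
    num_replicas=self.world_size,
    rank=self.rank,
    shuffle=shuffle,
    seed=self.seed,  # assumed to hold cfg.seed; previously hardcoded to 0
)
```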
                == self.max_steps_per_epoch
            ):
                break
        self._dataloader.sampler.set_epoch(curr_epoch)
some recipes are missing this
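For any recipe that skips it, the call belongs at the top of the epoch loop so the sampler reshuffles each epoch; a rough sketch:

```python
# Sketch only: advance the sampler's epoch before iterating so each epoch
# sees a different shuffle order.
for curr_epoch in range(self.epochs_run, self.total_epochs):
    self._dataloader.sampler.set_epoch(curr_epoch)
    for idx, batch in enumerate(self._dataloader):
        ...
```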
    write_hf_ckpt_config(ckpt_dir)
    write_hf_ckpt_config(tmpdir)

    # Train for two epochs
I don't see either epochs or max_steps_per_epoch in this config. Shouldn't we hardcode these and make max_steps_per_epoch something small?
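If that route is taken, the usual pattern in these integration tests is to append overrides to the recipe command; a hedged sketch (the exact config keys for this recipe are an assumption):

```python
# Sketch only: pin a short, deterministic run so the expected loss values
# correspond to a fixed, small number of steps.
cmd.extend([
    "epochs=2",
    "max_steps_per_epoch=2",
])
```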
@@ -123,18 +123,18 @@ def test_loss(self, tmpdir, monkeypatch):

    loss_values = get_loss_values_from_metric_logger(log_file)
    expected_loss_values = [
how were those generated?
@krammnic how come these values were changed?
@SalmanMohammadi Maybe I misunderstood the question, but these values just came from a run of the PPO integration test.
"llama2": [ | ||
10.523505210876465, | ||
10.522541999816895, | ||
10.484564781188965, | ||
10.550897598266602, | ||
10.519064903259277, | ||
10.475532531738281, | ||
10.478732109069824, | ||
10.447160720825195, | ||
10.512746810913086, | ||
10.506056785583496, | ||
10.509842872619629, | ||
10.574836730957031, | ||
10.444534301757812, | ||
10.466689109802246, | ||
10.503318786621094, | ||
10.464300155639648, | ||
10.458215713500977, | ||
10.477818489074707, | ||
10.396238327026367, | ||
10.40851879119873, | ||
10.433064460754395, | ||
10.500737190246582, | ||
10.483240127563477, | ||
10.43812084197998, | ||
], | ||
"llama3": [ |
how did we generate these?
    if dataloader_state_dict is not None:
        dataloader.load_state_dict(dataloader_state_dict)
        list(dataloader)  # Hack to force dataloader to finish iteration
if tests pass here, we have to delete this section in every recipe: #2490
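If the StatefulDataLoader state restore does its job, resume should reduce to something like the sketch below (module path per torchdata; treat it as an assumption, not the final recipe code), which would make the `list(dataloader)` exhaustion hack unnecessary:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

# Sketch only: StatefulDataLoader resumes mid-epoch from its own state dict,
# so there is no need to force a full iteration after loading.
dataloader = StatefulDataLoader(
    ds, batch_size=batch_size, sampler=sampler, collate_fn=collate_fn
)
if dataloader_state_dict is not None:
    dataloader.load_state_dict(dataloader_state_dict)
```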
Thanks for the review! Let me fix all the points and then we will iterate on this PR.
Force-pushed from 1cb34f7 to f9dabe7
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.
#2431
#2439
Changelog
What are the changes made in this PR?
Added StatefulDataLoader to all recipes except knowledge_single (that recipe will be updated by @jxtngx). Note that I haven't done verification runs yet; I will add them gradually ASAP.
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.
run pre-commit hooks and linters (pre-commit install)
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.