Enable SequenceParallel in 2D training #2503
Conversation
This reverts commit 1450d61.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2503
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 3 Pending as of commit e552802 with merge base ab8c23e.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #2503       +/-  ##
==========================================
+ Coverage    8.60%   66.41%   +57.80%
==========================================
  Files         323      373       +50
  Lines       19293    22005     +2712
==========================================
+ Hits         1661    14615    +12954
+ Misses      17632     7390    -10242

☔ View full report in Codecov by Sentry.
    ),
    "norm": SequenceParallel(),
    "output": ColwiseParallel(
        input_layouts=Shard(1), output_layouts=Replicate(), use_local_output=True
I believe use_local_output is True by default?
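For context, here is a minimal sketch of a sequence-parallel layer plan along the lines of the diff above, using torch.distributed.tensor.parallel. The toy model and mesh setup are illustrative, not torchtune's actual code; note that use_local_output already defaults to True for ColwiseParallel.

```python
# Sketch only: run with `torchrun --nproc_per_node=2 this_script.py` on a
# recent PyTorch (2.5+) that exposes torch.distributed.tensor publicly.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# Toy stand-in for the tail of a transformer decoder: final norm + output proj.
model = nn.Sequential()
model.add_module("norm", nn.RMSNorm(4096))
model.add_module("output", nn.Linear(4096, 32000, bias=False))

tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

layer_plan = {
    # Shard the norm's activations along the sequence dimension.
    "norm": SequenceParallel(),
    # The output projection takes sequence-sharded input (Shard(1)) and
    # gathers a replicated output so the loss sees the full logits tensor.
    "output": ColwiseParallel(
        input_layouts=Shard(1),
        output_layouts=Replicate(),
        use_local_output=True,  # already the default for ColwiseParallel
    ),
}
parallelize_module(model, tp_mesh, layer_plan)
```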
@@ -169,6 +170,8 @@ def padded_collate_sft(
        batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
        padding_idx (int): Padding index for input ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.
        pad_to_multiple_of (Optional[int]): If not None, pad the sequence to a multiple of this number.
            This is useful for proper sharding with e.g. SequenceParallel.
nice catch! thanks for that!
Suggestion: remove the hacky recipe change by just assuming that every training collation function supports padding_multiple.
torchtune/data/_collate.py (Outdated)
@@ -161,6 +161,7 @@ def padded_collate_sft(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
    pad_to_multiple_of: Optional[int] = None,
nit: can we drop the "of"
torchtune/data/_collate.py (Outdated)
@@ -169,6 +170,8 @@ def padded_collate_sft(
        batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
        padding_idx (int): Padding index for input ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.
        pad_to_multiple_of (Optional[int]): If not None, pad the sequence to a multiple of this number.
This should be added to the MM collate function too. It wouldn't hurt to give the other collation functions a once-over as well.
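To make the intended behavior concrete, here is a minimal sketch of how a pad_to_multiple_of argument could work inside a collate function. The helper name and the batch schema (a "tokens" key) are illustrative, not torchtune's exact implementation:

```python
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence


def pad_batch_to_multiple(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    pad_to_multiple_of: Optional[int] = None,
) -> torch.Tensor:
    # Pad every sequence to the max length in the batch.
    tokens = pad_sequence(
        [torch.tensor(x["tokens"]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    # Then round the sequence length up to the requested multiple,
    # e.g. seq_len=1001 with pad_to_multiple_of=4 becomes 1004.
    if pad_to_multiple_of is not None and pad_to_multiple_of > 1:
        remainder = tokens.size(1) % pad_to_multiple_of
        if remainder != 0:
            tokens = F.pad(
                tokens, (0, pad_to_multiple_of - remainder), value=padding_idx
            )
    return tokens


# Sequences of length 3 and 5, padded to a multiple of 4 -> final length 8.
batch = [{"tokens": [1, 2, 3]}, {"tokens": [4, 5, 6, 7, 8]}]
print(pad_batch_to_multiple(batch, pad_to_multiple_of=4).shape)  # torch.Size([2, 8])
```

With pad_to_multiple_of=1 (or None) the rounding step is a no-op, which is why the reviewers below suggest making it a universal default rather than a special case.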
recipes/full_finetune_distributed.py (Outdated)
    self.tensor_parallel_dim > 1
    and collate_fn == "torchtune.data.padded_collate_sft"
):
    collate_args = {"pad_to_multiple_of": self.tensor_parallel_dim}
Just add "pad_to_multiple=1" to every collate function that supports training. self.tensor_parallel_dim should be 1 by default anyway.
recipes/full_finetune_distributed.py (Outdated)
@@ -727,6 +745,7 @@ def _setup_data(
        collate_fn,
        padding_idx=self._tokenizer.pad_id,
        ignore_idx=self._loss_fn.ignore_index,
        **collate_args,
Making pad_to_multiple standard allows us to get rid of this and the above if statements.
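Here is a sketch of how the recipe side could look once every training collate function accepts the padding-multiple argument, so the special-case if statement goes away. The function below is illustrative; the keyword names mirror the diff above (pad_to_multiple_of here, though the reviewers suggest pad_to_multiple):

```python
from functools import partial
from typing import Callable

from torch.utils.data import DataLoader, Dataset


def build_dataloader(
    dataset: Dataset,
    collate_fn: Callable,  # e.g. torchtune.data.padded_collate_sft
    pad_id: int,
    ignore_index: int,
    tensor_parallel_dim: int = 1,
    batch_size: int = 2,
) -> DataLoader:
    # Always pass the padding multiple: with tensor_parallel_dim == 1 it is a
    # no-op, so the recipe no longer needs the
    # `if tensor_parallel_dim > 1 and collate_fn == ...` check.
    collate = partial(
        collate_fn,
        padding_idx=pad_id,
        ignore_idx=ignore_index,
        pad_to_multiple_of=tensor_parallel_dim,
    )
    return DataLoader(dataset, batch_size=batch_size, collate_fn=collate)
```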
Thank you! I still think it would be cleaner to set the default for pad_to_multiple_of to 1 instead of None, since a multiple of 1 means no change and doesn't require a special None type. I'll approve anyway.
Currently we are unable to train with TP and fused optimizer (see #2501) because our norm layers were not parallelized. This PR changes that. To do that, we need to:

(1) Revert #2054, which switched our RMSNorm to torch.nn.functional.rms_norm. That PR did demonstrate some performance improvements and we should probably get those back at some point (a hand-written sketch follows below).

(2) Pad to a multiple of tensor_parallel_dim by adding a pad_to_multiple_of option to padded_collate_sft -- otherwise we get unequal sequence lengths when calculating our loss. This is slightly inefficient, but hopefully tp_dim << seq_len in most cases. We should probably do this for other collate functions too, but this is the most widely used one, so starting here for now. (Really we should just direct people to use the packed dataset; in that case this padding is kinda implicit anyway.)

(2) is currently done super hackily; I will do some more testing and eventually clean it up.
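For reference, here is a minimal sketch of a hand-written RMSNorm module of the kind the revert of #2054 brings back in place of torch.nn.functional.rms_norm. This is a generic implementation, not necessarily torchtune's exact code:

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel scale."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 for stability, then cast back to the input dtype.
        x_fp32 = x.float()
        x_normed = x_fp32 * torch.rsqrt(
            x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps
        )
        return (x_normed * self.scale.float()).type_as(x)


# Example usage
norm = RMSNorm(dim=4096)
print(norm(torch.randn(2, 16, 4096)).shape)  # torch.Size([2, 16, 4096])
```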