
Enable SequenceParallel in 2D training #2503

Merged: 12 commits into pytorch:main on Mar 26, 2025

Conversation

ebsmothers (Contributor) commented:

Currently we are unable to train with TP and a fused optimizer (see #2501) because our norm layers were not parallelized. This PR changes that. To do so, we need to:

(1) Revert #2054, which switched our RMSNorm to torch.nn.functional.rms_norm. That PR did demonstrate some performance improvements, so we should probably look into recovering them after this change.
(2) Pad to a multiple of tensor_parallel_dim by adding a pad_to_multiple_of option to padded_collate_sft -- otherwise we get unequal sequence lengths when calculating our loss. This is slightly inefficient, but hopefully tp_dim << seq_len in most cases. We should probably do this for other collate functions too, but this one is the most widely used, so starting here for now. (Really we should just direct people to use packed datasets, in which case this padding is implicit anyway.)

(2) is currently done super hackily; I will do some more testing and eventually clean it up. A rough sketch of the intended padding behavior is below.
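For reference, here is a minimal, self-contained sketch of the padding behavior described in (2). The function name, dict keys, and exact signature below are illustrative stand-ins, not the final torchtune API:

```python
# Sketch only: pad every sequence to the batch max length, then round that
# length up to a multiple of `pad_to_multiple_of` so the sequence dimension
# divides evenly across TP ranks when SequenceParallel shards it.
import math
from typing import Dict, List, Optional

import torch

CROSS_ENTROPY_IGNORE_IDX = -100


def padded_collate_sft_sketch(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
    pad_to_multiple_of: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    max_len = max(len(sample["tokens"]) for sample in batch)
    if pad_to_multiple_of is not None:
        # e.g. max_len=1001 with tp_dim=4 pads out to 1004
        max_len = math.ceil(max_len / pad_to_multiple_of) * pad_to_multiple_of

    def _pad(seqs: List[List[int]], fill: int) -> torch.Tensor:
        return torch.stack(
            [torch.tensor(s + [fill] * (max_len - len(s))) for s in seqs]
        )

    return {
        "tokens": _pad([s["tokens"] for s in batch], padding_idx),
        "labels": _pad([s["labels"] for s in batch], ignore_idx),
    }


# Example: two samples of lengths 3 and 1 with pad_to_multiple_of=2
# are both padded to length 4.
out = padded_collate_sft_sketch(
    [{"tokens": [1, 2, 3], "labels": [1, 2, 3]}, {"tokens": [4], "labels": [4]}],
    pad_to_multiple_of=2,
)
print(out["tokens"].shape)  # torch.Size([2, 4])
```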

pytorch-bot commented Mar 14, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2503

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending

As of commit e552802 with merge base ab8c23e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Mar 14, 2025.
codecov-commenter commented Mar 14, 2025:

Codecov Report

Attention: Patch coverage is 50.00000% with 7 lines in your changes missing coverage. Please review.

Project coverage is 66.41%. Comparing base (ab8c23e) to head (859c5a8).
Report is 1 commit behind head on main.

Files with missing lines | Patch % | Missing lines
recipes/full_finetune_distributed.py | 0.00% | 5 ⚠️
torchtune/data/_collate.py | 33.33% | 2 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            main    #2503       +/-   ##
==========================================
+ Coverage   8.60%   66.41%   +57.80%     
==========================================
  Files        323      373       +50     
  Lines      19293    22005     +2712     
==========================================
+ Hits        1661    14615    +12954     
+ Misses     17632     7390    -10242     

☔ View full report in Codecov by Sentry.

            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1), output_layouts=Replicate(), use_local_output=True
Review comment (Contributor):

I believe use_local_output is True by default?
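For context on the plan this hunk comes from, here is a rough sketch of a SequenceParallel-aware TP plan in the style of torch.distributed.tensor.parallel. The module names and the wrapper function are assumptions based on the snippet above; the actual torchtune plan also covers the attention and MLP projections in each layer:

```python
# Sketch of a TP plan where the final norm participates in sequence sharding.
# Module names ("tok_embeddings", "norm", "output") mirror the hunk above and
# are assumptions; this is not the complete torchtune parallelization plan.
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def apply_tp_sketch(model, tp_mesh):
    plan = {
        # shard activations along the sequence dim right after the embedding
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(), output_layouts=Shard(1)
        ),
        # the final RMSNorm now operates on sequence-sharded activations
        "norm": SequenceParallel(),
        # gather back to a replicated plain tensor before computing the loss
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate(),
            use_local_output=True,  # also the default, per the comment above
        ),
    }
    return parallelize_module(model, tp_mesh, plan)
```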

@@ -169,6 +170,8 @@ def padded_collate_sft(
        batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
        padding_idx (int): Padding index for input ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.
        pad_to_multiple_of (Optional[int]): If not None, pad the sequence to a multiple of this number.
            This is useful for proper sharding with e.g. SequenceParallel.
Review comment (Contributor):

nice catch! thanks for that!

joecummings (Contributor) commented:

[Screenshot: 2025-03-26 at 12:51 PM] Confirmation that this works as expected. Let's take a look:
  1. Loss is the same: cool, but wait, you're testing with two different GPU configurations -- one has two GPUs and is all DP, and the other has DP + TP on 4 GPUs? Yes, this is because the data split and batch sizes should remain the same in order to properly test that the loss is the same. If we did full DP on 4 GPUs, the data would be sharded 4 ways instead of 2, and the order of data shown to the model would be different (see the sketch after this list).
  2. Tokens per second is lower w/ 'hybrid': Yep, this makes sense, as we are still on a single node and there will be more small communication synchronizations w/ TP than with pure FSDP.
  3. Memory usage is lower w/ 'hybrid': Duh, each GPU holds an even smaller fraction of the parameters at all times.
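As a small illustration of point 1 (the rank layout with TP as the innermost mesh dimension is an assumption for demonstration purposes): the sampler is driven by the data-parallel rank and degree, not the global world size, so DP=2 on 2 GPUs and DP=2 x TP=2 on 4 GPUs see the data in the same order.

```python
# Toy demo (CPU only): the data split depends on dp_degree/dp_rank, so a
# 2-GPU pure-DP run and a 4-GPU DP=2 x TP=2 run produce the same data order.
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(16))  # stand-in for a map-style dataset


def data_order(world_size: int, tensor_parallel_dim: int):
    dp_degree = world_size // tensor_parallel_dim
    orders = {}
    for global_rank in range(world_size):
        # assume TP is the innermost mesh dim: TP peers share one dp_rank
        dp_rank = global_rank // tensor_parallel_dim
        sampler = DistributedSampler(
            dataset, num_replicas=dp_degree, rank=dp_rank, shuffle=False
        )
        orders[global_rank] = list(sampler)
    return orders


# dp_rank 0 and 1 get identical index shards in both configurations,
# which is why the loss curves line up.
print(data_order(world_size=2, tensor_parallel_dim=1))
print(data_order(world_size=4, tensor_parallel_dim=2))
```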

@joecummings joecummings marked this pull request as ready for review March 26, 2025 16:58
pbontrager (Contributor) left a comment:

Suggestions to remove the hacky recipe change by just assuming that every training collation function supports a padding multiple.

@@ -161,6 +161,7 @@ def padded_collate_sft(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
    pad_to_multiple_of: Optional[int] = None,
Review comment (Contributor):

nit: can we drop the "of"

@@ -169,6 +170,8 @@ def padded_collate_sft(
        batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
        padding_idx (int): Padding index for input ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.
        pad_to_multiple_of (Optional[int]): If not None, pad the sequence to a multiple of this number.
Review comment (Contributor):

This should be added to the MM collate function too. It wouldn't hurt to give the other collation functions a once-over as well.

                self.tensor_parallel_dim > 1
                and collate_fn == "torchtune.data.padded_collate_sft"
            ):
                collate_args = {"pad_to_multiple_of": self.tensor_parallel_dim}
Review comment (Contributor):

Just add "pad_to_multiple=1" to every collate function used for training. self.tensor_parallel_dim should be 1 by default anyway.

@@ -727,6 +745,7 @@ def _setup_data(
            collate_fn,
            padding_idx=self._tokenizer.pad_id,
            ignore_idx=self._loss_fn.ignore_index,
            **collate_args,
Review comment (Contributor):

Making pad_to_multiple standard allows us to get rid of this and the above if statements
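A sketch of what the simplified wiring could look like if this suggestion is adopted. The names here follow the diff hunks and review suggestions above, the recipe attributes are replaced with stand-in values, and the collate function is the illustrative sketch from the PR description; treat this as an illustration rather than the code that actually landed:

```python
from functools import partial

# Stand-ins for recipe state (assumptions for illustration):
tensor_parallel_dim = 2      # self.tensor_parallel_dim in the recipe; 1 when TP is off
pad_id = 0                   # self._tokenizer.pad_id
loss_ignore_index = -100     # self._loss_fn.ignore_index

# With a no-op default on the padding-multiple argument, the if-statement and
# the separate collate_args dict from the hunks above are no longer needed:
collate_fn = partial(
    padded_collate_sft_sketch,   # illustrative collate fn from the earlier sketch
    padding_idx=pad_id,
    ignore_idx=loss_ignore_index,
    pad_to_multiple_of=tensor_parallel_dim,  # 1 when TP is off -> no extra padding
)

batch = collate_fn(
    [{"tokens": [1, 2, 3], "labels": [1, 2, 3]}, {"tokens": [4, 5], "labels": [4, 5]}]
)
print(batch["tokens"].shape)  # torch.Size([2, 4]) -- padded to a multiple of 2
```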

@joecummings joecummings changed the title Hacky changes to finish enabling 2D parallel Enable SequenceParallel in 2D training Mar 26, 2025
@joecummings joecummings requested a review from pbontrager March 26, 2025 20:49
pbontrager (Contributor) left a comment:

Thank you! I still think it would be cleaner to set the default to pad_to_multiple_of=1 instead of None, since a multiple of 1 means no change and doesn't require a special None type. I'll approve anyway.

@joecummings joecummings merged commit c5c160b into pytorch:main Mar 26, 2025
17 checks passed