Enable SequenceParallel in 2D training #2503

Merged · 12 commits · Mar 26, 2025
14 changes: 7 additions & 7 deletions recipes/full_finetune_distributed.py
@@ -158,10 +158,6 @@ def __init__(self, cfg: DictConfig) -> None:
raise ValueError(
f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
)
if self.tensor_parallel_dim > 1 and cfg.optimizer.get("fused", False):
raise ValueError(
"Tensor parallelism is currently incompatible with fused optimizer."
)

self.data_parallel_dim = self.world_size // self.tensor_parallel_dim

@@ -552,6 +548,10 @@ def _setup_model(

# Apply tensor parallelism to the model
if self.tensor_parallel_dim > 1:
if self.data_parallel_dim == 1 and self.fsdp_cpu_offload:
raise ValueError(
"Tensor parallelism is not supported with FSDP CPU offloading when data parallelism is disabled."
)
# Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
model = training.prepare_mha_for_tp(model, device_mesh["tp"])
parallelize_module(
@@ -727,6 +727,7 @@ def _setup_data(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
pad_to_multiple_of=self.tensor_parallel_dim,
)
if not packed
else padded_collate_packed
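
For context, the new pad_to_multiple_of=self.tensor_parallel_dim argument appears to ensure that, once sequence parallelism shards the sequence dimension across TP ranks, every batch's padded length divides evenly by the TP degree. A rough sketch of the arithmetic with hypothetical values (none of these names come from the recipe):

tensor_parallel_dim = 4          # TP/SP group size (illustrative)
seq_len = 1023                   # longest sequence in the batch

remainder = seq_len % tensor_parallel_dim                      # 3
pad = (tensor_parallel_dim - remainder) % tensor_parallel_dim  # 1
padded_len = seq_len + pad                                     # 1024
shard_len = padded_len // tensor_parallel_dim                  # 256 tokens per rank
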
@@ -785,7 +786,6 @@ def train(self) -> None:

with self.activations_handling_ctx:
logits = self._model(**batch)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we don't need to slice the logits. We just add an ignore index to labels.
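
The comment above describes shifting labels instead of slicing logits. A minimal illustration of that trick with made-up values (not code from the recipe):

import torch

ignore_index = -100  # assumed ignore value; the recipe takes it from the loss fn
labels = torch.tensor([[4, 5, 6, 7]])

# Drop the first label and append ignore_index so the result still lines up
# with the unshifted logits, avoiding a logits[..., :-1, :] slice.
shifted = torch.hstack(
    (labels[..., 1:], torch.full((labels.size(0), 1), ignore_index))
)
# shifted -> tensor([[   5,    6,    7, -100]])
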
@@ -813,7 +813,7 @@ def train(self) -> None:
torch.distributed.all_reduce(running_loss)

# We multiply by world_size to undo FSDP2 gradient normalization.
current_loss = current_loss * (self.world_size / num_tokens)
current_loss = current_loss * (self.dp_size / num_tokens)

current_loss.backward()

@@ -826,7 +826,7 @@ def train(self) -> None:
torch.distributed.all_reduce(running_loss)
# Manually scale the gradients from unnormalized loss by total # of tokens
# We multiply by world_size to undo FSDP2 gradient normalization.
training.scale_grads(self._model, self.world_size / num_tokens)
training.scale_grads(self._model, self.dp_size / num_tokens)
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
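
For context on the world_size to dp_size change above: the diff suggests FSDP2's gradient normalization is applied over the data-parallel group, so undoing it should use the number of data-parallel replicas rather than the full world size once TP claims some ranks. A minimal sketch of the scaling with hypothetical numbers (not recipe code):

world_size = 8                                 # total ranks
tensor_parallel_dim = 2                        # TP group size
dp_size = world_size // tensor_parallel_dim    # 4 data-parallel replicas

num_tokens = 1024                              # non-ignored tokens this step
running_loss = 2.0                             # summed, unnormalized loss

# Scaling by world_size (the old behavior) would over-correct by the TP factor
# whenever world_size != dp_size.
scaled_loss = running_loss * (dp_size / num_tokens)   # 2.0 * 4 / 1024
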
199 changes: 182 additions & 17 deletions tests/torchtune/data/test_collate.py
@@ -24,13 +24,14 @@


class TestPaddedCollateSFT:
def test_batch_pad_sequence(self):
"""
Tests that shorter input, label sequences are padded to the max seq len.
"""
padding_idx = -8
ignore_idx = -9
token_pairs = [
# Shared test constants
padding_idx = -8
ignore_idx = -9

# Common test batch data
@pytest.fixture
def test_batch(self):
return [
{
"tokens": [1, 2, 3],
"labels": [4, 5, 6],
@@ -40,20 +41,76 @@ def test_batch_pad_sequence(self):
"labels": [10],
},
]
padded = padded_collate_sft(
batch=token_pairs,
padding_idx=padding_idx,
ignore_idx=ignore_idx,

def test_batch_pad_sequence(self, test_batch):
"""
Tests that shorter input, label sequences are padded to the max seq len.
"""
# Apply padding via the collate function
padded_result = padded_collate_sft(
batch=test_batch,
padding_idx=self.padding_idx,
ignore_idx=self.ignore_idx,
)
padded_input = padded["tokens"][1]
padded_label = padded["labels"][1]

# Extract the padded sequences for the second item (shorter sequence)
padded_tokens = padded_result["tokens"][1]
padded_labels = padded_result["labels"][1]

# Verify padding was applied correctly
torch.testing.assert_close(
padded_input, torch.tensor([7, padding_idx, padding_idx])
padded_tokens, torch.tensor([7, self.padding_idx, self.padding_idx])
)
torch.testing.assert_close(
padded_label, torch.tensor([10, ignore_idx, ignore_idx])
padded_labels, torch.tensor([10, self.ignore_idx, self.ignore_idx])
)

def test_batch_pad_sequence_to_multiple_of(self, test_batch):
"""Test that padding to a multiple of X works as expected."""
# Apply padding with multiple-of-5 requirement
padded_result = padded_collate_sft(
batch=test_batch,
padding_idx=self.padding_idx,
ignore_idx=self.ignore_idx,
pad_to_multiple_of=5,
)

# Expected padded tokens (padded to length 5)
expected_tokens = torch.stack(
(
torch.tensor([1, 2, 3, self.padding_idx, self.padding_idx]),
torch.tensor(
[
7,
self.padding_idx,
self.padding_idx,
self.padding_idx,
self.padding_idx,
]
),
)
)

# Expected padded labels (padded to length 5)
expected_labels = torch.stack(
(
torch.tensor([4, 5, 6, self.ignore_idx, self.ignore_idx]),
torch.tensor(
[
10,
self.ignore_idx,
self.ignore_idx,
self.ignore_idx,
self.ignore_idx,
]
),
)
)

# Verify padding was applied correctly
torch.testing.assert_close(padded_result["tokens"], expected_tokens)
torch.testing.assert_close(padded_result["labels"], expected_labels)


class TestPaddedCollateTiledImagesAndMask:
img_shape = 1, 1, 1
@@ -84,6 +141,40 @@ def batch(self):
},
]

def test_raises_error_with_pad_multiple_provided_and_pad_direction_is_left(
self, batch
):
# We don't support padding to a multiple of X with left padding (inference)
with pytest.raises(
ValueError,
match="pad_to_multiple_of=5 is not supported for pad_direction='left'",
):
padded_collate_tiled_images_and_mask(
batch=batch,
padding_idx=0,
ignore_idx=-100,
pad_to_multiple_of=5,
pad_direction="left",
)

def test_padding_to_multiple(self, batch):
actual = padded_collate_tiled_images_and_mask(
batch=batch,
padding_idx=0,
ignore_idx=-100,
pad_to_multiple_of=5,
)

# Make sure tokens & labels are padded to a multiple of 5
expected_tokens = torch.tensor([[1, 2, 1, 3, 0], [1, 4, 0, 0, 0]])
expected_labels = torch.tensor([[4, 5, 6, 7, -100], [8, 9, -100, -100, -100]])
assert torch.allclose(actual["tokens"], expected_tokens)
assert torch.allclose(actual["labels"], expected_labels)

# We don't have to ensure images look any different b/c they are padded differently
# But we do need to make sure the masks are padded to a multiple of 5
assert actual["encoder_mask"].size(1) % 5 == 0

def test_right_pad_sequence(self, batch):
actual = padded_collate_tiled_images_and_mask(
batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="right"
@@ -312,6 +403,51 @@ def test_left_pad_sequence(self):


class TestPaddedCollate:
def test_throws_error_with_pad_direction_left_and_pad_to_multiple_of(self):
batch = [
{"tokens": [1, 2, 3], "labels": [4, 5, 6]},
]
with pytest.raises(
ValueError,
match="pad_to_multiple_of=7 is not supported for pad_direction='left'",
):
padded_collate(
batch,
pad_direction="left",
keys_to_pad=["tokens"],
padding_idx=-10,
pad_to_multiple_of=7,
)

def test_padded_collate_with_multiple_of(self):
batch = [
{"tokens": [1, 2, 3], "labels": [4, 5, 6]},
{"tokens": [4, 5, 6, 7], "labels": [8, 9, 10, 11]},
{"tokens": [8, 9, 10, 11, 12], "labels": [13, 14, 15, 16, 17]},
]
result = padded_collate(
batch,
pad_direction="right",
keys_to_pad=["tokens", "labels"],
padding_idx=-10,
pad_to_multiple_of=7,
)
expected_tokens = torch.tensor(
[
[1, 2, 3, -10, -10, -10, -10],
[4, 5, 6, 7, -10, -10, -10],
[8, 9, 10, 11, 12, -10, -10],
]
)
expected_labels = torch.tensor(
[
[4, 5, 6, -10, -10, -10, -10],
[8, 9, 10, 11, -10, -10, -10],
[13, 14, 15, 16, 17, -10, -10],
]
)
assert torch.equal(result["tokens"], expected_tokens)

def test_padded_collate_classifier_labels(self):
batch = [
{"tokens": [1, 2, 3], "labels": 1},
@@ -384,8 +520,9 @@ def test_value_error_raised_when_invalid_pad_direction(self):


class TestPaddedCollateDPO:
def test_dpo_collate(self):
batch = [
@pytest.fixture
def batch(self):
return [
{
"chosen_input_ids": [1, 2, 3],
"chosen_labels": [4, 5, 6],
@@ -399,6 +536,34 @@ def test_dpo_collate(self):
"rejected_labels": [18, 19, 20],
},
]

def test_dpo_collate_with_pad_to_multiple_of(self, batch):
input_ids, labels = padded_collate_dpo(
batch,
padding_idx=0,
ignore_idx=-100,
pad_to_multiple_of=7,
)
expected_input_ids = torch.tensor(
[
[1, 2, 3, 0, 0, 0, 0],
[11, 12, 0, 0, 0, 0, 0],
[7, 8, 0, 0, 0, 0, 0],
[15, 16, 17, 0, 0, 0, 0],
],
)
expected_labels = torch.tensor(
[
[4, 5, 6, -100, -100, -100, -100],
[13, 14, -100, -100, -100, -100, -100],
[9, 10, -100, -100, -100, -100, -100],
[18, 19, 20, -100, -100, -100, -100],
]
)
assert torch.equal(input_ids, expected_input_ids)
assert torch.equal(labels, expected_labels)

def test_dpo_collate(self, batch):
input_ids, labels = padded_collate_dpo(batch, padding_idx=0, ignore_idx=-100)
expected_input_ids = torch.tensor(
[[1, 2, 3], [11, 12, 0], [7, 8, 0], [15, 16, 17]]
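
The tests above exercise the new pad_to_multiple_of behavior. A self-contained sketch of the right-padding they expect; pad_to_multiple here is a hypothetical helper, not the torchtune API:

import torch
import torch.nn.functional as F

def pad_to_multiple(seq: torch.Tensor, multiple: int, pad_value: int) -> torch.Tensor:
    """Right-pad a 1D sequence so its length is a multiple of `multiple`."""
    remainder = seq.size(0) % multiple
    if remainder == 0:
        return seq
    return F.pad(seq, (0, multiple - remainder), value=pad_value)

tokens = torch.tensor([1, 2, 3])
print(pad_to_multiple(tokens, multiple=5, pad_value=-10))
# tensor([  1,   2,   3, -10, -10])
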
5 changes: 3 additions & 2 deletions tests/torchtune/modules/test_rms_norm.py
@@ -5,10 +5,12 @@
# LICENSE file in the root directory of this source tree.

import pytest

import torch

from tests.test_utils import assert_expected
from torch.nn.functional import normalize

from torchtune.modules.rms_norm import RMSNorm
from torchtune.training.seed import set_seed

@@ -64,7 +66,6 @@ def test_forward_fp16(self, rms_norm, input_random_fp16, dim) -> None:

# convert input to float since rms_norm computes in fp32
expected_fp16 = normalize(input_random_fp16.float(), p=2, dim=-1) * (dim**0.5)
expected_fp16 = expected_fp16.to(torch.float16)

assert_expected(output_fp16, expected_fp16, atol=1e-7, rtol=1e-3)
assert output_fp16.dtype == torch.float16
assert output_fp16.dtype == torch.float32
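
The updated assertion indicates that an fp16 input now yields an fp32 output. A minimal RMSNorm sketch consistent with that behavior, shown purely for illustration rather than as torchtune's implementation:

import torch
import torch.nn as nn

class SketchRMSNorm(nn.Module):
    """Toy RMSNorm that computes and returns in fp32."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp32 = x.float()  # upcast fp16/bf16 for a stable mean of squares
        rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return x_fp32 * rms * self.scale  # no downcast, so fp16 in -> fp32 out
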