Refine the gradient accumulation API #9078

Merged
merged 1 commit on May 12, 2025
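For orientation, the sketch below shows the refined gradient_accumulation API as it is exercised by the tests and call sites in this diff. It is a minimal, hedged example: the model, tensor shapes, and variable names are illustrative assumptions; only the import path and the calling convention come from the PR itself.

# A minimal sketch of the refined API, inferred from the tests in this PR;
# the model, shapes, and names below are illustrative, not part of the PR.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.gradient_accumulation import gradient_accumulation

device = xm.xla_device()
model = torch.nn.Linear(10, 5).to(device)

# The leading dimension of each input tensor is the number of accumulation
# steps; the helper iterates over it with an XLA `While` loop.
inputs = torch.randn(8, 10).to(device)
targets = torch.randn(8, 5).to(device)
counter = torch.tensor(0).to(device)  # optional carried tensor

def train_step(input_batch, target_batch, counter):
  loss = torch.nn.functional.mse_loss(model(input_batch), target_batch)
  # Carried tensors are returned after the loss, never mutated in place.
  return loss, counter + 1

# With carried tensors: returns the accumulated loss followed by their final
# values.
accumulated_loss, final_counter = gradient_accumulation(
    train_step, (inputs, targets), model, counter)

# Without carried tensors the refined API returns the loss directly; the old
# API required a trailing `None` and returned a one-element tuple.
torch_xla.sync()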
1 change: 1 addition & 0 deletions test/neuron/run_tests.sh
@@ -167,6 +167,7 @@ function run_xla_op_tests1 {
run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
run_save_tensor_ir run_test "$CDIR/spmd/test_spmd_graph_dump.py"
run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_graph_dump.py"
run_test "$CDIR/test_gradient_accumulation.py"
}

function run_xla_op_tests2 {
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -243,6 +243,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing
run_test "$CDIR/test_gradient_accumulation.py"
run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
18 changes: 12 additions & 6 deletions test/spmd/test_train_spmd_linear_model.py
@@ -5,6 +5,7 @@
import unittest

import torch
+ from torch_xla import runtime as xr

import test_xla_sharding_base

@@ -19,6 +20,9 @@
# the gradient checkpointing A/B test run for it.
SKIP_GRADIENT_CHECKPOINTING: bool = False

+ skipOnGpu = unittest.skipIf(xr.device_type() == 'CUDA',
+ 'https://github.com/pytorch/xla/issues/9128')


@contextmanager
def extended_argv(args):
@@ -33,7 +37,7 @@ def extended_argv(args):
class TestSPMDLinearModel(test_xla_sharding_base.XlaShardingTest):

def test_basic(self):
- print('Training loop with baseline')
+ print('Training loop with baseline', flush=True)
with extended_argv([]):
baseline_losses, baseline_result = train_and_evaluate()
# Verify that the model losses are not zero.
@@ -42,7 +46,7 @@ def test_basic(self):
assert not torch.any(baseline_result == 0)

if not SKIP_GRADIENT_CHECKPOINTING:
- print('Training loop with gradient checkpointing')
+ print('Training loop with gradient checkpointing', flush=True)
with extended_argv(['--use_gradient_checkpointing']):
checkpointing_losses, checkpointing_result = train_and_evaluate()
# Verify that the runs match with and without checkpointing.
@@ -62,11 +66,11 @@ def test_gradient_accumulation_matches(self):
"""

COMMON_GRAD_ACC_ARGS = ["--gradient_accumulation_steps", "8"]
- print('Training loop with traditional gradient accumulation')
+ print('Training loop with traditional gradient accumulation', flush=True)
with extended_argv(COMMON_GRAD_ACC_ARGS):
baseline_grad_acc_losses = train_and_evaluate_grad_acc()

- print('Training loop with XLA\'s `While` gradient accumulation')
+ print('Training loop with XLA\'s `While` gradient accumulation', flush=True)
with extended_argv(COMMON_GRAD_ACC_ARGS +
["--use_gradient_accumulation_loop"]):
loop_grad_acc_losses = train_and_evaluate_grad_acc()
@@ -79,8 +83,10 @@ def test_gradient_accumulation_matches(self):
loop_grad_acc_losses))

if not SKIP_GRADIENT_CHECKPOINTING:
- print('Training loop with XLA\'s `While` gradient accumulation and '
- 'gradient checkpointing.')
+ print(
+ 'Training loop with XLA\'s `While` gradient accumulation and '
+ 'gradient checkpointing.',
+ flush=True)
with extended_argv(
COMMON_GRAD_ACC_ARGS +
["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]):
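The A/B runs above toggle behavior through command-line flags wrapped in extended_argv, whose definition is only partially visible in this hunk. As a rough sketch, assuming the helper simply extends sys.argv for the duration of the with block (the actual implementation in test_train_spmd_linear_model.py may differ):

# Hypothetical stand-in for the extended_argv helper used above: it appends
# extra CLI flags to sys.argv so that train_and_evaluate() and
# train_and_evaluate_grad_acc() can parse them, then restores the original.
import sys
from contextlib import contextmanager

@contextmanager
def extended_argv(args):
  original = list(sys.argv)
  sys.argv = original + list(args)
  try:
    yield
  finally:
    sys.argv = original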
153 changes: 153 additions & 0 deletions test/test_gradient_accumulation.py
@@ -0,0 +1,153 @@
import sys
import unittest
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.test.test_utils as test_utils
from torch_xla.experimental.gradient_accumulation import gradient_accumulation

from test_utils import XlaTestCase  # type:ignore


class SimpleModel(torch.nn.Module):

  def __init__(self, input_dim=10, hidden_dim=20, output_dim=5):
    super(SimpleModel, self).__init__()
    self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
    self.fc2 = torch.nn.Linear(hidden_dim, output_dim)

  def forward(self, x):
    x = torch.relu(self.fc1(x))
    return self.fc2(x)


class GradAccumulationTest(XlaTestCase):

  def setUp(self):
    self.device = xm.xla_device()
    torch.manual_seed(123)

  def test_basic(self):
    """Compare results with and without the XLA loop"""
    batch_size = 8
    hidden_dim = 20
    input_dim = 10
    output_dim = 5

    inputs = torch.randn(batch_size, input_dim).to(self.device)
    targets = torch.randn(batch_size, output_dim).to(self.device)

    def train_step_fw(input_batch, target_batch, carried_tensor):
      output = model_ga(input_batch)
      loss = torch.nn.functional.mse_loss(output, target_batch)
      new_carried_tensor = carried_tensor + 5
      return loss, new_carried_tensor

    # Gradient accumulation with XLA loop
    torch.manual_seed(43)
    model_ga = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)
    carried_tensor_ga = torch.tensor([5, 5]).to(self.device)

    accumulated_loss_ga, accum_carried_tensor_ga = gradient_accumulation(
        train_step_fw, (inputs, targets), model_ga, carried_tensor_ga)

    torch_xla.sync()

    # Traditional accumulation
    torch.manual_seed(43)
    model_manual = SimpleModel(input_dim, hidden_dim,
                               output_dim).to(self.device)
    carried_tensor_manual = torch.tensor([5, 5]).to(self.device)

    accumulated_loss_manual = torch.tensor(0.0).to(self.device)
    for i in range(batch_size):
      loss, carried_tensor_manual = train_step_fw(inputs[i:i + 1],
                                                  targets[i:i + 1],
                                                  carried_tensor_manual)
      loss = loss / batch_size
      loss.backward()
      accumulated_loss_manual += loss.detach()

    torch_xla.sync()

    # Compare losses, carried tensors and resulting gradients
    super().compareResults([accumulated_loss_ga], [accumulated_loss_manual])
    super().compareResults([accum_carried_tensor_ga], [carried_tensor_manual])
    super().compareResults(model_ga.parameters(), model_manual.parameters())

  def test_with_carried_tensors(self):
    """Test gradient accumulation with carried tensors, including with RNG"""
    batch_size = 2
    hidden_dim = 20
    input_dim = 10
    output_dim = 5

    model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)

    inputs = torch.randn(batch_size, input_dim).to(self.device)
    targets = torch.randn(batch_size, output_dim).to(self.device)

    # Carried tensors
    counter = torch.tensor(0).to(self.device)
    tensor0 = torch.tensor(0.0).to(self.device)
    tensor0_baseline = tensor0.clone()

    # Define a train step function that updates the carried tensors. For the
    # RNG case, we subtract the previous value from a fresh random sample, so
    # that a non-zero final value shows each iteration drew a distinct random
    # number.
    def train_step_fw(input_batch, target_batch, counter, tensor0):
      output = model(input_batch)
      loss = torch.nn.functional.mse_loss(output, target_batch)
      # Update counter
      new_counter = counter + 1
      new_tensor0 = torch.rand_like(tensor0, device=self.device) - tensor0
      return loss, new_counter, new_tensor0

    # Run gradient accumulation
    accumulated_loss, final_counter, final_tensor0 = gradient_accumulation(
        train_step_fw, (inputs, targets), model, counter, tensor0)

    torch_xla.sync()

    self.assertEqual(final_counter.item(), batch_size)
    # Ensure that the result is not 0, showcasing that the RNG is unique
    # per iteration.
    self.assertNotEqual(final_tensor0.item(), 0.0)

  def test_error_empty_iterable_tensors(self):
    """Test that empty iterable_tensors raises an error."""
    model = SimpleModel().to(self.device)

    def train_step_fw():
      pass

    with self.assertRaises(ValueError):
      gradient_accumulation(train_step_fw, [], model)

  def test_error_mutated_input_tensors(self):
    """Test that mutating input tensors raises an error."""
    batch_size = 2
    hidden_dim = 20
    input_dim = 10
    output_dim = 5

    model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)

    inputs = torch.randn(batch_size, input_dim).to(self.device)
    targets = torch.randn(batch_size, output_dim).to(self.device)
    counter = torch.tensor(0).to(self.device)

    def train_step_fw(input_batch, target_batch, counter):
      output = model(input_batch)
      loss = torch.nn.functional.mse_loss(output, target_batch)
      # In-place mutation of an input tensor.
      counter += 1
      return loss, counter

    with self.assertRaises(AssertionError):
      accumulated_loss, final_counter = gradient_accumulation(
          train_step_fw, (inputs, targets), model, counter)


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
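As a usage note beyond what the test asserts: the helper runs the backward passes inside the loop, so gradients end up on the wrapped module's parameters and a full training step would typically follow with an optimizer update. A minimal sketch, where the optimizer and surrounding loop are assumptions rather than part of this PR:

# Illustrative wiring into a training step (not shown in this PR); the
# optimizer and its configuration are assumptions.
import torch
import torch_xla
from torch_xla.experimental.gradient_accumulation import gradient_accumulation

def run_macro_step(model, optimizer, inputs, targets):
  def step_fn(input_batch, target_batch):
    return torch.nn.functional.mse_loss(model(input_batch), target_batch)

  # No carried tensors, so only the accumulated loss is returned.
  accumulated_loss = gradient_accumulation(step_fn, (inputs, targets), model)
  optimizer.step()
  optimizer.zero_grad()
  torch_xla.sync()
  return accumulated_loss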
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -19,6 +19,7 @@ python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py"
python3 "$TEST_CDIR/spmd/test_xla_spmd_python_api_interaction.py"
python3 "$TEST_CDIR/spmd/test_xla_auto_sharding.py"
python3 "$TEST_CDIR/spmd/test_fsdp_v2.py"
python3 "$TEST_CDIR/test_gradient_accumulation.py"
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_shape_models.py" -v
python3 "$TEST_CDIR/test_autocast.py"
python3 "$TEST_CDIR/test_fp8.py"
3 changes: 1 addition & 2 deletions test/utils/train_spmd_linear_model_grad_acc.py
@@ -125,8 +125,7 @@ def train_step(input_id, label):

def train_loop_fn(data, target, running_loss):
if FLAGS.use_gradient_accumulation_loop:
- running_loss, = gradient_accumulation(train_step, (data, target), model,
- None)
+ running_loss = gradient_accumulation(train_step, (data, target), model)
else:
for i in range(FLAGS.gradient_accumulation_steps):
loss = train_step(data[i], target[i])
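The call-site change in this last hunk is the user-facing piece of the refinement: with no carried tensors, the caller no longer passes a trailing None placeholder and no longer unpacks a one-element tuple, because gradient_accumulation now returns the accumulated loss directly. Calls that do pass carried tensors, as in test/test_gradient_accumulation.py above, still unpack the loss followed by the carried tensors' final values.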