From 0db775d2929f7f97cce3103570310801a34e4ffa Mon Sep 17 00:00:00 2001 From: Rui Silva Date: Thu, 1 May 2025 21:56:05 +0000 Subject: [PATCH] Refine the gradient accumulation API --- test/neuron/run_tests.sh | 1 + test/run_tests.sh | 1 + test/spmd/test_train_spmd_linear_model.py | 18 +- test/test_gradient_accumulation.py | 153 +++++++ test/tpu/run_tests.sh | 1 + .../utils/train_spmd_linear_model_grad_acc.py | 3 +- .../experimental/gradient_accumulation.py | 428 +++++++----------- 7 files changed, 345 insertions(+), 260 deletions(-) create mode 100644 test/test_gradient_accumulation.py diff --git a/test/neuron/run_tests.sh b/test/neuron/run_tests.sh index c5f277fe06a5..630b93fdca5b 100755 --- a/test/neuron/run_tests.sh +++ b/test/neuron/run_tests.sh @@ -167,6 +167,7 @@ function run_xla_op_tests1 { run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py" run_save_tensor_ir run_test "$CDIR/spmd/test_spmd_graph_dump.py" run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_graph_dump.py" + run_test "$CDIR/test_gradient_accumulation.py" } function run_xla_op_tests2 { diff --git a/test/run_tests.sh b/test/run_tests.sh index 979f8731ca0c..1697b00a7cbd 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -243,6 +243,7 @@ function run_xla_op_tests3 { run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py" run_test "$CDIR/spmd/test_mp_input_sharding.py" run_test "$CDIR/spmd/test_train_spmd_linear_model.py" "$@" --skip-gradient-checkpointing + run_test "$CDIR/test_gradient_accumulation.py" run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py" run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_input_output_aliases.py" diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py index a8e8b459d7e9..c5e62828852a 100644 --- a/test/spmd/test_train_spmd_linear_model.py +++ b/test/spmd/test_train_spmd_linear_model.py @@ -5,6 +5,7 @@ import unittest import torch +from torch_xla import runtime as xr import test_xla_sharding_base @@ -19,6 +20,9 @@ # the gradient checkpointing A/B test run for it. SKIP_GRADIENT_CHECKPOINTING: bool = False +skipOnGpu = unittest.skipIf(xr.device_type() == 'CUDA', + 'https://github.com/pytorch/xla/issues/9128') + @contextmanager def extended_argv(args): @@ -33,7 +37,7 @@ def extended_argv(args): class TestSPMDLinearModel(test_xla_sharding_base.XlaShardingTest): def test_basic(self): - print('Training loop with baseline') + print('Training loop with baseline', flush=True) with extended_argv([]): baseline_losses, baseline_result = train_and_evaluate() # Verify that the model losses are not zero. @@ -42,7 +46,7 @@ def test_basic(self): assert not torch.any(baseline_result == 0) if not SKIP_GRADIENT_CHECKPOINTING: - print('Training loop with gradient checkpointing') + print('Training loop with gradient checkpointing', flush=True) with extended_argv(['--use_gradient_checkpointing']): checkpointing_losses, checkpointing_result = train_and_evaluate() # Verify that the runs match with and without checkpointing. 
@@ -62,11 +66,11 @@ def test_gradient_accumulation_matches(self):
     """
     COMMON_GRAD_ACC_ARGS = ["--gradient_accumulation_steps", "8"]
 
-    print('Training loop with traditional gradient accumulation')
+    print('Training loop with traditional gradient accumulation', flush=True)
     with extended_argv(COMMON_GRAD_ACC_ARGS):
       baseline_grad_acc_losses = train_and_evaluate_grad_acc()
 
-    print('Training loop with XLA\'s `While` gradient accumulation')
+    print('Training loop with XLA\'s `While` gradient accumulation', flush=True)
     with extended_argv(COMMON_GRAD_ACC_ARGS +
                        ["--use_gradient_accumulation_loop"]):
       loop_grad_acc_losses = train_and_evaluate_grad_acc()
@@ -79,8 +83,10 @@ def test_gradient_accumulation_matches(self):
                     loop_grad_acc_losses))
 
     if not SKIP_GRADIENT_CHECKPOINTING:
-      print('Training loop with XLA\'s `While` gradient accumulation and '
-            'gradient checkpointing.')
+      print(
+          'Training loop with XLA\'s `While` gradient accumulation and '
+          'gradient checkpointing.',
+          flush=True)
       with extended_argv(
           COMMON_GRAD_ACC_ARGS +
           ["--use_gradient_accumulation_loop", "--use_gradient_checkpointing"]):
diff --git a/test/test_gradient_accumulation.py b/test/test_gradient_accumulation.py
new file mode 100644
index 000000000000..6e431a4237d9
--- /dev/null
+++ b/test/test_gradient_accumulation.py
@@ -0,0 +1,153 @@
+import sys
+import unittest
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+from torch_xla.experimental.gradient_accumulation import gradient_accumulation
+
+from test_utils import XlaTestCase  # type:ignore
+
+
+class SimpleModel(torch.nn.Module):
+
+  def __init__(self, input_dim=10, hidden_dim=20, output_dim=5):
+    super(SimpleModel, self).__init__()
+    self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
+    self.fc2 = torch.nn.Linear(hidden_dim, output_dim)
+
+  def forward(self, x):
+    x = torch.relu(self.fc1(x))
+    return self.fc2(x)
+
+
+class GradAccumulationTest(XlaTestCase):
+
+  def setUp(self):
+    self.device = xm.xla_device()
+    torch.manual_seed(123)
+
+  def test_basic(self):
+    """Compare results with and without the XLA loop"""
+    batch_size = 8
+    hidden_dim = 20
+    input_dim = 10
+    output_dim = 5
+
+    inputs = torch.randn(batch_size, input_dim).to(self.device)
+    targets = torch.randn(batch_size, output_dim).to(self.device)
+
+    def train_step_fw(input_batch, target_batch, carried_tensor):
+      output = model_ga(input_batch)
+      loss = torch.nn.functional.mse_loss(output, target_batch)
+      new_carried_tensor = carried_tensor + 5
+      return loss, new_carried_tensor
+
+    # Gradient accumulation with XLA loop
+    torch.manual_seed(43)
+    model_ga = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device)
+    carried_tensor_ga = torch.tensor([5, 5]).to(self.device)
+
+    accumulated_loss_ga, accum_carried_tensor_ga = gradient_accumulation(
+        train_step_fw, (inputs, targets), model_ga, carried_tensor_ga)
+
+    torch_xla.sync()
+
+    # Traditional accumulation
+    torch.manual_seed(43)
+    model_manual = SimpleModel(input_dim, hidden_dim,
+                               output_dim).to(self.device)
+    carried_tensor_manual = torch.tensor([5, 5]).to(self.device)
+
+    accumulated_loss_manual = torch.tensor(0.0).to(self.device)
+    for i in range(batch_size):
+      loss, carried_tensor_manual = train_step_fw(inputs[i:i + 1],
+                                                  targets[i:i + 1],
+                                                  carried_tensor_manual)
+      loss = loss / batch_size
+      loss.backward()
+      accumulated_loss_manual += loss.detach()
+
+    torch_xla.sync()
+
+    # Compare losses, carried tensors and resulting gradients
+    super().compareResults([accumulated_loss_ga],
[accumulated_loss_manual]) + super().compareResults([accum_carried_tensor_ga], [carried_tensor_manual]) + super().compareResults(model_ga.parameters(), model_manual.parameters()) + + def test_with_carried_tensors(self): + """Test gradient accumulation with carried tensors, including with RNG""" + batch_size = 2 + hidden_dim = 20 + input_dim = 10 + output_dim = 5 + + model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device) + + inputs = torch.randn(batch_size, input_dim).to(self.device) + targets = torch.randn(batch_size, output_dim).to(self.device) + + # Carried tensors + counter = torch.tensor(0).to(self.device) + tensor0 = torch.tensor(0.0).to(self.device) + tensor0_baseline = tensor0.clone() + + # Define train step function that updates the carried tensor. In the case of + # RNG, we negate the previous value, in order to validate that we get unique + # RNG seeds for each iteration. + def train_step_fw(input_batch, target_batch, counter, tensor0): + output = model(input_batch) + loss = torch.nn.functional.mse_loss(output, target_batch) + # Update counter + new_counter = counter + 1 + new_tensor0 = torch.rand_like(tensor0, device=self.device) - tensor0 + return loss, new_counter, new_tensor0 + + # Run gradient accumulation + accumulated_loss, final_counter, final_tensor0 = gradient_accumulation( + train_step_fw, (inputs, targets), model, counter, tensor0) + + torch_xla.sync() + + self.assertEqual(final_counter.item(), batch_size) + # Ensure that the result is not 0, showcasing that the RNG is unique + # per iteration. + self.assertNotEqual(final_tensor0.item(), 0.0) + + def test_error_empty_iterable_tensors(self): + """Test that empty iterable_tensors raises an error.""" + model = SimpleModel().to(self.device) + + def train_step_fw(): + pass + + with self.assertRaises(ValueError): + gradient_accumulation(train_step_fw, [], model) + + def test_error_mutated_input_tensors(self): + """Test that mutating input tensors raises an error.""" + batch_size = 2 + hidden_dim = 20 + input_dim = 10 + output_dim = 5 + + model = SimpleModel(input_dim, hidden_dim, output_dim).to(self.device) + + inputs = torch.randn(batch_size, input_dim).to(self.device) + targets = torch.randn(batch_size, output_dim).to(self.device) + counter = torch.tensor(0).to(self.device) + + def train_step_fw(input_batch, target_batch, counter): + output = model(input_batch) + loss = torch.nn.functional.mse_loss(output, target_batch) + # In-place mutation of an input tensor. 
+      counter += 1
+      return loss, counter
+
+    with self.assertRaises(AssertionError):
+      accumulated_loss, final_counter = gradient_accumulation(
+          train_step_fw, (inputs, targets), model, counter)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index 6b596a157460..eaa8add40fbe 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -19,6 +19,7 @@ python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py"
 python3 "$TEST_CDIR/spmd/test_xla_spmd_python_api_interaction.py"
 python3 "$TEST_CDIR/spmd/test_xla_auto_sharding.py"
 python3 "$TEST_CDIR/spmd/test_fsdp_v2.py"
+python3 "$TEST_CDIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_shape_models.py" -v
 python3 "$TEST_CDIR/test_autocast.py"
 python3 "$TEST_CDIR/test_fp8.py"
diff --git a/test/utils/train_spmd_linear_model_grad_acc.py b/test/utils/train_spmd_linear_model_grad_acc.py
index 706ae4904041..9fb66a99cd90 100644
--- a/test/utils/train_spmd_linear_model_grad_acc.py
+++ b/test/utils/train_spmd_linear_model_grad_acc.py
@@ -125,8 +125,7 @@ def train_step(input_id, label):
 
   def train_loop_fn(data, target, running_loss):
     if FLAGS.use_gradient_accumulation_loop:
-      running_loss, = gradient_accumulation(train_step, (data, target), model,
-                                            None)
+      running_loss = gradient_accumulation(train_step, (data, target), model)
     else:
       for i in range(FLAGS.gradient_accumulation_steps):
         loss = train_step(data[i], target[i])
diff --git a/torch_xla/experimental/gradient_accumulation.py b/torch_xla/experimental/gradient_accumulation.py
index 5299291861ff..4e3f8682e68e 100644
--- a/torch_xla/experimental/gradient_accumulation.py
+++ b/torch_xla/experimental/gradient_accumulation.py
@@ -1,9 +1,10 @@
 import torch
 import torch_xla
 import torch_xla.core.xla_builder as xb
+import torch_xla.core.xla_model as xm
 
 from dataclasses import dataclass
-from typing import Any, Callable, Sequence, Tuple, Optional, List, Dict
+from typing import Any, Callable, Sequence, Tuple, Optional, List, Union
 import warnings
 
 
@@ -14,29 +15,20 @@ class GradientAccumulationContext:
   * num_gradient_steps: Number of steps to accumulate gradients over
   * num_iterable_tensors: Number of input tensors to iterate over
   * num_carried_tensors: Number of tensors carried between iterations
-  * num_model_params: Number of model parameters
-  * num_internal_tensors: Number of internal tensors used (default: 2)
-
-  Note: `num_internal_tensors` should only be changed if we create new internal
-  tensors.
   """
   num_gradient_steps: int
   num_iterable_tensors: int
   num_carried_tensors: int
-  num_model_params: int
-  num_internal_tensors: int = 2
 
 
 def gradient_accumulation(
-    train_step: Callable[..., Any],
-    iterable_tensors: Sequence[torch.Tensor],
-    model: torch.nn.Module,
-    carried_tensors: Optional[Tuple[torch.Tensor, ...]] = None
-) -> Tuple[torch.Tensor, ...]:
+    train_step: Callable[..., Any], iterable_tensors: Sequence[torch.Tensor],
+    model: torch.nn.Module, *carried_tensors: torch.Tensor
+) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
   """Accumulates gradients over multiple training steps using XLA's `While`
-  operator to iterate over the leading dimension of the iterable tensors.
-  The backward computation of the model is implicitly executed following the
-  train_step operations.
+  operator to iterate over the leading dimension of the iterable tensors.
+  The backward computation of the model is implicitly executed following the
+  train_step operations.
 
   Notes:
 
@@ -44,9 +36,10 @@ def gradient_accumulation(
     assumed that `train_step` is purposefully encapsulated inside of the
     loop. Hence, it is not recommended to have any operation involving the
     model parameters outside of `train_step`.
-  * Note that zeroing the gradients to zero instead of None, (e.g.
-    `.zero_grad(set_to_none=False)) will avoid the device transfer of the
-    initial gradients in every call.
+  * All of the inputs are expected to be materialized. Any inputs that are not
+    yet materialized will be synced early on.
+  * The train step is expected to be pure. If it is not, for instance if it
+    mutates its input tensors in place, an error is raised.
 
   Args:
     train_step: Training function that takes iterable tensors and carried
@@ -62,29 +55,30 @@ def gradient_accumulation(
     model: PyTorch model whose parameters will be updated. Note that the
       entire model computation will be traced and generated from within the
       loop.
-    carried_tensors: Optional tensors passed and updated between iterations.
+    carried_tensors: Unpacked tensor arguments that are updated between iterations.
 
   Returns:
-    (accumulated_loss, carried_tensor0, carried_tensor1, ...): A tuple including
-    the `accumulated_loss` and the same unpacked `carried_tensors` that were
-    provided as inputs. In addition, the model parameter gradients, if
-    applicable, contain the accumulated gradients.
+    accumulated_loss, *carried_tensors: A tuple of the `accumulated_loss` and
+    the updated `carried_tensors` that were provided as inputs. If no carried
+    tensors were provided, only the `accumulated_loss` is returned. In
+    addition, the model parameter gradients, if applicable, contain the
+    accumulated gradients.
 
   Example:
 
     >>> # Note: This is a partial example, since it is dependent on the
     >>> # training model. Please refer to existing tests.
-    >>>
+    >>> 
    >>> from torch_xla.experimental.gradient_accumulation import (
    >>>   gradient_accumulation
    >>> )
-    >>>
+    >>> 
    >>> def train_step(input, label, other_tensor):
    >>>   output = model(input_id)
    >>>   loss = loss_fn(output, label)
-    >>>   updated_other_tensor += 10
+    >>>   updated_other_tensor = other_tensor + 10
    >>>   return loss, updated_other_tensor
-    >>>
+    >>> 
    >>> some_tensor = torch.tensor(10).to(device)
    >>> for (data, target) in loader:
    >>>   # Assuming data's and target's first iterable dimension is 5.
@@ -94,7 +88,7 @@ def gradient_accumulation(
    >>>     train_step,
    >>>     (data, target),
    >>>     model,
-    >>>     (some_tensor,)
+    >>>     some_tensor
    >>>   )
    >>>   print(some_tensor) # Should be 60
    >>>   print(running_loss) # Should be the accumulated loss across all 5
@@ -115,7 +109,6 @@ def gradient_accumulation(
     if tensor.size(0) != accumulation_steps:
       raise ValueError(
           f"Element {i} of iterable_tensors has inconsistent first dimension")
-  carried_tensors = carried_tensors or tuple()
 
   return _gradient_accumulation(accumulation_steps, train_step,
                                 iterable_tensors, model, carried_tensors)
@@ -149,147 +142,93 @@ def num_params(self) -> int:
     return len(self._params)
 
 
+def _make_init_grad(param):
+  grad = torch.zeros_like(param, device=param.device, requires_grad=False)
+  param_sharding = torch_xla._XLAC._get_xla_op_sharding(param)
+  if param_sharding:
+    # Match the gradient sharding to the parameter sharding, if present.
+    torch_xla._XLAC._xla_mark_sharding(grad, param_sharding)
+  return grad
+
+
 def _gradient_accumulation_impl(context, body_fn, iterable_tensors, params,
                                 carried_tensors):
   builder = XlaBuildHelper('grad_acc')
   device = torch_xla.device()
 
-  def _prepare_fake_tensors(
-      iterable_tensors: Sequence[torch.Tensor],
-      carried_tensors: Sequence[torch.Tensor],
-  ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-
-    def __make_placeholder(tensor, is_iterable):
-      shape = tensor.shape[1:] if is_iterable else tensor.shape
-      return xb.create_placeholder_tensor(
-          shape=shape, dtype=tensor.dtype).requires_grad_(tensor.requires_grad)
-
-    fake_iterable_tensors = [
-        __make_placeholder(iter_tensor, True)
-        for iter_tensor in iterable_tensors
-    ]
-    fake_carried_tensors = [
-        __make_placeholder(carried_tensor, False)
-        for carried_tensor in carried_tensors
-    ]
-    return fake_iterable_tensors, fake_carried_tensors
-
-  # TODO - Fake the model once we are able to create placeholder tensors.
-  fake_iterable_tensors, fake_carried_tensors = _prepare_fake_tensors(
-      iterable_tensors, carried_tensors)
   init_iterator = torch.tensor(0, dtype=torch.int32, device=device)
   init_loss = torch.tensor(0, dtype=torch.float32, device=device)
+  init_grads = [_make_init_grad(p) for p in params if p.requires_grad]
+
+  builder.add_param(init_iterator)
+  builder.add_param(init_loss)
+  for grad in init_grads:
+    builder.add_param(grad)
+
+  iterable_tensor_slices = tuple(t[0] for t in iterable_tensors)
+  body_fn_inputs = [*iterable_tensor_slices, *params, *carried_tensors]
 
-  grads = [param.grad for param in params]
-  body_fn_inputs = (init_iterator, init_loss, *fake_iterable_tensors,
-                    *fake_carried_tensors, *params, *grads)
-  # TODO - Fake the gradients once we are able to create placeholder tensors.
-  # Since the body is expected to do an in-place mutation of the gradients, we
-  # clone the gradients and use that as an input to the body. This will ensure
-  # that we retain a device data IR node in the graph. The cloned gradient will
-  # be updated to denote an IR operation (e.g. %add), and that can not be
-  # captured as a device data input for the other required computations, namely
-  # the condition and init for the XLA while loop.
-  for param in params:
-    param.grad = param.grad.clone()
-  body_result = body_fn(init_iterator, init_loss, tuple(fake_iterable_tensors),
-                        tuple(fake_carried_tensors), tuple(params),
-                        tuple(grads))
+  # Ensure newly initialized inputs are fully materialized for the purpose of tracing.
+  # In addition, we sync the inputs that are expected to be materialized, namely all
+  # the model parameters, gradients (if present), iterable_tensors and carried_tensors.
+  # TODO(rpsilva): Use placeholder tensors for the iterable tensor slices, to avoid
+  # a redundant sync.
+  torch_xla._XLAC._xla_sync_multi(
+      body_fn_inputs, devices=[], wait=False, sync_xla_data=False)
+
+  tid_to_body_fn_inputs = {
+      torch_xla._XLAC._xla_get_tensor_id(t): (idx, t)
+      for idx, t in enumerate(body_fn_inputs)
+  }
+
+  for idx, (t, s) in enumerate(zip(iterable_tensors, iterable_tensor_slices)):
+    # Map each iterable tensor slice back to its full tensor.
+    tid_to_body_fn_inputs[torch_xla._XLAC._xla_get_tensor_id(s)] = (idx, t)
+
+  body_result = body_fn(iterable_tensor_slices, tuple(params), carried_tensors)
+  body_result = list(body_result)
+
+  # Ensure that none of the body_fn inputs were mutated in place, since the
+  # body function must be pure.
+  assert not any(
+      torch_xla._XLAC._check_tensor_need_materialization(body_fn_inputs))
+
+  # Ensure that any prior async operations on the device have terminated.
+  xm.wait_device_ops()
 
   (
       graph_input_tensor_ids,
       graph_input_xla_values,
-  ) = torch_xla._XLAC._get_tensors_xla_device_data_node(
-      list(body_result) + list(body_fn_inputs))
-
-  body_fn_input_tensor_ids = [
-      torch_xla._XLAC._xla_get_tensor_id(i) for i in body_fn_inputs
-  ]
-  uncaptured_input_tensor_ids = tuple(
-      v for i, v in zip(graph_input_tensor_ids, graph_input_xla_values)
-      if i not in body_fn_input_tensor_ids)
-
-  body_ctx = torch_xla._XLAC.lowering.LoweringContext()
-  body_ctx.set_name_string("bodyctx")
-  body_ctx.build(body_result + uncaptured_input_tensor_ids)
-  body_hlo = body_ctx.hlo()
+  ) = torch_xla._XLAC._get_tensors_xla_device_data_node(body_result,
+                                                        body_fn_inputs)
+  body_hlo = torch_xla._XLAC._get_xla_tensors_hlo_proto(body_result)
   body_computation = xb.computation_from_module_proto("bodycomputation",
                                                       body_hlo)
-
-  builder.add_param(init_iterator)
-  builder.add_param(init_loss)
-
-  def _build_parameter_mapping(
-      builder: XlaBuildHelper,
-      context: GradientAccumulationContext,
-      body_fn_inputs: Tuple[torch.Tensor, ...],
-      uncaptured_input_tensor_ids: Tuple[torch.Tensor, ...],
-      iterable_tensors: Sequence[torch.Tensor],
-      fake_iterable_tensors: Sequence[torch.Tensor],
-      carried_tensors: Tuple[torch.Tensor, ...],
-      fake_carried_tensors: Tuple[torch.Tensor, ...],
-      params: List[torch.Tensor],
-      grads: List[torch.Tensor],
-  ) -> Dict[int, int]:
-    param_mapping = {}
-
-    def add_to_mapping(val: torch.Tensor,
-                       fake_val: Optional[torch.Tensor] = None):
-      idx = builder.add_param(val)
-      param_id = body_ctx.tensor_parameter_id(
-          fake_val if fake_val is not None else val)
-      if param_id != -1:
-        param_mapping[param_id] = idx
-
-    # Process iterable tensors and carried inputs
-    for val, fake_val in zip(iterable_tensors, fake_iterable_tensors):
-      add_to_mapping(val, fake_val)
-    for val, fake_val in zip(carried_tensors, fake_carried_tensors):
-      add_to_mapping(val, fake_val)
-
-    # Process params, grads, and uncaptured input tensor ids
-    for tensor_list in (params, grads, uncaptured_input_tensor_ids):
-      for val in tensor_list:
-        add_to_mapping(val)
-
-    # Handle any additional hoisted variables
-    hoisted_vars = body_ctx.device_parameter_id_tensor_mapping()
-    for v in body_fn_inputs + uncaptured_input_tensor_ids:
-      param_id = body_ctx.tensor_parameter_id(v)
-      hoisted_vars.pop(param_id, None)
-
-    # TODO(rpsilva-aws): Derived from `experimental/scan.py`. Unify the RNG and
-    # hoisted paths.
-    seed_info_id = torch_xla._XLAC._get_seed_info_id()
-    seed_parameter_id = None
-    if seed_info_id in graph_input_tensor_ids:
-      seed_idx = graph_input_tensor_ids.index(seed_info_id)
-      seed_parameter_id = body_ctx.tensor_parameter_id(
-          graph_input_xla_values[seed_idx])
-      assert seed_parameter_id != -1, "`fn` uses random seed, but random seed is not \
        a parameter to the traced HLO graph"
-
-      # Replace the single seed value with a tensor of seeds, one per iteration.
- seed_tensor = hoisted_vars[seed_parameter_id] - assert seed_tensor.dtype == torch.int64 - hoisted_vars[seed_parameter_id] = torch.randint( - 0, - 2**62, (context.num_gradient_steps,), - dtype=torch.int64, - device=device) - - for param_id, tensor in hoisted_vars.items(): - idx = builder.add_param(tensor) - param_mapping[param_id] = idx - return param_mapping, seed_parameter_id - - param_mapping, seed_parameter_id = _build_parameter_mapping( - builder, context, body_fn_inputs, uncaptured_input_tensor_ids, - iterable_tensors, fake_iterable_tensors, carried_tensors, - fake_carried_tensors, params, grads) - - def _body_fn_wrapper(curr_iter: xb.Op, curr_loss: xb.Op, *while_params: - xb.Op): + del body_result + + # Capture the seed info ID, to identify all the seed tensors. + seed_info_id = torch_xla._XLAC._get_seed_info_id() + seed_info_indices = set() + # Maps the graph index to body_fn input/output index. Note that the list + # indices implicitly denote the position for the body_fn. + graph_to_body_idx = [] + # Collect all graph inputs if they match the provided body_fn_inputs. In case + # of hoisted tensors, we simply capture the XLA values. Note that for hoisted + # seed tensors, we fork a unique seed tensor for all iterations, to guarantee + # that the values differ across loop iterations. + for tid, xla_value in zip(graph_input_tensor_ids, graph_input_xla_values): + idx, t = tid_to_body_fn_inputs.get(tid, (None, xla_value)) + if tid == seed_info_id: + seed_info_indices.add(len(graph_to_body_idx)) + t = torch.randint( + 0, (1 << 63) - 1, (context.num_gradient_steps,), device=device) + builder.add_param(t) + graph_to_body_idx.append(idx) + + carried_tensors_start_idx = len(body_fn_inputs) - context.num_carried_tensors + + def _body_fn_wrapper(curr_iter: xb.Op, loss: xb.Op, *args: + xb.Op) -> Tuple[xb.Op, ...]: def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op: indices = [idx] + [idx.zeros_like() for _ in range(xs.shape().rank - 1)] @@ -298,72 +237,56 @@ def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op: sliced = xs.dynamic_slice(indices, slice_shape) return sliced.reshape(list(xs.shape().sizes)[1:]) - # TODO(rpsilva-aws): Derived from `experimental/scan.py`. Unify the RNG - # path. - def replace_rng_seed(curr_iter: xb.Op, *while_params: xb.Op): - """Slices the pre-generated seed tensor for the current iteration.""" - if seed_parameter_id is None: - return while_params - idx = param_mapping[seed_parameter_id] - replaced = list(while_params) - replaced[idx] = dynamic_slice(replaced[idx], curr_iter) - return replaced - - def call_fn_computation(*while_params: xb.Op) -> xb.Op: - fn_inputs = [ - while_params[param_mapping[i]] for i in range(len(param_mapping)) - ] - return xb.Op.call(body_computation, fn_inputs) - - iterable_tensors = while_params[:context.num_iterable_tensors] - idx = curr_iter - sliced_iterables = [ - dynamic_slice(iter_tensor, idx) for iter_tensor in iterable_tensors - ] - - # Call the computation with current values - result = call_fn_computation( - idx, curr_loss, - *replace_rng_seed(idx, *sliced_iterables, - *while_params[context.num_iterable_tensors:])) - - # Extract the carried tensors and accumulated gradients. 
-    carried_tensors_and_gradients = [
-        result.get_tuple_element(i) for i in range(
-            context.num_internal_tensors + context.num_iterable_tensors,
-            result.shape().tuple_size())
-    ]
-    one = xb.Op.scalar(idx.builder(), 1, dtype=xb.Type.S32)
-    updated_loss = curr_loss + result.get_tuple_element(1)
-    return (curr_iter + one, updated_loss, *iterable_tensors,
-            *carried_tensors_and_gradients)
-
-  def _cond_fn(curr_iter: xb.Op, *rest):
+    grads = list(args[:len(init_grads)])
+    graph_inputs = list(args[len(grads):])
+    graph_outputs = graph_inputs[:]
+
+    # Dynamic slice on iterable and seed tensors.
+    for graph_idx, body_idx in enumerate(graph_to_body_idx):
+      if graph_idx in seed_info_indices:
+        # Substitute the seed tensor that is fed to the user computation.
+        graph_inputs[graph_idx] = dynamic_slice(graph_inputs[graph_idx],
+                                                curr_iter)
+      if body_idx is None:
+        # Ignore hoisted variables
+        continue
+      if body_idx < context.num_iterable_tensors:
+        graph_inputs[graph_idx] = dynamic_slice(graph_inputs[graph_idx],
+                                                curr_iter)
+
+    result = xb.Op.call(body_computation, graph_inputs)
+
+    # Accumulate loss and grads
+    loss = loss + result.get_tuple_element(0)
+    for i, grad in enumerate(grads):
+      grads[i] = grad + result.get_tuple_element(i + 1)
+
+    # Update the carried tensors in graph_outputs for the next iteration
+    for graph_idx, body_idx in enumerate(graph_to_body_idx):
+      if body_idx is not None and body_idx >= carried_tensors_start_idx:
+        idx = body_idx - carried_tensors_start_idx
+        graph_outputs[graph_idx] = result.get_tuple_element(1 + len(grads) +
                                                            idx)
+
+    one = xb.Op.scalar(curr_iter.builder(), 1, dtype=xb.Type.S32)
+    return (curr_iter + one, loss, *grads, *graph_outputs)
+
+  def _cond_fn(curr_iter: xb.Op, *args: xb.Op) -> xb.Op:
     return curr_iter < xb.Op.scalar(
         curr_iter.builder(), context.num_gradient_steps, dtype=xb.Type.S32)
 
-  def _compute_output_indices(
-      context: GradientAccumulationContext) -> List[int]:
-    # Start with loss index
-    indices = [1]
-    # Add indices for carried tensors
-    carried_start = context.num_internal_tensors + context.num_iterable_tensors
-    carried_end = carried_start + context.num_carried_tensors
-    indices.extend(range(carried_start, carried_end))
-    # Add indices for accumulated gradients
-    grad_start = carried_end + context.num_model_params
-    grad_end = grad_start + context.num_model_params
-    indices.extend(range(grad_start, grad_end))
-    return indices
-
   w = xb.Op.mkwhile(builder.params, _cond_fn, _body_fn_wrapper)
-  outputs = [w.get_tuple_element(i) for i in _compute_output_indices(context)]
-  op = xb.Op.tuple(outputs)
-  computation = op.build('grad_acc_loop_torch_func')
-  result = torch_xla._XLAC._xla_user_computation('xla::_op_grad_acc_loop',
-                                                 builder.param_tensors,
-                                                 computation)
-  return result
+  computation = w.build('grad_acc_loop_torch_func')
+  _, loss, *outputs = torch_xla._XLAC._xla_user_computation(
+      'xla::_op_grad_acc_loop', builder.param_tensors, computation)
+  grads = outputs[:len(init_grads)]
+  graph_outputs = outputs[len(grads):]
+  carried_tensors = [None] * context.num_carried_tensors
+  for graph_idx, body_idx in enumerate(graph_to_body_idx):
+    if body_idx is not None and body_idx >= carried_tensors_start_idx:
+      idx = body_idx - carried_tensors_start_idx
+      carried_tensors[idx] = graph_outputs[graph_idx]
+  return (loss, grads, carried_tensors)
 
 
 def _gradient_accumulation(accumulation_steps, train_step, iterable_tensors,
@@ -371,14 +294,19 @@ def _gradient_accumulation(accumulation_steps, train_step, iterable_tensors,
   model_parameters = list(model.parameters())
context = GradientAccumulationContext(accumulation_steps, len(iterable_tensors), - len(carried_tensors), - len(model_parameters)) - - def body_fn(iteri: torch.Tensor, _: torch.Tensor, - iterable_tensors: Tuple[torch.Tensor, ...], - carried_tensors: Tuple[torch.Tensor, - ...], params: Tuple[torch.Tensor, ...], - grads: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: + len(carried_tensors)) + + def body_fn( + iterable_tensors: Tuple[torch.Tensor, ...], params: Tuple[torch.Tensor, + ...], + carried_tensors: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: + # We set the grads to None, and return the computed gradients within this body + # iteration, accumulating it in the body wrapper. Note that body_fn needs to be + # pure, without side effects and tensor mutations. + orig_grads = [param.grad for param in params] + for param in params: + param.grad = None + result = train_step(*iterable_tensors, *carried_tensors) if not context.num_carried_tensors: @@ -387,32 +315,28 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor, loss, *carried_tensors = result loss /= context.num_gradient_steps loss.backward() - grads = [param.grad for param in params] - return (iteri, loss, *iterable_tensors, *carried_tensors, *params, *grads) - for param in model_parameters: - if not param.requires_grad: - continue + grads = [param.grad for param in params if param.requires_grad] + # Restore original grads to make this function pure. + for param, grad in zip(params, orig_grads): + param.grad = grad + + return (loss, *grads, *carried_tensors) + + loss, grads, carried_tensors = _gradient_accumulation_impl( + context, body_fn, iterable_tensors, model_parameters, carried_tensors) + + params_with_grad = [ + param for param in model_parameters if param.requires_grad + ] + # Accumulate the resulting gradients. + for param, grad in zip(params_with_grad, grads): if param.grad is None: - param.grad = torch.zeros(param.size()).to( - param.device).requires_grad_(False) - param_sharding = torch_xla._XLAC._get_xla_op_sharding(param) - if param_sharding: - # Match the gradient sharding to the parameter's. - torch_xla._XLAC._xla_mark_sharding(param.grad, param_sharding) - - # Ensure that the input or pre-initialized gradient tensors can be donated - # after reassigned to the respective model parameters. If the buffer donor - # is not enabled, then this is a no-op. - torch_xla._XLAC._set_buffer_donation(param.grad, True) - - # Apply gradients to parameters - result = _gradient_accumulation_impl(context, body_fn, iterable_tensors, - model_parameters, carried_tensors) - - for param, grad in zip(model_parameters, - result[1 + context.num_carried_tensors:]): - if param.requires_grad: param.grad = grad + else: + param.grad.add_(grad) + + if not carried_tensors: + return loss - return (result[0], *result[1:context.num_carried_tensors + 1]) + return loss, *carried_tensors
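
The sketch below illustrates how the refined API is intended to be called after this change; it mirrors the docstring example and the new tests rather than defining anything additional. The toy model, tensor shapes, optimizer and the `step_counter` carried tensor are illustrative assumptions; only `gradient_accumulation`, `torch_xla.device()` and `torch_xla.sync()` come from this patch and the existing torch_xla API.

import torch
import torch_xla
from torch_xla.experimental.gradient_accumulation import gradient_accumulation

device = torch_xla.device()
accumulation_steps = 4

# Toy model and optimizer (illustrative only).
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Iterable tensors share a leading dimension equal to the number of
# accumulation steps; each loop iteration consumes one slice of it.
data = torch.randn(accumulation_steps, 8, 16).to(device)
target = torch.randn(accumulation_steps, 8, 4).to(device)

# Optional carried tensor, updated purely (no in-place mutation) per step.
step_counter = torch.tensor(0).to(device)


def train_step(batch, label, counter):
  output = model(batch)
  loss = torch.nn.functional.mse_loss(output, label)
  return loss, counter + 1


# Carried tensors are now passed unpacked (*carried_tensors) rather than as a
# single tuple, and are returned unpacked alongside the accumulated loss. With
# no carried tensors, only the accumulated loss is returned.
running_loss, step_counter = gradient_accumulation(train_step, (data, target),
                                                   model, step_counter)
optimizer.step()
optimizer.zero_grad()
torch_xla.sync()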