diff --git a/recipes/configs/llama4/scout_17B_16E_full.yaml b/recipes/configs/llama4/scout_17B_16E_full.yaml
index fdea0ec475..8497b81ca3 100644
--- a/recipes/configs/llama4/scout_17B_16E_full.yaml
+++ b/recipes/configs/llama4/scout_17B_16E_full.yaml
@@ -77,6 +77,7 @@ compile: False
 #   model: True
 #   loss: True
 #   optimizer_step: False
+#   scale_grads: True

 # Reduced precision
 dtype: bf16
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 61819ac17b..74e3105fab 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -313,10 +313,19 @@ def setup(self, cfg: DictConfig) -> None:
         self._compile_model = compile_bool
         self._compile_loss = compile_bool
         self._compile_optimizer_step = compile_bool
+        self._compile_scale_grads = compile_bool
         if isinstance(compile, DictConfig):
             self._compile_model = compile.get("model", True)
             self._compile_loss = compile.get("loss", True)
             self._compile_optimizer_step = compile.get("optimizer_step", False)
+            self._compile_scale_grads = compile.get("scale_grads", True)
+
+        # This indirection is needed to apply torch.compile to scale_grads step.
+        self._grad_scaler = training.scale_grads_
+        if self._compile_scale_grads:
+            self._grad_scaler = torch.compile(
+                self._grad_scaler, backend=self._compile_backend
+            )

         self._model = self._setup_model(
             cfg_model=cfg.model,
@@ -932,8 +941,12 @@ def train(self) -> None:
                         torch.distributed.all_reduce(num_tokens)
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
+
                     # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, self.dp_degree / num_tokens)
+                    self._grad_scaler(
+                        self._model.parameters(), self.dp_degree / num_tokens
+                    )
+
                     if self._clip_grad_norm is not None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py
index b2c327c617..ab5254bb2e 100644
--- a/torchtune/training/__init__.py
+++ b/torchtune/training/__init__.py
@@ -25,7 +25,7 @@
     shard_model,
     validate_no_params_on_meta_device,
 )
-from torchtune.training._grad_scaler import scale_grads
+from torchtune.training._grad_scaler import scale_grads, scale_grads_
 from torchtune.training._model_util import disable_dropout
 from torchtune.training._profiler import (
     DEFAULT_PROFILE_DIR,
@@ -139,6 +139,7 @@
     "OffloadActivations",
     "FormattedCheckpointFiles",
     "scale_grads",
+    "scale_grads_",
     "get_distributed_backend",
     "disable_dropout",
     "DATALOADER_KEY",
diff --git a/torchtune/training/_grad_scaler.py b/torchtune/training/_grad_scaler.py
index 484cd8f372..672c44d573 100644
--- a/torchtune/training/_grad_scaler.py
+++ b/torchtune/training/_grad_scaler.py
@@ -4,10 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import defaultdict
+from typing import Optional
+
 import torch
-from torch import nn
+from torch import nn, Tensor
+from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support
+from torchtune.utils._logging import deprecated


+@deprecated(msg="Please use `scale_grads_` instead.")
 def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
     """
     Utility to scale the gradients of a model.
@@ -29,3 +36,70 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
             scaler = scaler.to(device)
         if p.grad is not None:
             p.grad *= scaler
+
+
+@_no_grad
+def scale_grads_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    r"""Scale gradients of iterable parameters.
+
+    This function is equivalent to :func:`torch.mul_` applied to each parameter.
+    Gradients are modified in-place, multiplying by specified scaler.
+
+    Args:
+        parameters (_tensor_or_tensors): an iterable of Tensors or a
+            single Tensor that will have gradients scaled
+        scaler (torch.Tensor): multiplier to scale gradients
+        foreach (Optional[bool]): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+    Returns:
+        None
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    else:
+        parameters = list(parameters)
+    _scale_grad_(parameters, scaler, foreach)
+
+
+def _group_tensors_by_device(
+    tensors: list[torch.Tensor],
+) -> dict[torch.device, list[Tensor]]:
+    ret = defaultdict(list)
+    for i, tensor in enumerate(tensors):
+        ret[tensor.device].append(tensor)
+
+    return ret
+
+
+@_no_grad
+def _scale_grad_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return
+    grouped_grads = _group_tensors_by_device(grads)
+
+    for device, device_grads in grouped_grads.items():
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, scaler.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            scaler_device = scaler.to(device)
+            for g in device_grads:
+                g.mul_(scaler_device)
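
Illustrative sketch (not part of the patch): the compile indirection added in full_finetune_distributed.py, reduced to a standalone snippet. The compile_cfg dict and the "inductor" backend choice are assumptions for demonstration only; the recipe reads the equivalent values from its config and environment, and the scale_grads_ call assumes the patch above is applied.

import torch
from torchtune import training

# Hypothetical stand-in for the recipe's `compile` config entry, mirroring the
# commented YAML block above; it may also be a plain bool.
compile_cfg = {"model": True, "loss": True, "optimizer_step": False, "scale_grads": True}

compile_scale_grads = (
    compile_cfg.get("scale_grads", True)
    if isinstance(compile_cfg, dict)
    else bool(compile_cfg)
)

# Keep a callable handle so torch.compile can wrap the gradient-scaling step when enabled.
grad_scaler = training.scale_grads_
if compile_scale_grads:
    grad_scaler = torch.compile(grad_scaler, backend="inductor")

The training step then calls grad_scaler(model.parameters(), scaler) in place of the old training.scale_grads(model, scaler), exactly as the train() hunk above does with self._grad_scaler.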
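
A minimal usage sketch for the new scale_grads_ API (again assuming the patch is applied); the toy model and scaler values are illustrative only.

import torch
from torch import nn
from torchtune import training

# Toy model: run a forward/backward pass so gradients exist.
model = nn.Linear(8, 2)
model(torch.randn(4, 8)).sum().backward()

before = next(model.parameters()).grad.clone()

# Scale every gradient in place, e.g. by dp_degree / num_tokens in the recipe.
training.scale_grads_(model.parameters(), torch.tensor(0.5))
torch.testing.assert_close(next(model.parameters()).grad, before * 0.5)

# foreach=False forces the per-tensor fallback path; the result is identical.
training.scale_grads_(model.parameters(), torch.tensor(2.0), foreach=False)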