scale_grads with foreach + compile #2624
Changes from all commits
@@ -313,10 +313,19 @@ def setup(self, cfg: DictConfig) -> None:
        self._compile_model = compile_bool
        self._compile_loss = compile_bool
        self._compile_optimizer_step = compile_bool
+       self._compile_scale_grads = compile_bool
        if isinstance(compile, DictConfig):
            self._compile_model = compile.get("model", True)
            self._compile_loss = compile.get("loss", True)
            self._compile_optimizer_step = compile.get("optimizer_step", False)
+           self._compile_scale_grads = compile.get("scale_grads", True)
+
+       # This indirection is needed to apply torch.compile to scale_grads step.
+       self._grad_scaler = training.scale_grads_
Member: Can you leave a comment here explaining that we need this indirection for things to work with PT2 compile for some reason?

Contributor (Author): Sure, will add a comment.

+       if self._compile_scale_grads:
+           self._grad_scaler = torch.compile(
+               self._grad_scaler, backend=self._compile_backend
+           )

        self._model = self._setup_model(
            cfg_model=cfg.model,
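
To illustrate the indirection discussed in the review thread above, here is a minimal, self-contained sketch, not the recipe code itself: the tiny `scale_grads_` stand-in and the `"eager"` backend are assumptions for the example. Storing the function in an attribute gives the recipe a single callable that can be swapped for its `torch.compile`-wrapped version without touching the call site in `train()`.

```python
import torch
from torch import nn

def scale_grads_(parameters, scaler):
    # Simplified stand-in for torchtune.training.scale_grads_
    for p in parameters:
        if p.grad is not None:
            p.grad.mul_(scaler)

model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

grad_scaler = scale_grads_            # plain eager version
compile_scale_grads = True
if compile_scale_grads:
    # Wrap the stored callable, not the call site; the recipe passes
    # backend=self._compile_backend here ("eager" is only for this sketch).
    grad_scaler = torch.compile(grad_scaler, backend="eager")

grad_scaler(list(model.parameters()), torch.tensor(0.5))
print(model.weight.grad)
```
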
@@ -932,8 +941,12 @@ def train(self) -> None:
                    torch.distributed.all_reduce(num_tokens)
                    # This will ensure that the logged loss matches what we're optimizing
                    torch.distributed.all_reduce(running_loss)

                    # Manually scale the gradients from unnormalized loss by total # of tokens
-                   training.scale_grads(self._model, self.dp_degree / num_tokens)
+                   self._grad_scaler(
+                       self._model.parameters(), self.dp_degree / num_tokens
+                   )

                    if self._clip_grad_norm is not None:
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self._model.parameters(),
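
Why `dp_degree / num_tokens`: the recipe accumulates an unnormalized (summed-over-tokens) loss, so each rank's gradients are sums over its local tokens, and the data-parallel wrapper averages gradients across the `dp_degree` ranks during the backward all-reduce. Since `num_tokens` has already been all-reduced to the global token count, multiplying by `dp_degree / num_tokens` recovers a per-token mean over the global batch. A small numeric sketch of that arithmetic, with made-up numbers and assuming mean-reducing gradient synchronization:

```python
import torch

dp_degree = 4                                  # data-parallel world size (assumed)
tokens_per_rank = torch.tensor(250.0)
num_tokens = tokens_per_rank * dp_degree       # stand-in for all_reduce(num_tokens) -> 1000

# Pretend every rank accumulated the same summed-over-tokens gradient value.
g_sum_local = torch.tensor(500.0)
g_after_grad_allreduce = g_sum_local           # mean over identical ranks is unchanged

scaled = g_after_grad_allreduce * (dp_degree / num_tokens)
wanted = (g_sum_local * dp_degree) / num_tokens  # global grad sum / global token count

assert torch.allclose(scaled, wanted)          # both equal 2.0
```
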
@@ -25,7 +25,7 @@
    shard_model,
    validate_no_params_on_meta_device,
)
-from torchtune.training._grad_scaler import scale_grads
+from torchtune.training._grad_scaler import scale_grads, scale_grads_
Member: I think this needs to be added to all …

from torchtune.training._model_util import disable_dropout
from torchtune.training._profiler import (
    DEFAULT_PROFILE_DIR,
@@ -139,6 +139,7 @@
    "OffloadActivations",
    "FormattedCheckpointFiles",
    "scale_grads",
+   "scale_grads_",
    "get_distributed_backend",
    "disable_dropout",
    "DATALOADER_KEY",
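
With `scale_grads_` re-exported from `torchtune.training`, recipes can reach it as `training.scale_grads_`, which is how the first diff calls it. A short usage sketch, assuming a torchtune build that includes this change:

```python
import torch
from torch import nn
from torchtune import training

model = nn.Linear(2, 2)
model(torch.randn(1, 2)).sum().backward()

# Scale every gradient in place by 0.5 via the newly exported helper.
training.scale_grads_(model.parameters(), torch.tensor(0.5))
```
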
@@ -4,10 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+from collections import defaultdict
+from typing import Optional
+
import torch
-from torch import nn
+from torch import nn, Tensor
+from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support
+from torchtune.utils._logging import deprecated


+@deprecated(msg="Please use `scale_grads_` instead.")
def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
Member: @ebsmothers Should we add a deprecation warning here then? No need to keep both.

Member: Talked with Evan, let's add the deprecation label here.

    """
    Utility to scale the gradients of a model.

@@ -29,3 +36,70 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
            scaler = scaler.to(device)
        if p.grad is not None:
            p.grad *= scaler

+@_no_grad
+def scale_grads_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    r"""Scale gradients of iterable parameters.
+
+    This function is equivalent to :func:`torch.mul_` applied to each parameter.
+    Gradients are modified in-place, multiplying by the specified scaler.
+
+    Args:
+        parameters (_tensor_or_tensors): an iterable of Tensors or a
+            single Tensor that will have gradients scaled
+        scaler (torch.Tensor): multiplier to scale gradients
+        foreach (Optional[bool]): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors
+            and silently fall back to the slow implementation for other device types.
+            Default: ``None``
+
+    Returns:
+        None
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    else:
+        parameters = list(parameters)
+    _scale_grad_(parameters, scaler, foreach)

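
For reference, a quick sketch of the migration implied by the deprecation, assuming a torchtune build containing this PR: the old helper took the module, the new one takes the parameters directly, mirroring `torch.nn.utils.clip_grad_norm_`.

```python
import torch
from torch import nn
from torchtune.training import scale_grads, scale_grads_

model = nn.Linear(8, 8)
model(torch.randn(4, 8)).sum().backward()

# Old, now-deprecated API: takes the module and walks model.parameters() internally.
scale_grads(model, torch.tensor(0.5))

# New API: takes parameters (or any iterable of grad-holding tensors) directly
# and can take the foreach fast path; foreach=None auto-selects per device.
scale_grads_(model.parameters(), torch.tensor(0.5), foreach=None)
```
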
+def _group_tensors_by_device(
+    tensors: list[torch.Tensor],
+) -> dict[torch.device, list[Tensor]]:
+    ret = defaultdict(list)
+    for tensor in tensors:
+        ret[tensor.device].append(tensor)
+
+    return ret

+@_no_grad
+def _scale_grad_(
+    parameters: _tensor_or_tensors,
+    scaler: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return
+    grouped_grads = _group_tensors_by_device(grads)
+
+    for device, device_grads in grouped_grads.items():
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, scaler.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            scaler_device = scaler.to(device)
+            for g in device_grads:
+                g.mul_(scaler_device)
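
A small sanity-check sketch of the two branches `_scale_grad_` chooses between (CPU-only, illustrative values): the fused `torch._foreach_mul_` call and the per-tensor fallback produce the same in-place result.

```python
import torch

grads_a = [torch.full((3,), 2.0), torch.full((2,), 4.0)]
grads_b = [g.clone() for g in grads_a]
scaler = torch.tensor(0.25)

torch._foreach_mul_(grads_a, scaler)   # fused multi-tensor in-place multiply (fast path)
for g in grads_b:                      # per-tensor fallback (slow path)
    g.mul_(scaler)

assert all(torch.equal(a, b) for a, b in zip(grads_a, grads_b))
```
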
Reviewer: Nit, but I actually wonder whether it's worth having a separate flag for this. I think model, loss, and optimizer step are all pretty clear, but this one may be a bit niche for someone to look at and immediately understand (and I also think it's unlikely someone would actually want to experiment with it). Could even just use some heuristic like

self._compile_scale_grads = all([self._compile_model, self._compile_loss, self._compile_optimizer_step])