diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index 2ba9a6290..d38f08775 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = norm_type
+        self._norm_type = float(norm_type)
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
                 torch.nn.utils.clip_grad_norm_(
                     replicate_params,
                     self._max_gradient,
-                    norm_type=float(self._norm_type),
+                    norm_type=self._norm_type,
                 )
             else:
                 self.clip_grad_norm_()
@@ -139,7 +139,6 @@ def step(self, closure: Any = None) -> None:
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
         all_grads = []
         total_grad_norm = None
 
@@ -157,7 +156,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             sharded_grad_norm = _batch_cal_norm(
                 sharded_grads,
                 max_norm,
-                norm_type,
+                self._norm_type,
                 pgs,
             )
             total_grad_norm = (
@@ -165,7 +164,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
+                    if self._norm_type == torch.inf
                     else total_grad_norm + sharded_grad_norm
                 )
             )
@@ -184,7 +183,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             replicated_grad_norm = _batch_cal_norm(
                 replicated_grads,
                 max_norm,
-                norm_type,
+                self._norm_type,
                 None,
             )
             total_grad_norm = (
@@ -192,7 +191,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
+                    if self._norm_type == torch.inf
                     else total_grad_norm + replicated_grad_norm
                 )
             )
@@ -200,11 +199,20 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         else:
             square_replicated_grad_norm = 0
 
+        if total_grad_norm is not None:
+            total_grad_norm = (
+                torch.pow(total_grad_norm, 1.0 / self._norm_type)
+                if self._norm_type != torch.inf
+                else total_grad_norm
+            )
+        else:
+            return None
+
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
+            if total_grad_norm is not None and self._norm_type != torch.inf:
                 # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
+                grad_norm = total_grad_norm ** (1.0 / self._norm_type)
             else:
                 grad_norm = total_grad_norm
 
@@ -213,15 +221,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
             )
 
-        # Aggregation
-        if total_grad_norm is None:
-            return
-
-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
-        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
diff --git a/torchrec/optim/tests/test_clipping.py b/torchrec/optim/tests/test_clipping.py
index 0c837ec86..5f311459c 100644
--- a/torchrec/optim/tests/test_clipping.py
+++ b/torchrec/optim/tests/test_clipping.py
@@ -251,7 +251,7 @@ def _get_params_to_pg(
         return {param: [param.device_mesh.get_group()] for param in params}
 
     @with_comms
-    @parametrize("norm_type", ("inf",))
+    @parametrize("norm_type", ("inf", 1, 2))
     def test_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
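For context, here is a minimal standalone sketch (not part of the patch) of the clipping flow this change moves to: per-group values are accumulated as max |g| for the inf-norm or as sum(|g|^p) (i.e. ||g||_p^p) otherwise, the p-th root is taken exactly once, and the clip coefficient is built from a tensor so no `cast` is needed. The helper `_group_norm_pow`, the `grad_groups` argument, and `clip_grad_norm_sketch` are hypothetical stand-ins for `_batch_cal_norm` and the sharded/replicated gradient groups, not the TorchRec API. Note also that coercing once with `float(norm_type)` in `__init__` accepts `"inf"`, ints, and `torch.inf` uniformly, since `float("inf") == torch.inf`.

```python
from typing import List, Optional

import torch


def _group_norm_pow(grads: List[torch.Tensor], norm_type: float) -> torch.Tensor:
    """Per-group reduction: max |g| for the inf-norm, sum(|g|^p) (= ||g||_p^p) otherwise."""
    norms = torch.stack([torch.linalg.vector_norm(g.detach(), norm_type) for g in grads])
    if norm_type == torch.inf:
        return norms.max()
    return (norms**norm_type).sum()


def clip_grad_norm_sketch(
    grad_groups: List[List[torch.Tensor]], max_norm: float, norm_type: float
) -> Optional[torch.Tensor]:
    all_grads = [g for group in grad_groups for g in group]
    total = None
    for group in grad_groups:
        group_val = _group_norm_pow(group, norm_type)
        # Combine groups: element-wise max for the inf-norm, sum of p-th powers otherwise.
        total = (
            group_val
            if total is None
            else (
                torch.maximum(total, group_val)
                if norm_type == torch.inf
                else total + group_val
            )
        )
    if total is None:
        return None
    # As in the patch: take the p-th root once, before computing the clip coefficient.
    total_norm = total if norm_type == torch.inf else torch.pow(total, 1.0 / norm_type)
    clip_coef = torch.tensor(max_norm) / (total_norm + 1e-6)
    torch._foreach_mul_(all_grads, torch.clamp(clip_coef, max=1.0))
    return total_norm


if __name__ == "__main__":
    grads = [[torch.randn(4, 3)], [torch.randn(8)]]
    # Mirrors the new test parametrization: inf-, 1-, and 2-norm.
    for norm_type in (float("inf"), 1.0, 2.0):
        print(norm_type, clip_grad_norm_sketch(grads, max_norm=1.0, norm_type=norm_type))
```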