Refactor _batch_cal_norm and remove #pyre-ignore #3200

Closed. Wants to merge 2 commits.
torchrec/optim/clipping.py: 18 additions & 18 deletions

@@ -59,7 +59,7 @@ def __init__(
         super().__init__(optimizer)
         self._clipping = clipping
         self._max_gradient = max_gradient
-        self._norm_type = norm_type
+        self._norm_type = float(norm_type)
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
         self._step_num = 0
@@ -124,7 +124,7 @@ def step(self, closure: Any = None) -> None:
                 torch.nn.utils.clip_grad_norm_(
                     replicate_params,
                     self._max_gradient,
-                    norm_type=float(self._norm_type),
+                    norm_type=self._norm_type,
                 )
             else:
                 self.clip_grad_norm_()
@@ -139,7 +139,6 @@ def step(self, closure: Any = None) -> None:
     def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         """Clip the gradient norm of all parameters."""
         max_norm = self._max_gradient
-        norm_type = float(self._norm_type)
         all_grads = []
         total_grad_norm = None

@@ -157,15 +156,15 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             sharded_grad_norm = _batch_cal_norm(
                 sharded_grads,
                 max_norm,
-                norm_type,
+                self._norm_type,
                 pgs,
             )
             total_grad_norm = (
                 sharded_grad_norm
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
+                    if self._norm_type == torch.inf
                     else total_grad_norm + sharded_grad_norm
                 )
             )
@@ -184,27 +183,36 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
             replicated_grad_norm = _batch_cal_norm(
                 replicated_grads,
                 max_norm,
-                norm_type,
+                self._norm_type,
                 None,
             )
             total_grad_norm = (
                 replicated_grad_norm
                 if total_grad_norm is None
                 else (
                     torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
+                    if self._norm_type == torch.inf
                     else total_grad_norm + replicated_grad_norm
                 )
             )
             square_replicated_grad_norm = replicated_grad_norm
         else:
             square_replicated_grad_norm = 0

+        if total_grad_norm is not None:
+            total_grad_norm = (
+                torch.pow(total_grad_norm, 1.0 / self._norm_type)
+                if self._norm_type != torch.inf
+                else total_grad_norm
+            )
+        else:
+            return None
+
         global log_grad_norm
         if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
+            if total_grad_norm is not None and self._norm_type != torch.inf:
                 # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
+                grad_norm = total_grad_norm ** (1.0 / self._norm_type)
             else:
                 grad_norm = total_grad_norm

@@ -213,15 +221,7 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
                 f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
             )

-        # Aggregation
-        if total_grad_norm is None:
-            return
-
-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
-        # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
-        clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
+        clip_coef = torch.tensor(max_norm) / (total_grad_norm + 1e-6)
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm
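With the float conversion done once in __init__ and the 1/p root applied right after the sharded and replicated norms are merged, the clip coefficient is computed from a plain tensor, which is what lets the final lines drop the cast and both pyre-ignore comments. Below is a minimal, standalone sketch of that aggregation pattern; combine_group_norms and the example values are hypothetical and only illustrate the max-for-inf / sum-then-root behavior, not the library's API.

```python
import torch

def combine_group_norms(group_norms, norm_type):
    # Hypothetical helper mirroring the aggregation in the diff above: each
    # entry of `group_norms` is assumed to already be the p-th power of one
    # group's gradient norm (or the max absolute value when norm_type is inf).
    norm_type = float(norm_type)  # same normalization the PR moves into __init__
    total = None
    for g in group_norms:
        if total is None:
            total = g
        elif norm_type == torch.inf:
            total = torch.maximum(total, g)  # inf-norm: take the max across groups
        else:
            total = total + g  # finite p: sum the p-th powers
    if total is None:
        return None
    # Take the 1/p root exactly once, after all groups are combined.
    return total if norm_type == torch.inf else torch.pow(total, 1.0 / norm_type)

# Example: squared L2 norms 9.0 and 16.0 combine to an overall norm of 5.0,
# and a max_gradient of 1.0 would then scale gradients by roughly 1/5.
total = combine_group_norms([torch.tensor(9.0), torch.tensor(16.0)], 2)
clip_coef = torch.clamp(torch.tensor(1.0) / (total + 1e-6), max=1.0)
print(total, clip_coef)  # tensor(5.) tensor(0.2000)
```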
torchrec/optim/tests/test_clipping.py: 1 addition & 1 deletion

@@ -251,7 +251,7 @@ def _get_params_to_pg(
         return {param: [param.device_mesh.get_group()] for param in params}

     @with_comms
-    @parametrize("norm_type", ("inf",))
+    @parametrize("norm_type", ("inf", 1, 2))
     def test_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
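The test now parametrizes norm_type over the string "inf" and the integers 1 and 2, all of which collapse to a float in __init__. A small sanity check of that conversion (not part of the PR, shown only for illustration):

```python
import math

# float() accepts the string "inf" as well as numeric p values, so the stored
# self._norm_type compares cleanly against torch.inf / math.inf and supports
# the 1.0 / p arithmetic for finite p.
for norm_type in ("inf", 1, 2):
    p = float(norm_type)
    print(norm_type, p, p == math.inf)
# inf inf True
# 1 1.0 False
# 2 2.0 False
```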