-
Notifications
You must be signed in to change notification settings - Fork 293
FSDP2/MoE core training stability fixes #1370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -169,31 +169,37 @@ def optimizer_step( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Optional[Float[torch.Tensor, "1"]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Perform optimizer step""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import time as _time | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rank = dist.get_rank() if dist.is_initialized() else -1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_norm = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(model, HFModelWrapper): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model = model.model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.max_norm > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| t0 = _time.time() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # NOTE (sumanthrh): All `grad_norm`s returned here are the original grad norms before clipping. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(model, FSDP): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_norm = model.clip_grad_norm_(max_norm=self.max_norm) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(model, FSDPModule): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_norm = fsdp2_clip_grad_norm_(model.parameters(), max_norm=self.max_norm) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.max_norm) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Skip update if gradient norm is not finite | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if grad_norm is not None and not torch.isfinite(grad_norm): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if torch.distributed.is_initialized(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rank = torch.distributed.get_rank() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning(f"grad_norm is not finite: {grad_norm}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| optimizer.zero_grad() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return grad_norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # must call it even if grad_norm is non-finite, otherwise NCCL deadlocks. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # We zero_grad before stepping so the non-finite update is harmless. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| non_finite = grad_norm is not None and not torch.isfinite(grad_norm) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if non_finite: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| optimizer.zero_grad(set_to_none=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+190
to
+196
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this modification needed? We skip |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| t0 = _time.time() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| optimizer.step() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if scheduler is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+188
to
+200
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}") | |
| # Skip update if gradient norm is not finite | |
| if grad_norm is not None and not torch.isfinite(grad_norm): | |
| if torch.distributed.is_initialized(): | |
| rank = torch.distributed.get_rank() | |
| logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}") | |
| else: | |
| logger.warning(f"grad_norm is not finite: {grad_norm}") | |
| optimizer.zero_grad() | |
| return grad_norm | |
| # NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks | |
| # must call it even if grad_norm is non-finite, otherwise NCCL deadlocks. | |
| # We zero_grad before stepping so the non-finite update is harmless. | |
| non_finite = grad_norm is not None and not torch.isfinite(grad_norm) | |
| if non_finite: | |
| logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step") | |
| optimizer.zero_grad(set_to_none=True) | |
| t0 = _time.time() | |
| optimizer.step() | |
| if scheduler is not None: | |
| logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s") | |
| logger.debug(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}") | |
| # NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks | |
| # must call it even if grad_norm is non-finite, otherwise NCCL deadlocks. | |
| # We zero_grad before stepping so the non-finite update is harmless. | |
| non_finite = grad_norm is not None and not torch.isfinite(grad_norm) | |
| if non_finite: | |
| logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step") | |
| optimizer.zero_grad(set_to_none=True) | |
| t0 = _time.time() | |
| optimizer.step() | |
| logger.debug(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s") |
Was this helpful? React with 👍 or 👎 to provide feedback.
Copilot
AI
Mar 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new non-finite grad-norm handling changes correctness-sensitive behavior (all ranks must still call optimizer.step(), scheduler step is skipped, grads are cleared before stepping). There doesn’t appear to be a test covering this branch in the existing FSDP strategy test suite.
Add a unit/integration test that forces grad_norm to be non-finite (e.g., by injecting NaNs into a parameter grad) and asserts: (1) optimizer.step() is still invoked, (2) LR scheduler is not advanced, and (3) parameters do not change on the non-finite step.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
timemodule is imported inside theoptimizer_stepmethod. According to Python style guides (like PEP 8), imports should be placed at the top of the file. This improves readability and avoids potential overhead if the method is called frequently in a hot loop.Please move this import to the top of the file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1