FSDP2/MoE core training stability fixes #1370
ashutoshuiuc wants to merge 2 commits into NovaSky-AI:main from
Conversation
**Fix NCCL deadlock: non-finite grad_norm no longer skips `optimizer.step()`**
With FSDP, `optimizer.step()` involves NCCL collectives. ALL ranks must call it even when grad_norm is non-finite, otherwise NCCL deadlocks. Zero grads before stepping so the non-finite update is harmless.

**Move all-reduce metrics after all micro-batches**
Issuing `dist.all_reduce` inside `_forward_backward_micro` while FSDP2 backward gradient reductions are still in-flight on the NCCL stream causes deadlocks with MoE/hybrid models. Moved to `forward_backward()` after all micro-batches.

**Patch MoE expert modules for FSDP2 + non-reentrant gradient checkpointing**
Some MoE implementations only iterate over experts that received tokens. This creates variable computation graphs that break checkpoint recompute. Patches `Qwen3MoeSparseMoeBlock` and `Qwen3MoeExperts` to iterate ALL experts unconditionally, ensuring deterministic tensor counts.

**Patch non-reentrant gradient checkpointing for MoE compatibility**
Make `unpack_hook` tolerant of missing handles from MoE routing differences.

**FSDP2 meta tensor init always enabled**
FSDP2 handles tied embeddings correctly via broadcast + `tie_weights()`, so meta tensor init is always safe. FSDP1 still uses CPU init for tied embeddings.

**Batched/coalesced broadcast for FSDP2 state dict loading**
For MoE models with 18,000+ params, coalesced broadcasts in 500MB batches reduce init from minutes to seconds.

**Guard `output_router_logits` on actual MoE models (`num_local_experts > 0`)**

Relates to NovaSky-AI#1297
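The 500MB batching described above might be grouped along these lines (an illustrative helper; the actual PR code and names may differ):

```python
import torch

BATCH_SIZE_BYTES = 500 * 1024 * 1024


def batch_tensors_by_bytes(tensors, limit=BATCH_SIZE_BYTES):
    """Group tensors into consecutive batches whose total byte size stays
    under `limit`, so each batch can be broadcast in one coalesced call.
    Illustrative sketch, not the PR's implementation."""
    batches, current, current_bytes = [], [], 0
    for t in tensors:
        nbytes = t.numel() * t.element_size()
        # Flush the current batch before it would exceed the byte limit.
        if current and current_bytes + nbytes > limit:
            batches.append(current)
            current, current_bytes = [], 0
        current.append(t)
        current_bytes += nbytes
    if current:
        batches.append(current)
    return batches


# Toy check: 10 tensors of ~4 KB each, 10 KB batch limit -> batches of 2.
tensors = [torch.empty(1000) for _ in range(10)]
batches = batch_tensors_by_bytes(tensors, limit=10_000)
assert sum(len(b) for b in batches) == 10
assert all(sum(t.numel() * t.element_size() for t in b) <= 10_000 for b in batches)
```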
Code Review
This pull request introduces several important stability and performance fixes for FSDP2, particularly for Mixture-of-Experts (MoE) models. The changes include preventing NCCL deadlocks by ensuring optimizer.step() is always called, moving metric reductions to a safer point in the execution flow, and implementing batched parameter broadcasting for faster initialization. The PR also adds crucial monkey-patches for MoE models to work correctly with non-reentrant gradient checkpointing.
My review focuses on ensuring these patches are applied correctly. I've found a critical issue where a necessary patch function, _patch_checkpoint_for_moe, is defined but never called for either the policy or critic models. This would prevent the intended fix for gradient checkpointing with MoE from working. I've also noted a minor style issue regarding an import placement. The rest of the changes appear correct and well-implemented.
```python
    and not self.cfg.gradient_checkpointing_use_reentrant
)
if needs_expert_patch:
    _patch_moe_experts_for_fsdp2(wrapped_model.model)
```
The function _patch_checkpoint_for_moe is defined to make non-reentrant gradient checkpointing compatible with MoE models, but it is never called. This appears to be an oversight, as the functionality is crucial for the stability fixes mentioned in the pull request description. The patch should be applied here for MoE models using non-reentrant gradient checkpointing.
```diff
-_patch_moe_experts_for_fsdp2(wrapped_model.model)
+_patch_checkpoint_for_moe()
+_patch_moe_experts_for_fsdp2(wrapped_model.model)
```
```python
    and not self.cfg.gradient_checkpointing_use_reentrant
)
if needs_expert_patch:
    _patch_moe_experts_for_fsdp2(critic)
```
Similar to the PolicyWorker, the _patch_checkpoint_for_moe function is not being called for the CriticWorker. This patch is necessary for MoE model stability with non-reentrant gradient checkpointing and should be applied here.
```diff
-_patch_moe_experts_for_fsdp2(critic)
+_patch_checkpoint_for_moe()
+_patch_moe_experts_for_fsdp2(critic)
```
```python
    **kwargs,
) -> Optional[Float[torch.Tensor, "1"]]:
    """Perform optimizer step"""
    import time as _time
```
The time module is imported inside the optimizer_step method. According to Python style guides (like PEP 8), imports should be placed at the top of the file. This improves readability and avoids potential overhead if the method is called frequently in a hot loop.
Please move this import to the top of the file.
Pull request overview
This PR targets stability and performance issues when training MoE / hybrid models under FSDP2, especially around NCCL deadlocks, non-reentrant gradient checkpointing determinism, and initialization overhead.
Changes:
- Move DP metric all-reduction to after all micro-batches complete to avoid NCCL deadlocks during in-flight FSDP2 reductions.
- Add FSDP2/MoE-specific patches: deterministic expert iteration for non-reentrant checkpointing and enable meta-tensor init for tied embeddings under FSDP2.
- Speed up FSDP2 full-state loading via batched/coalesced broadcasts; adjust optimizer stepping behavior to avoid deadlocks on non-finite grad norms; guard `output_router_logits` to MoE-only models.
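The deferred-reduction contract can be illustrated with a toy aggregator: micro-batch metrics stay local, and a single reduction (here a plain average standing in for the eventual `dist.all_reduce`) runs once after all micro-batches finish. The helper name is hypothetical, not the worker API:

```python
def aggregate_micro_metrics(per_micro):
    """Average per-rank metric dicts across micro-batches.

    In the PR's design, this local aggregation happens first, and only the
    final result participates in a cross-rank all_reduce, keeping NCCL
    collectives out of the per-micro-batch backward window.
    Hypothetical helper, for illustration only.
    """
    keys = per_micro[0].keys()
    return {k: sum(m[k] for m in per_micro) / len(per_micro) for k in keys}


micro_metrics = [{"policy_loss": 1.0}, {"policy_loss": 3.0}]
assert aggregate_micro_metrics(micro_metrics) == {"policy_loss": 2.0}
```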
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| `skyrl/backends/skyrl_train/workers/worker.py` | Defers `all_reduce_metrics()` until after all micro-batches for policy/critic workers. |
| `skyrl/backends/skyrl_train/workers/model_wrapper.py` | Only enables `output_router_logits` when config indicates an MoE model. |
| `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Adds MoE expert forward patching for FSDP2 + non-reentrant checkpointing; adjusts meta init behavior for tied embeddings on FSDP2. |
| `skyrl/backends/skyrl_train/distributed/fsdp_utils.py` | Implements batched/coalesced broadcast to accelerate FSDP2 state-dict loading. |
| `skyrl/backends/skyrl_train/distributed/fsdp_strategy.py` | Ensures all ranks call `optimizer.step()` even on non-finite grad norms; ties embeddings before state-dict capture; adds timing logs and frees `full_state`. |
```python
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
if token_idx.numel() > 0:
    current_state = hidden_states[token_idx]
else:
    current_state = hidden_states[:0]

gate, up = nn.functional.linear(
    current_state, mod.gate_up_proj[expert_idx]
).chunk(2, dim=-1)
current_hidden_states = mod.act_fn(gate) * up
current_hidden_states = nn.functional.linear(
    current_hidden_states, mod.down_proj[expert_idx]
)

if token_idx.numel() > 0:
    current_hidden_states = (
        current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
```
The patched Qwen3MoeExperts forward still has data-dependent Python branching (if token_idx.numel() > 0) around both the input selection and the top_k_weights scaling/index_add. If routing differs between the original forward and the checkpoint recompute, this branch can execute in one pass but not the other, reintroducing the non-deterministic saved-tensor count problem this patch is meant to fix.
To keep the computation graph shape deterministic, avoid branching on token_idx.numel() and always run the same ops (indexing/scaling/index_add should be safe with empty index tensors).
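The suggestion above relies on PyTorch indexing and `index_add_` being well-defined with empty index tensors; a minimal standalone check (illustrative shapes, not the actual Qwen3 module):

```python
import torch

hidden = torch.randn(6, 4)        # 6 tokens, hidden dim 4
weight = torch.randn(4, 4)        # stand-in for one expert's projection
final = torch.zeros(6, 4)         # output accumulator

# An expert that received no tokens: empty index tensor.
token_idx = torch.empty(0, dtype=torch.long)

# All three ops below execute as valid no-ops on empty indices, so the
# original forward and the checkpoint recompute run the same graph shape
# regardless of routing.
current = hidden[token_idx]                               # shape (0, 4)
current = torch.nn.functional.linear(current, weight)     # still (0, 4)
final.index_add_(0, token_idx, current)                   # no-op

assert current.shape == (0, 4)
assert torch.equal(final, torch.zeros(6, 4))
```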
```python
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad()
```
For the non-finite grad-norm path, calling optimizer.zero_grad() without set_to_none=True can leave explicit zero tensors as gradients on some PyTorch versions/configs. Many optimizers (e.g., AdamW with weight decay / momentum buffers) can still update parameters even with zero grads, so this may not be a true no-op step.
To ensure the step is actually harmless while still keeping collectives matched across ranks, explicitly clear gradients with set_to_none=True (or otherwise guarantee params won’t be updated on the non-finite path).
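The reviewer's point can be demonstrated in isolation: with decoupled weight decay, AdamW still moves parameters when gradients are explicit zero tensors, while `set_to_none=True` leaves them untouched. A minimal sketch, not the PR's code:

```python
import torch

p1 = torch.nn.Parameter(torch.ones(3))
p2 = torch.nn.Parameter(torch.ones(3))
opt1 = torch.optim.AdamW([p1], lr=0.1, weight_decay=0.1)
opt2 = torch.optim.AdamW([p2], lr=0.1, weight_decay=0.1)

p1.grad = torch.zeros_like(p1)   # explicit zero grad: step still applies weight decay
p2.grad = None                   # set_to_none semantics: param skipped entirely

opt1.step()
opt2.step()

assert not torch.equal(p1.data, torch.ones(3))  # moved by decoupled weight decay
assert torch.equal(p2.data, torch.ones(3))      # untouched
```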
```diff
-optimizer.zero_grad()
+optimizer.zero_grad(set_to_none=True)
```
```python
logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")

# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad()
    return grad_norm

t0 = _time.time()
optimizer.step()
if scheduler is not None:
    logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
```
logger.info timing logs are emitted on every rank for every training step (clip_grad_norm_ and optimizer.step). In long runs this can generate extremely large logs and can materially slow training due to logging overhead.
Consider making these logs debug-level, gating them behind a config flag, and/or restricting to rank 0 (or periodic sampling) to avoid impacting throughput.
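One way to implement the suggested gating, sketched with a hypothetical helper (rank-0 only, DEBUG level, periodic sampling; names are illustrative):

```python
import logging
import time

logger = logging.getLogger("fsdp_strategy")


def log_timing(rank: int, step: int, label: str, t0: float, every: int = 100) -> bool:
    """Emit step timings at DEBUG level, from rank 0 only, and only every
    `every` steps. Returns True when a record was emitted.
    Hypothetical helper, not the PR's code."""
    if rank != 0 or step % every != 0:
        return False
    logger.debug("[rank %d] %s done in %.1fs", rank, label, time.time() - t0)
    return True


# Only rank 0 at sampled steps produces output:
assert log_timing(rank=0, step=0, label="optimizer.step()", t0=time.time()) is True
assert log_timing(rank=1, step=0, label="optimizer.step()", t0=time.time()) is False
assert log_timing(rank=0, step=7, label="optimizer.step()", t0=time.time()) is False
```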
```python
pg = dist.distributed_c10d._get_default_group()
if len(batch_tensors) > 1:
    dist._broadcast_coalesced(pg, batch_tensors, BATCH_SIZE_BYTES, 0)
else:
    dist.broadcast(batch_tensors[0], src=0)
```
fsdp2_load_full_state_dict now relies on private/underscored torch.distributed APIs (dist._broadcast_coalesced and dist.distributed_c10d._get_default_group). These aren’t part of the stable public surface area and can change across PyTorch versions/builds, which would cause initialization to fail at runtime.
To make this more robust, add a compatibility fallback (e.g., detect attribute availability and fall back to per-tensor dist.broadcast), and consider using a public API if one is available for your supported PyTorch versions.
```diff
-pg = dist.distributed_c10d._get_default_group()
-if len(batch_tensors) > 1:
-    dist._broadcast_coalesced(pg, batch_tensors, BATCH_SIZE_BYTES, 0)
-else:
-    dist.broadcast(batch_tensors[0], src=0)
+# Use private APIs if available for efficiency, but fall back to public APIs otherwise.
+pg = None
+if hasattr(dist, "distributed_c10d") and hasattr(dist.distributed_c10d, "_get_default_group"):
+    pg = dist.distributed_c10d._get_default_group()
+elif hasattr(dist, "group") and hasattr(dist.group, "WORLD"):
+    pg = dist.group.WORLD
+if len(batch_tensors) > 1 and hasattr(dist, "_broadcast_coalesced") and pg is not None:
+    # Fast path: coalesced broadcast when the private API is available.
+    dist._broadcast_coalesced(pg, batch_tensors, BATCH_SIZE_BYTES, 0)
+else:
+    # Fallback: broadcast each tensor individually using the public API.
+    for _tensor in batch_tensors:
+        dist.broadcast(_tensor, src=0)
```
```diff
@@ -930,15 +936,6 @@
         if self.cfg.algorithm.use_kl_loss:
             status["policy_kl"] = kl_loss.item()
```
_forward_backward_micro no longer all-reduces metrics (it now returns per-rank values to be reduced in forward_backward()), but the docstring still says it returns an "All-reduced metrics dict". This can mislead future changes and debugging.
Update the docstring (and/or inline comments) to reflect the new contract: micro-batch metrics are local, and reduction/all-reduce happens once after all micro-batches complete.
```python
# Note: `status` contains per-rank (local) metrics for this micro-batch.
# Any aggregation or all-reduce of metrics happens in `forward_backward()`
# after all micro-batches have been processed.
```
```python
status = {
    "critic_loss": loss.item(),
    "values_mean": masked_mean(values, loss_mask).item(),
    "values_clipfrac": clipfrac,
    "critic_lr": self.scheduler.get_last_lr()[0],
}

# All-reduce metrics across DP workers
status = all_reduce_metrics(status, self.strategy)

return status
```
Same as the policy worker: _forward_backward_micro no longer all-reduces metrics, but its docstring still claims the return value is all-reduced. This is now incorrect since reduction/all-reduce happens in forward_backward().
Update the docstring to match the new behavior to avoid confusion when debugging distributed metric discrepancies.
```python
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad()
    return grad_norm

t0 = _time.time()
optimizer.step()
if scheduler is not None:
    logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
# Only advance LR schedule when gradients were finite (non-finite steps are no-ops)
if scheduler is not None and not non_finite:
    scheduler.step()
optimizer.zero_grad()
return grad_norm
```
The new non-finite grad-norm handling changes correctness-sensitive behavior (all ranks must still call optimizer.step(), scheduler step is skipped, grads are cleared before stepping). There doesn’t appear to be a test covering this branch in the existing FSDP strategy test suite.
Add a unit/integration test that forces grad_norm to be non-finite (e.g., by injecting NaNs into a parameter grad) and asserts: (1) optimizer.step() is still invoked, (2) LR scheduler is not advanced, and (3) parameters do not change on the non-finite step.
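A single-process sketch of such a test, using a hypothetical `optimizer_step_like` helper that mirrors the described non-finite handling (not the actual strategy API, and no FSDP/NCCL involved):

```python
import torch

def optimizer_step_like(optimizer, scheduler, params):
    """Minimal single-process mimic of the PR's non-finite grad-norm path:
    always call step() (so collectives would stay matched under FSDP),
    clear grads first on the non-finite path, and skip the scheduler."""
    grad_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
    non_finite = not torch.isfinite(grad_norm)
    if non_finite:
        optimizer.zero_grad(set_to_none=True)   # step becomes a true no-op
    optimizer.step()                            # always invoked
    if scheduler is not None and not non_finite:
        scheduler.step()                        # only advance on finite steps
    optimizer.zero_grad(set_to_none=True)
    return grad_norm

p = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.SGD([p], lr=0.5)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.1)

p.grad = torch.full_like(p, float("nan"))       # inject a NaN gradient
before = p.detach().clone()
gn = optimizer_step_like(opt, sched, [p])

assert not torch.isfinite(gn)                   # non-finite norm detected
assert torch.equal(p.detach(), before)          # no parameter update
assert opt.param_groups[0]["lr"] == 0.5         # scheduler not advanced
```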
```python
# MoE - balancing loss
model_config = self.model.config.to_dict()
if "output_router_logits" in model_config:
    num_experts = model_config.get("num_local_experts", 0)
```
output_router_logits is now gated on num_local_experts > 0, but some MoE configs use num_experts (or other keys) instead. Since this file previously enabled the flag based solely on presence of output_router_logits, this change can silently disable router logits for valid MoE models whose config doesn’t define num_local_experts.
Consider detecting MoE via both num_local_experts and num_experts (consistent with the MoE detection logic added in fsdp_worker.py), so the guard works across MoE model families.
```diff
-num_experts = model_config.get("num_local_experts", 0)
+# Detect MoE across model families: prefer num_local_experts, fall back to num_experts
+num_experts = model_config.get("num_local_experts")
+if num_experts is None:
+    num_experts = model_config.get("num_experts", 0)
```
…ast fallback

- Call `_patch_checkpoint_for_moe()` for both policy and critic workers
- Detect MoE via both `num_local_experts` and `num_experts` config keys
- Use `optimizer.zero_grad(set_to_none=True)` for non-finite grad_norm path
- Add fallback for private `_broadcast_coalesced` API (per-tensor broadcast)
```diff
 logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")

-# Skip update if gradient norm is not finite
-if grad_norm is not None and not torch.isfinite(grad_norm):
-    if torch.distributed.is_initialized():
-        rank = torch.distributed.get_rank()
-        logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}")
-    else:
-        logger.warning(f"grad_norm is not finite: {grad_norm}")
-    optimizer.zero_grad()
-    return grad_norm
+# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
+# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
+# We zero_grad before stepping so the non-finite update is harmless.
+non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
+if non_finite:
+    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
+    optimizer.zero_grad(set_to_none=True)

 t0 = _time.time()
 optimizer.step()
 if scheduler is not None:
     logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
```
🟡 Per-step INFO logging on every rank causes excessive log output in production
The optimizer_step method adds logger.info() calls at lines 188 and 200 that fire on every optimizer step on every rank. All other logger.info calls in this file are for one-time initialization/config events (e.g., architecture name fixes at fsdp_strategy.py:367). In a typical training run with hundreds of ranks and thousands of steps, these two lines will generate millions of INFO log entries, degrading I/O performance and making logs unusable. These should be logger.debug() at most.
```diff
-logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")
-# Skip update if gradient norm is not finite
-if grad_norm is not None and not torch.isfinite(grad_norm):
-    if torch.distributed.is_initialized():
-        rank = torch.distributed.get_rank()
-        logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}")
-    else:
-        logger.warning(f"grad_norm is not finite: {grad_norm}")
-    optimizer.zero_grad()
-    return grad_norm
-# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
-# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
-# We zero_grad before stepping so the non-finite update is harmless.
-non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
-if non_finite:
-    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
-    optimizer.zero_grad(set_to_none=True)
-t0 = _time.time()
-optimizer.step()
-if scheduler is not None:
-    logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
+logger.debug(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")
+# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
+# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
+# We zero_grad before stepping so the non-finite update is harmless.
+non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
+if non_finite:
+    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
+    optimizer.zero_grad(set_to_none=True)
+t0 = _time.time()
+optimizer.step()
+logger.debug(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
```
```python
    return torch.tensor(0.0)

handle = holder.handles[gid]
if handle in frame.recomputed.get(gid, {}):
    ret = frame.recomputed[gid][handle]
else:
    ret = torch.tensor(0.0)
```
🟡 Checkpoint patch fallback returns scalar CPU tensor causing device/shape mismatch crash
In _patch_checkpoint_for_moe, the unpack_hook has two fallback paths (lines 226 and 232) that return torch.tensor(0.0) — a scalar float32 tensor on CPU. The autograd system expects the unpacked tensor to match the shape, dtype, and device of the original saved tensor (typically a multi-dimensional CUDA tensor). If these paths are triggered during backward, PyTorch will raise a device or shape mismatch error.
Affected fallback paths:

Line 226, when `frame.recomputed[gid]` is missing or empty:

```python
return torch.tensor(0.0)  # scalar, CPU, float32
```

Line 232, when `handle` is not found in `frame.recomputed[gid]`:

```python
ret = torch.tensor(0.0)  # scalar, CPU, float32
```

Both should at minimum produce a tensor on the correct device. A safer fallback would be `torch.zeros(1, device='cuda')` or, better yet, raising an explicit error with a clear diagnostic message.
Prompt for agents
In skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py, inside the _patch_checkpoint_for_moe function's unpack_hook (lines 220-234), the two fallback return paths that return torch.tensor(0.0) should be improved:
1. Line 226: return torch.tensor(0.0) — this is reached when gid is not in frame.recomputed or it's empty. Replace with a tensor on the correct device. Since we're inside a CUDA computation, use torch.tensor(0.0, device='cuda') at minimum, or raise a RuntimeError with a descriptive message about the checkpoint state.
2. Line 232: ret = torch.tensor(0.0) — reached when the handle is not in frame.recomputed[gid]. Same fix: either use the CUDA device or raise a descriptive error.
Ideally, try to infer the expected shape/dtype/device from the holder or frame metadata and produce a properly shaped zero tensor, rather than a scalar.
SumanthRH
left a comment
Still making my way through the fixes, left a few comments on some of the fixes
```python
    **kwargs,
) -> Optional[Float[torch.Tensor, "1"]]:
    """Perform optimizer step"""
    import time as _time
```
```python
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad(set_to_none=True)
```
Why is this modification needed? We skip optimizer.step on all ranks when grad norm is NaN ?
```python
from skyrl.train.config.config import InferenceEngineConfig


def _patch_moe_experts_for_fsdp2(model: nn.Module):
```
Are these patches still needed with transformers 5.0.0?
This seems to be the original issue: pytorch/pytorch#171355, and it appears to have been fixed in transformers in February: pytorch/pytorch#171355 (comment)
If so, let's also reorg and move the patches to a separate patches.py file
If the patch is needed, let's also verify E2E training with this? A first step would be to ensure some of our GPU CI tests pass (but modified to use Qwen 3 model):
Summary

- `optimizer.step()`: ALL ranks must call it with FSDP; zero grads before stepping instead of skipping the step
- `dist.all_reduce`: no longer issued while FSDP2 backward gradient reductions are in-flight
- `tie_weights()`: FSDP2 handles tied embeddings via broadcast + `tie_weights()`, so meta tensor init is always enabled
- `output_router_logits`: only enabled on actual MoE models (`num_local_experts > 0`)

Split from #1298 per maintainer feedback.

Relates to #1297
Relates to #1297