FSDP2/MoE core training stability fixes#1370

Open
ashutoshuiuc wants to merge 2 commits into NovaSky-AI:main from ashutoshuiuc:pr/fsdp2-moe-core-fixes

Conversation

@ashutoshuiuc

@ashutoshuiuc ashutoshuiuc commented Mar 23, 2026

Summary

  • Fix NCCL deadlock: Non-finite grad_norm no longer skips optimizer.step() — ALL ranks must call it with FSDP, zero grads before stepping instead
  • Move all-reduce metrics after all micro-batches: Prevents NCCL deadlock from issuing dist.all_reduce while FSDP2 backward gradient reductions are in-flight
  • Patch MoE experts for FSDP2: Qwen3MoeSparseMoeBlock and Qwen3MoeExperts iterate ALL experts unconditionally for deterministic computation graph with non-reentrant gradient checkpointing
  • Patch non-reentrant checkpoint for MoE: Tolerant unpack_hook for missing handles from MoE routing differences
  • FSDP2 meta tensor init always enabled: Handles tied embeddings via tie_weights()
  • Batched/coalesced broadcast: 500MB batches reduce MoE init from minutes to seconds (18K+ params)
  • Guard output_router_logits: Only enable on actual MoE models (num_local_experts > 0)

Split from #1298 per maintainer feedback.

Relates to #1297



Fix NCCL deadlock: non-finite grad_norm no longer skips optimizer.step()
  With FSDP, optimizer.step() involves NCCL collectives. ALL ranks must call
  it even when grad_norm is non-finite, otherwise NCCL deadlocks. Zero grads
  before stepping so the non-finite update is harmless.
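This pattern can be sketched as follows (the function name here is hypothetical; the actual logic lives in the strategy's optimizer step method):

```python
import torch

def optimizer_step_all_ranks(optimizer, grad_norm, scheduler=None):
    # Sketch of the fix described above (illustrative, not the exact
    # fsdp_strategy.py code). All ranks must reach optimizer.step()
    # because it issues NCCL collectives under FSDP; skipping it on
    # one rank deadlocks the others.
    non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
    if non_finite:
        # Clear gradients so the step cannot apply a non-finite update.
        optimizer.zero_grad(set_to_none=True)
    optimizer.step()  # still called on every rank, keeping collectives matched
    if scheduler is not None and not non_finite:
        scheduler.step()  # only advance the LR schedule on finite steps
    optimizer.zero_grad(set_to_none=True)
    return grad_norm
```

With gradients cleared to None, optimizers skip the affected parameters entirely, so the step is a true no-op while the collectives stay matched.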

Move all-reduce metrics after all micro-batches
  Issuing dist.all_reduce inside _forward_backward_micro while FSDP2 backward
  gradient reductions are still in-flight on the NCCL stream causes deadlocks
  with MoE/hybrid models. Moved to forward_backward() after all micro-batches.
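As a sketch, the reordering amounts to keeping micro-batch metrics local and reducing once at the end (`reduce_after_microbatches` is a hypothetical helper, not the worker's actual method):

```python
import torch
import torch.distributed as dist

def reduce_after_microbatches(micro_statuses):
    # Hypothetical helper illustrating the fix: average metrics locally
    # across micro-batches first; the single dist.all_reduce happens only
    # after backward (and its FSDP2 gradient reductions) has finished.
    merged = {k: sum(s[k] for s in micro_statuses) / len(micro_statuses)
              for k in micro_statuses[0]}
    if dist.is_available() and dist.is_initialized():
        for k, v in merged.items():
            t = torch.tensor(float(v))
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            merged[k] = (t / dist.get_world_size()).item()
    return merged
```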

Patch MoE expert modules for FSDP2 + non-reentrant gradient checkpointing
  Some MoE implementations only iterate over experts that received tokens.
  This creates variable computation graphs that break checkpoint recompute.
  Patches Qwen3MoeSparseMoeBlock and Qwen3MoeExperts to iterate ALL experts
  unconditionally, ensuring deterministic tensor counts.
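A sketch of the dense-iteration idea (toy shapes and names, not the actual patched Qwen3 code):

```python
import torch
import torch.nn as nn

def dense_expert_forward(hidden_states, experts, routing_weights, selected_experts):
    # Run EVERY expert each forward so the autograd graph shape is fixed
    # regardless of routing (illustrative sketch of the patch's idea).
    # hidden_states:    (num_tokens, hidden_dim)
    # experts:          list of per-expert modules
    # routing_weights:  (num_tokens, top_k) softmax weights
    # selected_experts: (num_tokens, top_k) expert indices
    out = torch.zeros_like(hidden_states)
    num_experts = len(experts)
    # one-hot mask reshaped to (num_experts, top_k, num_tokens)
    expert_mask = nn.functional.one_hot(selected_experts, num_experts).permute(2, 1, 0)
    for expert_idx in range(num_experts):  # ALL experts, even with zero tokens
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
        current = hidden_states[token_idx]      # may be empty: shape (0, H)
        current = experts[expert_idx](current)  # still recorded in the graph
        out.index_add_(0, token_idx,
                       current * routing_weights[token_idx, top_k_pos, None])
    return out
```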

Patch non-reentrant gradient checkpointing for MoE compatibility
  Make unpack_hook tolerant of missing handles from MoE routing differences.

FSDP2 meta tensor init always enabled
  FSDP2 handles tied embeddings correctly via broadcast + tie_weights(),
  so meta tensor init is always safe. FSDP1 still uses CPU init for tied
  embeddings.
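The tie_weights() flow can be illustrated on a toy module (a sketch of the general meta-init pattern, not the FSDP2 code path):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # Toy model with tied input/output embeddings.
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)

    def tie_weights(self):
        self.lm_head.weight = self.embed.weight

with torch.device("meta"):
    model = TinyLM()          # no real memory allocated yet

model.to_empty(device="cpu")  # materialize storage (uninitialized)
with torch.no_grad():
    for p in model.parameters():
        p.normal_()           # stand-in for broadcasting real weights in
model.tie_weights()           # re-establish sharing broken by meta init
```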

Batched/coalesced broadcast for FSDP2 state dict loading
  For MoE models with 18,000+ params, coalesced broadcasts in 500MB batches
  reduce init from minutes to seconds.
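The batching policy itself is simple to sketch (a simplified illustrative helper; the actual fsdp_utils.py implementation differs in details):

```python
def batch_by_bytes(named_sizes, budget_bytes=500 * 1024 * 1024):
    # Group parameters into batches whose total size stays under ~500MB,
    # so one coalesced broadcast can replace thousands of per-tensor
    # broadcasts. named_sizes: iterable of (name, nbytes) pairs.
    batches, current, current_bytes = [], [], 0
    for name, nbytes in named_sizes:
        if current and current_bytes + nbytes > budget_bytes:
            batches.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        batches.append(current)
    return batches
```

Each resulting batch is then sent with a single coalesced broadcast (or per-tensor broadcasts as a fallback).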

Guard output_router_logits on actual MoE models (num_local_experts > 0)

Relates to NovaSky-AI#1297
Copilot AI review requested due to automatic review settings March 23, 2026 13:45
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several important stability and performance fixes for FSDP2, particularly for Mixture-of-Experts (MoE) models. The changes include preventing NCCL deadlocks by ensuring optimizer.step() is always called, moving metric reductions to a safer point in the execution flow, and implementing batched parameter broadcasting for faster initialization. The PR also adds crucial monkey-patches for MoE models to work correctly with non-reentrant gradient checkpointing.

My review focuses on ensuring these patches are applied correctly. I've found a critical issue where a necessary patch function, _patch_checkpoint_for_moe, is defined but never called for either the policy or critic models. This would prevent the intended fix for gradient checkpointing with MoE from working. I've also noted a minor style issue regarding an import placement. The rest of the changes appear correct and well-implemented.

    and not self.cfg.gradient_checkpointing_use_reentrant
)
if needs_expert_patch:
    _patch_moe_experts_for_fsdp2(wrapped_model.model)
Contributor


critical

The function _patch_checkpoint_for_moe is defined to make non-reentrant gradient checkpointing compatible with MoE models, but it is never called. This appears to be an oversight, as the functionality is crucial for the stability fixes mentioned in the pull request description. The patch should be applied here for MoE models using non-reentrant gradient checkpointing.

Suggested change
    _patch_moe_experts_for_fsdp2(wrapped_model.model)
    _patch_checkpoint_for_moe()
    _patch_moe_experts_for_fsdp2(wrapped_model.model)

    and not self.cfg.gradient_checkpointing_use_reentrant
)
if needs_expert_patch:
    _patch_moe_experts_for_fsdp2(critic)
Contributor


critical

Similar to the PolicyWorker, the _patch_checkpoint_for_moe function is not being called for the CriticWorker. This patch is necessary for MoE model stability with non-reentrant gradient checkpointing and should be applied here.

Suggested change
    _patch_moe_experts_for_fsdp2(critic)
    _patch_checkpoint_for_moe()
    _patch_moe_experts_for_fsdp2(critic)

    **kwargs,
) -> Optional[Float[torch.Tensor, "1"]]:
    """Perform optimizer step"""
    import time as _time
Contributor


medium

The time module is imported inside the optimizer_step method. According to Python style guides (like PEP 8), imports should be placed at the top of the file. This improves readability and avoids potential overhead if the method is called frequently in a hot loop.

Please move this import to the top of the file.

Member


+1


Copilot AI left a comment


Pull request overview

This PR targets stability and performance issues when training MoE / hybrid models under FSDP2, especially around NCCL deadlocks, non-reentrant gradient checkpointing determinism, and initialization overhead.

Changes:

  • Move DP metric all-reduction to after all micro-batches complete to avoid NCCL deadlocks during in-flight FSDP2 reductions.
  • Add FSDP2/MoE-specific patches: deterministic expert iteration for non-reentrant checkpointing and enable meta-tensor init for tied embeddings under FSDP2.
  • Speed up FSDP2 full-state loading via batched/coalesced broadcasts; adjust optimizer stepping behavior to avoid deadlocks on non-finite grad norms; guard output_router_logits to MoE-only.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
skyrl/backends/skyrl_train/workers/worker.py Defers all_reduce_metrics() until after all micro-batches for policy/critic workers.
skyrl/backends/skyrl_train/workers/model_wrapper.py Only enables output_router_logits when config indicates an MoE model.
skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py Adds MoE expert forward patching for FSDP2 + non-reentrant checkpointing; adjusts meta init behavior for tied embeddings on FSDP2.
skyrl/backends/skyrl_train/distributed/fsdp_utils.py Implements batched/coalesced broadcast to accelerate FSDP2 state-dict loading.
skyrl/backends/skyrl_train/distributed/fsdp_strategy.py Ensures all ranks call optimizer.step() even on non-finite grad norms; ties embeddings before state-dict capture; adds timing logs and frees full_state.


Comment on lines +136 to +152
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
if token_idx.numel() > 0:
    current_state = hidden_states[token_idx]
else:
    current_state = hidden_states[:0]

gate, up = nn.functional.linear(
    current_state, mod.gate_up_proj[expert_idx]
).chunk(2, dim=-1)
current_hidden_states = mod.act_fn(gate) * up
current_hidden_states = nn.functional.linear(
    current_hidden_states, mod.down_proj[expert_idx]
)

if token_idx.numel() > 0:
    current_hidden_states = (
        current_hidden_states * top_k_weights[token_idx, top_k_pos, None]

Copilot AI Mar 23, 2026


The patched Qwen3MoeExperts forward still has data-dependent Python branching (if token_idx.numel() > 0) around both the input selection and the top_k_weights scaling/index_add. If routing differs between the original forward and the checkpoint recompute, this branch can execute in one pass but not the other, reintroducing the non-deterministic saved-tensor count problem this patch is meant to fix.

To keep the computation graph shape deterministic, avoid branching on token_idx.numel() and always run the same ops (indexing/scaling/index_add should be safe with empty index tensors).
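The claim that these ops are safe with empty index tensors is easy to verify (a standalone check, not project code):

```python
import torch

# Indexing, linear, and index_add_ all accept empty index tensors, so the
# per-expert body can run unconditionally with zero routed tokens and still
# record the same ops in the autograd graph.
hidden = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(8, 8)

token_idx = torch.empty(0, dtype=torch.long)   # expert received no tokens
current = hidden[token_idx]                    # shape (0, 8), no branch needed
current = torch.nn.functional.linear(current, weight)  # still shape (0, 8)

out = torch.zeros(4, 8)
out.index_add_(0, token_idx, current)          # adds nothing, but is graph-safe
```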

non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad()

Copilot AI Mar 23, 2026


For the non-finite grad-norm path, calling optimizer.zero_grad() without set_to_none=True can leave explicit zero tensors as gradients on some PyTorch versions/configs. Many optimizers (e.g., AdamW with weight decay / momentum buffers) can still update parameters even with zero grads, so this may not be a true no-op step.

To ensure the step is actually harmless while still keeping collectives matched across ranks, explicitly clear gradients with set_to_none=True (or otherwise guarantee params won’t be updated on the non-finite path).

Suggested change
optimizer.zero_grad()
optimizer.zero_grad(set_to_none=True)
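The difference matters in practice: with decoupled weight decay, AdamW moves a parameter even when its gradient is an explicit zero tensor, while a None gradient skips it entirely. A small standalone check (illustrative, not project code):

```python
import torch

def step_with_zero_grad(set_to_none):
    # AdamW with decoupled weight decay still updates a parameter whose
    # grad is an explicit zero tensor; with grad=None it is skipped,
    # making the "harmless" step a true no-op.
    p = torch.nn.Parameter(torch.ones(3))
    opt = torch.optim.AdamW([p], lr=0.1, weight_decay=0.5)
    p.grad = torch.zeros_like(p)
    opt.zero_grad(set_to_none=set_to_none)  # True -> grad becomes None
    opt.step()
    return p.detach().clone()
```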

Comment on lines +188 to +200
logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")

# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad()
    return grad_norm

t0 = _time.time()
optimizer.step()
if scheduler is not None:
    logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")

Copilot AI Mar 23, 2026


logger.info timing logs are emitted on every rank for every training step (clip_grad_norm_ and optimizer.step). In long runs this can generate extremely large logs and can materially slow training due to logging overhead.

Consider making these logs debug-level, gating them behind a config flag, and/or restricting to rank 0 (or periodic sampling) to avoid impacting throughput.

Comment on lines +309 to +313
pg = dist.distributed_c10d._get_default_group()
if len(batch_tensors) > 1:
    dist._broadcast_coalesced(pg, batch_tensors, BATCH_SIZE_BYTES, 0)
else:
    dist.broadcast(batch_tensors[0], src=0)

Copilot AI Mar 23, 2026


fsdp2_load_full_state_dict now relies on private/underscored torch.distributed APIs (dist._broadcast_coalesced and dist.distributed_c10d._get_default_group). These aren’t part of the stable public surface area and can change across PyTorch versions/builds, which would cause initialization to fail at runtime.

To make this more robust, add a compatibility fallback (e.g., detect attribute availability and fall back to per-tensor dist.broadcast), and consider using a public API if one is available for your supported PyTorch versions.

Suggested change
pg = dist.distributed_c10d._get_default_group()
if len(batch_tensors) > 1:
    dist._broadcast_coalesced(pg, batch_tensors, BATCH_SIZE_BYTES, 0)
else:
    dist.broadcast(batch_tensors[0], src=0)
# Use private APIs if available for efficiency, but fall back to public APIs otherwise.
pg = None
if hasattr(dist, "distributed_c10d") and hasattr(dist.distributed_c10d, "_get_default_group"):
    pg = dist.distributed_c10d._get_default_group()
elif hasattr(dist, "group") and hasattr(dist.group, "WORLD"):
    pg = dist.group.WORLD
if len(batch_tensors) > 1 and hasattr(dist, "_broadcast_coalesced") and pg is not None:
    # Fast path: coalesced broadcast when the private API is available.
    dist._broadcast_coalesced(pg, batch_tensors, BATCH_SIZE_BYTES, 0)
else:
    # Fallback: broadcast each tensor individually using the public API.
    for _tensor in batch_tensors:
        dist.broadcast(_tensor, src=0)

@@ -930,15 +936,6 @@
if self.cfg.algorithm.use_kl_loss:
    status["policy_kl"] = kl_loss.item()


Copilot AI Mar 23, 2026


_forward_backward_micro no longer all-reduces metrics (it now returns per-rank values to be reduced in forward_backward()), but the docstring still says it returns an "All-reduced metrics dict". This can mislead future changes and debugging.

Update the docstring (and/or inline comments) to reflect the new contract: micro-batch metrics are local, and reduction/all-reduce happens once after all micro-batches complete.

Suggested change
# Note: `status` contains per-rank (local) metrics for this micro-batch.
# Any aggregation or all-reduce of metrics happens in `forward_backward()`
# after all micro-batches have been processed.

Comment on lines 1149 to 1156
status = {
    "critic_loss": loss.item(),
    "values_mean": masked_mean(values, loss_mask).item(),
    "values_clipfrac": clipfrac,
    "critic_lr": self.scheduler.get_last_lr()[0],
}

# All-reduce metrics across DP workers
status = all_reduce_metrics(status, self.strategy)

return status

Copilot AI Mar 23, 2026


Same as the policy worker: _forward_backward_micro no longer all-reduces metrics, but its docstring still claims the return value is all-reduced. This is now incorrect since reduction/all-reduce happens in forward_backward().

Update the docstring to match the new behavior to avoid confusion when debugging distributed metric discrepancies.

Comment on lines +190 to 205
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad()
    return grad_norm

t0 = _time.time()
optimizer.step()
if scheduler is not None:
    logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
# Only advance LR schedule when gradients were finite (non-finite steps are no-ops)
if scheduler is not None and not non_finite:
    scheduler.step()
optimizer.zero_grad()
return grad_norm

Copilot AI Mar 23, 2026


The new non-finite grad-norm handling changes correctness-sensitive behavior (all ranks must still call optimizer.step(), scheduler step is skipped, grads are cleared before stepping). There doesn’t appear to be a test covering this branch in the existing FSDP strategy test suite.

Add a unit/integration test that forces grad_norm to be non-finite (e.g., by injecting NaNs into a parameter grad) and asserts: (1) optimizer.step() is still invoked, (2) LR scheduler is not advanced, and (3) parameters do not change on the non-finite step.

# MoE - balancing loss
model_config = self.model.config.to_dict()
if "output_router_logits" in model_config:
    num_experts = model_config.get("num_local_experts", 0)

Copilot AI Mar 23, 2026


output_router_logits is now gated on num_local_experts > 0, but some MoE configs use num_experts (or other keys) instead. Since this file previously enabled the flag based solely on presence of output_router_logits, this change can silently disable router logits for valid MoE models whose config doesn’t define num_local_experts.

Consider detecting MoE via both num_local_experts and num_experts (consistent with the MoE detection logic added in fsdp_worker.py), so the guard works across MoE model families.

Suggested change
num_experts = model_config.get("num_local_experts", 0)
# Detect MoE across model families: prefer num_local_experts, fall back to num_experts
num_experts = model_config.get("num_local_experts")
if num_experts is None:
    num_experts = model_config.get("num_experts", 0)

devin-ai-integration[bot]

This comment was marked as resolved.

…ast fallback

- Call _patch_checkpoint_for_moe() for both policy and critic workers
- Detect MoE via both num_local_experts and num_experts config keys
- Use optimizer.zero_grad(set_to_none=True) for non-finite grad_norm path
- Add fallback for private _broadcast_coalesced API (per-tensor broadcast)
Contributor

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 2 new potential issues.

View 8 additional findings in Devin Review.


Comment on lines +188 to +200
logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")

# Skip update if gradient norm is not finite
if grad_norm is not None and not torch.isfinite(grad_norm):
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}")
    else:
        logger.warning(f"grad_norm is not finite: {grad_norm}")
    optimizer.zero_grad()
    return grad_norm
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad(set_to_none=True)

t0 = _time.time()
optimizer.step()
if scheduler is not None:
    logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
Contributor


🟡 Per-step INFO logging on every rank causes excessive log output in production

The optimizer_step method adds logger.info() calls at lines 188 and 200 that fire on every optimizer step on every rank. All other logger.info calls in this file are for one-time initialization/config events (e.g., architecture name fixes at fsdp_strategy.py:367). In a typical training run with hundreds of ranks and thousands of steps, these two lines will generate millions of INFO log entries, degrading I/O performance and making logs unusable. These should be logger.debug() at most.

Suggested change
logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")
# Skip update if gradient norm is not finite
if grad_norm is not None and not torch.isfinite(grad_norm):
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}")
    else:
        logger.warning(f"grad_norm is not finite: {grad_norm}")
    optimizer.zero_grad()
    return grad_norm
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad(set_to_none=True)
t0 = _time.time()
optimizer.step()
if scheduler is not None:
    logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")

logger.debug(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad(set_to_none=True)
t0 = _time.time()
optimizer.step()
logger.debug(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")


Comment on lines +226 to +232
    return torch.tensor(0.0)

handle = holder.handles[gid]
if handle in frame.recomputed.get(gid, {}):
    ret = frame.recomputed[gid][handle]
else:
    ret = torch.tensor(0.0)
Contributor


🟡 Checkpoint patch fallback returns scalar CPU tensor causing device/shape mismatch crash

In _patch_checkpoint_for_moe, the unpack_hook has two fallback paths (lines 226 and 232) that return torch.tensor(0.0) — a scalar float32 tensor on CPU. The autograd system expects the unpacked tensor to match the shape, dtype, and device of the original saved tensor (typically a multi-dimensional CUDA tensor). If these paths are triggered during backward, PyTorch will raise a device or shape mismatch error.

Affected fallback paths

Line 226 — when frame.recomputed[gid] is missing or empty:

return torch.tensor(0.0)  # scalar, CPU, float32

Line 232 — when handle not found in frame.recomputed[gid]:

ret = torch.tensor(0.0)  # scalar, CPU, float32

Both should at minimum produce a tensor on the correct device. A safer fallback would be torch.zeros(1, device='cuda') or, better yet, raise an explicit error with a clear diagnostic message.

Prompt for agents
In skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py, inside the _patch_checkpoint_for_moe function's unpack_hook (lines 220-234), the two fallback return paths that return torch.tensor(0.0) should be improved:

1. Line 226: return torch.tensor(0.0) — this is reached when gid is not in frame.recomputed or it's empty. Replace with a tensor on the correct device. Since we're inside a CUDA computation, use torch.tensor(0.0, device='cuda') at minimum, or raise a RuntimeError with a descriptive message about the checkpoint state.

2. Line 232: ret = torch.tensor(0.0) — reached when the handle is not in frame.recomputed[gid]. Same fix: either use the CUDA device or raise a descriptive error.

Ideally, try to infer the expected shape/dtype/device from the holder or frame metadata and produce a properly shaped zero tensor, rather than a scalar.


@SumanthRH SumanthRH self-assigned this Mar 31, 2026
Member

@SumanthRH SumanthRH left a comment


Still making my way through the fixes; left a few comments on some of them so far.

    **kwargs,
) -> Optional[Float[torch.Tensor, "1"]]:
    """Perform optimizer step"""
    import time as _time
Member


+1

Comment on lines +190 to +196
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
    logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
    optimizer.zero_grad(set_to_none=True)
Member


Why is this modification needed? Don't we already skip optimizer.step on all ranks when grad_norm is NaN?

from skyrl.train.config.config import InferenceEngineConfig


def _patch_moe_experts_for_fsdp2(model: nn.Module):
Member


Are these patches still needed with transformers 5.0.0?

pytorch/pytorch#171355 seems to be the original issue, and it appears to have been fixed in transformers in February: pytorch/pytorch#171355 (comment)

Member


If so, let's also reorg and move the patches to a separate patches.py file

Member


If the patch is needed, let's also verify E2E training with this? A first step would be to ensure some of our GPU CI tests pass (but modified to use Qwen 3 model):

https://github.com/NovaSky-AI/SkyRL/blob/main/tests/backends/skyrl_train/gpu/gpu_ci/test_model_wrapper.py
