Overview
Tracking a set of changes that enable stable training of Mixture-of-Experts models (tested with Qwen3-30B-A3B MoE) under FSDP2 with
non-reentrant gradient checkpointing and Ulysses sequence parallelism. These changes also fix several NCCL deadlock scenarios and
improve training observability. Summarizing here for visibility and early feedback.
Changes
1. MoE Expert Patching for FSDP2 — fsdp_worker.py
Problem: MoE implementations (e.g., Qwen3MoeSparseMoeBlock, Qwen3MoeExperts) iterate only over the experts that received tokens
(via nonzero()). With FSDP2 + non-reentrant gradient checkpointing, the recompute forward may route tokens to different experts
than the original forward, causing check_recomputed_tensors_match assertion failures due to differing saved tensor counts.
Fix: _patch_moe_experts_for_fsdp2(model) replaces the selective expert loop with one that iterates ALL experts unconditionally,
ensuring a deterministic computation graph. Supports both the nn.ModuleList layout (transformers ≤ 4.57.x) and the fused 3D
parameter tensor layout (transformers ≥ 4.58.x). A secondary patch (_patch_checkpoint_for_moe) is included but commented out — it
makes unpack_hook tolerant of missing tensor handles as a fallback.
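As a minimal sketch of the dense-loop idea (all class and variable names here are hypothetical, not the actual patch): every expert runs on every forward, with unrouted tokens masked to zero, so the forward and the checkpoint recompute always save the same number of tensors.

```python
import torch
import torch.nn as nn

class DenseLoopMoE(nn.Module):
    """Illustrative top-1 MoE block that iterates ALL experts unconditionally."""
    def __init__(self, num_experts=4, hidden=8):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))
        self.gate = nn.Linear(hidden, num_experts)

    def forward(self, x):  # x: (tokens, hidden)
        probs = torch.softmax(self.gate(x), dim=-1)   # (tokens, num_experts)
        top_p, top_i = probs.max(dim=-1)              # top-1 routing
        out = torch.zeros_like(x)
        # Dense loop: every expert executes even if it received zero tokens,
        # so the autograd graph (and saved-tensor count) is deterministic.
        for idx, expert in enumerate(self.experts):
            mask = (top_i == idx).unsqueeze(-1).to(x.dtype)  # zeros out unrouted tokens
            out = out + mask * top_p.unsqueeze(-1) * expert(x)
        return out

moe = DenseLoopMoE()
y = moe(torch.randn(5, 8))
```

The trade-off is extra compute for idle experts, paid in exchange for a recompute-stable graph under non-reentrant checkpointing.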
2. FSDP2 Meta Tensor Init for Tied Embeddings — fsdp_worker.py
Problem: FSDP1 cannot use meta tensor initialization when tie_word_embeddings=True. The existing code disabled meta init for
all FSDP versions when embeddings are tied.
Fix: FSDP2 handles tied embeddings correctly via broadcast + tie_weights(), so meta tensor init is now always enabled for FSDP2
(only FSDP1 retains the old guard). Additionally, tie_weights() is called before state_dict() in fsdp_strategy.py to ensure
the tied parameters are properly linked.
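The tie_weights-before-state_dict pattern can be sketched like this (a toy model with hypothetical names, not the real code path): after parameters are materialized from meta tensors, re-pointing the head at the embedding guarantees state_dict() sees one shared tensor.

```python
import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):
    """Toy LM with tied input/output embeddings."""
    def __init__(self, vocab=16, hidden=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # Re-link the head to the embedding table; must be re-run after any
        # materialization/load step that may have replaced either Parameter.
        self.lm_head.weight = self.embed.weight

model = TinyTiedLM()
model.tie_weights()           # called again right before state_dict()
sd = model.state_dict()
```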
3. Non-Finite Grad Norm NCCL Deadlock Fix — fsdp_strategy.py
Problem: When grad_norm is non-finite, optimizer_step() called optimizer.zero_grad() and returned early, skipping
optimizer.step(). Since FSDP's optimizer.step() involves NCCL collectives, skipping it on some ranks causes a deadlock.
Fix: All ranks now proceed through optimizer.step() even when grad_norm is non-finite. Gradients are zeroed beforehand so the
step is harmless. Timing logs added for clip_grad_norm_ and optimizer.step().
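The control flow can be sketched as follows (hypothetical function signature, not the actual fsdp_strategy.py code): the key invariant is that every rank reaches optimizer.step(), and non-finite updates are neutralized by zeroing gradients rather than by early return.

```python
import torch

def optimizer_step(optimizer, params, max_norm=1.0):
    """All ranks take the same path through step(), which under FSDP
    contains NCCL collectives that must be entered collectively."""
    grad_norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
    if not torch.isfinite(grad_norm):
        # Zero the gradients so step() becomes a numerical no-op,
        # but STILL call it so no rank skips the collectives.
        for p in params:
            if p.grad is not None:
                p.grad.zero_()
    optimizer.step()
    optimizer.zero_grad()
    return grad_norm

p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.tensor([float("inf"), 0.0, 0.0])
opt = torch.optim.SGD([p], lr=0.1)
norm = optimizer_step(opt, [p])   # non-finite norm; parameters unchanged
```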
4. Batched/Coalesced Broadcast for FSDP2 State Dict Loading — fsdp_utils.py
Problem: fsdp2_load_full_state_dict() performed one dist.broadcast() per parameter. For MoE models with 18,000+ parameters,
this resulted in thousands of NCCL calls, taking minutes.
Fix: Parameters are accumulated into 500 MB batches and sent via dist._broadcast_coalesced(), reducing initialization time to
seconds. Memory is freed incrementally via full_sd.pop().
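The bucketing half of this fix can be sketched independently of the collective itself (helper name is hypothetical): tensors are accumulated until a byte budget is hit, and each bucket is then sent in a single coalesced call rather than one broadcast per parameter.

```python
import torch

def bucket_tensors(named_tensors, bucket_cap_bytes=500 * 1024**2):
    """Group (name, tensor) pairs into buckets of at most bucket_cap_bytes.
    Each bucket would then go through one dist._broadcast_coalesced()
    (a private torch.distributed API) instead of per-tensor broadcasts."""
    buckets, current, current_bytes = [], [], 0
    for name, t in named_tensors:
        nbytes = t.numel() * t.element_size()
        if current and current_bytes + nbytes > bucket_cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append((name, t))
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets

# Tiny cap to show the grouping: three 4-byte tensors with an 8-byte budget.
ts = [("a", torch.zeros(1)), ("b", torch.zeros(1)), ("c", torch.zeros(1))]
buckets = bucket_tensors(ts, bucket_cap_bytes=8)
```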
5. Pre-gathered Position IDs for Ulysses SP — monkey_patch.py, model_wrapper.py
Problem: _ulysses_flash_attention_forward() runs torch.distributed.all_gather on position_ids inline. During gradient
checkpointing backward recompute, this NCCL collective can deadlock or produce incorrect results because the communication context
differs from the original forward.
Fix: set_ulysses_position_ids() pre-computes and caches the all-gathered position_ids in module-level globals before entering
the model (outside the checkpointed region). The attention forward uses the cached version, with an inline fallback for
non-checkpointed paths. Also handles models (e.g., GraniteMoeHybrid) that don't propagate position_ids through decoder layers.
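A minimal sketch of the cache-then-fallback pattern (function and global names are illustrative, though set_ulysses_position_ids mirrors the description above): the all_gather happens once, outside the checkpointed region, and the attention forward only reads the cache.

```python
import torch
import torch.distributed as dist

# Module-level cache, populated BEFORE entering the model so the backward
# recompute never issues its own all_gather.
_ULYSSES_POSITION_IDS = None

def set_ulysses_position_ids(position_ids, group=None):
    """Pre-gather position_ids across the SP group and cache the result."""
    global _ULYSSES_POSITION_IDS
    if dist.is_initialized() and dist.get_world_size(group) > 1:
        gathered = [torch.empty_like(position_ids)
                    for _ in range(dist.get_world_size(group))]
        dist.all_gather(gathered, position_ids, group=group)
        _ULYSSES_POSITION_IDS = torch.cat(gathered, dim=-1)
    else:
        _ULYSSES_POSITION_IDS = position_ids

def get_ulysses_position_ids(local_position_ids):
    """Used inside the attention forward: prefer the cache; fall back to the
    local ids on non-checkpointed paths where nothing was cached."""
    if _ULYSSES_POSITION_IDS is not None:
        return _ULYSSES_POSITION_IDS
    return local_position_ids

pos = torch.arange(8).unsqueeze(0)
set_ulysses_position_ids(pos)   # single-process: cache is just the local ids
```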
6. MoE Router Logits Guard — model_wrapper.py
Problem: output_router_logits was set to True for any model whose config contained the key, including non-MoE models.
Fix: Added a num_local_experts > 0 guard so only actual MoE models get output_router_logits=True.
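The guard reduces to a one-liner; sketched here with a hypothetical helper name and SimpleNamespace stand-ins for model configs:

```python
from types import SimpleNamespace

def should_output_router_logits(config) -> bool:
    """A config may define output_router_logits without the model being MoE,
    so additionally require a positive expert count."""
    return (hasattr(config, "output_router_logits")
            and getattr(config, "num_local_experts", 0) > 0)

moe_cfg = SimpleNamespace(output_router_logits=False, num_local_experts=128)
dense_cfg = SimpleNamespace(output_router_logits=False)  # key present, not MoE
```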
7. All-Reduce Metrics After All Micro-Batches — worker.py
Problem: all_reduce_metrics() was called inside _forward_backward_micro() (per micro-batch). With FSDP2 + MoE, backward
gradient reductions may still be in-flight on the NCCL stream, and issuing dist.all_reduce there causes a deadlock.
Fix: Moved all_reduce_metrics() to forward_backward(), called once after all micro-batches complete. Both policy and critic
workers are updated.
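The restructured loop looks roughly like this (signatures are hypothetical): the collective reduction moves out of the per-micro-batch body and runs exactly once, after all backward passes have been issued.

```python
def forward_backward(micro_batches, step_fn, reduce_fn):
    """Run every micro-batch first, then reduce metrics ONCE, so the
    all_reduce never races FSDP2's in-flight gradient-reduction collectives."""
    metrics = []
    for mb in micro_batches:
        metrics.append(step_fn(mb))   # forward + backward only; no collectives here
    return reduce_fn(metrics)          # single reduction point after all backwards

out = forward_backward([1, 2, 3], step_fn=lambda mb: mb * mb, reduce_fn=sum)
```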
8. NUMA Affinity Rewrite — worker.py
Problem: The bitmask-based NUMA API (numa_parse_nodestring, numa_run_on_node_mask, numa_set_membind) caused segfaults from
bitmask pointer corruption. numa_num_configured_nodes() incorrectly counts virtual NVLink NUMA IDs on GB200 systems. Bare indexing
into CUDA_VISIBLE_DEVICES caused IndexError with fewer than 8 GPUs.
Fix: Replaced with integer-based API (numa_run_on_node, numa_set_preferred), numa_max_node() for real node counting,
numa_available() check, bounds-safe GPU ID lookup, and detailed logging.
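A sketch of the integer-based libnuma usage via ctypes (the real worker.py code is not reproduced here; this shows only the numa_run_on_node / numa_set_preferred / numa_max_node call pattern, degrading gracefully when libnuma is absent):

```python
import ctypes

def bind_to_numa_node(node: int) -> bool:
    """Pin the calling thread and prefer memory on one NUMA node using the
    integer API, avoiding the bitmask API entirely. Returns False when
    libnuma is unavailable so callers can proceed unbound."""
    try:
        libnuma = ctypes.CDLL("libnuma.so.1")
    except OSError:
        return False
    if libnuma.numa_available() < 0:
        return False
    max_node = libnuma.numa_max_node()       # highest real node id
    node = min(node, max_node)               # bounds-safe clamp
    if libnuma.numa_run_on_node(node) != 0:  # pin CPU execution
        return False
    libnuma.numa_set_preferred(node)         # prefer, not strictly bind, memory
    return True

ok = bind_to_numa_node(0)
```

numa_set_preferred (prefer) is used instead of numa_set_membind (strict bind), which also sidesteps the bitmask-pointer path that was segfaulting.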
9. Per-Micro-Batch Dispatch with Progress Logging — trainer.py
Problem: The trainer dispatched entire mini-batches at once, providing no visibility into per-micro-batch progress for long
training steps.
Fix: The trainer now iterates over micro-batches individually, logging elapsed time, estimated time remaining, and
per-micro-batch latency. Timing logs also added for optim_step.
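The dispatch loop with its progress logging can be sketched as (hypothetical signatures, stdlib timing only):

```python
import time

def dispatch_micro_batches(micro_batches, run_one):
    """Dispatch micro-batches one at a time, logging per-batch latency,
    total elapsed time, and a simple average-based ETA."""
    n = len(micro_batches)
    start = time.perf_counter()
    results = []
    for i, mb in enumerate(micro_batches, 1):
        t0 = time.perf_counter()
        results.append(run_one(mb))
        now = time.perf_counter()
        elapsed, latency = now - start, now - t0
        eta = (elapsed / i) * (n - i)
        print(f"micro-batch {i}/{n}: {latency:.2f}s this batch, "
              f"{elapsed:.2f}s elapsed, ~{eta:.2f}s remaining")
    return results

results = dispatch_micro_batches([1, 2, 3], run_one=lambda mb: mb * 2)
```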
10. Worker Dispatch Observability — worker_dispatch.py
Added get_dp_size() method and timing logs around ray.get() and _save_memory_snapshot() calls in
forward_backward_from_staged() and optim_step().
11. vLLM RAY_ADDRESS Fix — vllm_engine.py
Problem: vLLM's EngineCore subprocess fails with KeyError: 'bundles' when accessing placement_group_table() because it cannot
reach the Ray GCS, breaking tensor-parallel and pipeline-parallel inference.
Fix: Propagate RAY_ADDRESS from the Ray runtime context when not already set.
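The propagation logic is small; sketched here over a plain env dict (the real code would read ray.get_runtime_context().gcs_address and write os.environ):

```python
def ensure_ray_address(env: dict, gcs_address: str) -> dict:
    """Set RAY_ADDRESS for a child process only when not already set,
    so the subprocess can reach the same Ray GCS as its parent."""
    if not env.get("RAY_ADDRESS") and gcs_address:
        env["RAY_ADDRESS"] = gcs_address
    return env

env = ensure_ray_address({}, "10.0.0.1:6379")
preset = ensure_ray_address({"RAY_ADDRESS": "existing:6379"}, "10.0.0.1:6379")
```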