Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions skyrl/backends/skyrl_train/distributed/fsdp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,31 +169,37 @@ def optimizer_step(
**kwargs,
) -> Optional[Float[torch.Tensor, "1"]]:
"""Perform optimizer step"""
import time as _time
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The time module is imported inside the optimizer_step method. According to Python style guides (like PEP 8), imports should be placed at the top of the file. This improves readability and avoids potential overhead if the method is called frequently in a hot loop.

Please move this import to the top of the file.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1


rank = dist.get_rank() if dist.is_initialized() else -1
grad_norm = None
if isinstance(model, HFModelWrapper):
model = model.model

if self.max_norm > 0:
t0 = _time.time()
# NOTE (sumanthrh): All `grad_norm`s returned here are the original grad norms before clipping.
if isinstance(model, FSDP):
grad_norm = model.clip_grad_norm_(max_norm=self.max_norm)
elif isinstance(model, FSDPModule):
grad_norm = fsdp2_clip_grad_norm_(model.parameters(), max_norm=self.max_norm)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.max_norm)
logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")

# Skip update if gradient norm is not finite
if grad_norm is not None and not torch.isfinite(grad_norm):
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}")
else:
logger.warning(f"grad_norm is not finite: {grad_norm}")
optimizer.zero_grad()
return grad_norm
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
optimizer.zero_grad(set_to_none=True)
Comment on lines +190 to +196
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this modification needed? Do we now skip optimizer.step on all ranks when the grad norm is NaN?


t0 = _time.time()
optimizer.step()
if scheduler is not None:
logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
Comment on lines +188 to +200
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logger.info timing logs are emitted on every rank for every training step (clip_grad_norm_ and optimizer.step). In long runs this can generate extremely large logs and can materially slow training due to logging overhead.

Consider making these logs debug-level, gating them behind a config flag, and/or restricting to rank 0 (or periodic sampling) to avoid impacting throughput.

Copilot uses AI. Check for mistakes.
Comment on lines +188 to +200
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Per-step INFO logging on every rank causes excessive log output in production

The optimizer_step method adds logger.info() calls at lines 188 and 200 that fire on every optimizer step on every rank. All other logger.info calls in this file are for one-time initialization/config events (e.g., architecture name fixes at fsdp_strategy.py:367). In a typical training run with hundreds of ranks and thousands of steps, these two lines will generate millions of INFO log entries, degrading I/O performance and making logs unusable. These should be logger.debug() at most.

Suggested change
logger.info(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")
# Skip update if gradient norm is not finite
if grad_norm is not None and not torch.isfinite(grad_norm):
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}")
else:
logger.warning(f"grad_norm is not finite: {grad_norm}")
optimizer.zero_grad()
return grad_norm
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
optimizer.zero_grad(set_to_none=True)
t0 = _time.time()
optimizer.step()
if scheduler is not None:
logger.info(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
logger.debug(f"[rank {rank}] clip_grad_norm_ done in {_time.time() - t0:.1f}s, grad_norm={grad_norm}")
# NOTE: With FSDP, optimizer.step() involves NCCL collectives. ALL ranks
# must call it even if grad_norm is non-finite, otherwise NCCL deadlocks.
# We zero_grad before stepping so the non-finite update is harmless.
non_finite = grad_norm is not None and not torch.isfinite(grad_norm)
if non_finite:
logger.warning(f"rank {rank} grad_norm is not finite: {grad_norm}, zeroing grads before step")
optimizer.zero_grad(set_to_none=True)
t0 = _time.time()
optimizer.step()
logger.debug(f"[rank {rank}] optimizer.step() done in {_time.time() - t0:.1f}s")
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

# Only advance LR schedule when gradients were finite (non-finite steps are no-ops)
if scheduler is not None and not non_finite:
scheduler.step()
optimizer.zero_grad()
return grad_norm
Comment on lines +190 to 205
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new non-finite grad-norm handling changes correctness-sensitive behavior (all ranks must still call optimizer.step(), scheduler step is skipped, grads are cleared before stepping). There doesn’t appear to be a test covering this branch in the existing FSDP strategy test suite.

Add a unit/integration test that forces grad_norm to be non-finite (e.g., by injecting NaNs into a parameter grad) and asserts: (1) optimizer.step() is still invoked, (2) LR scheduler is not advanced, and (3) parameters do not change on the non-finite step.

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -273,9 +279,12 @@ def _fsdp_init_model(self, model, is_train=True, is_wrapped=False):
"reshard_after_forward": self.fsdp_config.reshard_after_forward,
}
module = model.model if is_wrapped else model
if getattr(module.config, "tie_word_embeddings", False):
module.tie_weights()
full_state = module.state_dict()
apply_fsdp2(module, fsdp_kwargs, self.fsdp_config)
fsdp2_load_full_state_dict(module, full_state, cpu_offload)
del full_state # free CPU memory (rank 0 held full model copy)
fsdp_module = module
else:
raise NotImplementedError(f"{self.fsdp_strategy} not implemented")
Expand Down
72 changes: 48 additions & 24 deletions skyrl/backends/skyrl_train/distributed/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,33 +293,57 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
tensor = tensor.contiguous()
return tensor

if dist.get_rank() == 0:
for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
full_param = full_param.detach().cuda()
mesh = sharded_param.device_mesh
dist.broadcast(full_param, src=0)
sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
full_param,
)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
sharded_sd[param_name] = sharded_tensor
# We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
else:
for param_name, sharded_param in meta_sharded_sd.items():
full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)
# Batched broadcast: coalesce many small tensors into fewer NCCL calls.
# For MoE models with 18,000+ params, this reduces init from minutes to seconds.
BATCH_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB per coalesced broadcast

param_names = list(meta_sharded_sd.keys())
batch_tensors = []
batch_names = []
batch_bytes = 0

def _flush_batch():
"""Broadcast current batch and distribute to shards."""
if not batch_tensors:
return
# Use coalesced broadcast when available (private API), fall back to per-tensor
if len(batch_tensors) > 1 and hasattr(dist, "_broadcast_coalesced"):
pg = dist.distributed_c10d._get_default_group()
dist._broadcast_coalesced(pg, batch_tensors, BATCH_SIZE_BYTES, 0)
else:
for t in batch_tensors:
dist.broadcast(t, src=0)

for name, full_tensor in zip(batch_names, batch_tensors):
sharded_param = meta_sharded_sd[name]
mesh = sharded_param.device_mesh
dist.broadcast(full_tensor, src=0)
sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
full_tensor,
)
to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, full_tensor)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
sharded_sd[param_name] = sharded_tensor
sharded_sd[name] = sharded_tensor

for param_name in param_names:
sharded_param = meta_sharded_sd[param_name]
if dist.get_rank() == 0:
full_tensor = full_sd.pop(param_name).detach().cuda()
else:
full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)

tensor_bytes = full_tensor.nelement() * full_tensor.element_size()

# If adding this tensor exceeds batch size, flush current batch first
if batch_bytes + tensor_bytes > BATCH_SIZE_BYTES and batch_tensors:
_flush_batch()
batch_tensors.clear()
batch_names.clear()
batch_bytes = 0

batch_tensors.append(full_tensor)
batch_names.append(param_name)
batch_bytes += tensor_bytes

# Flush remaining tensors
_flush_batch()

# we set `assign=True` because our params can be on meta device
model.load_state_dict(sharded_sd, assign=True)
Expand Down
Loading
Loading