Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions megatron/optimizer/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def adjust_lr_for_muon(self, lr, param_shape):
"""
A, B = param_shape[:2]
# We adjust the learning rate based on the size of the parameter matrix
#adjusted_ratio = max(1.0, float(A) / float(B)) ** 0.5
adjusted_ratio = 0.2 * math.sqrt(max(A, B))
#adjusted_ratio = max(1.0, float(A) / float(B)) ** 0.5  # MuP-compatible Muon variant; uncomment this line to enable MuP for Muon
adjusted_ratio = 0.2 * math.sqrt(max(A, B))  # comment out this line if the line above is uncommented
adjusted_lr = lr * adjusted_ratio
return adjusted_lr

Expand Down
34 changes: 0 additions & 34 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,40 +1182,6 @@ def train(
if smoothed_loss < best_loss or batch_num == 1:
best_loss = smoothed_loss

# explode = (batch_num > 1 and smoothed_loss > 4 * best_loss)


# if explode:
# if mpu.get_data_parallel_rank() == 0:
# log.info(f"Loss exploding at lr={curr_lr:.8f}, stopping LR finder")
# break

## Record the best loss (use the same global smoothed_loss)
#if smoothed_loss < best_loss or batch_num == 1:
# best_loss = smoothed_loss

# --- GLOBALIZE the "loss exploding" decision (any rank => all ranks) ---
#explode_local = (batch_num > 1 and smoothed_loss > 4 * best_loss)
#print(f"Rank {mpu.get_data_parallel_rank()}: explode_local={explode_local}")
#if mpu.get_data_parallel_rank() == 0:

# print(f"Iter {i}: batch_num={batch_num}, smoothed_loss={smoothed_loss:.8f}, best_loss={best_loss:.8f}, ratio={smoothed_loss/best_loss:.2f}")
#if tdist.is_available() and tdist.is_initialized():
# _exp = torch.tensor([1 if explode_local else 0], device=dev)
# tdist.all_reduce(_exp, op=tdist.ReduceOp.MAX, group=dp_group)
# explode = bool(_exp.item())
#else:
# explode = explode_local



#if explode:
# if mpu.get_data_parallel_rank() == 0:
# log.info(f"Loss exploding at lr={curr_lr:.8f}, stopping LR finder")
# Keep everyone in lockstep before breaking
# if tdist.is_available() and tdist.is_initialized():
# tdist.barrier(group=dp_group)
# break

# Record values for plotting
lr_finder_losses.append(smoothed_loss)
Expand Down