[skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy
#1296
```diff
@@ -246,10 +246,21 @@ def loss_func(logits, data):
     loss_mask = data["loss_mask"]
     rollout_action_logprobs = data["rollout_action_logprobs"]
     action_mask = data.get("action_mask")
+    num_microbatches = data.get("num_microbatches")

     dp_size = mpu.get_data_parallel_world_size()
     tp_grp = mpu.get_tensor_model_parallel_group()
     tp_rank = mpu.get_tensor_model_parallel_rank()

+    # Policy losses are pre-scaled to achieve the correct loss_reduction when summing across the entire minibatch
+    # (see `apply_loss_reduction_to_advantages_minibatch`).
+    # Megatron divides loss by num_microbatches
+    # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/pipeline_parallel/schedules.py#L248)
+    # and the data parallel all-reduce averages gradients across dp_size
+    # (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.15.2/megatron/core/distributed/distributed_data_parallel.py#L285)
+    # so we multiply by both factors to recover the correct sum reduction.
+    grad_sum_correction_factor = num_microbatches * dp_size

     # temperature normalization
     if temperature != 1.0:
         logits.div_(temperature)
```
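The correction factor in the comment above can be sanity-checked with a toy calculation. The sketch below is hypothetical (toy numbers, not skyrl-train code): it simulates Megatron's divide-by-`num_microbatches` followed by data-parallel gradient averaging, and shows that pre-multiplying each microbatch loss by `num_microbatches * dp_size` recovers a plain sum over the minibatch:

```python
num_microbatches = 4
dp_size = 2

# Per-microbatch losses on each data-parallel rank (toy values).
losses = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]

# Target: a plain sum over the whole minibatch.
target_sum = sum(sum(rank_losses) for rank_losses in losses)  # 36.0

def framework_reduce(per_rank):
    # What the framework effectively computes: divide by num_microbatches
    # on each rank, then average across data-parallel ranks.
    per_rank_avg = sum(sum(mb) / num_microbatches for mb in [per_rank]) if False else None
    totals = [sum(mb_losses) / num_microbatches for mb_losses in per_rank]
    return sum(totals) / dp_size

# Without correction, the "sum" is shrunk by both averaging factors.
assert framework_reduce(losses) == target_sum / (num_microbatches * dp_size)

# Pre-scaling every loss by the correction factor undoes both averages.
factor = num_microbatches * dp_size
scaled = [[loss * factor for loss in mb] for mb in losses]
assert framework_reduce(scaled) == target_sum
```

All the toy values divide exactly by powers of two, so the equalities hold without floating-point tolerance.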
|
```diff
@@ -279,13 +290,15 @@ def loss_func(logits, data):
     # SFT path: cross_entropy loss (negative log likelihood)
     if resolved_loss_name == "cross_entropy":
-        loss = policy_loss
+        unscaled_loss = policy_loss
+        loss = unscaled_loss * grad_sum_correction_factor

         # Compute elementwise loss for Tinker API (per-token NLL)
         with torch.no_grad():
             elementwise_loss = -action_log_probs
             if loss_mask is not None:
                 elementwise_loss = elementwise_loss * loss_mask
+            elementwise_loss = elementwise_loss * grad_sum_correction_factor

         # Build per-sequence loss_fn_outputs
         batch_size = action_log_probs.shape[0]
```

**Comment on lines 292 to 294** (Contributor, Author): Q: should this affect the SFT case? SFT doesn't look at the normalized advantages either, similar to the critic loss case. Before the PR, the SFT case sums the negative log likelihoods within a microbatch but still averages over microbatches and dp workers.
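The reviewer's question above boils down to a change in the effective reduction for SFT. A hypothetical sketch (toy numbers, not skyrl-train code) contrasting the two behaviors described in the comment: per-microbatch NLL sums averaged over microbatches and dp workers (pre-PR) versus a plain sum over the whole minibatch (post-PR, after scaling by `num_microbatches * dp_size`):

```python
num_microbatches, dp_size = 2, 2

# Per-token NLLs for each (dp_rank, microbatch) pair, toy values.
nll = {(0, 0): [1.0, 2.0], (0, 1): [3.0], (1, 0): [4.0], (1, 1): [5.0, 6.0]}

# Sum the NLLs within each microbatch.
mb_sums = {key: sum(tokens) for key, tokens in nll.items()}

# Pre-PR effective reduction: mean of the per-microbatch sums over
# microbatches and dp workers.
pre_pr = sum(mb_sums.values()) / (num_microbatches * dp_size)

# Post-PR effective reduction: plain sum over the whole minibatch.
post_pr = sum(mb_sums.values())

assert pre_pr == 5.25
assert post_pr == 21.0
```

The two differ by exactly the `num_microbatches * dp_size` factor, which is why the reviewer asks whether SFT (which never touches the normalized advantages) should pick up the scaling at all.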
```diff
@@ -310,7 +323,7 @@ def loss_func(logits, data):
         )

         metrics = {
-            "loss": loss.detach().item(),
+            "loss": unscaled_loss.detach().item(),
             "response_length": num_actions,
             "loss_fn_outputs": loss_fn_outputs,
         }
```
```diff
@@ -340,7 +353,8 @@ def loss_func(logits, data):
         kl_loss = torch.tensor(0.0)
     kl_loss_term = kl_loss * loss_config.kl_loss_coef

-    loss = policy_loss + kl_loss_term - entropy_loss_term
+    unscaled_loss = policy_loss + kl_loss_term - entropy_loss_term
+    loss = unscaled_loss * grad_sum_correction_factor

     # Build per-sequence loss_fn_outputs with logprobs.
     batch_size = action_log_probs.shape[0]
```

**Comment on lines +356 to +357** (Contributor): 🔴 KL loss and entropy loss gradients amplified by num_microbatches in Megatron model wrapper. Same issue as the FSDP worker but in the Megatron path.
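The amplification flagged in the comment above can be illustrated numerically. In this hypothetical sketch (toy numbers, not skyrl-train code), the policy loss has been pre-scaled so that multiplying by the correction factor and summing yields the intended total, but a per-microbatch mean term like the KL penalty, which was never pre-scaled, comes out `num_microbatches` times larger than intended:

```python
num_microbatches, dp_size = 4, 1
factor = num_microbatches * dp_size  # 4

policy_loss_per_mb = 2.0 / factor  # pre-scaled: summed total should stay 2.0
kl_term_per_mb = 0.25              # per-microbatch mean term, not pre-scaled

def megatron_reduce(per_mb_loss):
    # Megatron-style: each microbatch loss is divided by
    # num_microbatches, then the microbatch losses are summed.
    return sum(per_mb_loss / num_microbatches for _ in range(num_microbatches))

# The pre-scaled policy part comes out at the intended total of 2.0 ...
assert megatron_reduce(policy_loss_per_mb * factor) == 2.0

# ... but the KL term is amplified: 1.0 instead of the intended 0.25.
assert megatron_reduce(kl_term_per_mb * factor) == kl_term_per_mb * num_microbatches
```

This matches the review finding: applying `grad_sum_correction_factor` to the combined loss is only correct for terms that were pre-scaled for a sum reduction.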
```diff
@@ -363,7 +377,7 @@ def loss_func(logits, data):
         )

         metrics = {
-            "final_loss": loss.detach().item(),
+            "final_loss": unscaled_loss.detach().item(),
             "policy_loss": policy_loss.detach().item(),
             "policy_entropy": entropy.detach().item(),
             "policy_kl": kl_loss.detach().item(),
```

**Comment on lines +380 to +381** (Contributor, Author): Metrics fix 1: remove