Feat/r3 prod v2 #2593
Draft
samsja wants to merge 7 commits into
…ut error (#2590)
* feat(scheduler): train on partial groups instead of dropping on rollout error
When an individual-scoring env returns N rollouts for a single example and one errors out, scrap only the failed rollout: keep the survivors and ship the group through as soon as every dispatched rollout has come back (success or failure). Group-scoring envs still drop the whole group on any failure because their per-rollout scores are computed against the now-missing rollouts.
To make variable-size groups round-trip through advantage computation, group rollouts by (env_name, example_id) instead of positional slicing, and bucket groups by size so each advantage_fn call still sees a uniform 2D rewards tensor. Singleton groups produce zero advantage and get filtered out by the existing zero-advantage filter, with no special-casing.
Closes #2585.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* refactor(advantage): make advantage_fn per-group
Drop the bucket-by-size workaround in compute_advantages by changing the advantage_fn contract: AdvantageInputs.rollouts is now a single group (list[RolloutOutput]) and AdvantageOutputs.advantages is 1D. The framework calls advantage_fn once per group, which works cleanly for variable-size groups (partial-group training).
BREAKING: second change to this public API in three weeks. Custom advantage functions must drop the outer list dim. Migration documented in CHANGELOG.md and docs/bring-your-own-algorithms.md.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* refactor(advantage): return list[float] from advantage_fn
Drops the torch tensor from the public AdvantageOutputs contract; internal math stays in torch and converts via .tolist() at the boundary. Same partial-group support, simpler downstream consumers (no more .tolist() / no shape gymnastics in custom advantages).
BREAKING (folds into the per-group change in the previous commit): custom advantage functions must return AdvantageOutputs(advantages=[...]) rather than a tensor. CHANGELOG entry and docs example updated.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* refactor(scheduler): log each rollout failure inline, drop GroupState.last_failure_reason
The reason is now logged at the moment the failure is observed (one warning per failed rollout) instead of being stashed on the group and replayed at finalization. Removes the per-group field entirely and avoids the "first vs latest wins" semantic question that came up in review - each log line carries its own actual reason. Finalization warnings only carry counts now.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* chore(scheduler): drop verbose comment on GroupState.failed_rollouts
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
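A minimal sketch of the per-group contract described in these commits: rollouts are grouped by (env_name, example_id), the advantage function is called once per group, and it returns plain floats. The `Rollout` dataclass, the mean-baseline advantage, and this `compute_advantages` body are illustrative assumptions, not the repository's actual `RolloutOutput`, `AdvantageInputs`, or `AdvantageOutputs` code.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Rollout:
    # Hypothetical stand-in for RolloutOutput; only the fields this sketch needs.
    env_name: str
    example_id: int
    reward: float


def mean_baseline_advantage(group: list[Rollout]) -> list[float]:
    """Assumed per-group advantage: reward minus the group mean, as plain floats."""
    mean = sum(r.reward for r in group) / len(group)
    return [r.reward - mean for r in group]


def compute_advantages(rollouts: list[Rollout]) -> list[float]:
    """Group by (env_name, example_id) and call the advantage function once per group."""
    groups: dict[tuple[str, int], list[int]] = defaultdict(list)
    for idx, rollout in enumerate(rollouts):
        groups[(rollout.env_name, rollout.example_id)].append(idx)

    advantages = [0.0] * len(rollouts)
    for indices in groups.values():
        group_adv = mean_baseline_advantage([rollouts[i] for i in indices])
        # Scatter the per-group result back to the original rollout positions.
        for i, adv in zip(indices, group_adv):
            advantages[i] = adv
    return advantages


rollouts = [
    Rollout("math", 0, 1.0),
    Rollout("math", 0, 0.0),
    Rollout("math", 0, 0.5),  # full group of three
    Rollout("math", 1, 1.0),
    Rollout("math", 1, 0.0),  # partial group: a failed rollout was dropped
    Rollout("code", 0, 1.0),  # singleton group: advantage is zero
]
advantages = compute_advantages(rollouts)
```

Because each call sees exactly one group, groups of different sizes need no padding or bucketing, and the singleton's zero advantage can be left to a downstream zero-advantage filter.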
…2589)
* feat(orchestrator): per-env state_columns for extra rollout fields
Adds `state_columns: list[str] = []` to `EnvConfig` so each env can persist additional `State` fields into the saved JSONL rollouts on top of the always-saved `trajectory` and `sampling_args`. Merged at the call site (required first, deduped).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* refactor: drop seen set from state_columns dedup
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
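One way the "required first, deduped" merge could look without an explicit seen set is sketched below. `REQUIRED_COLUMNS` and `merge_state_columns` are hypothetical names for illustration; the use of `dict.fromkeys` is an assumption about the approach, not the code in the PR.

```python
# Columns the commit message says are always saved.
REQUIRED_COLUMNS = ["trajectory", "sampling_args"]


def merge_state_columns(state_columns: list[str]) -> list[str]:
    """Return required columns first, then per-env extras, without duplicates."""
    # dict.fromkeys keeps first-seen order, so required columns stay in front
    # and a repeated name is only emitted once.
    return list(dict.fromkeys(REQUIRED_COLUMNS + state_columns))


merged = merge_state_columns(["answer", "trajectory", "answer", "info"])
```

An env that lists nothing gets only the required columns, matching the empty-list default on the config field.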
Each DP rank was dividing its summed token loss by its own local loss_scale, then FSDP averaged the resulting gradients. Because ranks process different sequence lengths, that mean is not the true per-token mean over the global batch: ranks with fewer loss tokens get implicitly upweighted.
Mirror the SFT trainer fix (src/prime_rl/trainer/sft/train.py:416-427): all-reduce the local token count across dp_cp, divide by that global denominator on every rank, and multiply grads by fsdp_gradient_divide_factor after the microbatch loop so FSDP's per-rank averaging is undone and the final gradient is the per-token mean over the global batch.
Closes #2358. Adapted from #2359, which first diagnosed the bias and proposed the all-reduce-then-rescale approach.
Co-authored-by: irfanjamil <irfanjamil9@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
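The bias can be reproduced with a small single-process simulation. This is a sketch under stated assumptions: two simulated ranks, per-token loss values standing in for gradients (gradients are linear in the loss, so the same weighting applies), and a gradient divide factor equal to the number of ranks. No real all-reduce or FSDP call is made.

```python
# Per-token losses on each simulated data-parallel rank.
rank_token_losses = [
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],  # rank 0: six loss tokens
    [4.0, 4.0],                      # rank 1: two loss tokens
]
world_size = len(rank_token_losses)


def fsdp_average(per_rank_values: list[float]) -> float:
    """Stand-in for FSDP's gradient reduction: a plain mean over ranks."""
    return sum(per_rank_values) / len(per_rank_values)


# Old behaviour: each rank normalizes by its own token count, then ranks are averaged.
local_means = [sum(tokens) / len(tokens) for tokens in rank_token_losses]
biased = fsdp_average(local_means)

# Fixed behaviour: normalize by the global token count (what the all-reduce provides),
# then multiply by the divide factor to undo the per-rank averaging.
global_tokens = sum(len(tokens) for tokens in rank_token_losses)
per_rank_scaled = [sum(tokens) / global_tokens for tokens in rank_token_losses]
fixed = fsdp_average(per_rank_scaled) * world_size

# Reference: the true per-token mean over the whole batch.
true_mean = sum(sum(tokens) for tokens in rank_token_losses) / global_tokens
```

Here the rank with two tokens carries half the weight under local normalization even though it holds a quarter of the tokens, which is the upweighting the commit describes.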
* add per-env trainer metrics
* require env names for trainer batches
* address per-env metric review feedback
* reuse tensor stats for env metrics
* fix sft trajectory env-name fixture
* address trainer metric naming comments
* fix: reuse trainer ratio tensors for env metrics
* fix: derive dppo mask from shared ratio
* address PR review: drop precomputed loss inputs, use {all,env} keys
- compute importance-ratio / mismatch_kl inside the loss functions instead of
passing them in via LossInputs (per Mika)
- compute mismatch_kl inline in train.py only for per-env logging
- rename trainer aggregate keys to entropy/all and mismatch_kl/all to match the
orchestrator {all,env}/{mean,std,max} convention; drop the leftover bare
entropy/* and mismatch_kl/* keys
- drop the overly defensive env_names length check in DataLoader
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* guard mismatch_kl logging behind sft_loss flag
SFT batches don't have meaningful inference_logprobs (sft_loss_fn ignores
them), so computing and logging mismatch_kl for those microbatches is wasted
work and produces misleading numbers. Skip the inline mismatch_kl compute and
the per-env / debug-log emissions when sft_loss is True.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* address review: simplify probs_diff and move mismatch_kl/all to trainer loop
* reserve env_name='all' for aggregate metric keys
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
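The key convention and the reserved `all` name could be pictured as follows. The exact key layout (`<metric>/<env>/<stat>`), the `env_metrics` helper, and the choice of population standard deviation are assumptions for illustration; the commits only state the {all,env}/{mean,std,max} convention and that `env_name='all'` is reserved for aggregates.

```python
import statistics

AGGREGATE = "all"  # reserved for the aggregate bucket; no env may use this name


def env_metrics(name: str, values_by_env: dict[str, list[float]]) -> dict[str, float]:
    """Build mean/std/max entries per env plus one aggregate bucket over all envs."""
    if AGGREGATE in values_by_env:
        raise ValueError(f"env name {AGGREGATE!r} is reserved for aggregate metrics")

    buckets = dict(values_by_env)
    buckets[AGGREGATE] = [v for values in values_by_env.values() for v in values]

    metrics: dict[str, float] = {}
    for env, values in buckets.items():
        metrics[f"{name}/{env}/mean"] = statistics.fmean(values)
        metrics[f"{name}/{env}/std"] = statistics.pstdev(values)
        metrics[f"{name}/{env}/max"] = max(values)
    return metrics


metrics = env_metrics("entropy", {"math": [1.0, 3.0], "code": [2.0]})
```

Rejecting an env literally named `all` keeps its per-env keys from colliding with the aggregate ones.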
* feat: add trainer token jsonl export
* chore: use docstrings for token export config
No description provided.