feat: enable GRPO training with logprobs from offline trajectory data #467
Summary
This PR enables proper GRPO training with importance sampling when using offline trajectory data (e.g., from vLLM traces). It includes four complementary changes:
1. Extract logprobs from dict messages
Problem: ART's tokenizer only extracted logprobs from OpenAI `Choice` objects, but offline trajectory data often stores logprobs in plain Python dicts. This caused all dict-message logprobs to be set to NaN, so the importance ratio was always 1.0 (effectively REINFORCE instead of GRPO).
Solution: Modified `tokenize.py` to also extract logprobs from dict messages that have the format `{"logprobs": {"content": [{"logprob": -0.5}, ...]}}`.
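The actual change lives in `tokenize.py`; the following is a minimal sketch of the idea with a hypothetical helper name, showing how per-token logprobs can be read out of a dict message in that format (returning nothing when no logprobs are present):

```python
import math

def extract_dict_logprobs(message: dict) -> list[float]:
    # Hypothetical helper, not ART's actual API: read per-token logprobs
    # from a message shaped like
    # {"logprobs": {"content": [{"logprob": -0.5}, ...]}}.
    logprobs = message.get("logprobs") or {}
    content = logprobs.get("content") or []
    return [entry.get("logprob", math.nan) for entry in content]

# Example message from an offline trajectory (e.g., a vLLM trace):
msg = {
    "role": "assistant",
    "content": "The answer is 42.",
    "logprobs": {"content": [{"logprob": -0.5}, {"logprob": -1.2}]},
}
print(extract_dict_logprobs(msg))  # [-0.5, -1.2]
print(extract_dict_logprobs({"role": "user", "content": "hi"}))  # []
```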
2. Strip logprobs before RULER scoring
Problem: When trajectories contain verbose logprobs data, sending them to the RULER judge causes context length errors.
Solution: Strip logprobs from trajectories before sending to RULER using `strip_logprobs()`.
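The PR's helper is `strip_logprobs()`; its exact signature isn't shown here, so the sketch below illustrates the idea with a hypothetical function that drops the `logprobs` field from each message before the trajectory goes to the judge:

```python
def strip_logprobs_from_trajectory(messages: list[dict]) -> list[dict]:
    # Hypothetical helper mirroring what strip_logprobs() does conceptually:
    # drop the verbose "logprobs" payload so the judge prompt stays short.
    return [
        {key: value for key, value in message.items() if key != "logprobs"}
        for message in messages
    ]

trajectory = [
    {"role": "user", "content": "Solve the puzzle."},
    {
        "role": "assistant",
        "content": "Done.",
        "logprobs": {"content": [{"logprob": -0.3}]},
    },
]
# Send the stripped copy to RULER; keep the original (with logprobs) for training.
print(strip_logprobs_from_trajectory(trajectory))
```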
3. Preserve `_internal_config.engine_args`
Problem: When using `TrainableModel._internal_config.engine_args` to configure vLLM engine settings (like `max_logprobs`), the configuration was silently lost when using the SkyPilot backend.
Solution: Add a `model_validator(mode="wrap")` to preserve `_internal_config` during Pydantic deserialization.
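The real model is ART's `TrainableModel`; the sketch below uses a stand-in class with assumed field names to show how a Pydantic v2 `model_validator(mode="wrap")` can re-attach a private attribute that would otherwise be dropped when the model is rebuilt from a serialized dict:

```python
from typing import Any, Optional

from pydantic import BaseModel, PrivateAttr, model_validator


class InternalConfig(BaseModel):
    # Assumed shape: just enough to carry vLLM engine settings.
    engine_args: dict[str, Any] = {}


class TrainableModelSketch(BaseModel):
    # Stand-in for ART's TrainableModel; field names are illustrative.
    name: str
    _internal_config: Optional[InternalConfig] = PrivateAttr(default=None)

    @model_validator(mode="wrap")
    @classmethod
    def _preserve_internal_config(cls, data: Any, handler):
        # Grab the private config from the raw payload (e.g., the dict the
        # SkyPilot backend receives) before normal validation discards it...
        raw = data.get("_internal_config") if isinstance(data, dict) else None
        model = handler(data)  # ...run the usual validation...
        if raw is not None and model._internal_config is None:
            # ...then re-attach it so settings like max_logprobs survive.
            model._internal_config = InternalConfig.model_validate(raw)
        return model


m = TrainableModelSketch.model_validate(
    {"name": "agent-v1", "_internal_config": {"engine_args": {"max_logprobs": 20}}}
)
print(m._internal_config.engine_args)  # {'max_logprobs': 20}
```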
4. Add importance sampling observability metrics
Problem: ART computes importance sampling ratios internally but does not expose them, making it impossible to verify whether importance sampling is actually working.
Solution: Add three new metrics logged during training (a sketch of how they can be computed follows the list):
- `frac_old_logprobs_valid`: fraction of old logprobs that are not NaN (0 = no importance sampling)
- `mean_importance_ratio`: mean π_new/π_old across assistant tokens (should vary around 1.0)
- `clip_fraction`: fraction of tokens where PPO clipping was triggered (>0 means off-policy correction is active)
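How ART computes these internally isn't shown here; the sketch below illustrates one way the three metrics could be derived from new and old per-token logprobs, assuming a PPO-style clip range:

```python
import torch

def importance_sampling_metrics(
    new_logprobs: torch.Tensor,  # per-token logprobs under the current policy (π_new)
    old_logprobs: torch.Tensor,  # per-token logprobs stored in the trajectories (π_old)
    clip_epsilon: float = 0.2,   # assumed PPO-style clip range
) -> dict[str, float]:
    # Sketch only; metric names match the PR, the exact computation may differ.
    valid = ~torch.isnan(old_logprobs)

    # Where old logprobs are missing (NaN), fall back to ratio = 1.0,
    # which is the pre-PR behavior (plain REINFORCE, no correction).
    ratio = torch.where(
        valid,
        torch.exp(new_logprobs - torch.nan_to_num(old_logprobs)),
        torch.ones_like(new_logprobs),
    )
    clipped = (ratio < 1.0 - clip_epsilon) | (ratio > 1.0 + clip_epsilon)

    return {
        "frac_old_logprobs_valid": valid.float().mean().item(),
        "mean_importance_ratio": ratio.mean().item(),
        "clip_fraction": clipped.float().mean().item(),
    }

# Example: two of three tokens have valid old logprobs.
new_lp = torch.tensor([-0.4, -1.0, -0.7])
old_lp = torch.tensor([-0.5, -1.2, float("nan")])
print(importance_sampling_metrics(new_lp, old_lp))
```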
Impact
With valid old logprobs from offline trajectories, the importance ratio π_new / π_old is no longer fixed at 1.0, and `frac_old_logprobs_valid`, `mean_importance_ratio`, and `clip_fraction` are logged during training.
New Metrics Interpretation
See the descriptions of `frac_old_logprobs_valid`, `mean_importance_ratio`, and `clip_fraction` above.
Test plan
- `max_logprobs` setting works with the SkyPilot backend
- `./scripts/run_checks.sh` - all checks pass