# Technical Design: Importance Sampling Observability Metrics

## Problem Statement

ART computes importance sampling ratios internally for PPO/GRPO training but does not expose them as metrics for monitoring. Users have no visibility into:

1. Whether logprobs are being extracted correctly from trajectories
2. Whether importance sampling is actually active (vs. falling back to REINFORCE)
3. How often PPO clipping is triggered

This makes it difficult to debug training issues and verify that the importance sampling pipeline is working correctly.

### Background: How Importance Sampling Works in ART

```
Rollout Phase
  │
  ▼
Trajectories with logprobs attached to messages
  │
  ▼
Tokenization Phase (tokenize.py)
  │
  ├─► Dict messages: extract logprobs if present, else NaN
  └─► Choice objects: extract logprobs if present
  │
  ▼
Training Phase (train.py)
  │
  ├─► If logprobs are NaN: set old_logprobs = new_logprobs.detach()
  │     └─► prob_ratio = exp(0) = 1.0 (NO importance sampling)
  │
  └─► If logprobs are real: compute prob_ratio = exp(new - old)
        └─► PPO clipping applied when ratio outside [1-ε, 1+ε]
```

When all logprobs are NaN, ART silently falls back to vanilla REINFORCE (advantage-weighted policy gradient with no off-policy correction). This is valid but may not be what users expect.

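To make the fallback concrete, here is a minimal sketch (toy tensors, not ART's actual training code) of why substituting `new_logprobs.detach()` for NaN `old_logprobs` collapses the probability ratio to exactly 1.0, leaving a plain advantage-weighted policy gradient:

```python
import torch

# Toy per-token tensors; names mirror the diagram above, values are illustrative.
new_logprobs = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
old_logprobs = torch.tensor([float("nan")] * 3)  # rollout attached no logprobs
advantages = torch.tensor([0.5, -0.1, 1.0])

# The fallback: wherever old_logprobs is NaN, substitute new_logprobs.detach().
old_logprobs = torch.where(
    torch.isnan(old_logprobs), new_logprobs.detach(), old_logprobs
)

prob_ratio = torch.exp(new_logprobs - old_logprobs)  # exp(0) = 1.0 at every token
loss = -(prob_ratio * advantages).mean()
# The gradient matches the REINFORCE objective -(new_logprobs * advantages).mean().
```

With the ratio pinned at 1.0, PPO clipping never triggers, so `clip_fraction` stays at zero in this regime.
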
## Solution

Add three new metrics to ART's training loop that are logged to wandb:

### 1. `frac_old_logprobs_valid`

**What it measures:** Fraction of `old_logprobs` values that are NOT NaN at training time.

**Implementation:**
```python
old_logprobs_nan_mask = torch.isnan(old_logprobs)
frac_old_logprobs_valid = 1.0 - (
    old_logprobs_nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
```

**Interpretation:**
| Value | Meaning |
|-------|---------|
| 0.0 | All logprobs are NaN - importance sampling NOT active |
| ~0.3-0.5 | Partial logprobs - some tokens have valid logprobs |
| ~0.8-1.0 | Most logprobs valid - importance sampling fully active |

**Why not exactly 1.0?** System messages, tool calls, and prompt tokens don't have logprobs - only assistant response tokens do.

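As a toy illustration (made-up tensor, not a real batch), a sequence where only the trailing assistant tokens carry logprobs gives a fraction well below 1.0:

```python
import torch

nan = float("nan")
# 8 tokens: the first 5 are prompt/system tokens with no logprobs, the last 3 are assistant tokens.
old_logprobs = torch.tensor([nan, nan, nan, nan, nan, -0.9, -1.4, -0.2])

frac_old_logprobs_valid = 1.0 - (
    torch.isnan(old_logprobs).float().sum() / (old_logprobs.numel() + 1e-6)
).item()
print(frac_old_logprobs_valid)  # ≈ 0.375: 3 of 8 tokens have valid logprobs
```
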
### 2. `mean_importance_ratio`

**What it measures:** Mean importance sampling ratio π_new(a|s) / π_old(a|s) across assistant tokens.

**Implementation:**
```python
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
```

**Interpretation:**
| Value | Meaning |
|-------|---------|
| Exactly 1.0 | No distribution shift (or all NaN logprobs) |
| 0.8 - 1.2 | Healthy training - policy evolving gradually |
| < 0.5 or > 2.0 | Large distribution shift - may indicate issues |

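For intuition, the per-token ratio is just the exponential of the logprob difference, so small drift in the logprobs keeps it near 1.0 (toy values, not from a real run):

```python
import torch

new_logprobs = torch.tensor([-1.00, -0.50, -2.00])
old_logprobs = torch.tensor([-1.10, -0.45, -2.30])

prob_ratio = torch.exp(new_logprobs - old_logprobs)
print(prob_ratio)         # tensor([1.1052, 0.9512, 1.3499])
print(prob_ratio.mean())  # ≈ 1.14
```
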
### 3. `clip_fraction`

**What it measures:** Fraction of assistant tokens where PPO clipping was triggered.

**Implementation:**
```python
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
```

**Interpretation:**
| Value | Meaning |
|-------|---------|
| 0.0 | No clipping - either on-policy or no importance sampling |
| 0.01 - 0.1 | Healthy - some off-policy correction happening |
| > 0.3 | High clipping - policy has diverged significantly from rollout policy |

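Concretely, with a symmetric clip range of ε = ε_high = 0.2 (illustrative values), only ratios outside [0.8, 1.2] count as clipped:

```python
import torch

epsilon = epsilon_high = 0.2
prob_ratio = torch.tensor([0.95, 1.10, 1.35, 0.70])
assistant_mask = torch.tensor([1.0, 1.0, 1.0, 1.0])

is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
print(clip_fraction)  # ≈ 0.5: the 1.35 and 0.70 tokens are clipped
```
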
## Implementation Details

### Files Modified

**`src/art/unsloth/train.py`**

1. Compute `frac_old_logprobs_valid` before the NaN replacement:
```python
old_logprobs_nan_mask = torch.isnan(old_logprobs)
frac_old_logprobs_valid = 1.0 - (
    old_logprobs_nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
old_logprobs = torch.where(
    old_logprobs_nan_mask,  # reuse mask
    new_logprobs.detach(),
    old_logprobs,
)
```

2. Compute the clip metrics after the `prob_ratio` calculation:
```python
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
```

3. Log the new metrics:
```python
trainer._metrics["train"]["frac_old_logprobs_valid"].append(frac_old_logprobs_valid)
trainer._metrics["train"]["mean_importance_ratio"].append(mean_importance_ratio.item())
trainer._metrics["train"]["clip_fraction"].append(clip_fraction.item())
```

### Performance Impact

- **Memory:** Negligible - reuses existing tensors, only adds scalar computations
- **Compute:** Negligible - O(n) operations on existing tensors
- **Logging overhead:** 3 additional floats per training step

## Use Cases

### 1. Debugging Missing Logprobs

If `frac_old_logprobs_valid = 0`:
- Check that the rollout is requesting logprobs from the model (see the sketch after this list)
- Check that logprobs are being attached to trajectory messages
- Check that tokenization is extracting logprobs correctly (especially for dict messages)

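For the first check, the request made during rollout must ask for logprobs explicitly. A generic sketch against an OpenAI-compatible endpoint (the client setup and model name are placeholders; ART's own rollout helpers may differ):

```python
from openai import AsyncOpenAI

client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="unused")  # placeholder endpoint

async def sample(messages):
    completion = await client.chat.completions.create(
        model="my-model",  # placeholder model name
        messages=messages,
        logprobs=True,     # without this, the returned choice carries no logprobs
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None, "server did not return logprobs"
    return choice
```
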
### 2. Monitoring Training Health

Healthy training should show:
- `frac_old_logprobs_valid` stable and > 0
- `mean_importance_ratio` fluctuating around 1.0
- `clip_fraction` low but non-zero

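A simple per-step sanity check along these lines could be layered on top of the logged values (the thresholds come from the interpretation tables above; the function and its name are illustrative):

```python
def check_training_health(frac_valid: float, mean_ratio: float, clip_frac: float) -> list[str]:
    """Flag suspicious metric values using the thresholds discussed above."""
    warnings = []
    if frac_valid == 0.0:
        warnings.append("no valid old_logprobs: importance sampling is inactive")
    if mean_ratio < 0.5 or mean_ratio > 2.0:
        warnings.append(f"mean importance ratio {mean_ratio:.2f} suggests a large distribution shift")
    if clip_frac > 0.3:
        warnings.append(f"clip fraction {clip_frac:.2f} is high: policy far from rollout policy")
    return warnings
```
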
### 3. Detecting Distribution Drift

If `clip_fraction` suddenly increases:
- Policy may have diverged too far from rollout policy
- Consider reducing learning rate or increasing rollout frequency

## Backwards Compatibility

These changes are additive - existing code continues to work. The new metrics appear in wandb logs automatically if wandb is configured.

## Testing

Manual verification:
1. Run training with valid logprobs → `frac_old_logprobs_valid > 0`
2. Run training with `allow_training_without_logprobs=True` and no logprobs → `frac_old_logprobs_valid = 0`
3. Verify `mean_importance_ratio` deviates from 1.0 over training steps

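The metric computations can also be exercised in isolation on synthetic tensors. A pytest-style sketch (not part of ART's test suite; tensor values are made up):

```python
import torch

def test_importance_sampling_metrics_on_synthetic_tensors():
    nan = float("nan")
    old_logprobs = torch.tensor([nan, -1.0, -0.5, nan])
    new_logprobs = torch.tensor([-1.1, -0.9, -0.8, -0.3])
    assistant_mask = torch.tensor([0.0, 1.0, 1.0, 1.0])
    epsilon = epsilon_high = 0.2

    frac_old_logprobs_valid = 1.0 - (
        torch.isnan(old_logprobs).float().sum() / (old_logprobs.numel() + 1e-6)
    ).item()
    assert abs(frac_old_logprobs_valid - 0.5) < 1e-3  # 2 of 4 values are NaN

    old_logprobs = torch.where(torch.isnan(old_logprobs), new_logprobs, old_logprobs)
    prob_ratio = torch.exp(new_logprobs - old_logprobs)
    is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
    clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)

    # Positions whose old_logprobs were NaN collapse to a ratio of exactly 1.0.
    assert torch.isclose(prob_ratio[0], torch.tensor(1.0))
    assert 0.0 <= clip_fraction.item() <= 1.0
```
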
## Related Work

- PPO paper (Schulman et al., 2017) discusses importance sampling and clipping
- TRL's `PPOTrainer` logs similar metrics (`clipfrac`, `ratio`)
- This brings ART's observability closer to standard PPO implementations