
Commit 61a1526

JRMeyer and claude committed
feat: add importance sampling observability metrics
Adds three new metrics logged during training to help users verify that importance sampling is working correctly:

- frac_old_logprobs_valid: Fraction of old logprobs that are not NaN
- mean_importance_ratio: Mean π_new/π_old across assistant tokens
- clip_fraction: Fraction of tokens where PPO clipping was triggered

These metrics help diagnose whether GRPO/PPO importance sampling is active or if training has fallen back to vanilla REINFORCE (when all logprobs are NaN).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 3d79de3 commit 61a1526

File tree

3 files changed (+198, -3 lines changed)

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
# Technical Design: Importance Sampling Observability Metrics

## Problem Statement

ART computes importance sampling ratios internally for PPO/GRPO training but does not expose these metrics for monitoring. Users have no visibility into:

1. Whether logprobs are being extracted correctly from trajectories
2. Whether importance sampling is actually active (vs. falling back to REINFORCE)
3. How often PPO clipping is triggered

This makes it difficult to debug training issues and verify that the importance sampling pipeline is working correctly.

### Background: How Importance Sampling Works in ART

```
Rollout Phase
      │
      ▼
Trajectories with logprobs attached to messages
      │
      ▼
Tokenization Phase (tokenize.py)
      │
      ├─► Dict messages: extract logprobs if present, else NaN
      └─► Choice objects: extract logprobs if present
      │
      ▼
Training Phase (train.py)
      │
      ├─► If logprobs are NaN: set old_logprobs = new_logprobs.detach()
      │       └─► prob_ratio = exp(0) = 1.0 (NO importance sampling)
      │
      └─► If logprobs are real: compute prob_ratio = exp(new - old)
              └─► PPO clipping applied when ratio outside [1-ε, 1+ε]
```

When all logprobs are NaN, ART silently falls back to vanilla REINFORCE (advantage-weighted policy gradient with no off-policy correction). This is valid but may not be what users expect.
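
To make the fallback concrete, here is a minimal self-contained sketch (illustrative only, not ART's actual code) of why substituting `new_logprobs.detach()` for NaN `old_logprobs` collapses the ratio to 1.0 and leaves a plain advantage-weighted policy gradient:

```python
import torch

# Toy stand-ins for ART's per-token tensors (names are illustrative).
new_logprobs = torch.tensor([-0.5, -1.2, -0.8], requires_grad=True)
old_logprobs = torch.tensor([float("nan")] * 3)  # rollout attached no logprobs
advantages = torch.tensor([1.0, -0.5, 0.3])

# The fallback: replace NaN old logprobs with the current policy's detached values.
old_logprobs = torch.where(torch.isnan(old_logprobs), new_logprobs.detach(), old_logprobs)

prob_ratio = torch.exp(new_logprobs - old_logprobs)  # all ones: exp(0), no off-policy correction

# With ratio == 1 everywhere, clipping never triggers and the PPO surrogate's
# gradient matches a vanilla REINFORCE advantage-weighted policy gradient.
policy_loss = -(prob_ratio * advantages).mean()
policy_loss.backward()
print(new_logprobs.grad)  # -advantages / 3
```
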
## Solution

Add three new metrics to ART's training loop that are logged to wandb:

### 1. `frac_old_logprobs_valid`

**What it measures:** Fraction of `old_logprobs` values that are NOT NaN at training time.

**Implementation:**

```python
old_logprobs_nan_mask = torch.isnan(old_logprobs)
frac_old_logprobs_valid = 1.0 - (
    old_logprobs_nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
```

**Interpretation:**

| Value | Meaning |
|-------|---------|
| 0.0 | All logprobs are NaN - importance sampling NOT active |
| ~0.3-0.5 | Partial logprobs - some tokens have valid logprobs |
| ~0.8-1.0 | Most logprobs valid - importance sampling fully active |

**Why not exactly 1.0?** System messages, tool calls, and prompt tokens don't have logprobs - only assistant response tokens do.
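
As a toy illustration (made-up numbers): with 10 tokens of which only the last 4 assistant-response tokens carry logprobs, the metric reads roughly 0.4:

```python
import torch

# 10 tokens: 6 prompt/system/tool tokens have NaN, 4 assistant tokens have real logprobs.
old_logprobs = torch.tensor([float("nan")] * 6 + [-0.7, -1.1, -0.3, -2.0])

old_logprobs_nan_mask = torch.isnan(old_logprobs)
frac_old_logprobs_valid = 1.0 - (
    old_logprobs_nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
print(round(frac_old_logprobs_valid, 3))  # 0.4
```
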
### 2. `mean_importance_ratio`

**What it measures:** Mean importance sampling ratio π_new(a|s) / π_old(a|s) across assistant tokens.

**Implementation:**

```python
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
```

**Interpretation:**

| Value | Meaning |
|-------|---------|
| Exactly 1.0 | No distribution shift (or all NaN logprobs) |
| 0.8 - 1.2 | Healthy training - policy evolving gradually |
| < 0.5 or > 2.0 | Large distribution shift - may indicate issues |

### 3. `clip_fraction`

**What it measures:** Fraction of assistant tokens where PPO clipping was triggered.

**Implementation:**

```python
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
```

**Interpretation:**

| Value | Meaning |
|-------|---------|
| 0.0 | No clipping - either on-policy or no importance sampling |
| 0.01 - 0.1 | Healthy - some off-policy correction happening |
| > 0.3 | High clipping - policy has diverged significantly from rollout policy |
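
To see both ratio metrics on concrete numbers, here is a toy example (illustrative values, with ε = ε_high = 0.2):

```python
import torch

epsilon = epsilon_high = 0.2
prob_ratio = torch.tensor([1.00, 0.95, 1.30, 0.70, 1.05])
assistant_mask = torch.tensor([0.0, 1.0, 1.0, 1.0, 1.0])  # first token is a prompt token

is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)

print(clip_fraction.item())          # ≈ 0.5: ratios 1.30 and 0.70 fall outside [0.8, 1.2]
print(mean_importance_ratio.item())  # ≈ 1.0: (0.95 + 1.30 + 0.70 + 1.05) / 4
```
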
## Implementation Details

### Files Modified

**`src/art/unsloth/train.py`**

1. Compute `frac_old_logprobs_valid` before the NaN replacement:

```python
old_logprobs_nan_mask = torch.isnan(old_logprobs)
frac_old_logprobs_valid = 1.0 - (
    old_logprobs_nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
old_logprobs = torch.where(
    old_logprobs_nan_mask,  # reuse mask
    new_logprobs.detach(),
    old_logprobs,
)
```

2. Compute clip metrics after the `prob_ratio` calculation:

```python
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
```

3. Log the new metrics:

```python
trainer._metrics["train"]["frac_old_logprobs_valid"].append(frac_old_logprobs_valid)
trainer._metrics["train"]["mean_importance_ratio"].append(mean_importance_ratio.item())
trainer._metrics["train"]["clip_fraction"].append(clip_fraction.item())
```
### Performance Impact

- **Memory:** Negligible - reuses existing tensors, only adds scalar computations
- **Compute:** Negligible - O(n) operations on existing tensors
- **Logging overhead:** 3 additional floats per training step

## Use Cases

### 1. Debugging Missing Logprobs

If `frac_old_logprobs_valid = 0`:

- Check that rollout is requesting logprobs from the model (see the sketch below)
- Check that logprobs are being attached to trajectory messages
- Check that tokenization is extracting logprobs correctly (especially for dict messages)
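
The first check usually comes down to the sampling call itself. A minimal sketch, assuming an OpenAI-compatible inference endpoint (the client setup, model name, and helper function are placeholders, not ART's API): logprobs must be requested at sampling time, otherwise nothing reaches the trajectory and `frac_old_logprobs_valid` stays at 0.

```python
from openai import AsyncOpenAI

# Placeholder endpoint and model name; substitute whatever serves your rollouts.
client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="unused")

async def sample_turn(messages: list[dict]):
    completion = await client.chat.completions.create(
        model="my-model",
        messages=messages,
        logprobs=True,  # without this, no per-token logprobs are returned
    )
    choice = completion.choices[0]
    # choice.logprobs.content carries the per-token logprobs; if it is None,
    # the trajectory has nothing to attach and training falls back to REINFORCE.
    return choice
```
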
### 2. Monitoring Training Health

Healthy training should show (a small helper below turns these rules of thumb into warnings):

- `frac_old_logprobs_valid` stable and > 0
- `mean_importance_ratio` fluctuating around 1.0
- `clip_fraction` low but non-zero
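
For example, a tiny illustrative helper (not part of ART) that applies the thresholds from the interpretation tables above:

```python
def importance_sampling_health(frac_valid: float, mean_ratio: float, clip_frac: float) -> list[str]:
    """Return human-readable warnings based on the rules of thumb above (illustrative only)."""
    warnings = []
    if frac_valid == 0.0:
        warnings.append("all old logprobs are NaN: importance sampling is not active")
    if not 0.5 <= mean_ratio <= 2.0:
        warnings.append(f"mean_importance_ratio={mean_ratio:.2f}: large distribution shift")
    if clip_frac > 0.3:
        warnings.append(f"clip_fraction={clip_frac:.2f}: policy far from the rollout policy")
    return warnings

# Made-up values as they might be read off a wandb run:
print(importance_sampling_health(0.42, 1.05, 0.02))  # [] -> looks healthy
```
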
### 3. Detecting Distribution Drift

If `clip_fraction` suddenly increases:

- Policy may have diverged too far from rollout policy
- Consider reducing learning rate or increasing rollout frequency
## Backwards Compatibility

These changes are additive - existing code continues to work. The new metrics appear in wandb logs automatically if wandb is configured.

## Testing

Manual verification (a standalone sketch of the underlying formula follows the list):

1. Run training with valid logprobs → `frac_old_logprobs_valid > 0`
2. Run training with `allow_training_without_logprobs=True` and no logprobs → `frac_old_logprobs_valid = 0`
3. Verify `mean_importance_ratio` deviates from 1.0 over training steps
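
The first two expectations can also be checked in isolation, without a training run, by exercising the formula directly on synthetic tensors (a sketch only, not part of ART's test suite):

```python
import torch

def frac_valid(old_logprobs: torch.Tensor) -> float:
    nan_mask = torch.isnan(old_logprobs)
    return 1.0 - (nan_mask.float().sum() / (old_logprobs.numel() + 1e-6)).item()

# Case 2: no logprobs at all -> metric reads (approximately) 0.
assert frac_valid(torch.full((8,), float("nan"))) < 1e-3

# Case 1: valid logprobs present -> metric reads well above 0.
assert frac_valid(torch.tensor([float("nan"), -0.5, -1.0, -0.2])) > 0.5
```
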
## Related Work

- PPO paper (Schulman et al., 2017) discusses importance sampling and clipping
- TRL's `PPOTrainer` logs similar metrics (`clipfrac`, `ratio`)
- This brings ART's observability closer to standard PPO implementations

src/art/loss.py

Lines changed: 19 additions & 2 deletions
```diff
@@ -16,6 +16,9 @@ class Loss(BaseModel):
     mean_kl: torch.Tensor
     mean_entropy: torch.Tensor | None
     probs_corr: torch.Tensor
+    frac_old_logprobs_valid: float
+    mean_importance_ratio: torch.Tensor
+    clip_fraction: torch.Tensor


 def loss_fn(
@@ -32,6 +35,9 @@ def loss_fn(
     )
     weights = shift_tensor(inputs["weights"], 0.0)
     old_logprobs_mask = ~torch.isnan(old_logprobs)
+    frac_old_logprobs_valid = (
+        old_logprobs_mask.float().sum() / (old_logprobs.numel() + 1e-6)
+    ).item()
     probs_corr = torch.corrcoef(
         torch.stack(
             [
@@ -77,15 +83,23 @@ def loss_fn(
     )
     if tau := experimental_config.get("kimi_k2_tau", None):
         advantages -= tau * logprob_diff.detach()
+    clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
+    is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
+    clip_fraction = (is_clipped.float() * assistant_mask).sum() / (
+        assistant_mask.sum() + 1e-6
+    )
+    mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (
+        assistant_mask.sum() + 1e-6
+    )
     if experimental_config.get("ppo", True):
         policy_loss = -torch.min(
             prob_ratio * advantages,
-            torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
+            clipped_ratio * advantages,
         )
     else:
         # Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO)
         policy_loss = -(
-            torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high)
+            clipped_ratio.detach()
             * advantages
             * new_logprobs
         )
@@ -123,6 +137,9 @@ def loss_fn(
         mean_kl=mean_kl,
         mean_entropy=mean_entropy,
         probs_corr=probs_corr,
+        frac_old_logprobs_valid=frac_old_logprobs_valid,
+        mean_importance_ratio=mean_importance_ratio,
+        clip_fraction=clip_fraction,
     )
```

src/art/unsloth/train.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -167,7 +167,10 @@ def compute_loss(
     trainer._metrics["train"]["learning_rate"].append(config.learning_rate)
     trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item())
     if loss.mean_entropy is not None:
-        trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())  # type: ignore
+        trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())
+    trainer._metrics["train"]["frac_old_logprobs_valid"].append(loss.frac_old_logprobs_valid)
+    trainer._metrics["train"]["mean_importance_ratio"].append(loss.mean_importance_ratio.item())
+    trainer._metrics["train"]["clip_fraction"].append(loss.clip_fraction.item())
     if config.beta > 0.0:
         trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item())
     return loss.mean_policy_loss + config.beta * loss.mean_kl  # type: ignore
```
