
Commit 7ab48a4

[Feature] kl_mask_threshold
ghstack-source-id: 6bba3dc
Pull-Request: #3208
1 parent ed0d8dc commit 7ab48a4

File tree

test/llm/test_objectives.py
torchrl/objectives/llm/grpo.py

2 files changed: +152 -1 lines changed


test/llm/test_objectives.py

Lines changed: 120 additions & 1 deletion
@@ -23,6 +23,7 @@
     GRPOLossOutput,
     MCAdvantage,
 )
+from torchrl._utils import logger
 from torchrl.objectives.llm.sft import SFTLoss
 
 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -200,7 +201,7 @@ def test_grpo(self, mock_transformer_model, dapo):
         )
 
         # Create loss module
-        loss_fn = GRPOLoss(actor_network, eps=eps)
+        loss_fn = GRPOLoss(actor_network, clip_epsilon=eps)
 
         # Create fake data
         data = _mock_data_grpo(vocab_size=vocab_size, device=device)
@@ -245,6 +246,124 @@ def test_grpo(self, mock_transformer_model, dapo):
             0 <= loss_vals.clip_fraction <= 1
         ), f"clip_fraction out of range: {loss_vals.clip_fraction}"
 
+    def test_kl_mask_threshold(self, mock_transformer_model):
+        """Test that kl_mask_threshold properly filters out high-KL tokens."""
+        torch.manual_seed(42)
+        vocab_size = 1024
+        device = (
+            torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+        )
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # First, test that the data works without any threshold
+        loss_fn_baseline = GRPOLoss(
+            actor_network, clip_epsilon=0.2, kl_mask_threshold=None
+        )
+
+        data_baseline = data.clone()
+        loss_baseline = loss_fn_baseline(data_baseline)
+        logger.info(f"Baseline loss (no threshold): {loss_baseline.loss_objective}")
+        logger.info(f"Baseline ESS: {loss_baseline.ESS}")
+
+        # Check baseline is valid
+        if not torch.isfinite(loss_baseline.loss_objective):
+            raise ValueError(
+                f"Baseline loss is not finite: {loss_baseline.loss_objective}, skipping test"
+            )
+
+        # Now test with kl_mask_threshold enabled
+        # Use a very high threshold that should not mask any tokens
+        kl_threshold = 100.0  # Extremely high threshold to ensure no masking
+        loss_fn_with_threshold = GRPOLoss(
+            actor_network, clip_epsilon=0.2, kl_mask_threshold=kl_threshold
+        )
+
+        data_with_threshold = data.clone()
+        loss_with_threshold = loss_fn_with_threshold(data_with_threshold)
+
+        # Should produce valid output
+        assert isinstance(loss_with_threshold, GRPOLossOutput)
+
+        # Check that the loss is finite (with such a high threshold, it should be)
+        assert torch.isfinite(
+            loss_with_threshold.loss_objective
+        ), f"loss_with_threshold is not finite: {loss_with_threshold.loss_objective}"
+        assert torch.isfinite(
+            loss_with_threshold.ESS
+        ), f"ESS with threshold is not finite: {loss_with_threshold.ESS}"
+
+        logger.info(
+            f"Loss with high threshold (100.0): {loss_with_threshold.loss_objective}"
+        )
+        logger.info(f"ESS with high threshold: {loss_with_threshold.ESS}")
+
+        # The losses should be identical or very similar since we're not masking anything
+        # (the difference comes only from numerical precision)
+        assert torch.isclose(
+            loss_baseline.loss_objective, loss_with_threshold.loss_objective, rtol=1e-3
+        ), f"Losses differ too much with high threshold: {loss_baseline.loss_objective} vs {loss_with_threshold.loss_objective}"
+
+    def test_failure_missing_entries(self, mock_transformer_model):
+        """Test that GRPO fails when required keys are missing but works without optional keys."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+        loss_fn = GRPOLoss(actor_network, clip_epsilon=0.2)
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # Test 1: Missing sample_log_prob (required) should fail
+        data_missing_sample_log_prob = data.clone()
+        data_missing_sample_log_prob.exclude(("log_probs", "full"), inplace=True)
+
+        with pytest.raises(KeyError, match="Couldn't find the log-prob"):
+            loss_fn(data_missing_sample_log_prob)
+
+        # Test 2: Missing ref_log_probs (optional when kl_to_ref_coeff is None) should work
+        data_missing_ref = data.clone()
+        # Remove the ref_log_probs key if it exists
+        if ("next", "ref_log_probs", "full") in data_missing_ref.keys(True):
+            data_missing_ref.exclude(("next", "ref_log_probs", "full"), inplace=True)
+
+        # Should work fine without ref_log_probs when kl_to_ref_coeff is None
+        loss_vals = loss_fn(data_missing_ref)
+        assert isinstance(loss_vals, GRPOLossOutput)
+        assert torch.isfinite(loss_vals.loss_objective)
+
+        # Test 3: Missing ref_log_probs when kl_to_ref_coeff is set should fail
+        loss_fn_with_kl = GRPOLoss(actor_network, clip_epsilon=0.2, kl_to_ref_coeff=0.1)
+
+        data_missing_ref_for_kl = data.clone()
+        if ("next", "ref_log_probs", "full") in data_missing_ref_for_kl.keys(True):
+            data_missing_ref_for_kl.exclude(
+                ("next", "ref_log_probs", "full"), inplace=True
+            )
+
+        with pytest.raises(KeyError, match="Couldn't find the ref log-prob"):
+            loss_fn_with_kl(data_missing_ref_for_kl)
+
     def test_cispo(self, mock_transformer_model):
         """Test CISPO loss computation with mock models."""
         vocab_size = 1024
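
With the diff applied, the two new tests can be run in isolation via pytest's -k filter, e.g.:

pytest test/llm/test_objectives.py -k "kl_mask_threshold or failure_missing_entries"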

torchrl/objectives/llm/grpo.py

Lines changed: 32 additions & 0 deletions
@@ -101,6 +101,10 @@ class GRPOLoss(LossModule):
             - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
             - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
               recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
+        kl_mask_threshold (float | None, optional): enable token-wise trust-region filtering (KL-Mask).
+            When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
+            This stabilizes updates by skipping tokens that drifted too far from the reference distribution
+            (see table and description; enables per-token trust region).
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
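
For reference, the masking rule the docstring describes can be written as a small standalone computation. This is a minimal sketch; the tensor names and shapes below are illustrative only, not part of the torchrl API:

import torch

# Stand-ins for the current policy's and reference policy's per-token
# log-probabilities, shape (batch, seq). Padding positions are assumed
# zeroed out, as the forward pass below does via the valid-token mask.
cur_log_prob = torch.randn(2, 6)
ref_log_prob = torch.randn(2, 6)
kl_mask_threshold = 0.3

# Per-token surrogate KL: 0.5 * (log(pi_theta / pi_ref))^2
kl_token = 0.5 * (cur_log_prob - ref_log_prob) ** 2

# Tokens whose surrogate KL exceeds the threshold drop out of the loss.
tr_mask = kl_token <= kl_mask_threshold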
@@ -189,6 +193,7 @@ def __init__(
         actor_network: LLMWrapperBase | None = None,
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
+        kl_mask_threshold: float | None = None,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -208,6 +213,7 @@ def __init__(
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction if reduction is not None else "mean"
+        self.kl_mask_threshold = kl_mask_threshold
 
         # Determine device and register clip epsilon as buffer
         if device is None:
@@ -382,6 +388,32 @@ def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
             tensordict, adv_shape=advantage.shape[:-1]
         )
         mask = dist.mask
+
+        # Optional per-token trust-region filtering (KL-Mask) vs reference policy
+        if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
+            try:
+                inference_log_prob = tensordict.get(
+                    self.tensor_keys.sample_log_prob,
+                    as_padded_tensor=True,
+                    padding_side="left",
+                    padding_value=0.0,
+                )
+            except KeyError:
+                inference_log_prob = None
+            cur_log_prob = tensordict.get("_cur_log_prob", None)
+            if (inference_log_prob is not None) and (cur_log_prob is not None):
+                # Align to valid tokens only (safety)
+                cur_log_prob_masked = torch.where(
+                    expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
+                )
+                inference_log_prob_masked = torch.where(
+                    expand_as_right(mask, inference_log_prob), inference_log_prob, 0.0
+                )
+                log_is_ref = cur_log_prob_masked - inference_log_prob_masked
+                kl_token = 0.5 * (log_is_ref**2)
+                tr_mask = kl_token <= self.kl_mask_threshold
+                # Combine with attention mask
+                mask = mask & tr_mask
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
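
For users of the loss, enabling the filter is a one-argument change to the constructor. A minimal sketch, assuming GRPOLoss is imported from torchrl.objectives.llm.grpo (the file path in this diff) and that actor_network and data are built as in test_kl_mask_threshold above:

from torchrl.objectives.llm.grpo import GRPOLoss

# actor_network: a TransformersWrapper around the policy, as in the test above.
# data: a tensordict of rollouts, as produced by _mock_data_grpo in the test.
loss_fn = GRPOLoss(
    actor_network,
    clip_epsilon=0.2,
    kl_mask_threshold=0.3,  # mask tokens with 0.5 * log-ratio^2 above 0.3
)
loss_vals = loss_fn(data.clone())  # returns a GRPOLossOutput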
