Skip to content

Commit d1795bd

Browse files
committed
[Feature] kl_mask_threshold
ghstack-source-id: 1635871 Pull-Request: #3208
1 parent 65446d2 commit d1795bd

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

torchrl/objectives/llm/grpo.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ class GRPOLoss(LossModule):
8282
- float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
8383
- tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
8484
recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
85+
kl_mask_threshold (float | None, optional): enable token-wise trust-region filtering (KL-Mask).
86+
When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
87+
This stabilizes updates by skipping tokens that drifted too far from the reference distribution
88+
(this acts as a per-token trust region with respect to the reference policy).
8589
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
8690
loss to favour exploratory policies.
8791
samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -142,6 +146,7 @@ def __init__(
142146
actor_network: LLMWrapperBase | None = None,
143147
*,
144148
clip_epsilon: float | tuple[float, float] = 0.2,
149+
kl_mask_threshold: float | None = None,
145150
entropy_bonus: bool = True,
146151
samples_mc_entropy: int = 1,
147152
entropy_coeff: float = 0.01,
@@ -161,6 +166,7 @@ def __init__(
161166
self.samples_mc_entropy = samples_mc_entropy
162167
self.entropy_coeff = entropy_coeff
163168
self.reduction = reduction if reduction is not None else "mean"
169+
self.kl_mask_threshold = kl_mask_threshold
164170

165171
# Determine device and register clip epsilon as buffer
166172
if device is None:
@@ -335,6 +341,32 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
335341
tensordict, adv_shape=advantage.shape[:-1]
336342
)
337343
mask = dist.mask
344+
345+
# Optional per-token trust-region filtering (KL-Mask) vs reference policy
346+
if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
347+
try:
348+
ref_log_prob = tensordict.get(
349+
self.tensor_keys.ref_log_probs,
350+
as_padded_tensor=True,
351+
padding_side="left",
352+
padding_value=0.0,
353+
)
354+
except KeyError:
355+
ref_log_prob = None
356+
cur_log_prob = tensordict.get("_cur_log_prob", None)
357+
if (ref_log_prob is not None) and (cur_log_prob is not None):
358+
# Zero out padded/invalid token positions before differencing, so they cannot trip the threshold
359+
cur_log_prob_masked = torch.where(
360+
expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
361+
)
362+
ref_log_prob_masked = torch.where(
363+
expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
364+
)
365+
log_is_ref = cur_log_prob_masked - ref_log_prob_masked
366+
kl_token = 0.5 * (log_is_ref**2)
367+
tr_mask = kl_token <= self.kl_mask_threshold
368+
# Combine with attention mask
369+
mask = mask & tr_mask
338370
# ESS for logging
339371
with torch.no_grad():
340372
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according

0 commit comments

Comments
 (0)