Commit 65446d2

[Feature] CISPO
ghstack-source-id: c056c25 Pull-Request: #3207
1 parent cccfaa6 commit 65446d2

File tree: 1 file changed (+40 lines, -10 lines)


torchrl/objectives/llm/grpo.py

Lines changed: 40 additions & 10 deletions
@@ -348,16 +348,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             raise ValueError(
                 f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
             )
-        gain1 = log_weight.exp() * advantage
-
-        log_weight_clip = log_weight.clamp(*self._clip_bounds)
-        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
-        ratio = log_weight_clip.exp()
-        gain2 = ratio * advantage
-
-        # Token-level objective: compute min over clipped/unclipped at the token level
-        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
-        td_out = TensorDict({"loss_objective": -gain})
+        loss_objective, clip_fraction = self._compute_policy_objective(
+            log_weight, advantage
+        )
+        td_out = TensorDict({"loss_objective": loss_objective})
         td_out.set("clip_fraction", clip_fraction)
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging

@@ -406,6 +400,21 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             del tensordict["_cur_log_prob"]
         return GRPOLossOutput.from_tensordict(td_out)

+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Default GRPO objective: PPO-style min between unclipped and clipped ratios.
+
+        Returns (loss_objective, clip_fraction).
+        """
+        gain1 = log_weight.exp() * advantage
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain2 = ratio * advantage
+        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
+        return -gain, clip_fraction
+
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
     ) -> torch.Tensor | TensorDict:
@@ -594,6 +603,27 @@ def __init__(
         return coeff * kl_penalty, kl_penalty


+class CISPO(GRPOLoss):
+    """CISPO (Clipped Importance Sampling Policy Optimization).
+
+    Inherits the GRPO pipeline (masking, ESS, entropy, optional KL penalties) but
+    replaces the PPO-style min with a clipped-importance objective:
+        loss = - clip(weight, [1 - eps_low, 1 + eps_high]) * advantage
+
+    See MiniMax-M1 (CISPO) [arXiv](https://arxiv.org/html/2506.13585).
+    """
+
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # CISPO: use clipped importance weights directly
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain = ratio * advantage
+        return -gain, clip_fraction
+
+
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
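In contrast to the default objective, CISPO keeps the clipped importance weight but drops the element-wise min with the unclipped term, so each token contributes clip(w) * A to the gain. A minimal sketch under the same assumptions as the previous one (illustrative names and eps values, clip bounds expressed in log-space):

# Illustrative sketch only; not part of the commit.
import math

import torch


def cispo_objective(log_weight, advantage, eps_low=0.2, eps_high=0.2):
    # loss = -clip(weight, [1 - eps_low, 1 + eps_high]) * advantage
    low, high = math.log(1.0 - eps_low), math.log(1.0 + eps_high)
    log_weight_clip = log_weight.clamp(low, high)
    clip_fraction = (log_weight_clip != log_weight).float().mean()
    gain = log_weight_clip.exp() * advantage
    return -gain, clip_fraction

Because CISPO only overrides _compute_policy_objective, the rest of the GRPOLoss forward pass (masking, ESS, entropy, optional KL penalties) is reused as-is.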