@@ -348,16 +348,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             raise ValueError(
                 f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
             )
-        gain1 = log_weight.exp() * advantage
-
-        log_weight_clip = log_weight.clamp(*self._clip_bounds)
-        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
-        ratio = log_weight_clip.exp()
-        gain2 = ratio * advantage
-
-        # Token-level objective: compute min over clipped/unclipped at the token level
-        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
-        td_out = TensorDict({"loss_objective": -gain})
+        loss_objective, clip_fraction = self._compute_policy_objective(
+            log_weight, advantage
+        )
+        td_out = TensorDict({"loss_objective": loss_objective})
         td_out.set("clip_fraction", clip_fraction)
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
 
@@ -406,6 +400,21 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
         del tensordict["_cur_log_prob"]
         return GRPOLossOutput.from_tensordict(td_out)
 
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Default GRPO objective: PPO-style min between unclipped and clipped ratios.
+
+        Returns (loss_objective, clip_fraction).
+        """
+        gain1 = log_weight.exp() * advantage
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain2 = ratio * advantage
+        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
+        return -gain, clip_fraction
+
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
     ) -> torch.Tensor | TensorDict:
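A brief note on the clipping in the default objective above: the code clamps the log importance weight and only then exponentiates, which is equivalent to clipping the probability ratio itself provided that `self._clip_bounds` holds log-space bounds (an assumption inferred from the clamp-then-exp pattern; the contents of `_clip_bounds` are not shown in this diff). A minimal standalone check with an illustrative epsilon:

    import math
    import torch

    eps = 0.2  # illustrative clip epsilon, not taken from the diff
    clip_bounds = (math.log1p(-eps), math.log1p(eps))  # assumed log-space bounds

    log_w = torch.tensor([-1.0, 0.0, 1.0])
    ratio_a = log_w.clamp(*clip_bounds).exp()       # clamp in log-space, then exp
    ratio_b = log_w.exp().clamp(1 - eps, 1 + eps)   # exp, then clip the ratio
    assert torch.allclose(ratio_a, ratio_b)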
@@ -594,6 +603,27 @@ def __init__(
         return coeff * kl_penalty, kl_penalty
 
 
+class CISPO(GRPOLoss):
+    """CISPO (Clipped Importance Sampling Policy Optimization).
+
+    Inherits the GRPO pipeline (masking, ESS, entropy, optional KL penalties) but
+    replaces the PPO-style min with a clipped-importance objective:
+        loss = -clip(weight, [1 - eps_low, 1 + eps_high]) * advantage
+
+    See MiniMax-M1 (CISPO) [arXiv](https://arxiv.org/html/2506.13585).
+    """
+
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # CISPO: use clipped importance weights directly
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain = ratio * advantage
+        return -gain, clip_fraction
+
+
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
 
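To make the difference between the two objectives concrete, here is a small self-contained sketch, not the library code: the epsilon values and the log-space clip bounds are illustrative assumptions, and the toy tensors stand in for per-token log importance weights and advantages.

    import math
    import torch

    def grpo_objective(log_weight, advantage, eps_low=0.2, eps_high=0.2):
        # PPO-style surrogate: min of unclipped and clipped gains, negated into a loss.
        clip_bounds = (math.log1p(-eps_low), math.log1p(eps_high))
        gain1 = log_weight.exp() * advantage
        gain2 = log_weight.clamp(*clip_bounds).exp() * advantage
        return -torch.stack([gain1, gain2], -1).min(dim=-1).values

    def cispo_objective(log_weight, advantage, eps_low=0.2, eps_high=0.2):
        # CISPO: the clipped importance weight multiplies the advantage directly,
        # with no min against the unclipped term.
        clip_bounds = (math.log1p(-eps_low), math.log1p(eps_high))
        ratio = log_weight.clamp(*clip_bounds).exp()
        return -ratio * advantage

    log_w = torch.tensor([[0.5], [-0.5]])   # toy per-token log importance weights
    adv = torch.tensor([[1.0], [-1.0]])     # toy advantages
    print(grpo_objective(log_w, adv))
    print(cispo_objective(log_w, adv))

The two losses only differ where the unclipped gain falls below the clipped one; elsewhere CISPO and the PPO-style min coincide.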