
from collections import defaultdict, deque
from dataclasses import dataclass
- from typing import Literal
+ from typing import Literal, TypeVar

import torch
from tensordict import (
from torchrl.objectives.utils import _reduce, _sum_td_features


- class GRPOLossOutput(TensorClass["nocast"]):
-     """GRPO Loss Output."""
+ class LLMLossOutput(TensorClass["nocast"]):
+     """Base class for LLM loss outputs.
+
+     This base class defines the common structure for all LLM-based policy optimization
+     loss outputs (GRPO, DAPO, CISPO, etc.).
+     """

    loss_objective: torch.Tensor
    clip_fraction: torch.Tensor
@@ -48,6 +52,21 @@ class GRPOLossOutput(TensorClass["nocast"]):
    kl_to_inference: torch.Tensor | None = None


+ LLMOutputType = TypeVar("LLMOutputType", bound=LLMLossOutput)
+
+
+ class GRPOLossOutput(LLMLossOutput):
+     """GRPO Loss Output."""
+
+
+ class DAPOLossOutput(LLMLossOutput):
+     """DAPO Loss Output."""
+
+
+ class CISPOLossOutput(LLMLossOutput):
+     """CISPO Loss Output."""
+
+
class GRPOLoss(LossModule):
    """GRPO loss.

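Illustration (not part of the diff): because every output class added above derives from LLMLossOutput, and LLMOutputType is bound to that base, helper code can stay generic over the concrete output class while still relying on the shared fields. A minimal sketch, assuming the names above are in scope and that the tensorclass exposes the usual TensorDict-like detach():

>>> def detach_output(output: LLMOutputType) -> LLMOutputType:
...     # works for GRPOLossOutput, DAPOLossOutput and CISPOLossOutput alike;
...     # the caller gets back the same concrete type it passed in
...     return output.detach()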
@@ -123,6 +142,7 @@ class GRPOLoss(LossModule):
    """

    actor_network: LLMWrapperBase
+     output_type: type[LLMLossOutput] = GRPOLossOutput

    @dataclass
    class _AcceptedKeys(LossModule._AcceptedKeys):
@@ -137,6 +157,33 @@ class _AcceptedKeys(LossModule._AcceptedKeys):
        sample_log_prob: NestedKey = ("log_probs", "full")
        ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")

+     @property
+     def tensor_keys(self) -> _AcceptedKeys:
+         """Access the tensordict key configuration for this loss.
+
+         This property provides access to the configurable keys used by the loss module
+         to read tensors from input TensorDicts. These keys include:
+
+         - ``advantage``: key for the advantage values
+         - ``action``: key for the action tokens (default: ``("tokens", "full")``)
+         - ``sample_log_prob``: key for the log probabilities recorded by the policy that generated the data (default: ``("log_probs", "full")``)
+         - ``ref_log_probs``: key for the reference policy log probabilities (default: ``("next", "ref_log_probs", "full")``)
+
+         To modify these keys, use the :meth:`~.set_keys` method.
+
+         Examples:
+             >>> loss = GRPOLoss(actor_network)
+             >>> # Access current keys
+             >>> print(loss.tensor_keys.advantage)  # "advantage"
+             >>> # Modify keys
+             >>> loss.set_keys(advantage="my_advantage_key")
+             >>> print(loss.tensor_keys.advantage)  # "my_advantage_key"
+
+         Returns:
+             An instance of _AcceptedKeys containing all configurable tensordict keys.
+         """
+         return self._tensor_keys
+
    def __init__(
        self,
        actor_network: LLMWrapperBase | None = None,
@@ -316,7 +363,7 @@ def _get_cur_log_prob(self, tensordict):
        )
        return log_prob, dist, False

-     def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
+     def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
        # Some sanity checks and housekeeping:
        # - We may not have the tokens yet. If not, we will use the tokenizer of the actor to tokenize the text.
        #   We default to history rather than text because the history will account for multi-turn or multimodal inputs.
@@ -348,16 +395,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
            raise ValueError(
                f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
            )
-         gain1 = log_weight.exp() * advantage
-
-         log_weight_clip = log_weight.clamp(*self._clip_bounds)
-         clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
-         ratio = log_weight_clip.exp()
-         gain2 = ratio * advantage
-
-         # Token-level objective: compute min over clipped/unclipped at the token level
-         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
-         td_out = TensorDict({"loss_objective": -gain})
+         loss_objective, clip_fraction = self._compute_policy_objective(
+             log_weight, advantage
+         )
+         td_out = TensorDict({"loss_objective": loss_objective})
        td_out.set("clip_fraction", clip_fraction)
        td_out.set("kl_approx", kl_approx.detach().mean())  # for logging

@@ -404,7 +445,22 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
            td_out["loss_kl_to_inference"] = loss_kl
            td_out["kl_to_inference"] = kl_penalty.detach()
        del tensordict["_cur_log_prob"]
-         return GRPOLossOutput.from_tensordict(td_out)
+         return self.output_type.from_tensordict(td_out)
+
+     def _compute_policy_objective(
+         self, log_weight: torch.Tensor, advantage: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Default GRPO objective: PPO-style min over the unclipped and clipped surrogate terms.
+
+         Returns (loss_objective, clip_fraction).
+         """
+         gain1 = log_weight.exp() * advantage
+         log_weight_clip = log_weight.clamp(*self._clip_bounds)
+         clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+         ratio = log_weight_clip.exp()
+         gain2 = ratio * advantage
+         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
+         return -gain, clip_fraction

    def _get_entropy(
        self, dist: d.Distribution, adv_shape: torch.Size
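For concreteness (illustration only, not part of the diff): the default hook above is the familiar PPO clipped surrogate. The standalone sketch below reproduces the same computation on toy tensors, assuming `_clip_bounds` stores `(log(1 - eps_low), log(1 + eps_high))`, so that clamping in log space and exponentiating yields the clipped ratio:

import math

import torch

# per-token log importance weights log(pi_theta / pi_inference) and advantages
log_weight = torch.tensor([0.0, 0.5, -0.5])
advantage = torch.tensor([1.0, 1.0, -1.0])

eps_low, eps_high = 0.2, 0.2  # illustrative thresholds
clip_bounds = (math.log1p(-eps_low), math.log1p(eps_high))

gain1 = log_weight.exp() * advantage                # unclipped surrogate
log_weight_clip = log_weight.clamp(*clip_bounds)
clip_fraction = (log_weight_clip != log_weight).float().mean()
gain2 = log_weight_clip.exp() * advantage           # clipped surrogate
loss_objective = -torch.stack([gain1, gain2], -1).min(dim=-1).values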
@@ -548,10 +604,12 @@ def _log_weight(
class DAPO(GRPOLoss):
    """DAPO (Clip-Higher over GRPO).

-     Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in DAPO
-     [arXiv](https://arxiv.org/html/2503.14476).
+     Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in
+     the `DAPO <https://arxiv.org/html/2503.14476>`_ paper.
    """

+     output_type: type[LLMLossOutput] = DAPOLossOutput
+
    def __init__(
        self,
        tensordict: TensorDictBase,
@@ -594,6 +652,29 @@ def __init__(
        return coeff * kl_penalty, kl_penalty


+ class CISPO(GRPOLoss):
+     """CISPO (Clipped Importance Sampling Policy Optimization).
+
+     Inherits the GRPO pipeline (masking, ESS, entropy, optional KL penalties) but
+     replaces the PPO-style min with a clipped-importance objective:
+         loss = -clip(weight, [1 - eps_low, 1 + eps_high]) * advantage
+
+     See the `MiniMax-M1 (CISPO) <https://arxiv.org/html/2506.13585>`_ paper.
+     """
+
+     output_type: type[LLMLossOutput] = CISPOLossOutput
+
+     def _compute_policy_objective(
+         self, log_weight: torch.Tensor, advantage: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         # CISPO: use clipped importance weights directly
+         log_weight_clip = log_weight.clamp(*self._clip_bounds)
+         clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+         ratio = log_weight_clip.exp()
+         gain = ratio * advantage
+         return -gain, clip_fraction
+
+
class MCAdvantage(Transform):
    """Monte-Carlo advantage computation engine.

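Illustration (not part of the diff): the pattern this refactor enables is that a new objective subclasses GRPOLoss, overrides `_compute_policy_objective`, and optionally pairs itself with its own LLMLossOutput subclass via `output_type`, exactly as CISPO does above. A minimal hypothetical variant is sketched below; the class names are invented and the import path is assumed to be the module this diff edits.

import torch

from torchrl.objectives.llm.grpo import GRPOLoss, LLMLossOutput  # path assumed


class UnclippedISLossOutput(LLMLossOutput):
    """Output container for the sketch below."""


class UnclippedISLoss(GRPOLoss):
    """Plain importance-weighted policy gradient, with no clipping at all."""

    output_type: type[LLMLossOutput] = UnclippedISLossOutput

    def _compute_policy_objective(
        self, log_weight: torch.Tensor, advantage: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        gain = log_weight.exp() * advantage
        # nothing is clipped, so report a zero clip fraction for logging
        clip_fraction = torch.zeros((), dtype=log_weight.dtype, device=log_weight.device)
        return -gain, clip_fraction

Because `forward()` now calls `self.output_type.from_tensordict(td_out)`, such a subclass returns its own output class without re-implementing any of the masking, ESS, entropy, or KL-penalty handling.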