Commit 54036f8

[Feature] CISPO
ghstack-source-id: 966d46f Pull-Request: #3207
1 parent 13434eb commit 54036f8

3 files changed (+189 additions, -21 deletions)

test/llm/test_objectives.py

Lines changed: 69 additions & 2 deletions
@@ -16,7 +16,13 @@
 from torchrl.envs.llm.transforms.kl import RetrieveLogProb
 from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
 from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens
-from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
+from torchrl.objectives.llm.grpo import (
+    CISPO,
+    CISPOLossOutput,
+    GRPOLoss,
+    GRPOLossOutput,
+    MCAdvantage,
+)
 from torchrl.objectives.llm.sft import SFTLoss

 _has_transformers = importlib.util.find_spec("transformers") is not None

@@ -203,7 +209,6 @@ def test_grpo(self, mock_transformer_model, dapo):
         loss_vals = loss_fn(data)

         # Assertions: Check output type and structure
-        from torchrl.objectives.llm.grpo import GRPOLossOutput

         assert isinstance(
             loss_vals, GRPOLossOutput

@@ -240,6 +245,68 @@ def test_grpo(self, mock_transformer_model, dapo):
             0 <= loss_vals.clip_fraction <= 1
         ), f"clip_fraction out of range: {loss_vals.clip_fraction}"

+    def test_cispo(self, mock_transformer_model):
+        """Test CISPO loss computation with mock models."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+        eps = 0.20
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+        loss_fn = CISPO(actor_network, clip_epsilon=eps)
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # Compute loss
+        loss_vals = loss_fn(data)
+
+        # Assertions: Check output type and structure
+        assert isinstance(
+            loss_vals, CISPOLossOutput
+        ), f"Expected CISPOLossOutput, got {type(loss_vals)}"
+
+        # Check that all expected keys are present (same as GRPO)
+        assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
+        assert hasattr(loss_vals, "clip_fraction"), "Missing clip_fraction"
+        assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
+        assert hasattr(loss_vals, "ESS"), "Missing ESS"
+        assert hasattr(loss_vals, "entropy"), "Missing entropy"
+        assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"
+
+        # Check tensor shapes (all losses should be scalars after reduction)
+        assert (
+            loss_vals.loss_objective.shape == ()
+        ), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
+        assert (
+            loss_vals.clip_fraction.shape == ()
+        ), f"clip_fraction should be scalar, got {loss_vals.clip_fraction.shape}"
+        assert (
+            loss_vals.kl_approx.shape == ()
+        ), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
+        assert (
+            loss_vals.ESS.shape == ()
+        ), f"ESS should be scalar, got {loss_vals.ESS.shape}"
+
+        # Check that losses are finite
+        assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
+        assert torch.isfinite(loss_vals.ESS), "ESS is not finite"
+
+        # Check that clip_fraction is in valid range [0, 1]
+        assert (
+            0 <= loss_vals.clip_fraction <= 1
+        ), f"clip_fraction out of range: {loss_vals.clip_fraction}"
+

 class TestSFT:
     @pytest.fixture(scope="class")

torchrl/objectives/llm/__init__.py

Lines changed: 22 additions & 2 deletions
@@ -4,7 +4,27 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations

-from .grpo import GRPOLoss, GRPOLossOutput, MCAdvantage
+from .grpo import (
+    CISPO,
+    CISPOLossOutput,
+    DAPO,
+    DAPOLossOutput,
+    GRPOLoss,
+    GRPOLossOutput,
+    LLMLossOutput,
+    MCAdvantage,
+)
 from .sft import SFTLoss, SFTLossOutput

-__all__ = ["GRPOLoss", "GRPOLossOutput", "MCAdvantage", "SFTLoss", "SFTLossOutput"]
+__all__ = [
+    "CISPO",
+    "CISPOLossOutput",
+    "DAPO",
+    "DAPOLossOutput",
+    "GRPOLoss",
+    "GRPOLossOutput",
+    "LLMLossOutput",
+    "MCAdvantage",
+    "SFTLoss",
+    "SFTLossOutput",
+]
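
As a quick sketch of the expanded public API (not part of the commit; it assumes a torchrl build that includes this change), the re-exported symbols become importable from torchrl.objectives.llm directly:

# Sketch: imports enabled by the re-exports added above.
from torchrl.objectives.llm import (
    CISPO,
    CISPOLossOutput,
    DAPO,
    DAPOLossOutput,
    GRPOLoss,
    LLMLossOutput,
)

# Per the grpo.py diff below, CISPO and DAPO subclass GRPOLoss,
# and their output classes share the LLMLossOutput base.
assert issubclass(CISPO, GRPOLoss) and issubclass(DAPO, GRPOLoss)
assert issubclass(CISPOLossOutput, LLMLossOutput) and issubclass(
    DAPOLossOutput, LLMLossOutput
)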

torchrl/objectives/llm/grpo.py

Lines changed: 98 additions & 17 deletions
@@ -8,7 +8,7 @@

 from collections import defaultdict, deque
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, TypeVar

 import torch
 from tensordict import (

@@ -33,8 +33,12 @@
 from torchrl.objectives.utils import _reduce, _sum_td_features


-class GRPOLossOutput(TensorClass["nocast"]):
-    """GRPO Loss Output."""
+class LLMLossOutput(TensorClass["nocast"]):
+    """Base class for LLM loss outputs.
+
+    This base class defines the common structure for all LLM-based policy optimization
+    loss outputs (GRPO, DAPO, CISPO, etc.).
+    """

     loss_objective: torch.Tensor
     clip_fraction: torch.Tensor

@@ -48,6 +52,21 @@ class GRPOLossOutput(TensorClass["nocast"]):
     kl_to_inference: torch.Tensor | None = None


+LLMOutputType = TypeVar("LLMOutputType", bound=LLMLossOutput)
+
+
+class GRPOLossOutput(LLMLossOutput):
+    """GRPO Loss Output."""
+
+
+class DAPOLossOutput(LLMLossOutput):
+    """DAPO Loss Output."""
+
+
+class CISPOLossOutput(LLMLossOutput):
+    """CISPO Loss Output."""
+
+
 class GRPOLoss(LossModule):
     """GRPO loss.


@@ -123,6 +142,7 @@ class GRPOLoss(LossModule):
     """

     actor_network: LLMWrapperBase
+    output_type: type[LLMLossOutput] = GRPOLossOutput

     @dataclass
     class _AcceptedKeys(LossModule._AcceptedKeys):
@@ -137,6 +157,33 @@ class _AcceptedKeys(LossModule._AcceptedKeys):
         sample_log_prob: NestedKey = ("log_probs", "full")
         ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")

+    @property
+    def tensor_keys(self) -> _AcceptedKeys:
+        """Access the tensordict key configuration for this loss.
+
+        This property provides access to the configurable keys used by the loss module
+        to read tensors from input TensorDicts. These keys include:
+
+        - ``advantage``: key for the advantage values
+        - ``action``: key for the action tokens (default: ``("tokens", "full")``)
+        - ``sample_log_prob``: key for the log-probabilities of the sampled tokens under the policy that generated them (default: ``("log_probs", "full")``)
+        - ``ref_log_probs``: key for the reference policy log-probabilities (default: ``("next", "ref_log_probs", "full")``)
+
+        To modify these keys, use the :meth:`~.set_keys` method.
+
+        Examples:
+            >>> loss = GRPOLoss(actor_network)
+            >>> # Access current keys
+            >>> print(loss.tensor_keys.advantage)  # "advantage"
+            >>> # Modify keys
+            >>> loss.set_keys(advantage="my_advantage_key")
+            >>> print(loss.tensor_keys.advantage)  # "my_advantage_key"
+
+        Returns:
+            An instance of _AcceptedKeys containing all configurable tensordict keys.
+        """
+        return self._tensor_keys
+
     def __init__(
         self,
         actor_network: LLMWrapperBase | None = None,
@@ -316,7 +363,7 @@ def _get_cur_log_prob(self, tensordict):
         )
         return log_prob, dist, False

-    def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
+    def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
         # Some sanity checks and housekeeping:
         # - We may not have the tokens yet. If not, we will use the tokenizer of the actor to tokenize the text.
         #   We default to history rather than text because the history will account for multiturn, or multimodal inputs.

@@ -348,16 +395,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             raise ValueError(
                 f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
             )
-        gain1 = log_weight.exp() * advantage
-
-        log_weight_clip = log_weight.clamp(*self._clip_bounds)
-        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
-        ratio = log_weight_clip.exp()
-        gain2 = ratio * advantage
-
-        # Token-level objective: compute min over clipped/unclipped at the token level
-        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
-        td_out = TensorDict({"loss_objective": -gain})
+        loss_objective, clip_fraction = self._compute_policy_objective(
+            log_weight, advantage
+        )
+        td_out = TensorDict({"loss_objective": loss_objective})
         td_out.set("clip_fraction", clip_fraction)
         td_out.set("kl_approx", kl_approx.detach().mean())  # for logging


@@ -404,7 +445,22 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             td_out["loss_kl_to_inference"] = loss_kl
             td_out["kl_to_inference"] = kl_penalty.detach()
         del tensordict["_cur_log_prob"]
-        return GRPOLossOutput.from_tensordict(td_out)
+        return self.output_type.from_tensordict(td_out)
+
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Default GRPO objective: PPO-style min between unclipped and clipped ratios.
+
+        Returns (loss_objective, clip_fraction).
+        """
+        gain1 = log_weight.exp() * advantage
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain2 = ratio * advantage
+        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
+        return -gain, clip_fraction

     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size

@@ -548,10 +604,12 @@ def _log_weight(
 class DAPO(GRPOLoss):
     """DAPO (Clip-Higher over GRPO).

-    Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in DAPO
-    [arXiv](https://arxiv.org/html/2503.14476).
+    Validates asymmetric clip thresholds; recommended (0.20, 0.28), see Eq. (10) in
+    the `DAPO <https://arxiv.org/html/2503.14476>`_ paper.
     """

+    output_type: type[LLMLossOutput] = DAPOLossOutput
+
     def __init__(
         self,
         tensordict: TensorDictBase,

@@ -594,6 +652,29 @@ def __init__(
         return coeff * kl_penalty, kl_penalty


+class CISPO(GRPOLoss):
+    """CISPO (Clipped Importance Sampling Policy Optimization).
+
+    Inherits the GRPO pipeline (masking, ESS, entropy, optional KL penalties) but
+    replaces the PPO-style min with a clipped-importance objective:
+        loss = - clip(weight, [1 - eps_low, 1 + eps_high]) * advantage
+
+    See the `MiniMax-M1 (CISPO) <https://arxiv.org/html/2506.13585>`_ paper.
+    """
+
+    output_type: type[LLMLossOutput] = CISPOLossOutput
+
+    def _compute_policy_objective(
+        self, log_weight: torch.Tensor, advantage: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # CISPO: use clipped importance weights directly
+        log_weight_clip = log_weight.clamp(*self._clip_bounds)
+        clip_fraction = (log_weight_clip != log_weight).to(log_weight.dtype).mean()
+        ratio = log_weight_clip.exp()
+        gain = ratio * advantage
+        return -gain, clip_fraction
+
+
 class MCAdvantage(Transform):
     """Monte-Carlo advantage computation engine.
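
To make the difference between the two objectives concrete, here is a small self-contained sketch (not part of the commit) that applies the two _compute_policy_objective formulas above to a toy batch of token log-ratios and advantages. The mapping from clip_epsilon to log-space clip bounds as (log(1 - eps), log(1 + eps)) is an assumption made for illustration only.

import math

import torch

eps = 0.2
# Assumed log-space clip bounds derived from clip_epsilon (illustration only).
clip_bounds = (math.log1p(-eps), math.log1p(eps))

# Toy per-token log importance weights (log pi_new - log pi_old) and advantages.
log_weight = torch.tensor([-0.5, 0.0, 1.2, 1.2])
advantage = torch.tensor([1.0, 1.0, 1.0, -1.0])

log_weight_clip = log_weight.clamp(*clip_bounds)
clipped_ratio = log_weight_clip.exp()

# GRPO / PPO-style objective: elementwise min of unclipped and clipped gains.
stacked = torch.stack([log_weight.exp() * advantage, clipped_ratio * advantage], -1)
grpo_loss = -stacked.min(dim=-1).values

# CISPO objective: clipped importance weight times advantage, no min.
cispo_loss = -clipped_ratio * advantage

print(grpo_loss)  # approx. [-0.6065, -1.0000, -1.2000,  3.3201]
print(cispo_loss)  # approx. [-0.8000, -1.0000, -1.2000,  1.2000]
# The two differ exactly where the unclipped gain is the smaller of the pair
# (first and last tokens): GRPO keeps it, CISPO always uses the clipped weight.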

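More generally, the output_type attribute and the _compute_policy_objective hook introduced by this commit are what a new variant needs to override, which is exactly how CISPO itself is implemented above. A hypothetical sketch (the UnclippedPGLoss name and its behaviour are made up purely for illustration):

import torch

from torchrl.objectives.llm.grpo import GRPOLoss, LLMLossOutput


class UnclippedPGLossOutput(LLMLossOutput):
    """Output class for the hypothetical variant below."""


class UnclippedPGLoss(GRPOLoss):
    """Hypothetical variant: plain importance-weighted policy gradient, no clipping."""

    output_type: type[LLMLossOutput] = UnclippedPGLossOutput

    def _compute_policy_objective(
        self, log_weight: torch.Tensor, advantage: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # No clipping is applied, so the reported clip_fraction is always zero.
        gain = log_weight.exp() * advantage
        clip_fraction = torch.zeros((), dtype=log_weight.dtype, device=log_weight.device)
        return -gain, clip_fraction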