Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions packages/prime-rl-configs/src/prime_rl/configs/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,15 @@ class FakeDataConfig(BaseDataConfig):
"""Token id generator: ``increasing`` for deterministic sequences, ``random`` for random ids."""


class LossMaskConfig(BaseConfig):
class LossMaskRolesConfig(BaseConfig):
"""Per-role loss-mask: AND the renderer's ``is_sampled`` signal with
these booleans. Use when you want to restrict SFT supervision to a
subset of roles even when other roles' tokens are model-sampled
(e.g. GLM's ``<|observation|>`` stop opener attributed to the
following tool message — see renderer fix in
``PrimeIntellect-ai/renderers#66``).
"""

system: bool = False
"""System messages contribute to the loss."""

Expand All @@ -71,6 +79,25 @@ class LossMaskConfig(BaseConfig):
"""Tool messages contribute to the loss."""


LossMaskConfig: TypeAlias = Literal["sampled", "all"] | LossMaskRolesConfig
"""How to compute the per-token loss mask in SFT.

- ``"sampled"`` (default, recommended with renderers): every token the
renderer marks ``is_sampled=True`` is trainable, regardless of role.
Correctly trains chat-template stop tokens whose structural
attribution lives in the next message's span.
- ``"all"``: every renderer token contributes to the loss, regardless
of role or ``is_sampled``. Useful for debugging.
- :class:`LossMaskRolesConfig`: AND ``is_sampled`` with a per-role
filter. Strict opt-in.

The chat-template fallback path (no renderer registered) has no
``is_sampled`` signal, so it cannot honor ``"sampled"`` / ``"all"``:
both are treated as the ``LossMaskRolesConfig()`` defaults
(assistant-only). In particular, ``"all"`` does NOT train on every
role on this path.
"""


class SFTDataConfig(BaseDataConfig):
type: Literal["sft"] = "sft"

Expand All @@ -96,8 +123,8 @@ class SFTDataConfig(BaseDataConfig):
"""Random seed for shuffling. Re-shuffled per epoch by adding the epoch count to the seed."""

# Configuring
loss_mask: LossMaskConfig = LossMaskConfig()
"""Which message types contribute to the loss."""
loss_mask: LossMaskConfig = "sampled"
"""How to compute the per-token loss mask. See :data:`LossMaskConfig`."""

@model_validator(mode="after")
def validate_subsets_and_splits(self):
Expand Down
93 changes: 73 additions & 20 deletions src/prime_rl/trainer/sft/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers.tokenization_utils import PreTrainedTokenizer

from prime_rl.configs.sft import DataConfig, LossMaskConfig, SFTDataConfig
from prime_rl.configs.sft import (
DataConfig,
LossMaskConfig,
LossMaskRolesConfig,
SFTDataConfig,
)
from prime_rl.trainer.world import get_world
from prime_rl.utils.chat_template import (
IncrementalTokenizationError,
Expand All @@ -27,6 +32,41 @@
STACKING_DATASET_BUCKET_TIMEOUT = 10


def _always_true_role_filter(message: dict) -> bool:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this helper private?

"""Role filter that trains on every message regardless of role.

Used when :data:`LossMaskConfig` is ``"all"`` to bypass the
role-based ``role_to_mask`` gate in
:func:`renderers.base.build_training_sample`.
"""
return True


def _role_filter_for(cfg: LossMaskRolesConfig):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this helper private?

"""Build a per-message role filter from :class:`LossMaskRolesConfig`.

The returned callable is the ``role_to_mask`` argument passed to
:func:`renderers.base.build_training_sample` and to
:func:`prime_rl.utils.chat_template.build_incremental_token_mask`.
"""

def role_filter(message: dict) -> bool:
    """Return the ``cfg`` flag that corresponds to the message's role.

    ``cfg`` comes from the enclosing scope; the flag read is ``cfg.user``,
    ``cfg.assistant``, ``cfg.system`` or ``cfg.tool`` depending on
    ``message["role"]``.

    Raises:
        AssertionError: if ``message`` has no ``"role"`` key.
        ValueError: if the role is not one of the four roles above.
    """
    assert "role" in message, "Message must have a role"
    role = message["role"]
    # Guard-clause chain: each known role returns its own flag directly.
    if role == "user":
        return cfg.user
    if role == "assistant":
        return cfg.assistant
    if role == "system":
        return cfg.system
    if role == "tool":
        return cfg.tool
    raise ValueError(f"Invalid message role: {message['role']}")

return role_filter


class Sample(TypedDict):
input_ids: list[int]
position_ids: list[int]
Expand Down Expand Up @@ -125,7 +165,7 @@ def __init__(
seed: int = 0,
seq_len: int = 128,
non_dp_size: int = 1,
loss_mask_config: LossMaskConfig = LossMaskConfig(),
loss_mask_config: LossMaskConfig = "sampled",
max_examples: int | None = None,
max_epochs: int | None = None,
renderer: Renderer | None = None,
Expand Down Expand Up @@ -219,19 +259,23 @@ def resolve_messages(example: dict) -> list[dict]:
for t in raw_tools
]

def should_mask(message: dict) -> bool:
    """Return the per-role flag from ``self.loss_mask_config`` as a bool.

    ``self`` comes from the enclosing scope. The flag consulted is the
    ``user``, ``assistant``, ``system`` or ``tool`` attribute of
    ``self.loss_mask_config``, chosen by ``message["role"]``.

    Raises:
        AssertionError: if ``message`` has no ``"role"`` key.
        ValueError: if the role is not one of the four roles above.
    """
    assert "role" in message, "Message must have a role"
    role = message["role"]
    # The config is only read inside the branch for the matching role.
    if role == "user":
        enabled = self.loss_mask_config.user
    elif role == "assistant":
        enabled = self.loss_mask_config.assistant
    elif role == "system":
        enabled = self.loss_mask_config.system
    elif role == "tool":
        enabled = self.loss_mask_config.tool
    else:
        raise ValueError(f"Invalid message role: {message['role']}")
    return bool(enabled)
# Dispatch the loss-mask config to a renderer-side role filter.
# The string sentinels map to no filter ("sampled" → trust
# sampled_mask) or an always-True filter ("all"); the
# :class:`LossMaskRolesConfig` instance maps to a per-role
# boolean lookup. The chat-template fallback below has no
# ``is_sampled`` signal so we treat sentinel modes as the
# ``LossMaskRolesConfig()`` defaults (assistant-only).
loss_mask_cfg = self.loss_mask_config
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like all this logic, plus the always_true_role_filter, should be in role_filter_for which just returns a tuple of the renderer_role_filter and fallback_role_filter. I found this split in logic between the role filter helper and this dataset quite confusing.

if loss_mask_cfg == "sampled":
renderer_role_filter = None
fallback_role_filter = _role_filter_for(LossMaskRolesConfig())
elif loss_mask_cfg == "all":
renderer_role_filter = _always_true_role_filter
fallback_role_filter = _role_filter_for(LossMaskRolesConfig())
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"all" fallback silently restricts loss to assistant-only tokens

Medium Severity

When loss_mask = "all" and no renderer is configured, fallback_role_filter is set to _role_filter_for(LossMaskRolesConfig()), which defaults to assistant-only masking. The "all" mode is documented as training on "every…token…regardless of role," but this fallback silently contradicts that. The chat-template build_incremental_token_mask uses role_to_mask directly (no is_sampled dependency), so _always_true_role_filter would correctly honor the "all" semantic in the fallback path.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit de439e3. Configure here.

else:
renderer_role_filter = _role_filter_for(loss_mask_cfg)
fallback_role_filter = renderer_role_filter

if self.renderer is not None:
if example.get("chat_template_kwargs") and not self._warned_chat_template_kwargs:
Expand All @@ -246,15 +290,15 @@ def should_mask(message: dict) -> bool:
input_ids, loss_mask = build_training_sample(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dislike how different the names build_training_sample and build_incremental_token_mask are, given that they do the same thing with a different backend. Not necessarily important for this PR, but I didn't want to forget it.

self.renderer,
messages,
role_to_mask=should_mask,
role_to_mask=renderer_role_filter,
tools=tools,
)
else:
try:
input_ids, loss_mask = build_incremental_token_mask(
self.tokenizer,
messages,
role_to_mask=should_mask,
role_to_mask=fallback_role_filter,
tools=tools,
chat_template_kwargs=example.get("chat_template_kwargs", {}),
collapse_consecutive_tool_messages=True,
Expand All @@ -263,8 +307,13 @@ def should_mask(message: dict) -> bool:
self.logger.warning(f"Skipping example {example.get('__index', '')}: {e}")
return None

# If EOS token is not found, manually append it
if not self.tokenizer.eos_token_id in input_ids:
# The renderer's token stream is authoritative: it mirrors the model's
# chat template (which carries no EOS by design) and trains the real
# stop signals via sampled_mask (e.g. GLM's turn-closing <|observation|>
# / <|user|>). Don't inject an EOS the renderer didn't emit. Only the
# chat-template fallback path (no sampled_mask) needs EOS appended so
# the model learns to stop.
if self.renderer is None and not self.tokenizer.eos_token_id in input_ids:
self.logger.warning(
f"Did not find EOS token ID {self.tokenizer.eos_token_id} in input_ids. Is something wrong with the chat template? Manually appending EOS token..."
)
Expand All @@ -286,7 +335,11 @@ def should_mask(message: dict) -> bool:
f"input_ids, loss_mask and target_ids must have the same length, but got {len(input_ids)=}, {len(loss_mask)=}, {len(target_ids)=}"
)
assert sum(loss_mask) > 0, "There are no tokens in this sample that contribute to the loss"
assert self.tokenizer.eos_token_id in target_ids, "EOS token ID must be present in target_ids"
# Only the chat-template fallback guarantees an EOS; the renderer honors
# the template and may legitimately omit it (stop signals come from
# sampled turn markers instead).
if self.renderer is None:
assert self.tokenizer.eos_token_id in target_ids, "EOS token ID must be present in target_ids"

# Create sample (with one fake target for the last token)
return {
Expand Down
Loading