diff --git a/packages/prime-rl-configs/src/prime_rl/configs/sft.py b/packages/prime-rl-configs/src/prime_rl/configs/sft.py index 56e905cff7..5b88b92a8e 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/sft.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/sft.py @@ -57,7 +57,15 @@ class FakeDataConfig(BaseDataConfig): """Token id generator: ``increasing`` for deterministic sequences, ``random`` for random ids.""" -class LossMaskConfig(BaseConfig): +class LossMaskRolesConfig(BaseConfig): + """Per-role loss-mask: AND the renderer's ``is_sampled`` signal with + these booleans. Use when you want to restrict SFT supervision to a + subset of roles even when other roles' tokens are model-sampled + (e.g. GLM's ``<|observation|>`` stop opener attributed to the + following tool message — see renderer fix in + ``PrimeIntellect-ai/renderers#66``). + """ + system: bool = False """System messages contribute to the loss.""" @@ -71,6 +79,25 @@ class LossMaskConfig(BaseConfig): """Tool messages contribute to the loss.""" +LossMaskConfig: TypeAlias = Literal["sampled", "all"] | LossMaskRolesConfig +"""How to compute the per-token loss mask in SFT. + +- ``"sampled"`` (default, recommended with renderers): every token the + renderer marks ``is_sampled=True`` is trainable, regardless of role. + Correctly trains chat-template stop tokens whose structural + attribution lives in the next message's span. +- ``"all"``: every renderer token contributes to the loss, regardless + of role or ``is_sampled``. Useful for debugging. +- :class:`LossMaskRolesConfig`: AND ``is_sampled`` with a per-role + filter. Strict opt-in. + +The chat-template fallback path (no renderer registered) ignores +``"sampled"`` / ``"all"`` and treats the loss-mask as +``LossMaskRolesConfig()`` defaults — it has no ``is_sampled`` signal +to fall back on. +""" + + class SFTDataConfig(BaseDataConfig): type: Literal["sft"] = "sft" @@ -96,8 +123,8 @@ class SFTDataConfig(BaseDataConfig): """Random seed for shuffling. Re-shuffled per epoch by adding the epoch count to the seed.""" # Configuring - loss_mask: LossMaskConfig = LossMaskConfig() - """Which message types contribute to the loss.""" + loss_mask: LossMaskConfig = "sampled" + """How to compute the per-token loss mask. See :data:`LossMaskConfig`.""" @model_validator(mode="after") def validate_subsets_and_splits(self): diff --git a/src/prime_rl/trainer/sft/data.py b/src/prime_rl/trainer/sft/data.py index d0edd13787..196d140c6d 100644 --- a/src/prime_rl/trainer/sft/data.py +++ b/src/prime_rl/trainer/sft/data.py @@ -13,7 +13,12 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers.tokenization_utils import PreTrainedTokenizer -from prime_rl.configs.sft import DataConfig, LossMaskConfig, SFTDataConfig +from prime_rl.configs.sft import ( + DataConfig, + LossMaskConfig, + LossMaskRolesConfig, + SFTDataConfig, +) from prime_rl.trainer.world import get_world from prime_rl.utils.chat_template import ( IncrementalTokenizationError, @@ -27,6 +32,41 @@ STACKING_DATASET_BUCKET_TIMEOUT = 10 +def _always_true_role_filter(message: dict) -> bool: + """Role filter that trains on every message regardless of role. + + Used when :data:`LossMaskConfig` is ``"all"`` to bypass the + role-based ``role_to_mask`` gate in + :func:`renderers.base.build_training_sample`. + """ + return True + + +def _role_filter_for(cfg: LossMaskRolesConfig): + """Build a per-message role filter from :class:`LossMaskRolesConfig`. + + The returned callable is the ``role_to_mask`` argument passed to + :func:`renderers.base.build_training_sample` and to + :func:`prime_rl.utils.chat_template.build_incremental_token_mask`. + """ + + def role_filter(message: dict) -> bool: + assert "role" in message, "Message must have a role" + match message["role"]: + case "user": + return cfg.user + case "assistant": + return cfg.assistant + case "system": + return cfg.system + case "tool": + return cfg.tool + case _: + raise ValueError(f"Invalid message role: {message['role']}") + + return role_filter + + class Sample(TypedDict): input_ids: list[int] position_ids: list[int] @@ -125,7 +165,7 @@ def __init__( seed: int = 0, seq_len: int = 128, non_dp_size: int = 1, - loss_mask_config: LossMaskConfig = LossMaskConfig(), + loss_mask_config: LossMaskConfig = "sampled", max_examples: int | None = None, max_epochs: int | None = None, renderer: Renderer | None = None, @@ -219,19 +259,23 @@ def resolve_messages(example: dict) -> list[dict]: for t in raw_tools ] - def should_mask(message: dict) -> bool: - assert "role" in message, "Message must have a role" - match message["role"]: - case "user": - return True if self.loss_mask_config.user else False - case "assistant": - return True if self.loss_mask_config.assistant else False - case "system": - return True if self.loss_mask_config.system else False - case "tool": - return True if self.loss_mask_config.tool else False - case _: - raise ValueError(f"Invalid message role: {message['role']}") + # Dispatch the loss-mask config to a renderer-side role filter. + # The string sentinels map to no filter ("sampled" → trust + # sampled_mask) or an always-True filter ("all"); the + # :class:`LossMaskRolesConfig` instance maps to a per-role + # boolean lookup. The chat-template fallback below has no + # ``is_sampled`` signal so we treat sentinel modes as the + # ``LossMaskRolesConfig()`` defaults (assistant-only). + loss_mask_cfg = self.loss_mask_config + if loss_mask_cfg == "sampled": + renderer_role_filter = None + fallback_role_filter = _role_filter_for(LossMaskRolesConfig()) + elif loss_mask_cfg == "all": + renderer_role_filter = _always_true_role_filter + fallback_role_filter = _role_filter_for(LossMaskRolesConfig()) + else: + renderer_role_filter = _role_filter_for(loss_mask_cfg) + fallback_role_filter = renderer_role_filter if self.renderer is not None: if example.get("chat_template_kwargs") and not self._warned_chat_template_kwargs: @@ -246,7 +290,7 @@ def should_mask(message: dict) -> bool: input_ids, loss_mask = build_training_sample( self.renderer, messages, - role_to_mask=should_mask, + role_to_mask=renderer_role_filter, tools=tools, ) else: @@ -254,7 +298,7 @@ def should_mask(message: dict) -> bool: input_ids, loss_mask = build_incremental_token_mask( self.tokenizer, messages, - role_to_mask=should_mask, + role_to_mask=fallback_role_filter, tools=tools, chat_template_kwargs=example.get("chat_template_kwargs", {}), collapse_consecutive_tool_messages=True, @@ -263,8 +307,13 @@ def should_mask(message: dict) -> bool: self.logger.warning(f"Skipping example {example.get('__index', '')}: {e}") return None - # If EOS token is not found, manually append it - if not self.tokenizer.eos_token_id in input_ids: + # The renderer's token stream is authoritative: it mirrors the model's + # chat template (which carries no EOS by design) and trains the real + # stop signals via sampled_mask (e.g. GLM's turn-closing <|observation|> + # / <|user|>). Don't inject an EOS the renderer didn't emit. Only the + # chat-template fallback path (no sampled_mask) needs EOS appended so + # the model learns to stop. + if self.renderer is None and not self.tokenizer.eos_token_id in input_ids: self.logger.warning( f"Did not find EOS token ID {self.tokenizer.eos_token_id} in input_ids. Is something wrong with the chat template? Manually appending EOS token..." ) @@ -286,7 +335,11 @@ def should_mask(message: dict) -> bool: f"input_ids, loss_mask and target_ids must have the same length, but got {len(input_ids)=}, {len(loss_mask)=}, {len(target_ids)=}" ) assert sum(loss_mask) > 0, "There are no tokens in this sample that contribute to the loss" - assert self.tokenizer.eos_token_id in target_ids, "EOS token ID must be present in target_ids" + # Only the chat-template fallback guarantees an EOS; the renderer honors + # the template and may legitimately omit it (stop signals come from + # sampled turn markers instead). + if self.renderer is None: + assert self.tokenizer.eos_token_id in target_ids, "EOS token ID must be present in target_ids" # Create sample (with one fake target for the last token) return {