-
Notifications
You must be signed in to change notification settings - Fork 302
feat(sft): default loss_mask to renderer's sampled_mask #2644
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,12 @@ | |
| from torchdata.stateful_dataloader import StatefulDataLoader | ||
| from transformers.tokenization_utils import PreTrainedTokenizer | ||
|
|
||
| from prime_rl.configs.sft import DataConfig, LossMaskConfig, SFTDataConfig | ||
| from prime_rl.configs.sft import ( | ||
| DataConfig, | ||
| LossMaskConfig, | ||
| LossMaskRolesConfig, | ||
| SFTDataConfig, | ||
| ) | ||
| from prime_rl.trainer.world import get_world | ||
| from prime_rl.utils.chat_template import ( | ||
| IncrementalTokenizationError, | ||
|
|
@@ -27,6 +32,41 @@ | |
| STACKING_DATASET_BUCKET_TIMEOUT = 10 | ||
|
|
||
|
|
||
| def _always_true_role_filter(message: dict) -> bool: | ||
| """Role filter that trains on every message regardless of role. | ||
|
|
||
| Used when :data:`LossMaskConfig` is ``"all"`` to bypass the | ||
| role-based ``role_to_mask`` gate in | ||
| :func:`renderers.base.build_training_sample`. | ||
| """ | ||
| return True | ||
|
|
||
|
|
||
| def _role_filter_for(cfg: LossMaskRolesConfig): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why private |
||
| """Build a per-message role filter from :class:`LossMaskRolesConfig`. | ||
|
|
||
| The returned callable is the ``role_to_mask`` argument passed to | ||
| :func:`renderers.base.build_training_sample` and to | ||
| :func:`prime_rl.utils.chat_template.build_incremental_token_mask`. | ||
| """ | ||
|
|
||
| def role_filter(message: dict) -> bool: | ||
| assert "role" in message, "Message must have a role" | ||
| match message["role"]: | ||
| case "user": | ||
| return cfg.user | ||
| case "assistant": | ||
| return cfg.assistant | ||
| case "system": | ||
| return cfg.system | ||
| case "tool": | ||
| return cfg.tool | ||
| case _: | ||
| raise ValueError(f"Invalid message role: {message['role']}") | ||
|
|
||
| return role_filter | ||
|
|
||
|
|
||
| class Sample(TypedDict): | ||
| input_ids: list[int] | ||
| position_ids: list[int] | ||
|
|
@@ -125,7 +165,7 @@ def __init__( | |
| seed: int = 0, | ||
| seq_len: int = 128, | ||
| non_dp_size: int = 1, | ||
| loss_mask_config: LossMaskConfig = LossMaskConfig(), | ||
| loss_mask_config: LossMaskConfig = "sampled", | ||
| max_examples: int | None = None, | ||
| max_epochs: int | None = None, | ||
| renderer: Renderer | None = None, | ||
|
|
@@ -219,19 +259,23 @@ def resolve_messages(example: dict) -> list[dict]: | |
| for t in raw_tools | ||
| ] | ||
|
|
||
| def should_mask(message: dict) -> bool: | ||
| assert "role" in message, "Message must have a role" | ||
| match message["role"]: | ||
| case "user": | ||
| return True if self.loss_mask_config.user else False | ||
| case "assistant": | ||
| return True if self.loss_mask_config.assistant else False | ||
| case "system": | ||
| return True if self.loss_mask_config.system else False | ||
| case "tool": | ||
| return True if self.loss_mask_config.tool else False | ||
| case _: | ||
| raise ValueError(f"Invalid message role: {message['role']}") | ||
| # Dispatch the loss-mask config to a renderer-side role filter. | ||
| # The string sentinels map to no filter ("sampled" → trust | ||
| # sampled_mask) or an always-True filter ("all"); the | ||
| # :class:`LossMaskRolesConfig` instance maps to a per-role | ||
| # boolean lookup. The chat-template fallback below has no | ||
| # ``is_sampled`` signal so we treat sentinel modes as the | ||
| # ``LossMaskRolesConfig()`` defaults (assistant-only). | ||
| loss_mask_cfg = self.loss_mask_config | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like all this logic, plus the |
||
| if loss_mask_cfg == "sampled": | ||
| renderer_role_filter = None | ||
| fallback_role_filter = _role_filter_for(LossMaskRolesConfig()) | ||
| elif loss_mask_cfg == "all": | ||
| renderer_role_filter = _always_true_role_filter | ||
| fallback_role_filter = _role_filter_for(LossMaskRolesConfig()) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| else: | ||
| renderer_role_filter = _role_filter_for(loss_mask_cfg) | ||
| fallback_role_filter = renderer_role_filter | ||
|
|
||
| if self.renderer is not None: | ||
| if example.get("chat_template_kwargs") and not self._warned_chat_template_kwargs: | ||
|
|
@@ -246,15 +290,15 @@ def should_mask(message: dict) -> bool: | |
| input_ids, loss_mask = build_training_sample( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i dislike the extremely different names between |
||
| self.renderer, | ||
| messages, | ||
| role_to_mask=should_mask, | ||
| role_to_mask=renderer_role_filter, | ||
| tools=tools, | ||
| ) | ||
| else: | ||
| try: | ||
| input_ids, loss_mask = build_incremental_token_mask( | ||
| self.tokenizer, | ||
| messages, | ||
| role_to_mask=should_mask, | ||
| role_to_mask=fallback_role_filter, | ||
| tools=tools, | ||
| chat_template_kwargs=example.get("chat_template_kwargs", {}), | ||
| collapse_consecutive_tool_messages=True, | ||
|
|
@@ -263,8 +307,13 @@ def should_mask(message: dict) -> bool: | |
| self.logger.warning(f"Skipping example {example.get('__index', '')}: {e}") | ||
| return None | ||
|
|
||
| # If EOS token is not found, manually append it | ||
| if not self.tokenizer.eos_token_id in input_ids: | ||
| # The renderer's token stream is authoritative: it mirrors the model's | ||
| # chat template (which carries no EOS by design) and trains the real | ||
| # stop signals via sampled_mask (e.g. GLM's turn-closing <|observation|> | ||
| # / <|user|>). Don't inject an EOS the renderer didn't emit. Only the | ||
| # chat-template fallback path (no sampled_mask) needs EOS appended so | ||
| # the model learns to stop. | ||
| if self.renderer is None and not self.tokenizer.eos_token_id in input_ids: | ||
| self.logger.warning( | ||
| f"Did not find EOS token ID {self.tokenizer.eos_token_id} in input_ids. Is something wrong with the chat template? Manually appending EOS token..." | ||
| ) | ||
|
|
@@ -286,7 +335,11 @@ def should_mask(message: dict) -> bool: | |
| f"input_ids, loss_mask and target_ids must have the same length, but got {len(input_ids)=}, {len(loss_mask)=}, {len(target_ids)=}" | ||
| ) | ||
| assert sum(loss_mask) > 0, "There are no tokens in this sample that contribute to the loss" | ||
| assert self.tokenizer.eos_token_id in target_ids, "EOS token ID must be present in target_ids" | ||
| # Only the chat-template fallback guarantees an EOS; the renderer honors | ||
| # the template and may legitimately omit it (stop signals come from | ||
| # sampled turn markers instead). | ||
| if self.renderer is None: | ||
| assert self.tokenizer.eos_token_id in target_ids, "EOS token ID must be present in target_ids" | ||
|
|
||
| # Create sample (with one fake target for the last token) | ||
| return { | ||
|
|
||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why private