diff --git a/torchtitan/experiments/rl/README.md b/torchtitan/experiments/rl/README.md index 711d3da59c..855fc0d834 100644 --- a/torchtitan/experiments/rl/README.md +++ b/torchtitan/experiments/rl/README.md @@ -27,11 +27,12 @@ uv venv --python 3.12 titan-rl source titan-rl/bin/activate ``` -1. Install Monarch and TorchStore from main: +1. Install Monarch, TorchStore, and Renderers from main: ```bash uv pip install torchmonarch==0.4.1 uv pip install --no-deps "git+https://github.com/meta-pytorch/torchstore.git@main" uv pip install pygtrie portpicker +uv pip install "git+https://github.com/PrimeIntellect-ai/renderers.git@main" ``` 2. Install Flash Attention 3 kernels: diff --git a/torchtitan/experiments/rl/actors/generator.py b/torchtitan/experiments/rl/actors/generator.py index c53509c28e..4d9d072721 100644 --- a/torchtitan/experiments/rl/actors/generator.py +++ b/torchtitan/experiments/rl/actors/generator.py @@ -163,6 +163,9 @@ class SamplingConfig: max_tokens: int = 100 """Maximum number of tokens to generate per completion.""" + stop_token_ids: list[int] = field(default_factory=list) + """Role-boundary stop tokens from the renderer (e.g. Qwen3 `<|im_end|>`).""" + class VLLMGenerator(Actor, Configurable): """ @@ -411,6 +414,7 @@ async def generate( top_p=_sampling_config.top_p, max_tokens=_sampling_config.max_tokens, n=_sampling_config.n, + stop_token_ids=_sampling_config.stop_token_ids or None, seed=self.config.debug.seed, logprobs=1, output_kind=RequestOutputKind.FINAL_ONLY, @@ -436,14 +440,14 @@ async def generate( all_outputs.extend(self._engine.step()) # vLLM may return requests out of order; sort by the integer - # request_id we assigned so prompt_idx lines up with the input. + # request_id we assigned so request_idx lines up with the input. 
all_outputs.sort(key=lambda o: int(o.request_id)) completions: list[Completion] = [] generation_metrics: list[m.Metric] = [] output_token_counts: list[int] = [] for output in all_outputs: - prompt_idx = int(output.request_id) + request_idx = int(output.request_id) generation_metrics.extend( _prepare_generation_request_metrics(output, prefix=metrics_prefix) ) @@ -456,8 +460,7 @@ async def generate( completions.append( Completion( policy_version=self.policy_version, - prompt_idx=prompt_idx, - text=sample.text, + request_idx=request_idx, token_ids=sample.token_ids, token_logprobs=per_token_logprobs, finish_reason=sample.finish_reason, diff --git a/torchtitan/experiments/rl/actors/trainer.py b/torchtitan/experiments/rl/actors/trainer.py index 9c0e50c630..dc6a4a12cd 100644 --- a/torchtitan/experiments/rl/actors/trainer.py +++ b/torchtitan/experiments/rl/actors/trainer.py @@ -455,7 +455,7 @@ async def forward_backward( async def optim_step(self) -> OptimStepOutput: """Clip gradients, step optimizer + LR scheduler, return updated state.""" # TODO: Accept optional optimizer params (e.g. learning rate) - # to allow controller-owned schedules (see Tinker API). + # to allow controller-owned schedules. 
# capture LR before step current_lrs = self.lr_schedulers.schedulers[0].get_last_lr() diff --git a/torchtitan/experiments/rl/config_registry.py b/torchtitan/experiments/rl/config_registry.py index 3d89c743c2..7a7b2631d3 100644 --- a/torchtitan/experiments/rl/config_registry.py +++ b/torchtitan/experiments/rl/config_registry.py @@ -25,7 +25,8 @@ from torchtitan.experiments.rl.batcher import BatchConfig, Batcher from torchtitan.experiments.rl.grpo import GRPOLoss, RLTrainer from torchtitan.experiments.rl.observability.metrics import MetricsProcessor -from torchtitan.experiments.rl.sum_digits import SumDigitsEnv +from torchtitan.experiments.rl.renderer import RendererConfig +from torchtitan.experiments.rl.tasks.sum_digits import SumDigitsDataset, SumDigitsTask from torchtitan.models.qwen3 import model_registry @@ -39,10 +40,14 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: num_prompts_per_step=5, num_validation_samples=20, compile=CompileConfig(enable=True, backend="aot_eager"), - env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3), - validation_env=SumDigitsEnv.Config( - seed=99, correctness_reward=1.0, format_reward=0.3 - ), + tasks={ + "sum_digits": SumDigitsTask.Config( + train_dataset=SumDigitsDataset.Config(seed=42), + val_dataset=SumDigitsDataset.Config(seed=99), + ) + }, + group_size=group_size, + renderer=RendererConfig(name="qwen3", enable_thinking=True), metrics=MetricsProcessor.Config(enable_wandb=True), batcher=Batcher.Config( batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048), @@ -77,10 +82,9 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: ), checkpoint=CheckpointManager.Config(enable=False), sampling=SamplingConfig( - n=group_size, temperature=0.8, top_p=0.95, - max_tokens=100, + max_tokens=700, ), ), ) @@ -96,10 +100,14 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config: num_prompts_per_step=5, num_validation_samples=20, compile=CompileConfig(enable=True, backend="aot_eager"), - 
env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3), - validation_env=SumDigitsEnv.Config( - seed=99, correctness_reward=1.0, format_reward=0.3 - ), + tasks={ + "sum_digits": SumDigitsTask.Config( + train_dataset=SumDigitsDataset.Config(seed=42), + val_dataset=SumDigitsDataset.Config(seed=99), + ) + }, + group_size=group_size, + renderer=RendererConfig(name="qwen3", enable_thinking=True), metrics=MetricsProcessor.Config(enable_wandb=True), batcher=Batcher.Config( batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048), @@ -135,10 +143,9 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config: ), checkpoint=CheckpointManager.Config(enable=False), sampling=SamplingConfig( - n=group_size, temperature=0.8, top_p=0.95, - max_tokens=100, + max_tokens=700, ), ), ) @@ -154,10 +161,14 @@ def rl_grpo_qwen3_14b() -> RLTrainer.Config: num_prompts_per_step=5, num_validation_samples=20, compile=CompileConfig(enable=True, backend="aot_eager"), - env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3), - validation_env=SumDigitsEnv.Config( - seed=99, correctness_reward=1.0, format_reward=0.3 - ), + tasks={ + "sum_digits": SumDigitsTask.Config( + train_dataset=SumDigitsDataset.Config(seed=42), + val_dataset=SumDigitsDataset.Config(seed=99), + ) + }, + group_size=group_size, + renderer=RendererConfig(name="qwen3", enable_thinking=True), metrics=MetricsProcessor.Config(enable_wandb=True), batcher=Batcher.Config( batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048), @@ -192,17 +203,16 @@ def rl_grpo_qwen3_14b() -> RLTrainer.Config: ), checkpoint=CheckpointManager.Config(enable=False), sampling=SamplingConfig( - n=group_size, temperature=0.8, top_p=0.95, - max_tokens=100, + max_tokens=700, ), ), ) def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config: - """On-policy GRPO config for Qwen3-0.6B under same parallelism (4 GPUs: 2 gen + 2 train). 
+ """On-policy GRPO config for Qwen3-0.6B (4 GPUs: 2 gen + 2 train). Enables deterministic + batch-invariant mode for true on-policy RL training. """ @@ -215,10 +225,14 @@ def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config: num_prompts_per_step=5, num_validation_samples=20, compile=CompileConfig(enable=True, backend="aot_eager"), - env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3), - validation_env=SumDigitsEnv.Config( - seed=99, correctness_reward=1.0, format_reward=0.3 - ), + tasks={ + "sum_digits": SumDigitsTask.Config( + train_dataset=SumDigitsDataset.Config(seed=42), + val_dataset=SumDigitsDataset.Config(seed=99), + ) + }, + group_size=group_size, + renderer=RendererConfig(name="qwen3", enable_thinking=True), metrics=MetricsProcessor.Config(enable_wandb=True), batcher=Batcher.Config( batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048), @@ -257,10 +271,9 @@ def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config: ), checkpoint=CheckpointManager.Config(enable=False), sampling=SamplingConfig( - n=group_size, temperature=0.8, top_p=0.95, - max_tokens=100, + max_tokens=700, ), debug=batch_invariant_config, ), diff --git a/torchtitan/experiments/rl/env_types/__init__.py b/torchtitan/experiments/rl/env_types/__init__.py new file mode 100644 index 0000000000..1f7ade9fd3 --- /dev/null +++ b/torchtitan/experiments/rl/env_types/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torchtitan.experiments.rl.env_types.message_env import ( + MessageEnv, + MessageResetOutput, + MessageStepOutput, +) +from torchtitan.experiments.rl.env_types.renderer_env import ( + RendererWrapperEnv, + TokenizedStepOutput, + TurnMessages, +) + +__all__ = [ + "MessageEnv", + "MessageResetOutput", + "MessageStepOutput", + "RendererWrapperEnv", + "TokenizedStepOutput", + "TurnMessages", +] diff --git a/torchtitan/experiments/rl/env_types/message_env.py b/torchtitan/experiments/rl/env_types/message_env.py new file mode 100644 index 0000000000..cb8404cba9 --- /dev/null +++ b/torchtitan/experiments/rl/env_types/message_env.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import abc +from dataclasses import dataclass, field + +from renderers import Message, ToolSpec + + +@dataclass(kw_only=True, slots=True) +class MessageResetOutput: + """Initial prompt messages + tool specs from `MessageEnv.reset`.""" + + prompt_messages: list[Message] # [M_prompt] + """The messages that form the initial prompt (e.g. [system, user]).""" + + tools: list[ToolSpec] = field(default_factory=list) # [K_tools] + """Tool schemas exposed to the assistant. Empty for tool-less envs.""" + + +@dataclass(kw_only=True, slots=True) +class MessageStepOutput: + """The env's reply to the assistant's turn.""" + + env_messages: list[Message] = field(default_factory=list) # [M_env] + """The env's reply messages (tool / user). Empty when the rollout terminates + with no follow-up.""" + + done: bool = False + """`True` ends the rollout.""" + + env_rewards: dict[str, float] = field(default_factory=dict) + """Optional reward signal the env provides for this step; the rubric decides + whether and how to use it. 
Empty if the env scores nothing.""" + + def __post_init__(self) -> None: + # env replies are tool/user turns; the assistant turn comes from the generator + if any(m.get("role") == "assistant" for m in self.env_messages): + raise ValueError( + "MessageStepOutput.env_messages may not contain assistant messages" + ) + + +class MessageEnv(abc.ABC): + """User-written env in message space. Implement `reset` + `step`. + + Tip: `MessageEnv` works in messages and never sees token ids; you can have `RendererWrapperEnv` + wrap it and use a `Renderer` to convert messages <-> token ids for the generator. + + Example: + # a one-tool calculator env. It is multi-turn — the env answers the + # assistant's tool call, then ends once the assistant replies without a tool. + + class CalculatorEnv(MessageEnv): + async def reset(self) -> MessageResetOutput: + return MessageResetOutput( + prompt_messages=[{"role": "user", "content": "What is 12 * 7?"}], + tools=[CALCULATOR_TOOL], + ) + + async def step(self, assistant_message: Message) -> MessageStepOutput: + tool_calls = assistant_message.get("tool_calls") + if not tool_calls: + return MessageStepOutput(done=True) # assistant gave its final answer + result = run_calculator(tool_calls[0]) + return MessageStepOutput( + env_messages=[{"role": "tool", "content": result}] + ) + """ + + @abc.abstractmethod + async def reset(self) -> MessageResetOutput: + """Return the initial conversation + tools for prompt rendering.""" + + @abc.abstractmethod + async def step(self, assistant_message: Message) -> MessageStepOutput: + """Advance the env one turn given the assistant's latest message. + + `RendererWrapperEnv` parses the completion and handles + finish_reason / length / parse / timeout failures before calling this, + so the env only sees a well-formed assistant message. + + Args: + assistant_message: the assistant's parsed turn. + + Returns: + `MessageStepOutput` with the env's reply messages.
+ """ + + async def close(self) -> None: + """Release env-owned resources. Default no-op; idempotent.""" diff --git a/torchtitan/experiments/rl/env_types/renderer_env.py b/torchtitan/experiments/rl/env_types/renderer_env.py new file mode 100644 index 0000000000..5db8e97aa4 --- /dev/null +++ b/torchtitan/experiments/rl/env_types/renderer_env.py @@ -0,0 +1,283 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from renderers import Message, Renderer, ToolSpec + +from torchtitan.experiments.rl.env_types.message_env import ( + MessageEnv, + MessageStepOutput, +) +from torchtitan.experiments.rl.rollouts.types import RolloutStatus + +if TYPE_CHECKING: + from torchtitan.experiments.rl.actors.generator import Completion + +logger = logging.getLogger(__name__) + + +@dataclass(kw_only=True, slots=True) +class TurnMessages: + """Container for all message data collected during one rollout turn.""" + + prompt_messages: list[Message] # [M_prompt] + """Full conversation rendered into the prompt (the assistant's input this turn).""" + + assistant_message: Message | None = None + """The received assistant's message (generator output, parsed), if any. `None` on reset and on + parse failures.""" + + env_step: MessageStepOutput | None = None + """The env's reply this turn. `None` on reset, parse failure, length truncation, abort, and timeout (the env was not stepped).""" + + +@dataclass(kw_only=True, slots=True) +class TokenizedStepOutput: + """Output of `RendererWrapperEnv.reset` and `.step`.
Holds the next + prompt token ids, if the turn is not terminal, and all messages collected so far.""" + + next_prompt_token_ids: list[int] | None # [L_prompt] or None + """Input tokens for the NEXT generator call; `None` on a terminal turn.""" + + status: RolloutStatus + """`ONGOING` while the rollout runs; a terminal `RolloutStatus` otherwise.""" + + turn: TurnMessages + """Message-space view of this turn (full conversation + assistant message + env reply).""" + + +class RendererWrapperEnv: + """Token-space wrapper around a `MessageEnv`. + + In a rollout, the input to a generator is a tokenized prompt, but the input to MessageEnv.step and + its output are messages. A step that translates message <-> token is necessary. + This wrapper fills this role, using a renderer to convert between the two and facilitating + the communication between the generator and the MessageEnv. + + In this process, several checks are necessary, e.g. prompt is too long, + total number of tokens is too long, number of turns exceeded the limit, etc. This wrapper + also takes care of that, so we can keep the rollout loop clean and simple. + + If users have extra or different logic, they can wrap their MessageEnv with another class instead. + + Args: + message_env: the user's `MessageEnv` subclass instance. + renderer: a `renderers.Renderer` that converts messages <-> token ids. + config: `RendererWrapperEnv.Config`. + + Example: + + env = RendererWrapperEnv(message_env=SumDigitsEnv(...), renderer=renderer) + step = await env.reset() + while not step.status.is_terminal(): + completion = await generator.generate([step.next_prompt_token_ids]) + step = await env.step(completion) + """ + + @dataclass(kw_only=True, slots=True) + class Config: + """Limits enforced by the wrapper""" + + max_rollout_tokens: int | None = None + """Hard cap on prompt length for the next turn.
If the number of tokens meets/exceeds + it, the turn is terminal; `None` disables the check.""" + + # TODO: it's unclear if timeout should be on this layer or handled by the MessageEnv + step_timeout_s: float | None = 1800.0 + """Wall-clock timeout for one `MessageEnv.step` call.""" + + # TODO: add max_num_turns + + def __init__( + self, + *, + message_env: MessageEnv, + renderer: Renderer, + config: "RendererWrapperEnv.Config | None" = None, + ) -> None: + self._message_env = message_env + self._renderer = renderer + self._config = config or RendererWrapperEnv.Config() + self._tools: list[ToolSpec] | None = None + self._messages: list[Message] = [] + self._last_prompt_ids: list[int] = [] + + async def reset(self) -> TokenizedStepOutput: + """Render the initial conversation into the first generator prompt.""" + env_reset = await self._message_env.reset() + self._messages = list(env_reset.prompt_messages) + + # Render messages into tokens + self._tools = list(env_reset.tools) if env_reset.tools else None + token_ids = await asyncio.to_thread( + self._renderer.render_ids, + messages=self._messages, + tools=self._tools, + add_generation_prompt=True, + ) + + # Terminal if the prompt is already over budget + if self._is_prompt_overflow(prompt_len=len(token_ids)): + return _terminal( + prompt_messages=self._messages, + status=RolloutStatus.TRUNCATED_PROMPT_TOO_LONG, + ) + + self._last_prompt_ids = list(token_ids) + return TokenizedStepOutput( + next_prompt_token_ids=list(token_ids), + status=RolloutStatus.ONGOING, + turn=TurnMessages(prompt_messages=list(self._messages)), + ) + + async def step(self, completion: "Completion") -> TokenizedStepOutput: + """Advance the env by one sampled completion from the generator. + + Args: + completion: Generator output for the current prompt. + + Returns: + `TokenizedStepOutput` for the next generator call, or a terminal turn + when the rollout completes, truncates, or errors.
+ """ + # Parse first, so a truncated / aborted response still carries its message + try: + parsed = await asyncio.to_thread( + self._renderer.parse_response, + token_ids=list(completion.token_ids), + ) + except Exception: + logger.exception( + "parse_response failed (finish_reason=%s, %d tokens); -> ERROR_PARSE", + completion.finish_reason, + len(completion.token_ids), + ) + return _terminal( + prompt_messages=self._messages, + status=RolloutStatus.ERROR_PARSE, + ) + + assistant: Message = {"role": "assistant", "content": parsed.content} + if parsed.reasoning_content: + assistant["reasoning_content"] = parsed.reasoning_content + if parsed.tool_calls: + assistant["tool_calls"] = parsed.tool_calls + + # Truncated / aborted: the response is final and partial. Keep it for + # partial-reward grading and debugging; don't step the env on it. + # TODO: check if we should step the env on an incomplete message + if completion.finish_reason == "length": + return _terminal( + prompt_messages=self._messages, + status=RolloutStatus.TRUNCATED_LENGTH, + assistant_message=assistant, + ) + if completion.finish_reason == "abort": + return _terminal( + prompt_messages=self._messages, + status=RolloutStatus.ERROR_ABORT, + assistant_message=assistant, + ) + + # Apply the user's env step under a timeout + timeout = self._config.step_timeout_s + try: + if timeout is None: + env_step = await self._message_env.step(assistant) + else: + env_step = await asyncio.wait_for( + self._message_env.step(assistant), timeout=timeout + ) + except TimeoutError: + logger.warning("step timed out after %ss; -> ERROR_TIMEOUT", timeout) + return _terminal( + prompt_messages=self._messages, + status=RolloutStatus.ERROR_TIMEOUT, + assistant_message=assistant, + ) + + self._messages.append(assistant) + self._messages.extend(env_step.env_messages) + + if env_step.done: + return _terminal( + prompt_messages=self._messages, + status=RolloutStatus.COMPLETED, + assistant_message=assistant, + env_step=env_step, + ) 
+ + # Prepare the next prompt; full re-render if the renderer can't bridge. + # `tools` is passed because tool schemas are part of the chat template, so + # the bridged tokens must match what a full re-render (also tools-aware) produces. + bridged = await asyncio.to_thread( + self._renderer.bridge_to_next_turn, + previous_prompt_ids=self._last_prompt_ids, + previous_completion_ids=list(completion.token_ids), + new_messages=env_step.env_messages, + tools=self._tools, + ) + if bridged is None: + next_prompt_token_ids = await asyncio.to_thread( + self._renderer.render_ids, + messages=self._messages, + tools=self._tools, + add_generation_prompt=True, + ) + else: + next_prompt_token_ids = bridged.token_ids + + # Terminal if the next prompt is over budget + if self._is_prompt_overflow(prompt_len=len(next_prompt_token_ids)): + return _terminal( + prompt_messages=self._messages, + status=RolloutStatus.TRUNCATED_PROMPT_TOO_LONG, + assistant_message=assistant, + env_step=env_step, + ) + + self._last_prompt_ids = list(next_prompt_token_ids) + return TokenizedStepOutput( + next_prompt_token_ids=list(next_prompt_token_ids), + status=RolloutStatus.ONGOING, + turn=TurnMessages( + prompt_messages=list(self._messages), + assistant_message=assistant, + env_step=env_step, + ), + ) + + async def close(self) -> None: + await self._message_env.close() + + def _is_prompt_overflow(self, *, prompt_len: int) -> bool: + cap = self._config.max_rollout_tokens + return cap is not None and prompt_len >= cap + + +def _terminal( + *, + prompt_messages: list[Message], + status: RolloutStatus, + assistant_message: Message | None = None, + env_step: MessageStepOutput | None = None, +) -> TokenizedStepOutput: + """Build a terminal `TokenizedStepOutput` with the given status.""" + return TokenizedStepOutput( + next_prompt_token_ids=None, + status=status, + turn=TurnMessages( + prompt_messages=list(prompt_messages), + assistant_message=assistant_message, + env_step=env_step, + ), + ) diff --git 
a/torchtitan/experiments/rl/grpo.py b/torchtitan/experiments/rl/grpo.py index 1938c11a0d..0d6f3eb6e6 100644 --- a/torchtitan/experiments/rl/grpo.py +++ b/torchtitan/experiments/rl/grpo.py @@ -27,9 +27,9 @@ import os import statistics import time -from collections import defaultdict from collections.abc import Callable -from dataclasses import dataclass, field +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field, replace # must run before torch import os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") @@ -39,18 +39,32 @@ from monarch.actor import this_host from monarch.spmd import setup_torch_elastic_env_async -from torchtitan.components.tokenizer import HuggingFaceTokenizer from torchtitan.config import ( CompileConfig, ConfigManager, Configurable, ParallelismConfig, ) -from torchtitan.experiments.rl.actors.generator import SamplingConfig, VLLMGenerator +from torchtitan.experiments.rl.actors.generator import ( + Completion, + SamplingConfig, + VLLMGenerator, +) from torchtitan.experiments.rl.actors.trainer import PolicyTrainer from torchtitan.experiments.rl.batcher import Batcher +from torchtitan.experiments.rl.env_types import RendererWrapperEnv, TokenizedStepOutput from torchtitan.experiments.rl.observability import metrics as m -from torchtitan.experiments.rl.types import Completion, Episode, Trajectory +from torchtitan.experiments.rl.renderer import RendererConfig +from torchtitan.experiments.rl.rollouts import ( + prepare_rollout_metrics, + Rollout, + rollout_to_episode, + RolloutGroup, + RolloutStatus, + RolloutTurn, +) +from torchtitan.experiments.rl.tasks import Task +from torchtitan.experiments.rl.types import Episode from torchtitan.observability import structured_logger as sl from torchtitan.protocols.model_spec import ModelSpec @@ -167,57 +181,33 @@ def _bootstrap(): return _bootstrap -def _log_samples(items: list[Episode] | list[Completion]) -> None: - """Log the first sample per prompt 
for debugging.""" +def _log_samples(episodes: list[Episode]) -> None: + """Log the first episode per prompt for debugging.""" seen_prompts: set[int] = set() - for item in items: - if item.prompt_idx in seen_prompts: + for ep in episodes: + if ep.prompt_idx in seen_prompts: continue - seen_prompts.add(item.prompt_idx) - reward_str = f" reward={item.reward:+.1f}" if hasattr(item, "reward") else "" - logger.info(f" [prompt {item.prompt_idx}]{reward_str}") - logger.info(f" A: {item.text[:300].replace(chr(10), ' ').strip()}") - + seen_prompts.add(ep.prompt_idx) + logger.info(f" [prompt {ep.prompt_idx}] reward={ep.reward:+.1f}") + logger.info(f" A: {ep.text[:300].replace(chr(10), ' ').strip()}") -def _prepare_reward_metrics( - prefix: str, - trajectories: list[Trajectory], -) -> list[m.Metric]: - """One ``Mean`` metric per observed reward component across trajectories. - Example:: +class RLTrainer(Configurable): + """Top-level RL training orchestrator. - trajectories = [ - Trajectory( - sample_idx=0, - prompt_token_ids=p0, - transitions=[(c0, Step(rewards={"correctness": 1.0, "format": 0.5}, done=True))], - ), - Trajectory( - sample_idx=1, - prompt_token_ids=p1, - transitions=[(c1, Step(rewards={"correctness": 0.0}, done=True))], - ), - ] - _prepare_reward_metrics("reward/component", trajectories) - # -> [ - # Metric("reward/component/correctness", Mean(sum=1.0, count=2)), # 0.5 - # Metric("reward/component/format", Mean(sum=0.5, count=1)), # 0.5 - "format" only in trajectory 0 - # ] - """ - values_by_name: dict[str, list[float]] = defaultdict(list) - for trajectory in trajectories: - for _completion, step in trajectory.transitions: - for name, value in step.rewards.items(): - values_by_name[name].append(float(value)) - return [ - m.Metric(f"{prefix}/{name}", m.Mean.from_list(values)) - for name, values in sorted(values_by_name.items()) - ] + Owns a `PolicyTrainer` actor (gradient updates), a `VLLMGenerator` actor + (sampling), one or more `Task`s (rubric + env 
construction), and a + `Dataset` per phase (train/validation). Each training step samples groups + of rollouts, scores them via per-task rubrics, builds GRPO advantages, and + syncs trainer weights to the generator. + Example: -class RLTrainer(Configurable): - """Top-level RL training orchestrator.""" + cfg = config_registry.rl_grpo_qwen3_0_6b() + trainer = cfg.build() + await trainer.setup_async() + await trainer.train() + """ @dataclass(kw_only=True, slots=True) class Config(Configurable.Config): @@ -238,19 +228,22 @@ class Config(Configurable.Config): num_prompts_per_step: int = 5 """Number of distinct prompts (= GRPO groups) drawn per training step. + Total episodes per step is `num_prompts_per_step * group_size`.""" - The total episodes per step is `num_prompts_per_step` * `group_size`, - where `group_size` is `generator.sampling.n` (completions per prompt). - """ + group_size: int = 8 + """Sibling rollouts sampled per dataset row (the GRPO group). The generator + is always called with `n=1`; prompts are pre-expanded by `group_size`.""" num_validation_samples: int = 20 """Number of held-out prompts scored greedily (temp=0, n=1) per validation pass.""" - env: Configurable.Config = field(default=None) # type: ignore[assignment] - """Env config for training rollouts.""" + tasks: dict[str, Task.Config] = field(default_factory=dict) + """Map from task name to `Task.Config` (each `Task` owns its datasets). + Single-task runs use a one-key dict; multi-task mixing lands with the + datamix PR.""" - validation_env: Configurable.Config = field(default=None) # type: ignore[assignment] - """Env config for validation rollouts.""" + renderer: RendererConfig = field(default_factory=RendererConfig) + """Message-to-token renderer config.""" log_samples: bool = False """Log first completion per episode during training and validation.""" @@ -280,6 +273,8 @@ def __post_init__(self): "(weights are synced from the trainer via TorchStore). " "Set generator.checkpoint.enable=False." 
) + if not self.tasks: + raise ValueError("tasks must not be empty") if self.trainer.debug.batch_invariant: if not self.trainer.debug.deterministic: @@ -312,10 +307,25 @@ def __init__(self, config: Config): log_dir=config.dump_folder, job_config=config.to_dict(), ) - # TODO: Replace this single-turn tokenizer with renderer - self.tokenizer = HuggingFaceTokenizer(tokenizer_path=config.hf_assets_path) - # TODO: Use tokenizer.pad_id when available, falling back to eos_id. - self.batcher = Batcher(config.batcher, pad_id=self.tokenizer.eos_id) + self.renderer = config.renderer.build( + tokenizer_path=config.hf_assets_path, model_spec=config.model_spec + ) + self._stop_token_ids = list(self.renderer.get_stop_token_ids()) + # Sampling config reused for every generate call: stop tokens baked in, + # n=1 because prompts are pre-expanded by group_size. + self._sampling = replace( + config.generator.sampling, stop_token_ids=self._stop_token_ids, n=1 + ) + # TODO: pass our own tokenizer to the renderer and read pad/eos off it + # once `renderers` supports bring-your-own-tokenizer + # (https://github.com/PrimeIntellect-ai/renderers/pull/70). + # Until then, reach into the renderer's tokenizer for the pad id (eos doubles as pad). + self.batcher = Batcher( + config.batcher, pad_id=self.renderer._tokenizer.eos_token_id + ) + self._tasks: dict[str, Task] = { + name: cfg.build() for name, cfg in config.tasks.items() + } async def close(self): """Best-effort: tear down actors, close metric backends, then stop proc meshes.""" @@ -409,6 +419,17 @@ async def setup_async( nodes (no heterogeneous node configurations). Required when host_mesh is provided. """ + # Size the thread executor for the renderer's `asyncio.to_thread` calls: + # one per concurrent rollout, capped by the machine's CPUs. 
+ max_concurrent_rollouts = max( + self.config.num_prompts_per_step * self.config.group_size, + self.config.num_validation_samples, + ) + max_workers = max(1, min(max_concurrent_rollouts, os.cpu_count() or 1)) + asyncio.get_running_loop().set_default_executor( + ThreadPoolExecutor(max_workers=max_workers) + ) + config = self.config self.trainer_world_size = self._compute_world_size(config.trainer.parallelism) @@ -517,7 +538,7 @@ async def setup_async( model_path=config.hf_assets_path, compile_config=config.compile, max_num_seqs=max( - config.num_prompts_per_step * config.generator.sampling.n, + config.num_prompts_per_step * config.group_size, config.num_validation_samples, ), output_dir=config.dump_folder, @@ -539,100 +560,304 @@ async def setup_async( self.generator.pull_model_state_dict.call(0).get() @sl.log_trace_span("_collect_rollouts") - def _collect_rollouts( + async def _collect_rollouts( self, num_groups: int, step: int, - group_offset: int = 0, - ) -> tuple[list[Trajectory], list[m.Metric]]: - """Collect group rollouts and emit completion-shape rollout metrics. + group_offset: int, + ) -> tuple[list[RolloutGroup], list[m.Metric]]: + """Collect train rollout groups and emit rollout-shape metrics. Args: num_groups: Number of prompt groups to collect in this round. - step: Current training step (passed to env for curriculum). - group_offset: Starting group index so that env ``group_idx`` - values are unique across collection rounds within a step. + step: Current training step (tagged into `group_id` for metrics). + group_offset: Starting group index so generated `group_id`s + stay unique across collection rounds within a step. + + Returns: + Scored rollout groups and rollout/generator metrics. 
""" - envs = [ - self.config.env.build(step=step, group_idx=group_offset + i) - for i in range(num_groups) - ] - # TODO: Add a check max_tokens = min(max_tokens, context_window - model_input.length) - # and pass max_tokens to the generator call or skip the call if max_tokens<=0. - # Do the same for validation. - tokenized_prompts = [ - self.tokenizer.encode(env.prompt, add_bos=True, add_eos=False) - for env in envs - ] - completions, generation_metrics = self._get_rank_0_value( - self.generator.generate.call(tokenized_prompts).get() + rollout_groups, generation_metrics = await self._run_rollouts( + split="train", + num_groups=num_groups, + group_size=self.config.group_size, + sampling=self._sampling, + step=step, + group_offset=group_offset, + metrics_prefix="generator", ) - trajectories: list[Trajectory] = [] - with sl.log_trace_span("score"): - for c in completions: - step_result = envs[c.prompt_idx].step(c.text) - trajectories.append( - Trajectory( - sample_idx=group_offset + c.prompt_idx, - prompt_token_ids=tokenized_prompts[c.prompt_idx], - transitions=[(c, step_result)], + rollout_metrics = prepare_rollout_metrics( + "rollout", + [rollout for group in rollout_groups for rollout in group.rollouts], + ) + rollout_metrics += generation_metrics + return rollout_groups, rollout_metrics + + @sl.log_trace_span("_run_rollouts") + async def _run_rollouts( + self, + *, + split: str, + num_groups: int, + group_size: int, + sampling: SamplingConfig, + step: int, + group_offset: int, + metrics_prefix: str, + ) -> tuple[list[RolloutGroup], list[m.Metric]]: + """Build groups, batch-generate, then per group: env.step + + task.score_group. Per-group failures are logged and dropped. + + Steps: + 1. Sample examples from the task's `split` dataset + 2. Create N envs per example: `task.make_envs(example, group_size)` + 3. For each env, get the initial prompt (env.reset) + 4. Run one batched `generate` call (n=1; prompts pre-expanded) + 5. For each rollout, run `env.step` + 6. 
For each RolloutGroup, run `reward = task.score_group(RolloutGroup)` + 7. Return scored list[RolloutGroup] + TODO(continuous-batching): once available, run rollouts independently + instead of batching one `generate` over all prompts at once. + TODO(datamix): with >1 task, interleave/weight across self._tasks here. + """ + + @dataclass(kw_only=True, slots=True) + class _PendingGroup: + """One prompt group under construction: built before generation, then + stepped and scored into a `RolloutGroup`.""" + + group_id: str + example: object + task: Task + envs: list[RendererWrapperEnv] # [group_size] + + # 1. Sample examples from the task's `split` dataset. + # 2. Create N envs per example: `task.make_envs(example, group_size)`. + (task,) = self._tasks.values() + pending_groups: list[_PendingGroup] = [] + for group_idx in range(num_groups): + example = ( + task.sample_train_example() + if split == "train" + else task.sample_val_example() + ) + pending_groups.append( + _PendingGroup( + group_id=f"step={step}/group={group_offset + group_idx}", + example=example, + task=task, + envs=task.make_envs( + example=example, + group_size=group_size, + renderer=self.renderer, + ), # [N_samples_per_group] + ) + ) + + try: + # 3. For each env, get initial prompt (n_groups * n_rollouts_per_group) + initial_steps: list[list[TokenizedStepOutput]] = await asyncio.gather( + *( + asyncio.gather(*(env.reset() for env in group.envs)) + for group in pending_groups + ) + ) # [G][N] + + # Drop the whole group if its shared initial prompt is already over the + # token budget (TRUNCATED_PROMPT_TOO_LONG): there's no room to generate. + # TODO: add metrics for dropped groups + runnable_group_idxs = [ + group_idx + for group_idx in range(num_groups) + if initial_steps[group_idx][0].status is RolloutStatus.ONGOING + ] + + # 4. Run one batched `generate` call (n=1: one rollout per prompt).
+ # `rollout_index` is parallel to the prompts, so the i-th prompt (and + # i-th completion, once ordered) belongs to rollout_index[i]. + # TODO: pass the remaining budget (max_rollout_tokens - len(prompt)) to the + # sampling_config, to limit generation length in one turn. + rollout_index = [ + (group_idx, sample_idx) + for group_idx in runnable_group_idxs + for sample_idx in range(group_size) + ] + completions, gen_metrics = self._get_rank_0_value( + self.generator.generate.call( + [ + list(initial_steps[group_idx][sample_idx].next_prompt_token_ids) + for group_idx, sample_idx in rollout_index + ], + sampling_config=sampling, + metrics_prefix=metrics_prefix, + ).get() + ) + + # 5. For each rollout, run `env.step`. `completion.request_idx` is the + # completion's position in the flattened prompt list we sent to + # `generate`; sort by it so completions line up 1:1 with rollout_index. + ordered_completions = sorted( + completions, key=lambda completion: completion.request_idx + ) + rollouts = await asyncio.gather( + *( + self._do_single_rollout( + group_id=pending_groups[group_idx].group_id, + sample_idx=sample_idx, + env=pending_groups[group_idx].envs[sample_idx], + initial_step=initial_steps[group_idx][sample_idx], + completion=completion, + ) + for (group_idx, sample_idx), completion in zip( + rollout_index, ordered_completions, strict=True ) ) + ) + finally: + await asyncio.gather( + *(env.close() for group in pending_groups for env in group.envs), + return_exceptions=True, + ) - # Metrics - response_lens = [len(c.token_ids) for c in completions] - prompt_lens = [len(t.prompt_token_ids) for t in trajectories] - total_lens = [p + r for p, r in zip(prompt_lens, response_lens, strict=True)] - truncated = [c.finish_reason == "length" for c in completions] - rollout_metrics: list[m.Metric] = [ - m.Metric("rollout/response_length", m.Mean.from_list(response_lens)), - m.Metric("rollout/response_length", m.Max.from_list(response_lens)), -
m.Metric("rollout/prompt_length", m.Mean.from_list(prompt_lens)), - m.Metric("rollout/prompt_length", m.Max.from_list(prompt_lens)), - m.Metric("rollout/total_length", m.Max.from_list(total_lens)), - m.Metric("rollout/truncation_rate", m.Mean.from_list(truncated)), - ] - rollout_metrics += generation_metrics - rollout_metrics += _prepare_reward_metrics( - prefix="reward/component", trajectories=trajectories + # 6. For each RolloutGroup, run `reward = task.score_group(RolloutGroup)`. + # Group by the rollout's own group_id (sample order preserved) + rollouts_by_group_id: dict[str, list[Rollout]] = {} + for rollout in rollouts: + rollouts_by_group_id.setdefault(rollout.group_id, []).append(rollout) + + rollout_groups: list[RolloutGroup] = [] + num_failed_groups = 0 + for group_idx in runnable_group_idxs: + pending_group = pending_groups[group_idx] + rollouts = rollouts_by_group_id[pending_group.group_id] # [N], sample order + try: + rewards = await pending_group.task.score_group( + rollouts, pending_group.example + ) + for rollout, reward in zip(rollouts, rewards, strict=True): + rollout.reward = reward.reward + rollout.reward_breakdown = reward.reward_breakdown + rollout_groups.append( + RolloutGroup( + group_id=pending_group.group_id, + env_input=pending_group.example, + rollouts=rollouts, + ) + ) + except Exception: + logger.exception( + "group %s scoring failed; dropping", pending_group.group_id + ) + num_failed_groups += 1 + + gen_metrics = list(gen_metrics) + gen_metrics.append( + m.Metric("rollout/group_failures", m.Sum(float(num_failed_groups))) + ) + return rollout_groups, gen_metrics + + @sl.log_trace_span("do_single_rollout") + async def _do_single_rollout( + self, + *, + group_id: str, + sample_idx: int, + env: RendererWrapperEnv, + initial_step: TokenizedStepOutput, + completion: Completion, + ) -> Rollout: + """Step one env into a `Rollout`. On failure, return the turns + collected so far with an `ERROR` status.
+ + Reward is left unset; the controller scores via `task.score_group(...)` + afterward and fills `reward` / `reward_breakdown`. + + Args: + group_id: Stable prompt-group ID used for advantage centering. + sample_idx: Sample index within the group (0..group_size-1). + env: The env for this rollout. + initial_step: Initial prompt step for this env. + completion: Generator completion for this env's initial prompt. + + Returns: + One unscored Rollout. + """ + rollout_turns: list[RolloutTurn] = [] + try: + step_result = await env.step(completion) + turn = step_result.turn + env_step = turn.env_step + rollout_turns.append( + RolloutTurn( + prompt_token_ids=list(initial_step.next_prompt_token_ids), + assistant_token_ids=list(completion.token_ids), + assistant_logprobs=list(completion.token_logprobs), + policy_version=completion.policy_version, + prompt_messages=list(initial_step.turn.prompt_messages), + assistant_message=turn.assistant_message, + env_messages=list(env_step.env_messages) if env_step else [], + env_rewards=dict(env_step.env_rewards) if env_step else {}, + ) + ) + status = step_result.status + # TODO(multi-turn): while not status.is_terminal(): generate → step → append turn. + if not status.is_terminal(): + raise RuntimeError( + f"env {group_id}/{sample_idx} returned a non-terminal turn; " + "the controller does not yet support multi-turn rollouts." 
+ ) + except Exception: + logger.exception( + "rollout %s/%d failed; keeping %d turn(s) as ERROR", + group_id, + sample_idx, + len(rollout_turns), + ) + status = RolloutStatus.ERROR + return Rollout( + group_id=group_id, sample_idx=sample_idx, status=status, turns=rollout_turns ) - return trajectories, rollout_metrics @staticmethod @sl.log_trace_span("_build_episodes") def _build_episodes( - trajectories: list[Trajectory], + rollout_groups: list[RolloutGroup], ) -> tuple[list[Episode], list[m.Metric]]: - """Group trajectories by sample, apply mean-baseline advantage, emit metrics.""" - groups: dict[int, list[Trajectory]] = {} - for t in trajectories: - groups.setdefault(t.sample_idx, []).append(t) + """Build train episodes and GRPO advantages from scored rollout groups. + + Centers each group's rewards by its mean, skips rollouts without + training tokens, and emits reward/advantage metrics. + + Args: + rollout_groups: Scored rollout groups from one collection round. + Returns: + Train episodes plus episode-level metrics. + """ + # Mean-baseline advantage per group episodes: list[Episode] = [] group_stds: list[float] = [] - for sample_idx, group in groups.items(): - rewards = [t.total_reward for t in group] - group_mean = sum(rewards) / len(rewards) - # Population standard deviation; NaN for an empty group. - group_stds.append(statistics.pstdev(float(r) for r in rewards)) - for t in group: - # Single-turn: exactly one (completion, step) per trajectory. - c, _ = t.transitions[0] - episodes.append( - Episode( - policy_version=c.policy_version, - prompt_idx=sample_idx, - prompt_token_ids=t.prompt_token_ids, - text=c.text, - token_ids=c.token_ids, - token_logprobs=c.token_logprobs, - reward=t.total_reward, - advantage=t.total_reward - group_mean, - ) + for group_idx, group in enumerate(rollout_groups): + # Drop the whole group if any sibling has no trainable tokens (e.g. an + # ERROR rollout with no turns); rollout_to_episode requires one turn. 
+ if any(not rollout.turns for rollout in group.rollouts): + logger.warning( + "group %s has a turn-less rollout; dropping the group", + group.group_id, ) + continue + + rewards = [rollout.reward for rollout in group.rollouts] + group_mean = sum(rewards) / len(rewards) + group_stds.append(statistics.pstdev(rewards)) + + for rollout in group.rollouts: + rollout.advantage = rollout.reward - group_mean + episode = rollout_to_episode(rollout) + episodes.append(replace(episode, prompt_idx=group_idx)) - num_groups = len(groups) + num_groups = len(rollout_groups) zero_std_frac = ( sum(1 for s in group_stds if s == 0.0) / num_groups if num_groups else 0.0 ) @@ -668,62 +893,37 @@ def _build_episodes( @sl.log_trace_span("validate") async def validate(self) -> list[m.Metric]: - """Run validation on held-out prompts using greedy sampling. + """Run greedy validation on held-out prompts. - TODO: investigate using pass@k. + Returns: + Validation rollout metrics, generation metrics, and validation + timing. """ + # TODO: investigate using pass@k for validation. 
t_validate_start = time.perf_counter() num_samples = self.config.num_validation_samples - envs = [ - self.config.validation_env.build(step=0, group_idx=i) - for i in range(num_samples) - ] - greedy = SamplingConfig( - n=1, - temperature=0.0, - top_p=1.0, - max_tokens=self.config.generator.sampling.max_tokens, - ) + greedy = replace(self._sampling, temperature=0.0, top_p=1.0) - tokenized_prompts: list[list[int]] = [ - self.tokenizer.encode(env.prompt, add_bos=True, add_eos=False) - for env in envs - ] - completions, generation_metrics = self._get_rank_0_value( - self.generator.generate.call( - tokenized_prompts, - sampling_config=greedy, - metrics_prefix="validation_generator", - ).get() + rollout_groups, generation_metrics = await self._run_rollouts( + split="val", + num_groups=num_samples, + group_size=1, + sampling=greedy, + step=0, + group_offset=0, + metrics_prefix="validation_generator", ) - - trajectories = [ - Trajectory( - sample_idx=i, - prompt_token_ids=tokenized_prompts[i], - transitions=[(c, envs[i].step(c.text))], - ) - for i, c in enumerate(completions) - ] + rollouts = [rollout for group in rollout_groups for rollout in group.rollouts] if self.config.log_samples: - _log_samples(completions) + preview = [rollout_to_episode(r) for r in rollouts if r.turns and r.reward is not None] + _log_samples(preview) - validation_metrics: list[m.Metric] = [ - m.Metric( - "validation/reward", - m.SummaryStats.from_list([t.total_reward for t in trajectories]), - ), - m.Metric( - "validation/response_length", - m.Mean.from_list([len(c.token_ids) for c in completions]), - ), - m.Metric("validation/num_samples", m.NoReduce(float(len(trajectories)))), - ] - validation_metrics += generation_metrics - validation_metrics += _prepare_reward_metrics( - prefix="validation/reward/component", trajectories=trajectories + validation_metrics = prepare_rollout_metrics("validation", rollouts) + validation_metrics.append( + m.Metric("validation/num_samples", m.NoReduce(float(len(rollouts)))) ) +
validation_metrics += generation_metrics t_validate_s = time.perf_counter() - t_validate_start validation_metrics.append(m.Metric("timing/validate", m.NoReduce(t_validate_s))) @@ -766,7 +966,7 @@ async def train(self): # token budget. The Batcher then packs, truncates to # global_batch_size rows, and splits into microbatches. t_rollout_start = time.perf_counter() - trajectories: list[Trajectory] = [] + rollout_groups: list[RolloutGroup] = [] rollout_metrics: list[m.Metric] = [] collected_tokens = 0 group_offset = 0 @@ -776,20 +976,21 @@ async def train(self): # rows, so actual token consumption may exceed collected_tokens. num_tokens_target = self.batcher.num_tokens_target(self.trainer_dp_degree) while collected_tokens < num_tokens_target: - new_trajectories, new_metrics = self._collect_rollouts( + new_rollout_groups, new_metrics = await self._collect_rollouts( num_groups, step=step, group_offset=group_offset ) - trajectories.extend(new_trajectories) + rollout_groups.extend(new_rollout_groups) rollout_metrics.extend(new_metrics) # Both prompt length and completion length are counted. collected_tokens += sum( - len(t.prompt_token_ids) + len(c.token_ids) - 1 - for t in new_trajectories - for c, _ in t.transitions + len(t.prompt_token_ids) + len(t.assistant_token_ids) - 1 + for group in new_rollout_groups + for r in group.rollouts + for t in r.turns ) group_offset += num_groups - episodes, episode_metrics = self._build_episodes(trajectories) + episodes, episode_metrics = self._build_episodes(rollout_groups) t_rollout_s = time.perf_counter() - t_rollout_start if self.config.log_samples: diff --git a/torchtitan/experiments/rl/renderer.py b/torchtitan/experiments/rl/renderer.py new file mode 100644 index 0000000000..c497cd1d83 --- /dev/null +++ b/torchtitan/experiments/rl/renderer.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from pydantic import TypeAdapter +from renderers import create_renderer, Renderer, RendererConfig as _RendererConfig + +from torchtitan.config import Configurable + +if TYPE_CHECKING: + from torchtitan.protocols.model_spec import ModelSpec + +# `renderers` exposes a pydantic discriminated union (on `name`); let it route +# `name` -> the matching config class +_RENDERER_CONFIG = TypeAdapter(_RendererConfig) + +# Map a TorchTitan model family (`ModelSpec.name`) to a `renderers` name, used to +# resolve `name="auto"` without relying on the HF tokenizer's `name_or_path` +# (a local path for our checkpoints, which misses the renderers auto map). +_TORCHTITAN_RENDERER_BY_MODEL: dict[str, str] = { + "qwen3": "qwen3", +} + + +def _resolve_renderer_name(name: str, model_spec: "ModelSpec | None") -> str: + """Resolve `name="auto"` via the TorchTitan model; explicit names pass through.""" + if name != "auto": + return name + if model_spec is not None: + mapped = _TORCHTITAN_RENDERER_BY_MODEL.get(model_spec.name) + if mapped is not None: + return mapped + # Last resort: let `renderers` resolve from the tokenizer's model id. 
+ return "auto" + + +def _build_renderer_config(name: str, cfg: "RendererConfig"): + """Build the `renderers` config for `name`, passing only the knobs it supports.""" + if name == "auto": + return None + config_cls = type(_RENDERER_CONFIG.validate_python({"name": name})) + supported_args: dict[str, bool | str] = { + field: getattr(cfg, field) + for field in ( + "enable_thinking", + "preserve_all_thinking", + "preserve_thinking_between_tool_calls", + "tool_parser", + "reasoning_parser", + ) + if field in config_cls.model_fields and getattr(cfg, field) is not None + } + return config_cls(**supported_args) + + +@dataclass(kw_only=True, slots=True) +class RendererConfig(Configurable.Config): + """Selects the renderer used for message <-> token conversion. + + Wraps `PrimeIntellect-ai/renderers`. `build` loads a tokenizer from + `tokenizer_path` and constructs the `renderers` config for `name`. + + Args: + name: Renderer name. `"auto"` resolves from the TorchTitan model + (`ModelSpec.name`); or name it explicitly: `"qwen3"`, `"gpt-oss"`, + `"deepseek-v3"`, ... (see renderers `RENDERER_REGISTRY`). + tool_parser: Tool-call parser name (renderer-specific). + reasoning_parser: Reasoning parser name (renderer-specific). + enable_thinking: Let the model emit reasoning. + preserve_all_thinking: Keep historical reasoning in future prompts. + preserve_thinking_between_tool_calls: Keep reasoning during tool loops. 
+ + Example: + + renderer = RendererConfig(name="qwen3").build(tokenizer_path="./Qwen3-0.6B") + prompt_ids = renderer.render_ids( + [{"role": "user", "content": "hi"}], add_generation_prompt=True + ) + """ + + name: str = "auto" + tool_parser: str | None = None + reasoning_parser: str | None = None + enable_thinking: bool = True + preserve_all_thinking: bool = False + preserve_thinking_between_tool_calls: bool = False + + def build( + self, *, tokenizer_path: str, model_spec: "ModelSpec | None" = None + ) -> Renderer: + # TODO(renderers#70): use TorchTitan's tokenizer once `renderers` supports + # bring-your-own-tokenizer (PR adds a Tokenizer protocol; drops transformers). + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + renderer_name = _resolve_renderer_name(self.name, model_spec) + return create_renderer(tokenizer, _build_renderer_config(renderer_name, self)) diff --git a/torchtitan/experiments/rl/rollouts/__init__.py b/torchtitan/experiments/rl/rollouts/__init__.py new file mode 100644 index 0000000000..174117aab6 --- /dev/null +++ b/torchtitan/experiments/rl/rollouts/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torchtitan.experiments.rl.rollouts.types import ( + Rollout, + RolloutGroup, + RolloutStatus, + RolloutTurn, +) +from torchtitan.experiments.rl.rollouts.utils import ( + last_assistant_text, + prepare_rollout_metrics, + rollout_to_episode, +) + +__all__ = [ + "Rollout", + "RolloutGroup", + "RolloutStatus", + "RolloutTurn", + "last_assistant_text", + "prepare_rollout_metrics", + "rollout_to_episode", +] diff --git a/torchtitan/experiments/rl/rollouts/types.py b/torchtitan/experiments/rl/rollouts/types.py new file mode 100644 index 0000000000..9fce2c0347 --- /dev/null +++ b/torchtitan/experiments/rl/rollouts/types.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import StrEnum + +from renderers import Message + + +_TRUNCATED = frozenset({"truncated_length", "truncated_prompt_too_long"}) +_ERROR = frozenset({"error_parse", "error_timeout", "error_abort", "error"}) + + +class RolloutStatus(StrEnum): + """Per-rollout status.""" + + ONGOING = "ongoing" + COMPLETED = "completed" + TRUNCATED_LENGTH = "truncated_length" + TRUNCATED_PROMPT_TOO_LONG = "truncated_prompt_too_long" + ERROR_PARSE = "error_parse" + ERROR_TIMEOUT = "error_timeout" + ERROR_ABORT = "error_abort" + ERROR = "error" + + def is_truncated(self) -> bool: + return self.value in _TRUNCATED + + def is_error(self) -> bool: + return self.value in _ERROR + + def is_terminal(self) -> bool: + return self is not RolloutStatus.ONGOING + + +@dataclass(kw_only=True, slots=True) +class RolloutTurn: + """One generator completion + the env response to that completion.""" + + # TODO: add a `logs` field (raw prompt/response text, finish_reason, timings) + # so a turn can be dumped and inspected without re-deriving from tokens. 
+ + # Fields needed for training + prompt_token_ids: list[int] # [L_prompt] + """Tokenized conversation up to this turn, used to generate the assistant response.""" + + assistant_token_ids: list[int] # [L_response] + """Tokens the assistant produced this turn.""" + + assistant_logprobs: list[float] # [L_response] + """Per-token logprobs from the generator policy for the assistant tokens.""" + + # Filtering / debugging + policy_version: int + """Trainer policy version when this response was sampled.""" + + # Logging + prompt_messages: list[Message] = field(default_factory=list) # [M_prompt] + """Full conversation up to this turn, used to generate the assistant response.""" + + assistant_message: Message | None = None + """The assistant's message (generator output, parsed).""" + + env_messages: list[Message] = field(default_factory=list) # [M_env] + """The env's reply messages this turn (tool / user).""" + + # For rubrics + env_rewards: dict[str, float] = field(default_factory=dict) + """Optional per-turn reward signals the env attached; the rubric decides how to use them.""" + + +@dataclass(kw_only=True, slots=True) +class Rollout: + """A complete rollout: ordered turns + terminal state + reward + identifier.""" + + # TODO: add a `logs` field (per-turn debug records / event trace) to make a + # full rollout reconstructable for debugging. 
+ + group_id: str + """Prompt-group ID; siblings share it for advantage centering.""" + + sample_idx: int + """Sample index within the group (0..group_size-1).""" + + turns: list[RolloutTurn] = field(default_factory=list) # [K_turns] + """Ordered rollout turns.""" + + status: RolloutStatus = RolloutStatus.COMPLETED + """Rollout-level terminal status.""" + + reward: float | None = None + """Final weighted reward, filled by the rubric.""" + + reward_breakdown: dict[str, float] = field(default_factory=dict) + """Raw per-reward-function values, filled by the rubric.""" + + # TODO: make it per token + advantage: float | None = None + """Advantage for this sample.""" + + +@dataclass(kw_only=True, slots=True) +class RolloutGroup: + group_id: str + """Prompt-group ID; siblings share it for advantage centering.""" + + env_input: object + """The env input (dataset payload) shared by the group; passed to the rubric.""" + + rollouts: list[Rollout] # [group_size] + """Sibling rollouts sampled from the group's shared prompt.""" diff --git a/torchtitan/experiments/rl/rollouts/utils.py b/torchtitan/experiments/rl/rollouts/utils.py new file mode 100644 index 0000000000..b7ba67a6e8 --- /dev/null +++ b/torchtitan/experiments/rl/rollouts/utils.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +from collections import defaultdict + +from torchtitan.experiments.rl.observability import metrics as m +from torchtitan.experiments.rl.rollouts.types import Rollout +from torchtitan.experiments.rl.types import Episode + + +def last_assistant_text(rollout: Rollout) -> str: + """Return the assistant message text from the last turn, or `""`.""" + if not rollout.turns: + return "" + msg = rollout.turns[-1].assistant_message + return (msg.get("content") or "") if msg else "" + + +def rollout_to_episode(rollout: Rollout) -> Episode: + """Flatten a scored single-turn `Rollout` into an `Episode`, a class + that holds only the information needed for training. + """ + # TODO: support multi-turn rollout flattening. + # TODO: rename Episode -> TrainingSample / rollout_to_episode -> + # rollout_to_training_sample (consistent with TrainingBatch). + if len(rollout.turns) != 1: + raise ValueError( + f"rollout_to_episode expects exactly one turn; got {len(rollout.turns)}." + ) + turn = rollout.turns[0] + return Episode( + policy_version=turn.policy_version, + prompt_idx=rollout.sample_idx, + prompt_token_ids=turn.prompt_token_ids, + text=last_assistant_text(rollout), + token_ids=turn.assistant_token_ids, + token_logprobs=turn.assistant_logprobs, + reward=rollout.reward, + advantage=rollout.advantage if rollout.advantage is not None else 0.0, + ) + + +def prepare_rollout_metrics(prefix: str, rollouts: list[Rollout]) -> list[m.Metric]: + """Build rollout-derived metrics (lengths, truncation, reward breakdown). + + Args: + prefix: Metric namespace (e.g. `"rollout"` or `"validation"`). + rollouts: Rollouts to summarize. 
+ """ + # Lengths, truncation, reward + # TODO: adapt for multi-turn rollouts + response_lens = [len(t.assistant_token_ids) for r in rollouts for t in r.turns] + prompt_lens = [len(r.turns[0].prompt_token_ids) for r in rollouts if r.turns] + total_lens = [ + len(r.turns[-1].prompt_token_ids) + len(r.turns[-1].assistant_token_ids) + for r in rollouts + if r.turns + ] + + truncated = [float(r.status.is_truncated()) for r in rollouts] + rewards = [r.reward for r in rollouts if r.reward is not None] + + out: list[m.Metric] = [ + m.Metric(f"{prefix}/response_length", m.Mean.from_list(response_lens)), + m.Metric(f"{prefix}/response_length", m.Max.from_list(response_lens)), + m.Metric(f"{prefix}/prompt_length", m.Mean.from_list(prompt_lens)), + m.Metric(f"{prefix}/prompt_length", m.Max.from_list(prompt_lens)), + m.Metric(f"{prefix}/total_length", m.Mean.from_list(total_lens)), + m.Metric(f"{prefix}/total_length", m.Max.from_list(total_lens)), + m.Metric(f"{prefix}/truncation_rate", m.Mean.from_list(truncated)), + m.Metric(f"{prefix}_reward", m.SummaryStats.from_list(rewards)), + ] + + # Per-component reward breakdown + values_by_name: dict[str, list[float]] = defaultdict(list) + for rollout in rollouts: + for name, value in rollout.reward_breakdown.items(): + values_by_name[name].append(float(value)) + out.extend( + m.Metric(f"{prefix}_reward/component/{name}", m.Mean.from_list(values)) + for name, values in sorted(values_by_name.items()) + ) + return out diff --git a/torchtitan/experiments/rl/rubrics/__init__.py b/torchtitan/experiments/rl/rubrics/__init__.py new file mode 100644 index 0000000000..0873a97141 --- /dev/null +++ b/torchtitan/experiments/rl/rubrics/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torchtitan.experiments.rl.rubrics.rubric import Reward, RewardFn, Rubric + +__all__ = ["Reward", "RewardFn", "Rubric"] diff --git a/torchtitan/experiments/rl/rubrics/rubric.py b/torchtitan/experiments/rl/rubrics/rubric.py new file mode 100644 index 0000000000..c8792c076c --- /dev/null +++ b/torchtitan/experiments/rl/rubrics/rubric.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import abc +import asyncio +from dataclasses import dataclass, field + +from torchtitan.config import Configurable +from torchtitan.experiments.rl.rollouts.types import Rollout +from torchtitan.observability import structured_logger as sl + + +class RewardFn(Configurable, abc.ABC): + """A single reward function, as a Configurable callable. + + Subclass and implement `__call__`. Its `Config` carries the `weight` used in + the rubric's weighted sum, plus any args a stateful reward fn needs (a reward + model path, an LLM-judge endpoint, a threshold, ...). + + Example: + class RewardCorrect(RewardFn): + @dataclass(kw_only=True, slots=True) + class Config(RewardFn.Config): + pass # only needs `weight` + + async def __call__(self, rollout, env_input) -> float: + ... + """ + + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + weight: float = 1.0 + """Relative weight in the rubric's weighted sum (normalized across fns).""" + + def __init__(self, config: Config) -> None: + self.weight = config.weight + + @abc.abstractmethod + async def __call__(self, rollout: Rollout, env_input: object) -> float: + """Return this fn's score for one rollout. + + Args: + rollout: Rollout to score. + env_input: Dataset payload used to build the env (target/metadata). 
+ """ + + +@dataclass(frozen=True, kw_only=True, slots=True) +class Reward: + """One rollout's reward + per-reward-fn breakdown. + + Example: + >>> Reward(reward=0.5, reward_breakdown={"RewardCorrect": 1.0, "RewardFormat": 0.0}) + """ + + reward: float + """Final weighted reward for this rollout; the only field the loss uses.""" + + reward_breakdown: dict[str, float] = field(default_factory=dict) + """Raw per-reward-fn outputs (unweighted), keyed by reward-fn class name, + used to compute the scalar `reward`.""" + + +class Rubric(Configurable): + """Holds reward functions and scores rollouts with their weighted sum. + + Reward fns and their weights live in config (`reward_fns`), so a rubric is + just configured for common cases— no subclass needed. Subclass and override `score_group` + for cross-sibling scoring (pairwise comparison, diversity, rank normalization). + + Setting `truncation_reward` / `error_reward` short-circuits the reward fns for + rollouts whose status is truncated / errored. + + Example: + rubric = Rubric.Config( + reward_fns=[RewardCorrect.Config(weight=1.0), RewardFormat.Config(weight=0.3)], + truncation_reward=0.0, + ).build() + """ + + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + reward_fns: list[RewardFn.Config] = field(default_factory=list) + """The rubric's reward fns + weights; built and weight-normalized at init.""" + + truncation_reward: float | None = None + """Reward to assign a truncated rollout. If provided, reward fns are skipped. + If `None`, reward fns run for partial credit.""" + + error_reward: float | None = None + """Reward to assign an errored rollout. If provided, reward fns are skipped. 
+ If `None`, reward fns run for partial credit.""" + + def __init__(self, config: Config) -> None: + self._config = config + self._reward_fns = [rwd_cfg.build() for rwd_cfg in config.reward_fns] + + # Sanity checks + if not self._reward_fns: + raise ValueError("Rubric.Config.reward_fns must not be empty") + names = [type(fn).__name__ for fn in self._reward_fns] + if len(names) != len(set(names)): + raise ValueError(f"reward fn names must be unique; got {names}") + self._weight_sum = sum(fn.weight for fn in self._reward_fns) + if self._weight_sum <= 0: + raise ValueError( + f"rubric weights must sum to a positive value; got {self._weight_sum}" + ) + + @sl.log_trace_span("score_single_rollout") + async def _score_single_rollout( + self, rollout: Rollout, env_input: object + ) -> Reward: + """Score one rollout. Short-circuits to `truncation_reward` / + `error_reward` when those are set and the rollout truncated / errored. + + Args: + rollout: Rollout to score. + env_input: Dataset payload used to build the env (target/metadata). + + Returns: + Final weighted reward + per-fn raw breakdown. + """ + # Short-circuit on truncate / error and return the configured reward. + cfg = self._config + if cfg.truncation_reward is not None and rollout.status.is_truncated(): + return Reward( + reward=cfg.truncation_reward, + reward_breakdown={"truncated": cfg.truncation_reward}, + ) + if cfg.error_reward is not None and rollout.status.is_error(): + return Reward( + reward=cfg.error_reward, + reward_breakdown={"errored": cfg.error_reward}, + ) + + # Run all reward fns and weight-sum (weights normalized to sum to 1.0).
+ per_fn_rewards = await asyncio.gather( + *(fn(rollout, env_input) for fn in self._reward_fns) + ) + + reward_breakdown = {} + total_reward = 0.0 + for fn, r in zip(self._reward_fns, per_fn_rewards, strict=True): + reward_breakdown[type(fn).__name__] = r + total_reward += (fn.weight / self._weight_sum) * r + + return Reward(reward=total_reward, reward_breakdown=reward_breakdown) + + @sl.log_trace_span("score_group") + async def score_group( + self, + rollouts: list[Rollout], + env_input: object, + ) -> list[Reward]: + """Score every rollout in one prompt group. + + Override for cross-rollout rewards (pairwise comparison, diversity, + rank normalization). + + Args: + rollouts: Sibling rollouts sampled from one prompt group. + env_input: Dataset payload originally used to construct the rollout env. + + Returns: + One `Reward` per rollout, in input order. + """ + return await asyncio.gather( + *(self._score_single_rollout(r, env_input) for r in rollouts) + ) diff --git a/torchtitan/experiments/rl/sum_digits.py b/torchtitan/experiments/rl/sum_digits.py deleted file mode 100644 index 9b6d4b475f..0000000000 --- a/torchtitan/experiments/rl/sum_digits.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import random -import re -from dataclasses import dataclass - -from torchtitan.config import Configurable -from torchtitan.experiments.rl.types import Step - - -class SumDigitsEnv(Configurable): - """Single-turn, single-use env for one sum-of-digits problem. - - Construct via ``SumDigitsEnv.Config(seed=...).build(step=, group_idx=)``. - The problem is a pure function of ``(config.seed, step, group_idx)``: same - inputs always produce the same prompt and target. No RNG state is shared - between envs. 
- """ - - @dataclass(kw_only=True, slots=True) - class Config(Configurable.Config): - correctness_reward: float = 1.0 - """Reward for a response containing ``[ANSWER] ``.""" - - format_reward: float = 0.3 - """Reward bonus for any ``[ANSWER] `` tag in the response.""" - - seed: int = 42 - """Seed mixed with ``(step, group_idx)`` to deterministically generate problems.""" - - SYSTEM_PROMPT = """\ -You are a helpful assistant. Solve the problem step by step. -When you have your final answer, state it as [ANSWER] . - -Example: -User: What is the total digit sum of [12, 345, 67]? -Assistant: Break each number into digits: -12 → 1, 2 -345 → 3, 4, 5 -67 → 6, 7 -Sum all digits: 1 + 2 + 3 + 4 + 5 + 6 + 7 = 28 -[ANSWER] 28""" - - def __init__(self, config: Config, *, step: int = 0, group_idx: int = 0): - self._config = config - rng = random.Random(f"{config.seed}:{step}:{group_idx}") - n = rng.randint(2, 4) - numbers = [rng.randint(10, 99) for _ in range(n)] - self._target = sum(int(d) for num in numbers for d in str(num)) - question = f"What is the total digit sum of {numbers}?" 
- self.prompt = f"{self.SYSTEM_PROMPT}\n\n{question}" - - def step(self, completion: str) -> Step: - return Step( - rewards={ - "correctness": self._correctness_reward(completion), - "format": self._format_reward(completion), - }, - done=True, - ) - - def _correctness_reward(self, completion: str) -> float: - matches = re.findall(r"\[ANSWER\]\s*(-?\d+)", completion) - correct = bool(matches) and int(matches[-1]) == self._target - return self._config.correctness_reward if correct else 0.0 - - def _format_reward(self, completion: str) -> float: - if re.search(r"\[ANSWER\]\s*-?\d+", completion): - return self._config.format_reward - return 0.0 diff --git a/torchtitan/experiments/rl/tasks/__init__.py b/torchtitan/experiments/rl/tasks/__init__.py new file mode 100644 index 0000000000..69647afc39 --- /dev/null +++ b/torchtitan/experiments/rl/tasks/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.experiments.rl.tasks.task import Task + +__all__ = ["Task"] diff --git a/torchtitan/experiments/rl/tasks/sum_digits/__init__.py b/torchtitan/experiments/rl/tasks/sum_digits/__init__.py new file mode 100644 index 0000000000..6d7172c811 --- /dev/null +++ b/torchtitan/experiments/rl/tasks/sum_digits/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torchtitan.experiments.rl.tasks.sum_digits.data import ( + SumDigitsDataset, + SumDigitsInput, +) +from torchtitan.experiments.rl.tasks.sum_digits.env import SumDigitsEnv +from torchtitan.experiments.rl.tasks.sum_digits.rubric import ( + RewardCorrect, + RewardFormat, +) +from torchtitan.experiments.rl.tasks.sum_digits.task import SumDigitsTask + +__all__ = [ + "RewardCorrect", + "RewardFormat", + "SumDigitsDataset", + "SumDigitsEnv", + "SumDigitsInput", + "SumDigitsTask", +] diff --git a/torchtitan/experiments/rl/tasks/sum_digits/data.py b/torchtitan/experiments/rl/tasks/sum_digits/data.py new file mode 100644 index 0000000000..3b46ba5c32 --- /dev/null +++ b/torchtitan/experiments/rl/tasks/sum_digits/data.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import random +from dataclasses import dataclass + +from torchtitan.config import Configurable + + +@dataclass(frozen=True, kw_only=True, slots=True) +class SumDigitsInput: + """Typed payload for one SumDigits problem.""" + + numbers: list[int] # [N_numbers] + """Numbers the model must digit-sum.""" + + target: int + """Ground-truth total digit sum.""" + + +class SumDigitsDataset(Configurable): + """Stateful, seeded RNG dataset of SumDigits problems. 
+ + Example: + + ds = SumDigitsDataset(SumDigitsDataset.Config(seed=42)) + ex = ds.sample_example() + # ex is a SumDigitsInput (numbers + target) + """ + + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + seed: int = 42 + + def __init__(self, config: Config) -> None: + self._rng = random.Random(config.seed) + + def sample_example(self) -> SumDigitsInput: + """Sample one SumDigits problem.""" + n = self._rng.randint(2, 4) + numbers = [self._rng.randint(10, 99) for _ in range(n)] + target = sum(int(d) for num in numbers for d in str(num)) + return SumDigitsInput(numbers=numbers, target=target) diff --git a/torchtitan/experiments/rl/tasks/sum_digits/env.py b/torchtitan/experiments/rl/tasks/sum_digits/env.py new file mode 100644 index 0000000000..ae8d4611d0 --- /dev/null +++ b/torchtitan/experiments/rl/tasks/sum_digits/env.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from renderers import Message + +from torchtitan.experiments.rl.env_types import ( + MessageEnv, + MessageResetOutput, + MessageStepOutput, +) +from torchtitan.experiments.rl.tasks.sum_digits.data import SumDigitsInput + + +SYSTEM_PROMPT = """\ +You are a helpful assistant. Solve the problem step by step. +When you have your final answer, state it as [ANSWER] . + +Example: +User: What is the total digit sum of [12, 345, 67]? 
+Assistant: Break each number into digits: +12 -> 1, 2 +345 -> 3, 4, 5 +67 -> 6, 7 +Sum all digits: 1 + 2 + 3 + 4 + 5 + 6 + 7 = 28 +[ANSWER] 28""" + + +class SumDigitsEnv(MessageEnv): + def __init__(self, *, env_input: SumDigitsInput) -> None: + self._numbers = env_input.numbers + + async def reset(self) -> MessageResetOutput: + """Return the system prompt and one SumDigits user question.""" + question = f"What is the total digit sum of {self._numbers}?" + return MessageResetOutput( + prompt_messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": question}, + ] + ) + + async def step(self, assistant_message: Message) -> MessageStepOutput: + # Single-turn env: end after the assistant's first message. + return MessageStepOutput(done=True) diff --git a/torchtitan/experiments/rl/tasks/sum_digits/rubric.py b/torchtitan/experiments/rl/tasks/sum_digits/rubric.py new file mode 100644 index 0000000000..388a843094 --- /dev/null +++ b/torchtitan/experiments/rl/tasks/sum_digits/rubric.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +import re +from dataclasses import dataclass + +from torchtitan.experiments.rl.rollouts import last_assistant_text, Rollout +from torchtitan.experiments.rl.rubrics import RewardFn + +from torchtitan.experiments.rl.tasks.sum_digits.data import SumDigitsInput + + +_ANSWER_RE = re.compile(r"\[ANSWER\]\s*(-?\d+)") +_FORMAT_RE = re.compile(r"\[ANSWER\]\s*-?\d+") + + +class RewardCorrect(RewardFn): + """`1.0` if the last `[ANSWER] <int>` equals the target, else `0.0`.""" + + @dataclass(kw_only=True, slots=True) + class Config(RewardFn.Config): + pass + + async def __call__(self, rollout: Rollout, env_input: SumDigitsInput) -> float: + text = last_assistant_text(rollout) + matches = _ANSWER_RE.findall(text) + if not matches: + return 0.0 + return 1.0 if int(matches[-1]) == env_input.target else 0.0 + + +class RewardFormat(RewardFn): + """`1.0` if the response contains `[ANSWER] <int>`, else `0.0`.""" + + @dataclass(kw_only=True, slots=True) + class Config(RewardFn.Config): + pass + + async def __call__(self, rollout: Rollout, env_input: object) -> float: + return 1.0 if _FORMAT_RE.search(last_assistant_text(rollout)) else 0.0 + + +__all__ = ["RewardCorrect", "RewardFormat"] diff --git a/torchtitan/experiments/rl/tasks/sum_digits/task.py b/torchtitan/experiments/rl/tasks/sum_digits/task.py new file mode 100644 index 0000000000..57064d6144 --- /dev/null +++ b/torchtitan/experiments/rl/tasks/sum_digits/task.py @@ -0,0 +1,78 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +from dataclasses import dataclass, field + +from renderers import Renderer + +from torchtitan.experiments.rl.env_types import RendererWrapperEnv +from torchtitan.experiments.rl.rubrics import Rubric +from torchtitan.experiments.rl.tasks import Task +from torchtitan.experiments.rl.tasks.sum_digits.data import ( + SumDigitsDataset, + SumDigitsInput, +) +from torchtitan.experiments.rl.tasks.sum_digits.env import SumDigitsEnv +from torchtitan.experiments.rl.tasks.sum_digits.rubric import ( + RewardCorrect, + RewardFormat, +) + + +class SumDigitsTask(Task): + """SumDigits task: have the model sum a sequence of digits.""" + + @dataclass(kw_only=True, slots=True) + class Config(Task.Config): + train_dataset: SumDigitsDataset.Config = field( + default_factory=lambda: SumDigitsDataset.Config(seed=42) + ) + val_dataset: SumDigitsDataset.Config = field( + default_factory=lambda: SumDigitsDataset.Config(seed=99) + ) + rubric: Rubric.Config = field( + default_factory=lambda: Rubric.Config( + reward_fns=[ + RewardCorrect.Config(weight=1.0), + RewardFormat.Config(weight=0.3), + ] + ) + ) + env_config: RendererWrapperEnv.Config = field( + default_factory=RendererWrapperEnv.Config + ) + """Renderer-wrapper limits, e.g. 
`max_rollout_tokens`.""" + + def __init__(self, config: Config) -> None: + super().__init__(config) # builds self.rubric from config.rubric + self._train_dataset = config.train_dataset.build() + self._val_dataset = config.val_dataset.build() + self._env_config = config.env_config + + def sample_train_example(self) -> SumDigitsInput: + return self._train_dataset.sample_example() + + def sample_val_example(self) -> SumDigitsInput: + return self._val_dataset.sample_example() + + def make_envs( + self, + *, + example: SumDigitsInput, + group_size: int, + renderer: Renderer, + ) -> list[RendererWrapperEnv]: + """Construct SumDigits envs for one prompt group.""" + return [ + RendererWrapperEnv( + message_env=SumDigitsEnv(env_input=example), + renderer=renderer, + config=self._env_config, + ) + for _ in range(group_size) + ] diff --git a/torchtitan/experiments/rl/tasks/task.py b/torchtitan/experiments/rl/tasks/task.py new file mode 100644 index 0000000000..ee808c2728 --- /dev/null +++ b/torchtitan/experiments/rl/tasks/task.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import abc +from dataclasses import dataclass + +from renderers import Renderer + +from torchtitan.config import Configurable +from torchtitan.experiments.rl.env_types import RendererWrapperEnv +from torchtitan.experiments.rl.rollouts.types import Rollout +from torchtitan.experiments.rl.rubrics import Reward, Rubric + + +# TODO(continuous-batching): when VLLMGenerator gains continuous batching, +# move the rollout loop onto Task as `run_rollout(example, client) -> Rollout`, +# so each rollout drives its own generate calls. +class Task(Configurable, abc.ABC): + """Bundles everything needed to run and score one rollout: a dataset, + how to build its envs, and a `Rubric`. 
+ + The flow for one prompt group: + + sample = task.sample_train_example() # the env input from the task's dataset + envs = task.make_envs(example=sample, group_size=group_size, renderer=renderer) # MessageEnvs wrapped in RendererWrapperEnv + run_rollout_fn(sampler, envs) # the controller runs the rollout loop + rewards = task.score_group(rollouts, sample) # the Rubric scores them + + `MessageEnv` works in messages; `RendererWrapperEnv` (what `make_envs` returns) + adds the message <-> token plumbing. `score_group` delegates to + `rubric.score_group` by default; override it for cross-sibling scoring. + + Example: + class MyTask(Task): + @dataclass(kw_only=True, slots=True) + class Config(Task.Config): + train_dataset: MyDataset.Config = field(default_factory=MyDataset.Config) + val_dataset: MyDataset.Config = field(default_factory=MyDataset.Config) + rubric: MyRubric.Config = field(default_factory=MyRubric.Config) + env_config: RendererWrapperEnv.Config = field( + default_factory=RendererWrapperEnv.Config + ) + + def __init__(self, config: Config) -> None: + super().__init__(config) # builds self.rubric from config.rubric + self._train = config.train_dataset.build() + self._val = config.val_dataset.build() + self._env_config = config.env_config + + def sample_train_example(self) -> MyInput: + return self._train.sample_example() + + def sample_val_example(self) -> MyInput: + return self._val.sample_example() + + def make_envs(self, *, example, group_size, renderer): + return [RendererWrapperEnv(...) 
for _ in range(group_size)] + """ + + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + rubric: Rubric.Config + + rubric: Rubric + """Built from `config.rubric` by the base `__init__`; used by `score_group`.""" + + def __init__(self, config: Config) -> None: + self.rubric = config.rubric.build() + + @abc.abstractmethod + def sample_train_example(self) -> object: + """Sample one training example (the env input) from this task's dataset.""" + + @abc.abstractmethod + def sample_val_example(self) -> object: + """Sample one validation example (the env input) from this task's dataset.""" + + # TODO: revisit the Renderer being injected into `make_envs` once we + # know whether Task should own a Renderer (per-task chat templates). + @abc.abstractmethod + def make_envs( + self, + *, + example: object, + group_size: int, + renderer: Renderer, + ) -> list[RendererWrapperEnv]: + """Construct `group_size` single-use envs from one dataset example. + + Args: + example: the env input from `sample_train_example` / `sample_val_example`. + group_size: number of sibling envs for this prompt group. + renderer: Renderer shared by the rollout controller. + + Returns: + `group_size` `RendererWrapperEnv` instances, each ready for one rollout. + """ + + async def score_group( + self, + rollouts: list[Rollout], + env_input: object, + ) -> list[Reward]: + """Score one group's rollouts; the controller applies the rewards. + + Default impl delegates to `self.rubric.score_group`. Override for + cross-sibling scoring (judge, pairwise, diversity) or partial-credit + reward shaping. + + Args: + rollouts: Sibling rollouts in one prompt group, already stepped. + env_input: Dataset payload shared by the group. + + Returns: + One `Reward` per rollout, in input order. 
+ """ + return await self.rubric.score_group(rollouts, env_input) diff --git a/torchtitan/experiments/rl/tests/test_grpo_metrics.py b/torchtitan/experiments/rl/tests/test_grpo_metrics.py deleted file mode 100644 index 1d3011b33c..0000000000 --- a/torchtitan/experiments/rl/tests/test_grpo_metrics.py +++ /dev/null @@ -1,662 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Unit tests for RL metric helpers + controller-subroutine outputs. - -These tests do **not** start Monarch, vLLM, W&B, or distributed -process groups. Controller subroutines are invoked as static / instance -methods on plain dataclasses. -""" - -from __future__ import annotations - -import math -from unittest.mock import MagicMock, patch - -import pytest - -import torch - -from torchtitan.experiments.rl.grpo import _prepare_reward_metrics, GRPOLoss, RLTrainer -from torchtitan.experiments.rl.observability import metrics as m -from torchtitan.experiments.rl.types import Completion, Step, Trajectory - - -# --------------------------------------------------------------------------- -# _prepare_reward_metrics -# --------------------------------------------------------------------------- - - -def _step(rewards: dict[str, float]) -> Step: - return Step(rewards=rewards, done=True) - - -def _reward_trajectory(rewards: dict[str, float], sample_idx: int = 0) -> Trajectory: - """Single-turn trajectory with a fake completion + the given rewards.""" - fake_completion = Completion( - policy_version=0, - prompt_idx=sample_idx, - text="", - token_ids=[], - token_logprobs=[], - ) - return Trajectory( - sample_idx=sample_idx, - prompt_token_ids=[], - transitions=[(fake_completion, _step(rewards))], - ) - - -class TestBuildRewardMetrics: - def test_one_metric_per_observed_name(self) -> None: - trajectories = [ - 
_reward_trajectory({"correctness": 1.0, "format": 0.5}, sample_idx=0), - _reward_trajectory({"correctness": 0.0, "format": 1.0}, sample_idx=1), - ] - metrics = _prepare_reward_metrics("reward/component", trajectories) - keys = {entry.key for entry in metrics} - assert keys == { - "reward/component/correctness", - "reward/component/format", - } - for entry in metrics: - assert isinstance(entry.value, m.Mean) - - def test_components_observed_in_some_trajectories_only(self) -> None: - # `format` only appears in the second trajectory — it should - # average over that one entry (no zero-fill). - trajectories = [ - _reward_trajectory({"correctness": 1.0}, sample_idx=0), - _reward_trajectory({"format": 0.5}, sample_idx=1), - ] - metrics = _prepare_reward_metrics("reward/component", trajectories) - agg = m.MetricsProcessor._aggregate_metrics(metrics) - assert agg["reward/component/correctness/mean"] == 1.0 - assert agg["reward/component/format/mean"] == 0.5 - - def test_empty_input(self) -> None: - assert _prepare_reward_metrics("reward/component", []) == [] - - def test_prefix_controls_namespace(self) -> None: - trajectories = [_reward_trajectory({"correctness": 1.0}, sample_idx=0)] - metrics = _prepare_reward_metrics("validation/reward/component", trajectories) - assert metrics[0].key == "validation/reward/component/correctness" - - -# --------------------------------------------------------------------------- -# Controller subroutines: _collect_rollouts, _build_episodes, validate -# --------------------------------------------------------------------------- - - -def _completion( - prompt_idx: int, - response_len: int, - finish_reason: str | None = "stop", - *, - policy_version: int = 0, -) -> Completion: - return Completion( - policy_version=policy_version, - prompt_idx=prompt_idx, - text="x" * response_len, - token_ids=list(range(response_len)), - token_logprobs=[0.0] * response_len, - finish_reason=finish_reason, - ) - - -class _FakeEnv: - """Minimal env stub: 
step(text) returns a preset reward dict.""" - - def __init__(self, rewards: dict[str, float], prompt: str = "p"): - self.prompt = prompt - self._rewards = rewards - - def step(self, text: str) -> Step: - return _step(self._rewards) - - -def _build_collect_rollouts_inputs(self_obj): - """Build a hollow RLTrainer instance + completions wired into the - fake generator — without spawning real meshes. - """ - - completions = [ - _completion(prompt_idx=0, response_len=10), - _completion(prompt_idx=0, response_len=20), - _completion(prompt_idx=1, response_len=15), - ] - - class _RewardEnvBuilder: - @staticmethod - def build(*, step, group_idx): - return _FakeEnv({"correctness": float(group_idx), "format": 0.5}) - - self_obj.config = MagicMock() - self_obj.config.env = _RewardEnvBuilder - self_obj.tokenizer = MagicMock() - self_obj.tokenizer.encode.side_effect = lambda prompt, **_: [ord(prompt)] - self_obj.generator = MagicMock() - # `_get_rank_0_value` is the layer that strips Monarch's ValueMesh, - # so just make it return whatever it's handed. - self_obj._get_rank_0_value = lambda value, has_gpus=True: (completions, []) - return completions - - -class TestCollectRollouts: - def test_passes_token_ids_to_generator(self) -> None: - """Controller tokenizes env prompts and hands the IDs (not strings) - to ``generator.generate.call``.""" - controller = RLTrainer.__new__(RLTrainer) - _build_collect_rollouts_inputs(controller) - controller._collect_rollouts(num_groups=2, step=0) - # _FakeEnv.prompt == "p"; encode side_effect returns [ord(prompt)] = [112]. 
- controller.generator.generate.call.assert_called_once_with([[112], [112]]) - - def test_emits_expected_metric_keys(self) -> None: - controller = RLTrainer.__new__(RLTrainer) - completions = _build_collect_rollouts_inputs(controller) - trajectories, rollout_metrics = controller._collect_rollouts( - num_groups=2, step=0 - ) - assert len(trajectories) == len(completions) - agg = m.MetricsProcessor._aggregate_metrics(rollout_metrics) - # Length keys: Mean+Max for prompt/response, Max-only for total. - assert "rollout/response_length/mean" in agg - assert "rollout/response_length/max" in agg - assert "rollout/prompt_length/mean" in agg - assert "rollout/prompt_length/max" in agg - assert "rollout/total_length/max" in agg - # Reward-component keys derived from env step output (now under - # the top-level reward/ namespace). - assert "reward/component/correctness/mean" in agg - assert "reward/component/format/mean" in agg - - def test_truncation_rate(self) -> None: - """rollout/truncation_rate averages - finish_reason == 'length' over completions.""" - controller = RLTrainer.__new__(RLTrainer) - completions = [ - _completion(0, 10, finish_reason="length"), - _completion(0, 10, finish_reason="stop"), - _completion(1, 10, finish_reason="length"), - _completion(1, 10, finish_reason="length"), - ] - - controller.config = MagicMock() - controller.config.env = MagicMock() - controller.config.env.build = lambda *, step, group_idx: _FakeEnv({"r": 1.0}) - controller.tokenizer = MagicMock() - controller.tokenizer.encode.side_effect = lambda prompt, **_: [ord(prompt)] - controller.generator = MagicMock() - controller._get_rank_0_value = lambda value, has_gpus=True: (completions, []) - - _, rollout_metrics = controller._collect_rollouts(num_groups=2, step=0) - agg = m.MetricsProcessor._aggregate_metrics(rollout_metrics) - # 3 of 4 completions hit max_tokens. 
- assert agg["rollout/truncation_rate/mean"] == pytest.approx(0.75) - - def test_total_length_uses_per_episode_max(self) -> None: - """rollout/total_length/max must be max(prompt+response per - episode), **not** max(prompt) + max(response) — the latter - may combine two different episodes.""" - controller = RLTrainer.__new__(RLTrainer) - - # Carefully chosen so per-side maxes don't align: the longest - # prompt has the shortest response, etc. The tokenizer mock below - # produces token lists of length == len(prompt), so the env prompts - # control prompt-side lengths. - env_prompts = {0: "x" * 10, 1: "x" * 2, 2: "x" * 5} - completions = [ - _completion(prompt_idx=0, response_len=2), # total 10 + 2 = 12 - _completion(prompt_idx=1, response_len=10), # total 2 + 10 = 12 - _completion(prompt_idx=2, response_len=5), # total 5 + 5 = 10 - ] - controller.config = MagicMock() - controller.config.env = MagicMock() - controller.config.env.build = lambda *, step, group_idx: _FakeEnv( - {"r": 1.0}, prompt=env_prompts[group_idx] - ) - controller.tokenizer = MagicMock() - controller.tokenizer.encode.side_effect = lambda prompt, **_: list( - prompt.encode() - ) - controller.generator = MagicMock() - controller._get_rank_0_value = lambda value, has_gpus=True: (completions, []) - - trajectories, rollout_metrics = controller._collect_rollouts( - num_groups=3, step=0 - ) - agg = m.MetricsProcessor._aggregate_metrics(rollout_metrics) - per_side_max_sum = max(len(t.prompt_token_ids) for t in trajectories) + max( - len(c.token_ids) for c in completions - ) # = 10 + 10 = 20 - actual_max = max( - len(t.prompt_token_ids) + len(c.token_ids) - for t, c in zip(trajectories, completions, strict=True) - ) # = 12 - assert agg["rollout/total_length/max"] == actual_max - assert agg["rollout/total_length/max"] < per_side_max_sum - - def test_group_offset_keeps_rollout_rounds_distinct(self) -> None: - controller = RLTrainer.__new__(RLTrainer) - _build_collect_rollouts_inputs(controller) - - 
first_round, _ = controller._collect_rollouts( - num_groups=2, step=0, group_offset=0 - ) - second_round, _ = controller._collect_rollouts( - num_groups=2, step=0, group_offset=2 - ) - - assert [t.sample_idx for t in first_round] == [0, 0, 1] - assert [t.sample_idx for t in second_round] == [2, 2, 3] - - episodes, _ = RLTrainer._build_episodes(first_round + second_round) - assert all(ep.advantage == 0.0 for ep in episodes) - - -def _trajectory( - sample_idx: int, - prompt_len: int, - response_len: int, - reward: float, - *, - policy_version: int = 0, -) -> Trajectory: - completion = _completion(sample_idx, response_len, policy_version=policy_version) - return Trajectory( - sample_idx=sample_idx, - prompt_token_ids=list(range(prompt_len)), - transitions=[(completion, _step({"r": reward}))], - ) - - -class TestBuildEpisodes: - def test_emits_expected_metric_keys(self) -> None: - trajectories = [ - _trajectory(0, 4, 5, reward=1.0), - _trajectory(0, 4, 5, reward=0.0), - _trajectory(1, 4, 5, reward=0.5), - _trajectory(1, 4, 5, reward=0.5), - ] - episodes, episode_metrics = RLTrainer._build_episodes(trajectories) - assert len(episodes) == 4 - agg = m.MetricsProcessor._aggregate_metrics(episode_metrics) - # SummaryStats expansion (5 sub-keys each, top-level reward/advantage). - for prefix in ("reward", "advantage"): - for sub in ("max", "mean", "min", "std", "sum"): - assert f"{prefix}/_{sub}" in agg - # Group std + degenerate fraction live under reward/. - assert "reward/group_std/mean" in agg - assert "reward/group_std/max" in agg - assert "reward/zero_std_frac" in agg - # Per-rollout policy version distribution (min/max only). - assert "rollout/policy_version/mean" not in agg - assert "rollout/policy_version/min" in agg - assert "rollout/policy_version/max" in agg - # num_prompts/num_episodes were dropped: make sure they don't - # creep back in. 
- assert "rollout/num_prompts" not in agg - assert "rollout/num_episodes" not in agg - - def test_policy_version_metrics_single_version(self) -> None: - """When all rollouts came from the same policy version, min == max.""" - single_version = [ - _trajectory(0, 4, 5, reward=1.0, policy_version=5), - _trajectory(1, 4, 5, reward=0.5, policy_version=5), - ] - _, em = RLTrainer._build_episodes(single_version) - agg = m.MetricsProcessor._aggregate_metrics(em) - assert agg["rollout/policy_version/min"] == 5.0 - assert agg["rollout/policy_version/max"] == 5.0 - - def test_policy_version_metrics_mixed_versions(self) -> None: - """Mixed rollout versions emit min and max.""" - mixed_versions = [ - _trajectory(0, 4, 5, reward=1.0, policy_version=2), - _trajectory(1, 4, 5, reward=0.5, policy_version=4), - ] - _, em = RLTrainer._build_episodes(mixed_versions) - agg = m.MetricsProcessor._aggregate_metrics(em) - assert agg["rollout/policy_version/min"] == 2.0 - assert agg["rollout/policy_version/max"] == 4.0 - - def test_degenerate_group_fraction(self) -> None: - # Two groups: both constant => fraction == 1.0. - all_constant = [ - _trajectory(0, 4, 5, reward=1.0), - _trajectory(0, 4, 5, reward=1.0), - _trajectory(1, 4, 5, reward=0.5), - _trajectory(1, 4, 5, reward=0.5), - ] - _, em = RLTrainer._build_episodes(all_constant) - agg = m.MetricsProcessor._aggregate_metrics(em) - assert agg["reward/zero_std_frac"] == 1.0 - - # Mixed: one group constant (degenerate), one group varied. - mixed = [ - _trajectory(0, 4, 5, reward=1.0), - _trajectory(0, 4, 5, reward=1.0), - _trajectory(1, 4, 5, reward=0.0), - _trajectory(1, 4, 5, reward=1.0), - ] - _, em = RLTrainer._build_episodes(mixed) - agg = m.MetricsProcessor._aggregate_metrics(em) - assert agg["reward/zero_std_frac"] == 0.5 - - # Both groups have variance => 0.0. 
- none_constant = [ - _trajectory(0, 4, 5, reward=0.0), - _trajectory(0, 4, 5, reward=1.0), - _trajectory(1, 4, 5, reward=0.5), - _trajectory(1, 4, 5, reward=1.5), - ] - _, em = RLTrainer._build_episodes(none_constant) - agg = m.MetricsProcessor._aggregate_metrics(em) - assert agg["reward/zero_std_frac"] == 0.0 - - -# --------------------------------------------------------------------------- -# RLTrainer.Config wiring -# --------------------------------------------------------------------------- - - -class TestRLTrainerConfigWiring: - """Use the canonical `rl_grpo_qwen3_0_6b` registry config so the test - matches a real production config and stays insulated from any future - tightening of VLLMGenerator/PolicyTrainer field validators.""" - - def test_metrics_default_uses_factory(self) -> None: - from torchtitan.experiments.rl.config_registry import rl_grpo_qwen3_0_6b - - cfg = rl_grpo_qwen3_0_6b() - baseline = m.MetricsProcessor.Config() - assert cfg.metrics.console_log_keys_train == baseline.console_log_keys_train - assert ( - cfg.metrics.console_log_keys_validation - == baseline.console_log_keys_validation - ) - - def test_metrics_defaults_are_independent_copies(self) -> None: - """Mutating one Config's allow lists must not bleed into other instances.""" - from torchtitan.experiments.rl.config_registry import rl_grpo_qwen3_0_6b - - cfg = rl_grpo_qwen3_0_6b() - cfg.metrics.console_log_keys_train.append("X") - cfg.metrics.console_log_keys_validation.append("Y") - # A fresh Config still has the pristine defaults. 
- fresh = rl_grpo_qwen3_0_6b() - assert "X" not in fresh.metrics.console_log_keys_train - assert "Y" not in fresh.metrics.console_log_keys_validation - - def test_metrics_default_wandb_enabled(self) -> None: - from torchtitan.experiments.rl.config_registry import rl_grpo_qwen3_0_6b - - cfg = rl_grpo_qwen3_0_6b() - assert cfg.metrics.enable_wandb is True - assert cfg.metrics.enable_tensorboard is False - - -# --------------------------------------------------------------------------- -# GRPOLoss bridge -# --------------------------------------------------------------------------- - - -class TestGRPOLossBridge: - def test_loss_keeps_gradient(self) -> None: - """`loss` must remain differentiable so `.backward()` works. - Regression test for `_token_weighted_mean` accidentally detaching.""" - loss_fn = GRPOLoss(GRPOLoss.Config(clip_eps=0.2)) - policy_logprobs = [ - torch.zeros(2, requires_grad=True), - torch.zeros(8, requires_grad=True), - ] - - loss, _loss_metrics = loss_fn( - policy_logprobs=policy_logprobs, - advantages=torch.tensor([1.0, -1.0]), - num_global_valid_tokens=torch.tensor(10.0), - ) - - assert loss.requires_grad - assert loss.grad_fn is not None - loss.backward() - assert all(sample.grad is not None for sample in policy_logprobs) - - def test_returns_loss_and_pre_normalized_metrics(self) -> None: - loss_fn = GRPOLoss(GRPOLoss.Config(clip_eps=0.2)) - # Two samples with unequal response lengths. - policy_logprobs = [ - torch.zeros(2, requires_grad=True), - torch.zeros(8, requires_grad=True), - ] - advantages = torch.tensor([1.0, -1.0]) - # Single-rank case: global == local valid tokens. 
- num_global_valid_tokens = torch.tensor(10.0) - - loss, loss_metrics = loss_fn( - policy_logprobs=policy_logprobs, - advantages=advantages, - num_global_valid_tokens=num_global_valid_tokens, - ) - assert isinstance(loss, torch.Tensor) - assert isinstance(loss_metrics, dict) - for key in ("loss/mean", "loss/ratio/mean", "loss/ratio/clipped_frac"): - assert key in loss_metrics - - def test_loss_is_token_weighted_sum_over_global_tokens(self) -> None: - """loss = sum_i(sample_loss_i * num_tokens_i) / num_global_valid_tokens. - - Under unequal response lengths this differs from a naive sample mean. - """ - loss_fn = GRPOLoss(GRPOLoss.Config(clip_eps=0.2)) - policy_logprobs = [ - torch.full((2,), 0.1, requires_grad=True), - torch.full((8,), 0.0, requires_grad=True), - ] - advantages = torch.tensor([1.0, -1.0]) - num_global_valid_tokens = torch.tensor(10.0) - - loss, loss_metrics = loss_fn( - policy_logprobs=policy_logprobs, - advantages=advantages, - num_global_valid_tokens=num_global_valid_tokens, - ) - # loss/mean metric is the same value as loss (both pre-normalized). - assert math.isclose( - loss_metrics["loss/mean"].item(), - loss.item(), - rel_tol=1e-6, - ) - - # And it is NOT equal to the unweighted sample mean of policy gradient - # losses, which is what the prior implementation used. 
- per_sample_mean_logprobs = torch.stack( - [sample_logprobs.mean() for sample_logprobs in policy_logprobs] - ) - ratio = torch.exp(per_sample_mean_logprobs) - clipped_ratio = torch.clamp(ratio, 1 - 0.2, 1 + 0.2) - sample_policy_gradient_losses = -torch.min( - ratio * advantages, clipped_ratio * advantages - ) - unweighted_sample_mean = float(sample_policy_gradient_losses.mean().item()) - assert not math.isclose( - loss.item(), unweighted_sample_mean, rel_tol=1e-4, abs_tol=1e-6 - ) - - -# --------------------------------------------------------------------------- -# Trainer reducers (single-DP fast paths) -# --------------------------------------------------------------------------- - - -def _stub_trainer_for_reducers(dp_size: int): - """Build a minimal stand-in for `PolicyTrainer` so we can exercise - the reducer methods on a CPU box without spawning Monarch / NCCL. - """ - # Late import: importing `PolicyTrainer` triggers monarch/torchtitan - # actor wiring at module load time, which is fine for CPU. - from torchtitan.experiments.rl.actors.trainer import PolicyTrainer - - inst = PolicyTrainer.__new__(PolicyTrainer) - inst.dp_size = dp_size - inst.device = torch.device("cpu") - inst.parallel_dims = MagicMock() - inst.parallel_dims.get_optional_mesh = MagicMock(return_value=None) - return inst - - -class TestReducerFastPaths: - def test_single_dp_identical(self) -> None: - # Pre-normalized values pass through SUM-reduce unchanged on a single - # rank (no mesh -> no all-reduce -> values are exactly what we passed). 
- trainer = _stub_trainer_for_reducers(dp_size=1) - out = trainer.reduce_forward_backward_metrics( - sum_reduced_metrics={ - "loss/mean": torch.tensor(3.0), - "bit_wise/logprob_diff/mean": torch.tensor(0.001), - "bit_wise/ratio_tokens_different/mean": torch.tensor(0.0), - }, - max_reduced_metrics={"bit_wise/logprob_diff/max": torch.tensor(0.005)}, - ) - assert out["loss/mean"] == pytest.approx(3.0) - assert out["bit_wise/logprob_diff/mean"] == pytest.approx(0.001) - assert out["bit_wise/logprob_diff/max"] == pytest.approx(0.005) - assert out["bit_wise/ratio_tokens_different/mean"] == 0.0 - - def test_unbiased_sum_reduction_across_ranks(self) -> None: - """Two ranks contribute pre-normalized shares; SUM-reducing - reconstructs the global value. - - Rank 0 shares: loss/mean=10/15 (token-weighted local share). - Rank 1 shares: loss/mean=30/15. - SUM-reduce: 40/15 = 2.667 (the global token-weighted mean). - """ - trainer = _stub_trainer_for_reducers(dp_size=2) - trainer.parallel_dims.get_optional_mesh = MagicMock(return_value="loss") - - rank0_share = torch.tensor([10.0 / 15.0], dtype=torch.float32) - rank1_share = torch.tensor([30.0 / 15.0], dtype=torch.float32) - - def fake_all_reduce(t, *, reduceOp, group): - if t.numel() == 1 and t.dtype == torch.float32: - return rank0_share + rank1_share - return t - - with patch( - "torchtitan.experiments.rl.actors.trainer.funcol.all_reduce", - side_effect=fake_all_reduce, - ): - out = trainer.reduce_forward_backward_metrics( - sum_reduced_metrics={"loss/mean": rank0_share[0]}, - max_reduced_metrics={"bit_wise/logprob_diff/max": torch.tensor(0.0)}, - ) - assert out["loss/mean"] == pytest.approx(40.0 / 15.0) - - def test_max_reduce_path(self) -> None: - """MAX-reduced metrics compose via elementwise max across ranks. - - Patches funcol.all_reduce to dispatch on reduceOp: SUM doubles - (simulating two ranks contributing equal shares); MAX takes the - elementwise max with a higher second-rank value. 
- """ - import torch.distributed.distributed_c10d as c10d - - trainer = _stub_trainer_for_reducers(dp_size=2) - trainer.parallel_dims.get_optional_mesh = MagicMock(return_value="loss") - - rank1_max = torch.tensor([0.006], dtype=torch.float32) - - def fake_all_reduce(t, *, reduceOp, group): - if reduceOp == c10d.ReduceOp.SUM.name: - return t * 2 - if reduceOp == c10d.ReduceOp.MAX.name: - return torch.maximum(t, rank1_max) - raise AssertionError(f"unexpected reduceOp={reduceOp!r}") - - with patch( - "torchtitan.experiments.rl.actors.trainer.funcol.all_reduce", - side_effect=fake_all_reduce, - ): - out = trainer.reduce_forward_backward_metrics( - sum_reduced_metrics={"loss/mean": torch.tensor(0.5)}, - max_reduced_metrics={"bit_wise/logprob_diff/max": torch.tensor(0.003)}, - ) - # SUM doubled: 0.5 + 0.5 = 1.0. MAX = max(0.003, 0.006) = 0.006. - assert out["loss/mean"] == pytest.approx(1.0) - assert out["bit_wise/logprob_diff/max"] == pytest.approx(0.006) - - def test_sum_only_skips_max_collective(self) -> None: - """max_reduced_metrics={} must not crash and must not call the - MAX collective; the SUM bucket is still reduced normally.""" - import torch.distributed.distributed_c10d as c10d - - trainer = _stub_trainer_for_reducers(dp_size=2) - trainer.parallel_dims.get_optional_mesh = MagicMock(return_value="loss") - - seen_ops: list[str] = [] - - def fake_all_reduce(t, *, reduceOp, group): - seen_ops.append(reduceOp) - if reduceOp == c10d.ReduceOp.SUM.name: - return t * 2 - raise AssertionError(f"unexpected reduceOp={reduceOp!r}") - - with patch( - "torchtitan.experiments.rl.actors.trainer.funcol.all_reduce", - side_effect=fake_all_reduce, - ): - out = trainer.reduce_forward_backward_metrics( - sum_reduced_metrics={"loss/mean": torch.tensor(0.5)}, - max_reduced_metrics={}, - ) - assert seen_ops == [c10d.ReduceOp.SUM.name] - assert out == {"loss/mean": pytest.approx(1.0)} - - def test_max_only_skips_sum_collective(self) -> None: - """sum_reduced_metrics={} must not crash 
and must not call the - SUM collective; the MAX bucket is still reduced normally.""" - import torch.distributed.distributed_c10d as c10d - - trainer = _stub_trainer_for_reducers(dp_size=2) - trainer.parallel_dims.get_optional_mesh = MagicMock(return_value="loss") - - seen_ops: list[str] = [] - rank1_max = torch.tensor([0.006], dtype=torch.float32) - - def fake_all_reduce(t, *, reduceOp, group): - seen_ops.append(reduceOp) - if reduceOp == c10d.ReduceOp.MAX.name: - return torch.maximum(t, rank1_max) - raise AssertionError(f"unexpected reduceOp={reduceOp!r}") - - with patch( - "torchtitan.experiments.rl.actors.trainer.funcol.all_reduce", - side_effect=fake_all_reduce, - ): - out = trainer.reduce_forward_backward_metrics( - sum_reduced_metrics={}, - max_reduced_metrics={ - "bit_wise/logprob_diff/max": torch.tensor(0.003), - }, - ) - assert seen_ops == [c10d.ReduceOp.MAX.name] - assert out == {"bit_wise/logprob_diff/max": pytest.approx(0.006)} - - def test_both_empty_returns_empty(self) -> None: - """Both buckets empty: no collectives called, empty dict returned.""" - trainer = _stub_trainer_for_reducers(dp_size=2) - trainer.parallel_dims.get_optional_mesh = MagicMock(return_value="loss") - with patch( - "torchtitan.experiments.rl.actors.trainer.funcol.all_reduce", - side_effect=AssertionError("should not be called"), - ): - out = trainer.reduce_forward_backward_metrics( - sum_reduced_metrics={}, - max_reduced_metrics={}, - ) - assert out == {} diff --git a/torchtitan/experiments/rl/tests/test_shutdown.py b/torchtitan/experiments/rl/tests/test_shutdown.py index 80218ddbbc..dc57e9d2ec 100644 --- a/torchtitan/experiments/rl/tests/test_shutdown.py +++ b/torchtitan/experiments/rl/tests/test_shutdown.py @@ -10,24 +10,18 @@ from torchtitan.experiments.rl import grpo from torchtitan.experiments.rl.actors.generator import VLLMGenerator - - -class _FakeConfigManager: - config = object() - - def parse_args(self): - return self.config +from torchtitan.experiments.rl.batcher 
import Batcher class _FakeRLTrainer: instances = [] - def __init__(self, config): + def __init__(self, config=None): self.config = config self.events = [] self.instances.append(self) - async def setup(self): + async def setup_async(self): self.events.append("setup") if getattr(self.config, "fail_setup", False): raise RuntimeError("setup failed") @@ -43,11 +37,55 @@ async def close(self): self.events.append("close") +class _FakeDebug: + enable_structured_logging = False + + +class _FakeTrainerConfig: + debug = _FakeDebug() + + +class _FakeConfig: + """Fake config whose build() returns a _FakeRLTrainer.""" + + dump_folder = "/tmp/test_rl" + trainer = _FakeTrainerConfig() + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def build(self): + return _FakeRLTrainer(config=self) + + +class _FakeConfigManager: + config = _FakeConfig() + + def parse_args(self): + return self.config + + +def _make_stub_rl_trainer(): + """Create an RLTrainer with a minimal stub config (no VLLMGenerator validation).""" + from torchtitan.experiments.rl.observability import metrics as m + + class _StubConfig: + batcher = Batcher.Config() + metrics = m.MetricsProcessor.Config() + dump_folder = "/tmp/test_rl" + hf_assets_path = "./tests/assets/tokenizer" + + def to_dict(self): + return {} + + return grpo.RLTrainer(_StubConfig()) + + def test_main_shuts_down_after_success(monkeypatch): - _FakeConfigManager.config = object() + _FakeConfigManager.config = _FakeConfig() _FakeRLTrainer.instances = [] monkeypatch.setattr(grpo, "ConfigManager", _FakeConfigManager) - monkeypatch.setattr(grpo, "RLTrainer", _FakeRLTrainer) asyncio.run(grpo.main()) @@ -55,13 +93,9 @@ def test_main_shuts_down_after_success(monkeypatch): def test_main_shuts_down_after_train_failure(monkeypatch): - class FailingConfig: - fail_train = True - - _FakeConfigManager.config = FailingConfig() + _FakeConfigManager.config = _FakeConfig(fail_train=True) _FakeRLTrainer.instances = [] 
monkeypatch.setattr(grpo, "ConfigManager", _FakeConfigManager) - monkeypatch.setattr(grpo, "RLTrainer", _FakeRLTrainer) with pytest.raises(RuntimeError, match="train failed"): asyncio.run(grpo.main()) @@ -70,13 +104,9 @@ class FailingConfig: def test_main_shuts_down_after_setup_failure(monkeypatch): - class FailingConfig: - fail_setup = True - - _FakeConfigManager.config = FailingConfig() + _FakeConfigManager.config = _FakeConfig(fail_setup=True) _FakeRLTrainer.instances = [] monkeypatch.setattr(grpo, "ConfigManager", _FakeConfigManager) - monkeypatch.setattr(grpo, "RLTrainer", _FakeRLTrainer) with pytest.raises(RuntimeError, match="setup failed"): asyncio.run(grpo.main()) @@ -85,7 +115,7 @@ class FailingConfig: def test_rl_trainer_shutdown_is_noop_before_meshes_spawn(): - trainer = grpo.RLTrainer(object()) + trainer = _make_stub_rl_trainer() asyncio.run(trainer.close()) @@ -99,14 +129,9 @@ def test_main_swallows_cancellation_after_shutdown(monkeypatch): running task; ``main`` runs ``close`` in ``finally`` and the explicit ``except`` clause swallows the interrupt so the process exits 0 without a traceback.""" - - class CancelledConfig: - cancel_train = True - - _FakeConfigManager.config = CancelledConfig() + _FakeConfigManager.config = _FakeConfig(cancel_train=True) _FakeRLTrainer.instances = [] monkeypatch.setattr(grpo, "ConfigManager", _FakeConfigManager) - monkeypatch.setattr(grpo, "RLTrainer", _FakeRLTrainer) # No exception escapes; close still ran. 
asyncio.run(grpo.main()) @@ -146,7 +171,7 @@ async def stop(self): def test_shutdown_calls_actor_close_before_mesh_stop(): events: list[str] = [] - rl_trainer = grpo.RLTrainer(object()) + rl_trainer = _make_stub_rl_trainer() rl_trainer.trainer = _StubActor("trainer.close", events) rl_trainer.generator = _StubActor("generator.close", events) rl_trainer._proc_meshes = [ @@ -167,7 +192,7 @@ def test_shutdown_calls_actor_close_before_mesh_stop(): def test_shutdown_continues_after_actor_close_failure(): events: list[str] = [] - rl_trainer = grpo.RLTrainer(object()) + rl_trainer = _make_stub_rl_trainer() rl_trainer.trainer = _StubActor("trainer.close", events, raises=True) rl_trainer.generator = _StubActor("generator.close", events) rl_trainer._proc_meshes = [_StubMesh("mesh.stop[0]", events)] diff --git a/torchtitan/experiments/rl/types.py b/torchtitan/experiments/rl/types.py index abad041596..8bd6b53a54 100644 --- a/torchtitan/experiments/rl/types.py +++ b/torchtitan/experiments/rl/types.py @@ -9,71 +9,24 @@ import torch -@dataclass(kw_only=True, slots=True) -class Step: - """Env transition: named reward components, done flag, optional next observation. - - ``rewards`` is a dict of component-name to value (e.g. - ``{"correctness": 1.0, "format": 0.3}``); envs are free to define - any decomposition. Trainers read the scalar ``reward`` property - (sum of components); loggers iterate ``rewards.items()`` for - per-component reporting without needing to know the keys. - - ``observation`` (the next prompt the agent will see) is only - populated by multi-turn envs. Single-turn envs leave it None. - """ - - rewards: dict[str, float] - done: bool - observation: str | None = None - - @property - def reward(self) -> float: - return sum(self.rewards.values()) - - @dataclass(kw_only=True, slots=True) class Completion: - """A single generated sequence from the generator. - - Pure generation artifact - no reward, no advantage. 
``prompt_idx`` - is the position of the source prompt in the input ``prompts`` list. - """ + """A single generated sequence from the generator.""" policy_version: int - prompt_idx: int - text: str + request_idx: int token_ids: list[int] token_logprobs: list[float] finish_reason: str | None = None """vLLM `CompletionOutput.finish_reason` ("stop" | "length" | "abort")""" -@dataclass(kw_only=True, slots=True) -class Trajectory: - """One rollout: a sequence of ``(Completion, Step)`` transitions. - - Single-turn tasks produce trajectories with one transition. The - Completion carries the generator's response-side metadata; the Step - carries the env's reward and done flag; - """ - - sample_idx: int - prompt_token_ids: list[int] - transitions: list[tuple[Completion, Step]] - - @property - def total_reward(self) -> float: - return sum(s.reward for _, s in self.transitions) - - +# TODO: rename `Episode` -> `TrainingSample` +# and `rollout_to_episode` -> `rollout_to_training_sample` @dataclass(kw_only=True, slots=True) class Episode: - """Training sample: flattened trajectory + GRPO advantage. - - Flat shape (rather than composition) because the trainer collate - path and logging read these fields directly. - """ + """Training sample: flattened Rollout turns + GRPO advantage, + ready for collation into a batch.""" policy_version: int prompt_idx: int @@ -104,7 +57,7 @@ class TrainingBatch: @dataclass(frozen=True, slots=True) class OptimStepOutput: - """Result returned by ``PolicyTrainer.optim_step`` to the controller.""" + """Result returned by `PolicyTrainer.optim_step` to the controller.""" policy_version: int metrics: dict[str, float]