-
Notifications
You must be signed in to change notification settings - Fork 839
[RL] - MessageEnv, Rollout types, Rubric, Renderer #3453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6913120
7c7c9f7
8e4fc60
dc690a1
40ef4c5
767022f
50b5572
86ed252
31dcd8b
7dad6f5
d440e46
1523db5
312c292
14aaeca
6292f1d
e9af27a
85880db
9243ba1
701b4d1
11003ba
5044e86
2dcd0f9
7a72a41
0cf8460
df376bb
d3ba2df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,8 @@ | |
| from torchtitan.experiments.rl.batcher import BatchConfig, Batcher | ||
| from torchtitan.experiments.rl.grpo import GRPOLoss, RLTrainer | ||
| from torchtitan.experiments.rl.observability.metrics import MetricsProcessor | ||
| from torchtitan.experiments.rl.sum_digits import SumDigitsEnv | ||
| from torchtitan.experiments.rl.renderer import RendererConfig | ||
| from torchtitan.experiments.rl.tasks.sum_digits import SumDigitsDataset, SumDigitsTask | ||
| from torchtitan.models.qwen3 import model_registry | ||
|
|
||
|
|
||
|
|
@@ -39,10 +40,14 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: | |
| num_prompts_per_step=5, | ||
| num_validation_samples=20, | ||
| compile=CompileConfig(enable=True, backend="aot_eager"), | ||
| env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3), | ||
| validation_env=SumDigitsEnv.Config( | ||
| seed=99, correctness_reward=1.0, format_reward=0.3 | ||
| ), | ||
| tasks={ | ||
| "sum_digits": SumDigitsTask.Config( | ||
| train_dataset=SumDigitsDataset.Config(seed=42), | ||
| val_dataset=SumDigitsDataset.Config(seed=99), | ||
| ) | ||
| }, | ||
| group_size=group_size, | ||
| renderer=RendererConfig(name="qwen3", enable_thinking=True), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is the renderer tied to a model name? Is it because it will use the tokenizer path? Currently
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| metrics=MetricsProcessor.Config(enable_wandb=True), | ||
| batcher=Batcher.Config( | ||
| batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048), | ||
|
|
@@ -77,10 +82,9 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: | |
| ), | ||
| checkpoint=CheckpointManager.Config(enable=False), | ||
| sampling=SamplingConfig( | ||
| n=group_size, | ||
| temperature=0.8, | ||
| top_p=0.95, | ||
| max_tokens=100, | ||
| max_tokens=700, | ||
| ), | ||
| ), | ||
| ) | ||
|
|
@@ -96,10 +100,14 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config: | |
| num_prompts_per_step=5, | ||
| num_validation_samples=20, | ||
| compile=CompileConfig(enable=True, backend="aot_eager"), | ||
| env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3), | ||
| validation_env=SumDigitsEnv.Config( | ||
| seed=99, correctness_reward=1.0, format_reward=0.3 | ||
| ), | ||
| tasks={ | ||
| "sum_digits": SumDigitsTask.Config( | ||
| train_dataset=SumDigitsDataset.Config(seed=42), | ||
| val_dataset=SumDigitsDataset.Config(seed=99), | ||
| ) | ||
| }, | ||
| group_size=group_size, | ||
| renderer=RendererConfig(name="qwen3", enable_thinking=True), | ||
| metrics=MetricsProcessor.Config(enable_wandb=True), | ||
| batcher=Batcher.Config( | ||
| batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048), | ||
|
|
@@ -135,10 +143,9 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config: | |
| ), | ||
| checkpoint=CheckpointManager.Config(enable=False), | ||
| sampling=SamplingConfig( | ||
| n=group_size, | ||
| temperature=0.8, | ||
| top_p=0.95, | ||
| max_tokens=100, | ||
| max_tokens=700, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The batcher's max_seq_length will need to change accordingly; otherwise all the samples will be dropped because they are longer than |
||
| ), | ||
| ), | ||
| ) | ||
|
|
@@ -154,10 +161,14 @@ def rl_grpo_qwen3_14b() -> RLTrainer.Config: | |
| num_prompts_per_step=5, | ||
| num_validation_samples=20, | ||
| compile=CompileConfig(enable=True, backend="aot_eager"), | ||
| env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3), | ||
| validation_env=SumDigitsEnv.Config( | ||
| seed=99, correctness_reward=1.0, format_reward=0.3 | ||
| ), | ||
| tasks={ | ||
| "sum_digits": SumDigitsTask.Config( | ||
| train_dataset=SumDigitsDataset.Config(seed=42), | ||
| val_dataset=SumDigitsDataset.Config(seed=99), | ||
| ) | ||
| }, | ||
| group_size=group_size, | ||
| renderer=RendererConfig(name="qwen3", enable_thinking=True), | ||
| metrics=MetricsProcessor.Config(enable_wandb=True), | ||
| batcher=Batcher.Config( | ||
| batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048), | ||
|
|
@@ -192,17 +203,16 @@ def rl_grpo_qwen3_14b() -> RLTrainer.Config: | |
| ), | ||
| checkpoint=CheckpointManager.Config(enable=False), | ||
| sampling=SamplingConfig( | ||
| n=group_size, | ||
| temperature=0.8, | ||
| top_p=0.95, | ||
| max_tokens=100, | ||
| max_tokens=700, | ||
| ), | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
| def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config: | ||
| """On-policy GRPO config for Qwen3-0.6B under same parallelism (4 GPUs: 2 gen + 2 train). | ||
| """On-policy GRPO config for Qwen3-0.6B (4 GPUs: 2 gen + 2 train). | ||
|
|
||
| Enables deterministic + batch-invariant mode for true on-policy RL training. | ||
| """ | ||
|
|
@@ -215,10 +225,14 @@ def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config: | |
| num_prompts_per_step=5, | ||
| num_validation_samples=20, | ||
| compile=CompileConfig(enable=True, backend="aot_eager"), | ||
| env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3), | ||
| validation_env=SumDigitsEnv.Config( | ||
| seed=99, correctness_reward=1.0, format_reward=0.3 | ||
| ), | ||
| tasks={ | ||
| "sum_digits": SumDigitsTask.Config( | ||
| train_dataset=SumDigitsDataset.Config(seed=42), | ||
| val_dataset=SumDigitsDataset.Config(seed=99), | ||
| ) | ||
| }, | ||
| group_size=group_size, | ||
| renderer=RendererConfig(name="qwen3", enable_thinking=True), | ||
| metrics=MetricsProcessor.Config(enable_wandb=True), | ||
| batcher=Batcher.Config( | ||
| batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048), | ||
|
|
@@ -257,10 +271,9 @@ def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config: | |
| ), | ||
| checkpoint=CheckpointManager.Config(enable=False), | ||
| sampling=SamplingConfig( | ||
| n=group_size, | ||
| temperature=0.8, | ||
| top_p=0.95, | ||
| max_tokens=100, | ||
| max_tokens=700, | ||
| ), | ||
| debug=batch_invariant_config, | ||
| ), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from torchtitan.experiments.rl.env_types.message_env import ( | ||
| MessageEnv, | ||
| MessageResetOutput, | ||
| MessageStepOutput, | ||
| ) | ||
| from torchtitan.experiments.rl.env_types.renderer_env import ( | ||
| RendererWrapperEnv, | ||
| TokenizedStepOutput, | ||
| TurnMessages, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "MessageEnv", | ||
| "MessageResetOutput", | ||
| "MessageStepOutput", | ||
| "RendererWrapperEnv", | ||
| "TokenizedStepOutput", | ||
| "TurnMessages", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import abc | ||
| from dataclasses import dataclass, field | ||
|
|
||
| from renderers import Message, ToolSpec | ||
|
|
||
|
|
||
| @dataclass(kw_only=True, slots=True) | ||
| class MessageResetOutput: | ||
| """Initial prompt messages + tool specs from `MessageEnv.reset`.""" | ||
|
|
||
| prompt_messages: list[Message] # [M_prompt] | ||
| """The messages that form the initial prompt (e.g. [system, user]).""" | ||
|
|
||
| tools: list[ToolSpec] = field(default_factory=list) # [K_tools] | ||
| """Tool schemas exposed to the assistant. Empty for tool-less envs.""" | ||
|
|
||
|
|
||
| @dataclass(kw_only=True, slots=True) | ||
| class MessageStepOutput: | ||
| """The env's reply to the assistant's turn.""" | ||
|
|
||
| env_messages: list[Message] = field(default_factory=list) # [M_env] | ||
| """The env's reply messages (tool / user). Empty when the rollout terminates | ||
| with no follow-up.""" | ||
|
|
||
| done: bool = False | ||
| """`True` ends the rollout.""" | ||
|
|
||
| env_rewards: dict[str, float] = field(default_factory=dict) | ||
| """Optional reward signal the env provides for this step; the rubric decides | ||
| whether and how to use it. Empty if the env scores nothing.""" | ||
|
|
||
| def __post_init__(self) -> None: | ||
| # env replies are tool/user turns; the assistant turn comes from the generator | ||
| if any(m.get("role") == "assistant" for m in self.env_messages): | ||
| raise ValueError( | ||
| "MessageStepOutput.env_messages may not contain assistant messages" | ||
| ) | ||
|
|
||
|
|
||
class MessageEnv(abc.ABC):
    """User-written env in message space. Implement `reset` + `step`.

    Tip: `MessageEnv` works in messages and never sees token ids. You can have
    `RendererWrapperEnv` wrap it and use a `Renderer` to convert
    messages <-> token ids for the generator.

    An env can also be used as an async context manager, which guarantees
    `close` runs even when the rollout raises:

        async with CalculatorEnv() as env:
            first = await env.reset()

    Example:
        # a one-tool calculator env. It is multi-turn — the env answers the
        # assistant's tool call, then ends once the assistant replies without a tool.

        class CalculatorEnv(MessageEnv):
            async def reset(self) -> MessageResetOutput:
                return MessageResetOutput(
                    prompt_messages=[{"role": "user", "content": "What is 12 * 7?"}],
                    tools=[CALCULATOR_TOOL],
                )

            async def step(self, assistant_message: Message) -> MessageStepOutput:
                tool_calls = assistant_message.get("tool_calls")
                if not tool_calls:
                    return MessageStepOutput(done=True)  # assistant gave its final answer
                result = run_calculator(tool_calls[0])
                return MessageStepOutput(
                    env_messages=[{"role": "tool", "content": result}]
                )
    """

    @abc.abstractmethod
    async def reset(self) -> MessageResetOutput:
        """Return the initial conversation + tools for prompt rendering."""

    @abc.abstractmethod
    async def step(self, assistant_message: Message) -> MessageStepOutput:
        """Advance the env one turn given the assistant's latest message.

        `RendererWrapperEnv` parses the completion and handles
        finish_reason / length / parse / timeout failures before calling this,
        so the env only sees a well-formed assistant message.

        Args:
            assistant_message: the assistant's parsed turn.

        Returns:
            `MessageStepOutput` with the env's reply messages.
        """

    async def close(self) -> None:
        """Release env-owned resources. Default no-op; idempotent."""

    async def __aenter__(self) -> MessageEnv:
        """Enter `async with`; returns the env itself without calling `reset`."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Leave `async with` by awaiting `close`.

        Returns `None`, so an exception raised inside the block still
        propagates to the caller after cleanup.
        """
        await self.close()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this be used for SFT as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to depend on the latest main of `renderers`?