Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6913120
[RL] Add Batcher as Configurable, support [B, L] microbatches
wwwjn May 18, 2026
7c7c9f7
rebase
wwwjn May 18, 2026
8e4fc60
remove position in loss calc
wwwjn May 19, 2026
dc690a1
update algorithm
wwwjn May 19, 2026
40ef4c5
update rebase
wwwjn May 19, 2026
767022f
update trainer to using new TrainBatch
wwwjn May 20, 2026
50b5572
update loss
wwwjn May 20, 2026
86ed252
update names and configs
wwwjn May 20, 2026
31dcd8b
update batcher with waste metrics
wwwjn May 20, 2026
7dad6f5
update configs
wwwjn May 22, 2026
d440e46
address comments
wwwjn May 26, 2026
1523db5
v8 datatypes + env protocol
May 27, 2026
312c292
v8 rubric + rollouts refactor
May 27, 2026
14aaeca
v8 Task surface refactor + style pass
May 28, 2026
6292f1d
v8 docstring/style pass + truncation fix + renderer tokenizer
May 29, 2026
e9af27a
Merge upstream/main (landed batcher PR) into v8 branch
May 29, 2026
85880db
Fix NoReduce collision: rollout/group_failures sums across collection…
May 29, 2026
9243ba1
Adopt renderers typed-config API (create_renderer(tokenizer, config))
May 29, 2026
701b4d1
Revert unrelated merge drift to upstream/main
May 29, 2026
11003ba
RL rollout loop cleanup: RolloutGroup, per-rollout step, status enum
May 29, 2026
5044e86
RL example default: thinking on, max_tokens=700; tidy rollout types
May 29, 2026
2dcd0f9
RL env/types review pass (docs 69+70): renames, env reward, renderer …
May 29, 2026
7a72a41
Fix _build_episodes group-skip + rollout_to_episode text
May 29, 2026
0cf8460
Docstring concision pass + rename envs/ -> env_types/
May 29, 2026
df376bb
Make Task an ABC; base Config owns rubric, subclass builds it
May 30, 2026
d3ba2df
[rl] v8: message/token env types, Configurable rubric, Task ABC
Jun 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion torchtitan/experiments/rl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ uv venv --python 3.12 titan-rl
source titan-rl/bin/activate
```

1. Install Monarch and TorchStore from main:
1. Install Monarch, TorchStore, and Renderers from main:
```bash
uv pip install torchmonarch==0.4.1
uv pip install --no-deps "git+https://github.com/meta-pytorch/torchstore.git@main"
uv pip install pygtrie portpicker
uv pip install "git+https://github.com/PrimeIntellect-ai/renderers.git@main"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this be used for sft as well?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to depend on renders lastest main?

```

2. Install Flash Attention 3 kernels:
Expand Down
11 changes: 7 additions & 4 deletions torchtitan/experiments/rl/actors/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ class SamplingConfig:
max_tokens: int = 100
"""Maximum number of tokens to generate per completion."""

stop_token_ids: list[int] = field(default_factory=list)
"""Role-boundary stop tokens from the renderer (e.g. Qwen3 `<|im_end|>`)."""


class VLLMGenerator(Actor, Configurable):
"""
Expand Down Expand Up @@ -411,6 +414,7 @@ async def generate(
top_p=_sampling_config.top_p,
max_tokens=_sampling_config.max_tokens,
n=_sampling_config.n,
stop_token_ids=_sampling_config.stop_token_ids or None,
seed=self.config.debug.seed,
logprobs=1,
output_kind=RequestOutputKind.FINAL_ONLY,
Expand All @@ -436,14 +440,14 @@ async def generate(
all_outputs.extend(self._engine.step())

# vLLM may return requests out of order; sort by the integer
# request_id we assigned so prompt_idx lines up with the input.
# request_id we assigned so request_idx lines up with the input.
all_outputs.sort(key=lambda o: int(o.request_id))

completions: list[Completion] = []
generation_metrics: list[m.Metric] = []
output_token_counts: list[int] = []
for output in all_outputs:
prompt_idx = int(output.request_id)
request_idx = int(output.request_id)
generation_metrics.extend(
_prepare_generation_request_metrics(output, prefix=metrics_prefix)
)
Expand All @@ -456,8 +460,7 @@ async def generate(
completions.append(
Completion(
policy_version=self.policy_version,
prompt_idx=prompt_idx,
text=sample.text,
request_idx=request_idx,
token_ids=sample.token_ids,
token_logprobs=per_token_logprobs,
finish_reason=sample.finish_reason,
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/experiments/rl/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ async def forward_backward(
async def optim_step(self) -> OptimStepOutput:
"""Clip gradients, step optimizer + LR scheduler, return updated state."""
# TODO: Accept optional optimizer params (e.g. learning rate)
# to allow controller-owned schedules (see Tinker API).
# to allow controller-owned schedules.

# capture LR before step
current_lrs = self.lr_schedulers.schedulers[0].get_last_lr()
Expand Down
65 changes: 39 additions & 26 deletions torchtitan/experiments/rl/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from torchtitan.experiments.rl.batcher import BatchConfig, Batcher
from torchtitan.experiments.rl.grpo import GRPOLoss, RLTrainer
from torchtitan.experiments.rl.observability.metrics import MetricsProcessor
from torchtitan.experiments.rl.sum_digits import SumDigitsEnv
from torchtitan.experiments.rl.renderer import RendererConfig
from torchtitan.experiments.rl.tasks.sum_digits import SumDigitsDataset, SumDigitsTask
from torchtitan.models.qwen3 import model_registry


Expand All @@ -39,10 +40,14 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config:
num_prompts_per_step=5,
num_validation_samples=20,
compile=CompileConfig(enable=True, backend="aot_eager"),
env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3),
validation_env=SumDigitsEnv.Config(
seed=99, correctness_reward=1.0, format_reward=0.3
),
tasks={
"sum_digits": SumDigitsTask.Config(
train_dataset=SumDigitsDataset.Config(seed=42),
val_dataset=SumDigitsDataset.Config(seed=99),
)
},
group_size=group_size,
renderer=RendererConfig(name="qwen3", enable_thinking=True),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why render related to a model name? Is it because it will use tokenizer path? Currently hf_assets_path is only used to local tokenizer. Can we move the tokenizer and tokenizer path under render?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why render related to a model name? Is it because it will use tokenizer path?
yes

Currently hf_assets_path is only used to local tokenizer. Can we move the tokenizer and tokenizer path under render?
Not yet. They dont support users providing their own tokenizer yet, but they will soon.

metrics=MetricsProcessor.Config(enable_wandb=True),
batcher=Batcher.Config(
batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048),
Expand Down Expand Up @@ -77,10 +82,9 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config:
),
checkpoint=CheckpointManager.Config(enable=False),
sampling=SamplingConfig(
n=group_size,
temperature=0.8,
top_p=0.95,
max_tokens=100,
max_tokens=700,
),
),
)
Expand All @@ -96,10 +100,14 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config:
num_prompts_per_step=5,
num_validation_samples=20,
compile=CompileConfig(enable=True, backend="aot_eager"),
env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3),
validation_env=SumDigitsEnv.Config(
seed=99, correctness_reward=1.0, format_reward=0.3
),
tasks={
"sum_digits": SumDigitsTask.Config(
train_dataset=SumDigitsDataset.Config(seed=42),
val_dataset=SumDigitsDataset.Config(seed=99),
)
},
group_size=group_size,
renderer=RendererConfig(name="qwen3", enable_thinking=True),
metrics=MetricsProcessor.Config(enable_wandb=True),
batcher=Batcher.Config(
batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048),
Expand Down Expand Up @@ -135,10 +143,9 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config:
),
checkpoint=CheckpointManager.Config(enable=False),
sampling=SamplingConfig(
n=group_size,
temperature=0.8,
top_p=0.95,
max_tokens=100,
max_tokens=700,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need to change batcher max_seq_length accordingly, otherwise all the samples will be dropped because they are longer than max_seq_length

),
),
)
Expand All @@ -154,10 +161,14 @@ def rl_grpo_qwen3_14b() -> RLTrainer.Config:
num_prompts_per_step=5,
num_validation_samples=20,
compile=CompileConfig(enable=True, backend="aot_eager"),
env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3),
validation_env=SumDigitsEnv.Config(
seed=99, correctness_reward=1.0, format_reward=0.3
),
tasks={
"sum_digits": SumDigitsTask.Config(
train_dataset=SumDigitsDataset.Config(seed=42),
val_dataset=SumDigitsDataset.Config(seed=99),
)
},
group_size=group_size,
renderer=RendererConfig(name="qwen3", enable_thinking=True),
metrics=MetricsProcessor.Config(enable_wandb=True),
batcher=Batcher.Config(
batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048),
Expand Down Expand Up @@ -192,17 +203,16 @@ def rl_grpo_qwen3_14b() -> RLTrainer.Config:
),
checkpoint=CheckpointManager.Config(enable=False),
sampling=SamplingConfig(
n=group_size,
temperature=0.8,
top_p=0.95,
max_tokens=100,
max_tokens=700,
),
),
)


def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config:
"""On-policy GRPO config for Qwen3-0.6B under same parallelism (4 GPUs: 2 gen + 2 train).
"""On-policy GRPO config for Qwen3-0.6B (4 GPUs: 2 gen + 2 train).

Enables deterministic + batch-invariant mode for true on-policy RL training.
"""
Expand All @@ -215,10 +225,14 @@ def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config:
num_prompts_per_step=5,
num_validation_samples=20,
compile=CompileConfig(enable=True, backend="aot_eager"),
env=SumDigitsEnv.Config(seed=42, correctness_reward=1.0, format_reward=0.3),
validation_env=SumDigitsEnv.Config(
seed=99, correctness_reward=1.0, format_reward=0.3
),
tasks={
"sum_digits": SumDigitsTask.Config(
train_dataset=SumDigitsDataset.Config(seed=42),
val_dataset=SumDigitsDataset.Config(seed=99),
)
},
group_size=group_size,
renderer=RendererConfig(name="qwen3", enable_thinking=True),
metrics=MetricsProcessor.Config(enable_wandb=True),
batcher=Batcher.Config(
batch=BatchConfig(local_batch_size=2, global_batch_size=8, seq_len=2048),
Expand Down Expand Up @@ -257,10 +271,9 @@ def rl_grpo_qwen3_0_6b_batch_invariant() -> RLTrainer.Config:
),
checkpoint=CheckpointManager.Config(enable=False),
sampling=SamplingConfig(
n=group_size,
temperature=0.8,
top_p=0.95,
max_tokens=100,
max_tokens=700,
),
debug=batch_invariant_config,
),
Expand Down
25 changes: 25 additions & 0 deletions torchtitan/experiments/rl/env_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.experiments.rl.env_types.message_env import (
MessageEnv,
MessageResetOutput,
MessageStepOutput,
)
from torchtitan.experiments.rl.env_types.renderer_env import (
RendererWrapperEnv,
TokenizedStepOutput,
TurnMessages,
)

__all__ = [
"MessageEnv",
"MessageResetOutput",
"MessageStepOutput",
"RendererWrapperEnv",
"TokenizedStepOutput",
"TurnMessages",
]
96 changes: 96 additions & 0 deletions torchtitan/experiments/rl/env_types/message_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import abc
from dataclasses import dataclass, field

from renderers import Message, ToolSpec


@dataclass(kw_only=True, slots=True)
class MessageResetOutput:
"""Initial prompt messages + tool specs from `MessageEnv.reset`."""

prompt_messages: list[Message] # [M_prompt]
"""The messages that form the initial prompt (e.g. [system, user])."""

tools: list[ToolSpec] = field(default_factory=list) # [K_tools]
"""Tool schemas exposed to the assistant. Empty for tool-less envs."""


@dataclass(kw_only=True, slots=True)
class MessageStepOutput:
"""The env's reply to the assistant's turn."""

env_messages: list[Message] = field(default_factory=list) # [M_env]
"""The env's reply messages (tool / user). Empty when the rollout terminates
with no follow-up."""

done: bool = False
"""`True` ends the rollout."""

env_rewards: dict[str, float] = field(default_factory=dict)
"""Optional reward signal the env provides for this step; the rubric decides
whether and how to use it. Empty if the env scores nothing."""

def __post_init__(self) -> None:
# env replies are tool/user turns; the assistant turn comes from the generator
if any(m.get("role") == "assistant" for m in self.env_messages):
raise ValueError(
"MessageStepOutput.env_messages may not contain assistant messages"
)


class MessageEnv(abc.ABC):
"""User-written env in message space. Implement `reset` + `step`.

Tip: `MessageEnv` works in messages and never sees token ids; You can have `RendererWrapperEnv`
wrap it and use a `Renderer` to convert messages <-> token ids for the generator.

Example:
# a one-tool calculator env. It is multi-turn — the env answers the
# assistant's tool call, then ends once the assistant replies without a tool.

class CalculatorEnv(MessageEnv):
async def reset(self) -> MessageResetOutput:
return MessageResetOutput(
prompt_messages=[{"role": "user", "content": "What is 12 * 7?"}],
tools=[CALCULATOR_TOOL],
)

async def step(self, assistant_message: Message) -> MessageStepOutput:
tool_calls = assistant_message.get("tool_calls")
if not tool_calls:
return MessageStepOutput(done=True) # assistant gave its final answer
result = run_calculator(tool_calls[0])
return MessageStepOutput(
env_messages=[{"role": "tool", "content": result}]
)
"""

@abc.abstractmethod
async def reset(self) -> MessageResetOutput:
"""Return the initial conversation + tools for prompt rendering."""

@abc.abstractmethod
async def step(self, assistant_message: Message) -> MessageStepOutput:
"""Advance the env one turn given the assistant's latest message.

`RendererWrapperEnv` parses the completion and handles
finish_reason / length / parse / timeout failures before calling this,
so the env only sees a well-formed assistant message.

Args:
assistant_message: the assistant's parsed turn.

Returns:
`MessageStepOutput` with the env's reply messages.
"""

async def close(self) -> None:
"""Release env-owned resources. Default no-op; idempotent."""
Loading
Loading