[RL] - MessageEnv, Rollout types, Rubric, Renderer#3453
Conversation
- Add shared pack() function in torchtitan/components/dataloading/utils.py - Add Batcher class with Configurable pattern: packs episodes into fixed [B, seq_length] TrainBatches with gradient accumulation support - Refactor TrainBatch to use [B, L] tensors instead of [1, total_tokens] - Update PolicyTrainer.forward_backward to work with [B, L] microbatches - Simplify compute_logprobs/verify_logprob_identity in actors/utils.py
- envs/: MessageEnv ABC + MessageReset/MessageStep; RendererEnv wrapper + RendererEnvConfig + TokenizedTurn; Rollout / RolloutTurn / RolloutStatus / DatasetOutput types. - rubrics/: Rubric base class with score_one / score_group, typed Reward dataclass; Rubric owns truncation_reward / error_reward (doc 37 Option B). - recipes/: Task base + SumDigitsTask (dataset + env + grader + recipe). - grpo.py: _run_rollouts refactored (Option G — inline orchestration + _do_group_step helper + per-group failure isolation); validate() reuses the same path. - Drops: sum_digits.py (orphan), test_grpo_metrics.py (broken), Trajectory + Step from types.py.
- rollouts/: new folder holding Rollout / RolloutTurn / RolloutStatus / DatasetOutput (types.py) and last_assistant_text / rollout_to_episode / prepare_rollout_metrics (utils.py). envs/types.py deleted. - Rubric: Configurable with register_funcs() hook + lazy @cached_property weight normalization (sum-to-1). New RewardFn dataclass (fn + weight). truncation_reward / error_reward become Optional[float] short-circuit knobs (None = run reward fns on partial response). - SumDigitsRubric subclass added in grader.py; SumDigitsDataset is now Configurable; SumDigitsTask.Config composes nested sub-configs (dataset / rubric / env_limits) matching PolicyTrainer.Config style. - MessageReset/Step -> MsgResponseReset/Step. - RendererEnvConfig -> EnvLimits; kwarg config -> limits. - recipes/sum_digits/recipe.py -> task.py. - _run_rollouts: overflow-aware prompt routing -- initial-prompt overflow builds TRUNCATED_OVERFLOW rollouts directly instead of sending an empty prompt to the generator. - prepare_rollout_metrics consolidates the rollout/validation metric blocks; _prepare_reward_metrics removed. - Cleanups: RolloutStatus is_*() via frozenset; validate_rollout_output dropped; Rollout.group_id / sample_idx required; _log_samples Episode-only; reward_correct stops asserting env_input type.
Task surface
- DatasetOutput.env_name -> DatasetOutput.task; SumDigitsDataset
ENV_NAME -> TASK_NAME.
- Dataset moves off Task onto RLTrainer.Config:
train: Task.Config -> train_dataset + tasks dict keyed by task name.
Same for validation.
Rows route to the matching Task via example.task.
- SumDigitsTask: drops dataset + sample_example; only rubric,
env_limits, and make_envs remain.
- Task base gains score_group(rollouts, env_input) -> list[Rollout]:
default delegates to self.rubric.score_group + fills reward and
reward_components. Step logic stays in grpo._do_group_step.
- TODOs: continuous-batching Task.do_single_rollout; revisit Camp A
vs B (dataset on Task vs framework).
Style + cleanups (same files)
- Dataclasses converted from per-field docstrings to Args-at-top with
inline shape comments. Drops double-backticks and stale PR
references.
- Bare # title block comments in _run_rollouts, _build_episodes,
validate, train, _do_group_step. Dashed banners removed in train.
- Dead code: _shard_episodes removed (Batcher does sharding).
- Bug fixes:
- prepare_rollout_metrics total_lens computed per-rollout
(multi-turn safe).
- _do_group_step initial.next_token_ids raises on None instead of
silent default to [].
- Rubric: register_funcs result checked for unique fn names.
- renderer_env: env_step.status defaulting uses replace() not
in-place mutation; _terminal accepts last_response_messages
instead of post-construction mutation.
Smoke: imports, config build, rubric e2e (3 statuses + short-circuit
modes), task.score_group fills rewards.
- Dataclasses: per-field docstrings; single backticks; drop internal-doc refs. - RendererEnv.step_completion parses before classifying length/abort, so truncated/aborted/timed-out turns keep their response message + tokens (partial-reward grading + debugging). - Completion.text removed (text comes from rollout messages); RolloutTurn gains reward/reward_components; TokenizedTurn carries terminal status. - Read pad/eos from the renderer's tokenizer; drop the standalone tokenizer. - Task.score_group returns list[Reward]; controller applies them. - Condense _run_rollouts; README installs renderers.
Conflicts resolved: - batcher.py / TrainingBatch labels / compute_logprobs -> take main (landed batcher PR). - grpo.py / config_registry / types Completion+Episode -> keep v8 (our rollout/env/rubric/recipe layer), adapted to main's batcher API. - config_registry: keep recipes structure (tasks/dataset), import BatchConfig from rl.batcher, enable_wandb=True (matches main). - Dropped stale test_grpo_metrics.py (tested the removed Step/Trajectory API).
renderers 0.1.8.dev37 replaced the name-based factory with a typed-config one. RendererConfig.build picks the config variant for name from the public discriminated union and passes only supported knobs; adds an enable_thinking field. Removes the private-attr workaround.
The merge with the landed batcher PR pulled in changes that don't belong
to this PR:
- Drop half-done-batcher-base leftovers: BatchConfig in config/configs.py
(the landed batcher has its own in rl.batcher), config/__init__ export,
components/dataloading/{__init__,utils}.py, and the text_datasets.py
rewrite. RL never imports any of these.
- grpo.py: revert train() docstring, reworded comments, and the
per-microbatch metrics rewrite back to upstream's aggregation; revert the
setup_async docstring; re-add _shard_episodes (dead in upstream too).
Keep only the genuine v8 substitutions (Rollout types, async
_collect_rollouts, _build_episodes).
- _run_rollouts/_collect_rollouts return list[RolloutGroup]; _PendingGroup build struct; whole-group drop on prompt overflow; flat n=1 generate with completions zipped 1:1 (no rebucket). - _do_group_step -> _do_single_rollout (one env+completion -> Rollout) with try/except inside so partial turns survive as an ERROR rollout. - RolloutStatus gains ONGOING + ERROR + is_terminal(); TRUNCATED_OVERFLOW -> TRUNCATED_PROMPT_OVERFLOW; TokenizedResponseStep.status is required. - _is_overflow -> _is_prompt_overflow (prompt_len >= max_rollout_tokens); drop unused EnvLimits.max_generation_tokens. - TokenizedTurn -> TokenizedResponseStep; Rubric.register_funcs is @abc.abstractmethod; logging TODOs on Rollout/RolloutTurn.
- config_registry: enable_thinking=True, max_tokens=700 for all rl_grpo_qwen3_* configs. Qwen3-0.6B one-shots sum_digits with thinking + enough budget (reward ~0.95-1.0 from step 1); 100/200 tokens truncate mid-<think>. - rollouts/types.py: shorten RolloutStatus docstring; group RolloutTurn fields.
…build - recipes/ -> tasks/; DatasetOutput.task -> task_name. - env carriers: MsgResponseReset/Step -> ResetOutput/StepOutput; TokenizedResponseStep -> TokenizedStepOutput. StepOutput drops status (keeps done); RendererEnv owns RolloutStatus. - EnvLimits -> RendererEnvConfig (field renderer_env_config; RendererEnv(config=)). - next_token_ids/next_messages -> next_prompt_token_ids/next_prompt_messages. - last_response_messages/response_messages split -> assistant_message + env_messages. - RolloutTurn gains prompt_messages; restore env-set reward path: StepOutput.reward_components -> TokenizedStepOutput.env_reward_components -> RolloutTurn.reward_components. - renderer.py build() via pydantic TypeAdapter; drop defensive list copy; log parse/timeout exceptions; MessageEnv docstring + Example; revert 9 unrelated RLTrainer.Config docstrings; 'policy' -> 'limits'.
- _build_episodes: drop a group when any sibling has no turns (turn-less ERROR rollout) instead of checking reward-is-None. The old check never fired (score_group always sets rewards) and let a turn-less rollout reach rollout_to_episode, which raises (requires exactly one turn). - rollout_to_episode: derive text from the rollout (last_assistant_text) instead of taking it as a param; drop the now-unused import in grpo.
- Apply 71_docstring_concision: tighten/dedupe docstrings & comments across 9 files (incl. _run_rollouts mental-model/step-list dedup, rubric duplicate Example, RewardFn.weight grammar, _score_one stray 'S', 'umbers' typo). Docstrings/comments only; no code changes. - Rename envs/ package -> env_types/ (message_env, renderer_env, __init__); update all imports.
| from torchtitan.experiments.rl.tasks.sum_digits.grader import SumDigitsRubric | ||
|
|
||
|
|
||
| class SumDigitsTask(Task): |
There was a problem hiding this comment.
This class guides most of the discussion here. RL requires multiple components to travel together: Dataset, rubrics, env, rollout loop logic. E.g. you don't use the SearchRubric to grade CodingDataset.
Ideally, you want an 'Agent' or 'Workflow' that holds all of this together.
The controller can just do: workflow.run_rollout, and it will get the rollout, without knowing about it's internals.
This makes training easier, because now I can datamix my workflows. I can share them. I can import them.
This class is a step in this direction. I want to eventually put the rollout logic here as well. That way, workflow_A can be different than workflow_B, and the controller doesn't need to know that.
We will have to adapt and understand how stateful we want this class to be. Currently, the controller stitches things together, e.g.
def run_rollout():
sample = Task.get_sample
envs = Task.make_envs(sample)
...
Task.score_group(RolloutGroup)but in the future, it could just be:
Task.run_rollout()There was a problem hiding this comment.
For the user, it means that they can easily swap our Task with something like another harness like https://github.com/NVIDIA-NeMo/ProRL-Agent-Server or rllm, or vice versa.
I think it is the right mental model, but we need to learn how to package it well.
There was a problem hiding this comment.
The controller can just do: workflow.run_rollout, and it will get the rollout, without knowing about it's internals.
I feel we should do this. The RLTrainer should just own trainer, generator, define task / env and maybe pass generator to them to obtain rollouts.
|
|
||
|
|
||
| @dataclass(kw_only=True, slots=True) | ||
| class RolloutTurn: |
There was a problem hiding this comment.
With turn, the user has a full snapshot of every point in the rollout, from a token and message perspective. We can log this to a json, and debugging will be beautiful :)
| @abc.abstractmethod | ||
| def register_funcs(self) -> list[RewardFn]: | ||
| """Return this rubric's reward fns + weights (see class Example).""" |
There was a problem hiding this comment.
User can define anything here: other LLM calls, reward models, simple functions, etc.
There was a problem hiding this comment.
can they just define a Rubric and use it in a config, instead of register
| self.rubric = config.rubric.build() | ||
| self.renderer_env_config = config.renderer_env_config | ||
|
|
||
| def make_envs( |
There was a problem hiding this comment.
we currently always create a new env, but the user can be creative, e.g. have a pool of envs, in case creating an env is expensive.
| """Per-reward-fn raw output, keyed by `fn.__name__`.""" | ||
|
|
||
|
|
||
| class Rubric(Configurable, abc.ABC): |
There was a problem hiding this comment.
Another option would be to have the env provide the reward. This seems to be an antipattern. Most libraries have rubrics applied after the rollout is completed. It makes sense: This can be another LLM, for example. Also, you might want to compare samples or rerank them, so having access to the whole group is necessary.
A downside is that, if you have a sandbox in you env, you might need the same sandbox in your Rubric to evaluate the answer. But there are ways to mitigate it.
| """Ground-truth total digit sum.""" | ||
|
|
||
|
|
||
| class SumDigitsDataset(Configurable): |
There was a problem hiding this comment.
some libraries have the dataset live inside of the env or the 'Task'/workflow/agent. It makes sense: the data is attached to those.
Another option is like Tinker: Have the dataset yield a builder, and that builder produce an env.
builder = dataset.sample()
envs = builder.make_envs()I decided to have the dataset be its own class, outside of the env, so we could easily create interleaved dataset, dataloader, etc, without overthinking it.
I also decided to link the dataset to the Env by having a task_name field, instead of returning a builder. My thought process is that the user may want to have a pool of envs, for example, instead of always creating a new one. It made sense to me to let the Task.make_envs hold that logic, isntead of overloading the dataset with it.
There was a problem hiding this comment.
Haven't read enough to understand how "task_name" links things together.
But it sounds very natural to me that we should couple Env / Task and dataset, one way or another.
There was a problem hiding this comment.
What's the consideration of putting the dataset a field of SumDigitsTask.Config() vs. connecting them using TASK_NAME? A Task can be the container of Env, Rubric and Dataset?
There was a problem hiding this comment.
We can put the dataset in the Task. The reason I didn't is that, when we do datamix, now we would have to do at the task level, instead of regular dataloader(interleaved_dataset). Its a trade-off.
Hmmm, perhaps we can do:
all_datasets = []
for ds in all_datasets:
ds = task.get_dataset()
final_ds = interleaved(all_datasets)
tianyu-l
left a comment
There was a problem hiding this comment.
Thanks, did a first pass. Overall I feel the logic among Env, Task, Rollout, Dataset, Rubric could be more clear.
| return sum(s.reward for _, s in self.transitions) | ||
|
|
||
|
|
||
| # TODO: rename `Episode` -> `TrainSample` and `rollout_to_episode` -> |
There was a problem hiding this comment.
nit: TrainingSample, to be consistent with TrainingBatch
| uv pip install torchmonarch==0.4.1 | ||
| uv pip install --no-deps "git+https://github.com/meta-pytorch/torchstore.git@main" | ||
| uv pip install pygtrie portpicker | ||
| uv pip install "git+https://github.com/PrimeIntellect-ai/renderers.git@main" |
There was a problem hiding this comment.
would this be used for sft as well?
There was a problem hiding this comment.
Do we need to depend on renders lastest main?
| preserve_thinking_between_tool_calls: bool = False | ||
|
|
||
| def build(self, *, model_path: str) -> Renderer: | ||
| from transformers import AutoTokenizer |
There was a problem hiding this comment.
Can we avoid this dependency?
There was a problem hiding this comment.
| `model_path` and constructs the `renderers` config matching `name`. | ||
|
|
||
| Args: | ||
| name: Renderer name (e.g. `"qwen3"`, `"auto"`). |
There was a problem hiding this comment.
Not clear what this means. We should make it explicitly referring to torchtitan model (and their tokenizer) and avoid transformers dependency.
| top_p=_sampling_config.top_p, | ||
| max_tokens=_sampling_config.max_tokens, | ||
| n=_sampling_config.n, | ||
| stop_token_ids=list(_sampling_config.stop_token_ids) or None, |
There was a problem hiding this comment.
| stop_token_ids=list(_sampling_config.stop_token_ids) or None, | |
| stop_token_ids=_sampling_config.stop_token_ids or None, |
| """The env's reply messages this turn (tool / user).""" | ||
|
|
||
| # For rubrics | ||
| reward_components: dict[str, float] = field(default_factory=dict) |
There was a problem hiding this comment.
can we call it reward which should always decompose into "components"
There was a problem hiding this comment.
I think that we want to have a simple field "reward: float", because thats what the final loss will use. No ambiguity here. And components is the breakdown for logging or weight averaging or some specific advantage computation.
There was a problem hiding this comment.
If single field reward is always coming from the components, can we have cached_property that derive the single field from components. If not (instead single field is the only thing RL algorithm eventually care, and everything else is only for logging), then I'm fine with it.
| assistant_message: Message | None = None | ||
| """The model's parsed turn (generator output as a message).""" | ||
|
|
||
| env_messages: list[Message] = field(default_factory=list) # [M_env] |
There was a problem hiding this comment.
how is env_messages related to prompt_messages?
There was a problem hiding this comment.
Prompt_messages: input to the generator (all history up to that point)
assistant_message: output of genereator
env_message: output of the environment, e.g. tool calls, new user message, etc
There was a problem hiding this comment.
butsomewhere else you used next_prompt_...
Is it the same as env_message or a subset?
|
|
||
|
|
||
| @dataclass(frozen=True, kw_only=True, slots=True) | ||
| class DatasetOutput: |
There was a problem hiding this comment.
"Output" sounds confusing, it can be input to rollout / grader
There was a problem hiding this comment.
maybe DataSample?
| from torchtitan.experiments.rl.tasks.sum_digits.grader import SumDigitsRubric | ||
|
|
||
|
|
||
| class SumDigitsTask(Task): |
There was a problem hiding this comment.
The controller can just do: workflow.run_rollout, and it will get the rollout, without knowing about it's internals.
I feel we should do this. The RLTrainer should just own trainer, generator, define task / env and maybe pass generator to them to obtain rollouts.
| Returns: | ||
| One `Reward` per rollout, in input order. | ||
| """ | ||
| return await asyncio.gather(*(self._score_one(r, env_input) for r in rollouts)) |
There was a problem hiding this comment.
After this PR we are still in Sync RL, but defining these functions here as they are basic for async RL?
There was a problem hiding this comment.
we do asyncio.gather because the reward_fns can be multiple LLMs, for example. So we can run them in parallel. This is orthogonal to async/sync RL. Does this make sense?
| return await asyncio.gather(*(self._score_one(r, env_input) for r in rollouts)) | ||
|
|
||
|
|
||
| def _fn_name(fn: Callable) -> str: |
There was a problem hiding this comment.
What's this helper function for? Can we just inline it?
There was a problem hiding this comment.
We should set up CPU CI test for RL to guard these tests
| """Ground-truth total digit sum.""" | ||
|
|
||
|
|
||
| class SumDigitsDataset(Configurable): |
There was a problem hiding this comment.
What's the consideration of putting the dataset a field of SumDigitsTask.Config() vs. connecting them using TASK_NAME? A Task can be the container of Env, Rubric and Dataset?
There was a problem hiding this comment.
Can we rename to rubrics.py which more aligned with our naming now
| self._validation_dataset = config.validation_dataset.build() | ||
| self._str2task_map: dict[str, Task] = { | ||
| name: cfg.build() for name, cfg in config.tasks.items() | ||
| } |
There was a problem hiding this comment.
We only have one task in our current RL loop now, are you going to support data mix soon? I guess we can simplify in this PR and only support one task for now, and handle multi-tasks together with data-mix PR
| completions, generation_metrics = self._get_rank_0_value( | ||
| self.generator.generate.call(tokenized_prompts).get() | ||
| group_size = self.config.generator.sampling.n | ||
| sampling_cfg = replace( |
There was a problem hiding this comment.
Is stop_token_ids same as eos_ids?
There was a problem hiding this comment.
thats my understanding, but we should get it directly from the renderer.tokenizer. Also, maybe the user can have some specific logic to stop when some token T appears
|
|
||
| Steps: | ||
| 1. Get examples from dataset | ||
| 2. For each example, find associated task, e.g. CodingTask, SearchTask, etc |
There was a problem hiding this comment.
So the reason we don't put dataset a subfield of Task is because of data-mixing? We will do datamixing at DatasetOutput level, not Task level? What does other repo model these concepts?
There was a problem hiding this comment.
What does other repo model these concepts?
I dont recall. I can take a look. Another option is to play it by year, so what we are comfortable with and refactor later when we try datamix. For now, lets put it inside of the Task, since thats a pattern i have seen as well and both and tianyu shared that it made more sense to you. I will make the changes.
| group_id=f"{example.task_name}/step={step}/group={group_offset + group_idx}", | ||
| example=example, | ||
| task=task, | ||
| envs=task.make_envs( |
There was a problem hiding this comment.
Do we need to create enviornments repeatedly for each sample? Eg, spin up a docker for each single CodingTask?
Or can we create a Enviornment for each Task?
There was a problem hiding this comment.
its up to the user to decide what happens inside of "make_envs", i.e. create a fresh new one or pull from a pool
| # 4. For each env, get initial prompt (n_groups * n_rollouts_per_group) | ||
| initial_steps: list[list[TokenizedStepOutput]] = await asyncio.gather( | ||
| *( | ||
| asyncio.gather(*(env.initial_prompt() for env in group.envs)) |
There was a problem hiding this comment.
What is initial_prompt here? Can you give an example? I'm confused why the prompt doesn't come from dataset
There was a problem hiding this comment.
the sample (history of messages) comes from the dataset. The env adds the system message and adds tool calls. The prompt is the final constructed input for the generator. But notice that we are not opinionated about it: if the user wants the env_input to include the system prompt as well, they can do it.
How to review?
a. Read the contents in tasks/sum_digits
b. Read grpo.py:collect_rollouts
c. Read the rest
Summary
Our current script does not use messages or chat template. Now it will be the default. Users write
reset/step_message;a
RendererEnvwraps it and owns all message <-> token plumbing done by theRenderer.Typed rollout records:
RolloutGroup(Rollout(RolloutTurn)) replace the oldTrajectory/(Completion, Step)pairs. They now carry messages and tokens that support multi-turn.Rubric: Class to hold functions for scoring after rollout is finished
It also handles partial scoring in case of truncation and error.
a) create/store Envs for an specific task
b) Holds the
rubricassociated to that taskc) in the future will hold the rollout loop -- which can be customized by users if they want to.
This makes it trivial to do i) dataset mix; ii) share/import tasks
Why?
For single turn we can naively concatenate prompt+response. To enable multiturn we had to enable (1) and (2). Given i was refactoring it, i added (3) and (4)
Blockers/next steps: