Skip to content

[RL] - MessageEnv, Rollout types, Rubric, Renderer#3453

Open
felipemello1 wants to merge 24 commits into
pytorch:mainfrom
felipemello1:v8-datatypes-env-protocol
Open

[RL] - MessageEnv, Rollout types, Rubric, Renderer#3453
felipemello1 wants to merge 24 commits into
pytorch:mainfrom
felipemello1:v8-datatypes-env-protocol

Conversation

@felipemello1
Copy link
Copy Markdown
Contributor

@felipemello1 felipemello1 commented May 29, 2026

How to review?
a. Read the contents in tasks/sum_digits
b. Read grpo.py:collect_rollouts
c. Read the rest

Summary

  1. A message-level env protocol:

Our current script does not use messages or chat template. Now it will be the default. Users write reset / step_message;

class SumDigitsEnv(MessageEnv):
    async def reset(self) -> MsgResponseReset:
		...
    async def step_message(self, msg: Message) -> MsgResponseStep:
        ...

a RendererEnv wraps it and owns all message <-> token plumbing done by the Renderer.

example = dataset.get_sample()
env = RendererEnv( 
	message_env=SumDigitsEnv(env_input=example.env_input),
    renderer=renderer
)
initial_turn = env.reset()
  1. Typed rollout records: RolloutGroup(Rollout(RolloutTurn)) replace the old Trajectory / (Completion, Step) pairs. They now carry messages and tokens that support multi-turn.

  2. Rubric: Class to hold functions for scoring after rollout is finished

class MyRubric(Rubric):
    def register_funcs(self) -> list[RewardFn]:
        return [
            RewardFn(fn=my_reward_fn1, weight=0.5),
            RewardFn(fn=my_reward_fn2, weight=0.5),
        ]

rubric = MyRubric(config=MyRubric.Config(truncation_reward=0.0))
rewards = await rubric.score_group(my_rollouts, env_input)
for reward, rollout in zip(rewards, my_rollouts):
    my_rollout.reward = reward.reward
    my_rollout.reward_components = reward.components

It also handles partial scoring in case of truncation and error.

  1. Task: A class that knows how to

a) create/store Envs for an specific task
b) Holds the rubric associated to that task
c) in the future will hold the rollout loop -- which can be customized by users if they want to.

This makes it trivial to do i) dataset mix; ii) share/import tasks

experiments/rl/
├── grpo.py                  # controller + rollout loop  (changed)
├── renderer.py              # NEW: RendererConfig -> a renderers.Renderer
├── env_types/                    # NEW: the env protocol
│   ├── message_env.py       #   MessageEnv (ABC), ResetOutput, StepOutput
│   └── renderer_env.py      #   RendererEnv, RendererEnvConfig, TokenizedStepOutput
├── rollouts/                # NEW: the datatypes
│   ├── types.py             #   Rollout, RolloutTurn, RolloutGroup, RolloutStatus, DatasetOutput
│   └── utils.py             #   rollout_to_episode, prepare_rollout_metrics
├── rubrics/rubric.py        # NEW: Rubric, RewardFn, Reward
└── tasks/                   # NEW (was sum_digits.py)
    ├── task.py              #   Task
    └── sum_digits/          #   the worked example
        └── data.py · env.py · grader.py · task.py

Why?

For single turn we can naively concatenate prompt+response. To enable multiturn we had to enable (1) and (2). Given i was refactoring it, i added (3) and (4)

Blockers/next steps:

  1. Continuous Batching (CB): Necessary so multiturn Rollouts can progress independently
  2. Async rollout + AlphabetSort multiturn task: Small PR once this and continuous batching lands

wwwjn and others added 18 commits May 20, 2026 13:19
- Add shared pack() function in torchtitan/components/dataloading/utils.py
- Add Batcher class with Configurable pattern: packs episodes into
  fixed [B, seq_length] TrainBatches with gradient accumulation support
- Refactor TrainBatch to use [B, L] tensors instead of [1, total_tokens]
- Update PolicyTrainer.forward_backward to work with [B, L] microbatches
- Simplify compute_logprobs/verify_logprob_identity in actors/utils.py
- envs/: MessageEnv ABC + MessageReset/MessageStep; RendererEnv wrapper
  + RendererEnvConfig + TokenizedTurn; Rollout / RolloutTurn / RolloutStatus
  / DatasetOutput types.
- rubrics/: Rubric base class with score_one / score_group, typed Reward
  dataclass; Rubric owns truncation_reward / error_reward (doc 37 Option B).
- recipes/: Task base + SumDigitsTask (dataset + env + grader + recipe).
- grpo.py: _run_rollouts refactored (Option G — inline orchestration +
  _do_group_step helper + per-group failure isolation); validate() reuses
  the same path.
- Drops: sum_digits.py (orphan), test_grpo_metrics.py (broken), Trajectory
  + Step from types.py.
- rollouts/: new folder holding Rollout / RolloutTurn / RolloutStatus /
  DatasetOutput (types.py) and last_assistant_text / rollout_to_episode /
  prepare_rollout_metrics (utils.py). envs/types.py deleted.
- Rubric: Configurable with register_funcs() hook + lazy @cached_property
  weight normalization (sum-to-1). New RewardFn dataclass (fn + weight).
  truncation_reward / error_reward become Optional[float] short-circuit
  knobs (None = run reward fns on partial response).
- SumDigitsRubric subclass added in grader.py; SumDigitsDataset is now
  Configurable; SumDigitsTask.Config composes nested sub-configs
  (dataset / rubric / env_limits) matching PolicyTrainer.Config style.
- MessageReset/Step -> MsgResponseReset/Step.
- RendererEnvConfig -> EnvLimits; kwarg config -> limits.
- recipes/sum_digits/recipe.py -> task.py.
- _run_rollouts: overflow-aware prompt routing -- initial-prompt overflow
  builds TRUNCATED_OVERFLOW rollouts directly instead of sending an empty
  prompt to the generator.
- prepare_rollout_metrics consolidates the rollout/validation metric
  blocks; _prepare_reward_metrics removed.
- Cleanups: RolloutStatus is_*() via frozenset; validate_rollout_output
  dropped; Rollout.group_id / sample_idx required; _log_samples
  Episode-only; reward_correct stops asserting env_input type.
Task surface
- DatasetOutput.env_name -> DatasetOutput.task; SumDigitsDataset
  ENV_NAME -> TASK_NAME.
- Dataset moves off Task onto RLTrainer.Config:
    train: Task.Config -> train_dataset + tasks dict keyed by task name.
    Same for validation.
  Rows route to the matching Task via example.task.
- SumDigitsTask: drops dataset + sample_example; only rubric,
  env_limits, and make_envs remain.
- Task base gains score_group(rollouts, env_input) -> list[Rollout]:
  default delegates to self.rubric.score_group + fills reward and
  reward_components. Step logic stays in grpo._do_group_step.
- TODOs: continuous-batching Task.do_single_rollout; revisit Camp A
  vs B (dataset on Task vs framework).

Style + cleanups (same files)
- Dataclasses converted from per-field docstrings to Args-at-top with
  inline shape comments. Drops double-backticks and stale PR
  references.
- Bare # title block comments in _run_rollouts, _build_episodes,
  validate, train, _do_group_step. Dashed banners removed in train.
- Dead code: _shard_episodes removed (Batcher does sharding).
- Bug fixes:
  - prepare_rollout_metrics total_lens computed per-rollout
    (multi-turn safe).
  - _do_group_step initial.next_token_ids raises on None instead of
    silent default to [].
  - Rubric: register_funcs result checked for unique fn names.
  - renderer_env: env_step.status defaulting uses replace() not
    in-place mutation; _terminal accepts last_response_messages
    instead of post-construction mutation.

Smoke: imports, config build, rubric e2e (3 statuses + short-circuit
modes), task.score_group fills rewards.
- Dataclasses: per-field docstrings; single backticks; drop internal-doc refs.
- RendererEnv.step_completion parses before classifying length/abort, so
  truncated/aborted/timed-out turns keep their response message + tokens
  (partial-reward grading + debugging).
- Completion.text removed (text comes from rollout messages); RolloutTurn gains
  reward/reward_components; TokenizedTurn carries terminal status.
- Read pad/eos from the renderer's tokenizer; drop the standalone tokenizer.
- Task.score_group returns list[Reward]; controller applies them.
- Condense _run_rollouts; README installs renderers.
Conflicts resolved:
- batcher.py / TrainingBatch labels / compute_logprobs -> take main (landed batcher PR).
- grpo.py / config_registry / types Completion+Episode -> keep v8 (our rollout/env/rubric/recipe layer), adapted to main's batcher API.
- config_registry: keep recipes structure (tasks/dataset), import BatchConfig from rl.batcher, enable_wandb=True (matches main).
- Dropped stale test_grpo_metrics.py (tested the removed Step/Trajectory API).
renderers 0.1.8.dev37 replaced the name-based factory with a typed-config one.
RendererConfig.build picks the config variant for name from the public
discriminated union and passes only supported knobs; adds an enable_thinking
field. Removes the private-attr workaround.
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 29, 2026
Felipe Mello added 5 commits May 28, 2026 23:30
The merge with the landed batcher PR pulled in changes that don't belong
to this PR:

- Drop half-done-batcher-base leftovers: BatchConfig in config/configs.py
  (the landed batcher has its own in rl.batcher), config/__init__ export,
  components/dataloading/{__init__,utils}.py, and the text_datasets.py
  rewrite. RL never imports any of these.
- grpo.py: revert train() docstring, reworded comments, and the
  per-microbatch metrics rewrite back to upstream's aggregation; revert the
  setup_async docstring; re-add _shard_episodes (dead in upstream too).
  Keep only the genuine v8 substitutions (Rollout types, async
  _collect_rollouts, _build_episodes).
- _run_rollouts/_collect_rollouts return list[RolloutGroup]; _PendingGroup
  build struct; whole-group drop on prompt overflow; flat n=1 generate with
  completions zipped 1:1 (no rebucket).
- _do_group_step -> _do_single_rollout (one env+completion -> Rollout) with
  try/except inside so partial turns survive as an ERROR rollout.
- RolloutStatus gains ONGOING + ERROR + is_terminal(); TRUNCATED_OVERFLOW ->
  TRUNCATED_PROMPT_OVERFLOW; TokenizedResponseStep.status is required.
- _is_overflow -> _is_prompt_overflow (prompt_len >= max_rollout_tokens);
  drop unused EnvLimits.max_generation_tokens.
- TokenizedTurn -> TokenizedResponseStep; Rubric.register_funcs is
  @abc.abstractmethod; logging TODOs on Rollout/RolloutTurn.
- config_registry: enable_thinking=True, max_tokens=700 for all rl_grpo_qwen3_*
  configs. Qwen3-0.6B one-shots sum_digits with thinking + enough budget
  (reward ~0.95-1.0 from step 1); 100/200 tokens truncate mid-<think>.
- rollouts/types.py: shorten RolloutStatus docstring; group RolloutTurn fields.
…build

- recipes/ -> tasks/; DatasetOutput.task -> task_name.
- env carriers: MsgResponseReset/Step -> ResetOutput/StepOutput;
  TokenizedResponseStep -> TokenizedStepOutput. StepOutput drops status
  (keeps done); RendererEnv owns RolloutStatus.
- EnvLimits -> RendererEnvConfig (field renderer_env_config; RendererEnv(config=)).
- next_token_ids/next_messages -> next_prompt_token_ids/next_prompt_messages.
- last_response_messages/response_messages split -> assistant_message + env_messages.
- RolloutTurn gains prompt_messages; restore env-set reward path:
  StepOutput.reward_components -> TokenizedStepOutput.env_reward_components
  -> RolloutTurn.reward_components.
- renderer.py build() via pydantic TypeAdapter; drop defensive list copy;
  log parse/timeout exceptions; MessageEnv docstring + Example;
  revert 9 unrelated RLTrainer.Config docstrings; 'policy' -> 'limits'.
- _build_episodes: drop a group when any sibling has no turns (turn-less
  ERROR rollout) instead of checking reward-is-None. The old check never
  fired (score_group always sets rewards) and let a turn-less rollout reach
  rollout_to_episode, which raises (requires exactly one turn).
- rollout_to_episode: derive text from the rollout (last_assistant_text)
  instead of taking it as a param; drop the now-unused import in grpo.
@felipemello1 felipemello1 marked this pull request as ready for review May 29, 2026 21:52
@felipemello1 felipemello1 requested review from tianyu-l and wwwjn May 29, 2026 21:52
@felipemello1 felipemello1 changed the title [NOT READY][RL] - MessageEnv, Rollout types, Rubric, Renderer [RL] - MessageEnv, Rollout types, Rubric, Renderer May 29, 2026
- Apply 71_docstring_concision: tighten/dedupe docstrings & comments across
  9 files (incl. _run_rollouts mental-model/step-list dedup, rubric duplicate
  Example, RewardFn.weight grammar, _score_one stray 'S', 'umbers' typo).
  Docstrings/comments only; no code changes.
- Rename envs/ package -> env_types/ (message_env, renderer_env, __init__);
  update all imports.
from torchtitan.experiments.rl.tasks.sum_digits.grader import SumDigitsRubric


class SumDigitsTask(Task):
Copy link
Copy Markdown
Contributor Author

@felipemello1 felipemello1 May 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class guides most of the discussion here. RL requires multiple components to travel together: Dataset, rubrics, env, rollout loop logic. E.g. you don't use the SearchRubric to grade CodingDataset.

Ideally, you want an 'Agent' or 'Workflow' that holds all of this together.

The controller can just do: workflow.run_rollout, and it will get the rollout, without knowing about it's internals.

This makes training easier, because now I can datamix my workflows. I can share them. I can import them.

This class is a step in this direction. I want to eventually put the rollout logic here as well. That way, workflow_A can be different than workflow_B, and the controller doesn't need to know that.

We will have to adapt and understand how stateful we want this class to be. Currently, the controller stitches things together, e.g.

def run_rollout():
	sample = Task.get_sample
	envs = Task.make_envs(sample)
	...
	Task.score_group(RolloutGroup)

but in the future, it could just be:

Task.run_rollout()

Copy link
Copy Markdown
Contributor Author

@felipemello1 felipemello1 May 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the user, it means that they can easily swap our Task with something like another harness like https://github.com/NVIDIA-NeMo/ProRL-Agent-Server or rllm, or vice versa.

I think it is the right mental model, but we need to learn how to package it well.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The controller can just do: workflow.run_rollout, and it will get the rollout, without knowing about it's internals.

I feel we should do this. The RLTrainer should just own trainer, generator, define task / env and maybe pass generator to them to obtain rollouts.



@dataclass(kw_only=True, slots=True)
class RolloutTurn:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With turn, the user has a full snapshot of every point in the rollout, from a token and message perspective. We can log this to a json, and debugging will be beautiful :)

Comment on lines +83 to +85
@abc.abstractmethod
def register_funcs(self) -> list[RewardFn]:
"""Return this rubric's reward fns + weights (see class Example)."""
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User can define anything here: other LLM calls, reward models, simple functions, etc.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can they just define a Rubric and use it in a config, instead of register

self.rubric = config.rubric.build()
self.renderer_env_config = config.renderer_env_config

def make_envs(
Copy link
Copy Markdown
Contributor Author

@felipemello1 felipemello1 May 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we currently always create a new env, but the user can be creative, e.g. have a pool of envs, in case creating an env is expensive.

"""Per-reward-fn raw output, keyed by `fn.__name__`."""


class Rubric(Configurable, abc.ABC):
Copy link
Copy Markdown
Contributor Author

@felipemello1 felipemello1 May 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option would be to have the env provide the reward. This seems to be an antipattern. Most libraries have rubrics applied after the rollout is completed. It makes sense: This can be another LLM, for example. Also, you might want to compare samples or rerank them, so having access to the whole group is necessary.

A downside is that, if you have a sandbox in you env, you might need the same sandbox in your Rubric to evaluate the answer. But there are ways to mitigate it.

"""Ground-truth total digit sum."""


class SumDigitsDataset(Configurable):
Copy link
Copy Markdown
Contributor Author

@felipemello1 felipemello1 May 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some libraries have the dataset live inside of the env or the 'Task'/workflow/agent. It makes sense: the data is attached to those.

Another option is like Tinker: Have the dataset yield a builder, and that builder produce an env.

builder = dataset.sample()
envs = builder.make_envs()

I decided to have the dataset be its own class, outside of the env, so we could easily create interleaved dataset, dataloader, etc, without overthinking it.

I also decided to link the dataset to the Env by having a task_name field, instead of returning a builder. My thought process is that the user may want to have a pool of envs, for example, instead of always creating a new one. It made sense to me to let the Task.make_envs hold that logic, isntead of overloading the dataset with it.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't read enough to understand how "task_name" links things together.

But it sounds very natural to me that we should couple Env / Task and dataset, one way or another.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the consideration of putting the dataset a field of SumDigitsTask.Config() vs. connecting them using TASK_NAME? A Task can be the container of Env, Rubric and Dataset?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can put the dataset in the Task. The reason I didn't is that, when we do datamix, now we would have to do at the task level, instead of regular dataloader(interleaved_dataset). Its a trade-off.

Hmmm, perhaps we can do:

all_datasets = []
for ds in all_datasets:
	ds = task.get_dataset()

final_ds = interleaved(all_datasets)

Copy link
Copy Markdown
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, did a first pass. Overall I feel the logic among Env, Task, Rollout, Dataset, Rubric could be more clear.

return sum(s.reward for _, s in self.transitions)


# TODO: rename `Episode` -> `TrainSample` and `rollout_to_episode` ->
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: TrainingSample, to be consistent with TrainingBatch

uv pip install torchmonarch==0.4.1
uv pip install --no-deps "git+https://github.com/meta-pytorch/torchstore.git@main"
uv pip install pygtrie portpicker
uv pip install "git+https://github.com/PrimeIntellect-ai/renderers.git@main"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this be used for sft as well?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to depend on renders lastest main?

preserve_thinking_between_tool_calls: bool = False

def build(self, *, model_path: str) -> Renderer:
from transformers import AutoTokenizer
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid this dependency?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`model_path` and constructs the `renderers` config matching `name`.

Args:
name: Renderer name (e.g. `"qwen3"`, `"auto"`).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not clear what this means. We should make it explicitly referring to torchtitan model (and their tokenizer) and avoid transformers dependency.

top_p=_sampling_config.top_p,
max_tokens=_sampling_config.max_tokens,
n=_sampling_config.n,
stop_token_ids=list(_sampling_config.stop_token_ids) or None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
stop_token_ids=list(_sampling_config.stop_token_ids) or None,
stop_token_ids=_sampling_config.stop_token_ids or None,

"""The env's reply messages this turn (tool / user)."""

# For rubrics
reward_components: dict[str, float] = field(default_factory=dict)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we call it reward which should always decompose into "components"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we want to have a simple field "reward: float", because thats what the final loss will use. No ambiguity here. And components is the breakdown for logging or weight averaging or some specific advantage computation.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If single field reward is always coming from the components, can we have cached_property that derive the single field from components. If not (instead single field is the only thing RL algorithm eventually care, and everything else is only for logging), then I'm fine with it.

assistant_message: Message | None = None
"""The model's parsed turn (generator output as a message)."""

env_messages: list[Message] = field(default_factory=list) # [M_env]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is env_messages related to prompt_messages?

Copy link
Copy Markdown
Contributor Author

@felipemello1 felipemello1 Jun 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prompt_messages: input to the generator (all history up to that point)
assistant_message: output of genereator
env_message: output of the environment, e.g. tool calls, new user message, etc

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

butsomewhere else you used next_prompt_...

Is it the same as env_message or a subset?

Comment thread torchtitan/experiments/rl/rollouts/types.py


@dataclass(frozen=True, kw_only=True, slots=True)
class DatasetOutput:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Output" sounds confusing, it can be input to rollout / grader

Copy link
Copy Markdown
Contributor Author

@felipemello1 felipemello1 Jun 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe DataSample?

from torchtitan.experiments.rl.tasks.sum_digits.grader import SumDigitsRubric


class SumDigitsTask(Task):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The controller can just do: workflow.run_rollout, and it will get the rollout, without knowing about it's internals.

I feel we should do this. The RLTrainer should just own trainer, generator, define task / env and maybe pass generator to them to obtain rollouts.

Returns:
One `Reward` per rollout, in input order.
"""
return await asyncio.gather(*(self._score_one(r, env_input) for r in rollouts))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After this PR we are still in Sync RL, but defining these functions here as they are basic for async RL?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do asyncio.gather because the reward_fns can be multiple LLMs, for example. So we can run them in parallel. This is orthogonal to async/sync RL. Does this make sense?

return await asyncio.gather(*(self._score_one(r, env_input) for r in rollouts))


def _fn_name(fn: Callable) -> str:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this helper function for? Can we just inline it?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should set up CPU CI test for RL to guard these tests

"""Ground-truth total digit sum."""


class SumDigitsDataset(Configurable):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the consideration of putting the dataset a field of SumDigitsTask.Config() vs. connecting them using TASK_NAME? A Task can be the container of Env, Rubric and Dataset?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename to rubrics.py which more aligned with our naming now

self._validation_dataset = config.validation_dataset.build()
self._str2task_map: dict[str, Task] = {
name: cfg.build() for name, cfg in config.tasks.items()
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only have one task in our current RL loop now, are you going to support data mix soon? I guess we can simplify in this PR and only support one task for now, and handle multi-tasks together with data-mix PR

completions, generation_metrics = self._get_rank_0_value(
self.generator.generate.call(tokenized_prompts).get()
group_size = self.config.generator.sampling.n
sampling_cfg = replace(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is stop_token_ids same as eos_ids?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thats my understanding, but we should get it directly from the renderer.tokenizer. Also, maybe the user can have some specific logic to stop when some token T appears


Steps:
1. Get examples from dataset
2. For each example, find associated task, e.g. CodingTask, SearchTask, etc
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the reason we don't put dataset a subfield of Task is because of data-mixing? We will do datamixing at DatasetOutput level, not Task level? What does other repo model these concepts?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does other repo model these concepts?

I dont recall. I can take a look. Another option is to play it by year, so what we are comfortable with and refactor later when we try datamix. For now, lets put it inside of the Task, since thats a pattern i have seen as well and both and tianyu shared that it made more sense to you. I will make the changes.

group_id=f"{example.task_name}/step={step}/group={group_offset + group_idx}",
example=example,
task=task,
envs=task.make_envs(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to create enviornments repeatedly for each sample? Eg, spin up a docker for each single CodingTask?

Or can we create a Enviornment for each Task?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its up to the user to decide what happens inside of "make_envs", i.e. create a fresh new one or pull from a pool

# 4. For each env, get initial prompt (n_groups * n_rollouts_per_group)
initial_steps: list[list[TokenizedStepOutput]] = await asyncio.gather(
*(
asyncio.gather(*(env.initial_prompt() for env in group.envs))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is initial_prompt here? Can you give an example? I'm confused why the prompt doesn't come from dataset

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the sample (history of messages) comes from the dataset. The env adds the system message and adds tool calls. The prompt is the final constructed input for the generator. But notice that we are not opinionated about it: if the user wants the env_input to include the system prompt as well, they can do it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rl ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants