-
Notifications
You must be signed in to change notification settings - Fork 286
[train] Prefix-aware merge for step-wise trajectories (#1277) #1377
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| """ | ||
| Prefix-aware merging of step-wise trajectory turns for training. | ||
|
|
||
| When step_wise_trajectories=True, each turn is initially a separate sample. | ||
| We merge consecutive turns into fewer samples only when the next turn's prompt | ||
| token IDs have the previous full sequence (prompt + response) as an exact prefix. | ||
| Otherwise we keep them as separate samples (token-id prefix match only). | ||
| """ | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import List, Optional, Tuple | ||
|
|
||
|
|
||
| def _is_prefix(sequence: List[int], candidate: List[int]) -> bool: | ||
| """Check if sequence is a prefix of candidate (exact token-id match).""" | ||
| if len(sequence) > len(candidate): | ||
| return False | ||
| return sequence == candidate[: len(sequence)] | ||
|
|
||
|
|
||
@dataclass
class MergedStepWiseSample:
    """A single training sample after merging one or more step-wise turns."""

    # Token IDs of the first turn's observation only (never re-includes later
    # delta-observation tokens — those live in the response stream).
    prompt_token_ids: List[int]
    # Concatenated response stream: resp1 + delta_ob2 + resp2 + ...
    response_ids: List[int]
    # Per-token rewards aligned with response_ids (0.0 on delta-ob tokens).
    rewards: List[float]
    # Per-token loss mask aligned with response_ids (0 on delta-ob tokens).
    loss_masks: List[int]
    # Per-token rollout logprobs aligned with response_ids; None when the
    # caller did not supply rollout logprobs.
    rollout_logprobs: Optional[List[float]] = None
    # True iff any merged turn was the trajectory's final turn.
    is_last_step: bool = False
|
|
||
|
|
||
def merge_step_wise_turns_for_trajectory(
    prompt_token_ids: List[List[int]],
    response_ids: List[List[int]],
    rewards: List[List[float]],
    loss_masks: List[List[int]],
    is_last_step: List[bool],
    rollout_logprobs: Optional[List[List[float]]] = None,
) -> Tuple[List[MergedStepWiseSample], int]:
    """
    Merge consecutive turns for a single trajectory when the next observation
    has the previous full sequence (prompt + response) as an exact prefix.

    No data leakage: prompt is the first turn's observation only; response is
    resp1 + delta_ob2 + resp2 + ... (delta_ob tokens have zero loss mask).

    Args:
        prompt_token_ids: Per-turn prompt (observation) token IDs.
        response_ids: Per-turn response (action) token IDs.
        rewards: Per-turn per-token rewards (list of lists).
        loss_masks: Per-turn loss masks (list of lists).
        is_last_step: Per-turn flag True only on the final turn of the trajectory.
        rollout_logprobs: Optional per-turn rollout logprobs (list of lists).

    Returns:
        (merged_samples, prefix_mismatch_count)
        - merged_samples: List of merged training samples for this trajectory.
        - prefix_mismatch_count: Number of times we did not merge due to prefix mismatch.
    """
    n = len(prompt_token_ids)
    assert n == len(response_ids) == len(rewards) == len(loss_masks) == len(is_last_step)
    # Only accumulate logprobs when the caller supplied them; building and then
    # discarding zero-filled lists when rollout_logprobs is None is wasted work.
    track_logprobs = rollout_logprobs is not None
    if track_logprobs:
        assert len(rollout_logprobs) == n

    merged: List[MergedStepWiseSample] = []
    prefix_mismatch_count = 0

    # Full sequence so far (obs + response) for prefix check only
    full_sequence: List[int] = []
    # Initial observation only — prompt so that prompt + response = correct full sequence with no overlap
    initial_prompt: List[int] = []
    # Response stream: resp1 + delta_ob2 + resp2 + ... (delta_ob with zero loss so no duplicate tokens)
    acc_response_ids: List[int] = []
    acc_rewards: List[float] = []
    acc_loss_masks: List[int] = []
    acc_logprobs: List[float] = []
    acc_is_last_step = False

    def flush() -> None:
        """Emit current accumulated sample and reset."""
        nonlocal full_sequence, initial_prompt, acc_response_ids, acc_rewards, acc_loss_masks, acc_logprobs, acc_is_last_step
        if not initial_prompt and not acc_response_ids:
            return
        merged.append(
            MergedStepWiseSample(
                prompt_token_ids=list(initial_prompt),
                response_ids=list(acc_response_ids),
                rewards=list(acc_rewards),
                loss_masks=list(acc_loss_masks),
                rollout_logprobs=list(acc_logprobs) if track_logprobs else None,
                is_last_step=acc_is_last_step,
            )
        )
        full_sequence = []
        initial_prompt = []
        acc_response_ids = []
        acc_rewards = []
        acc_loss_masks = []
        acc_logprobs = []
        acc_is_last_step = False

    for i in range(n):
        ob_tokens = prompt_token_ids[i]
        ac_tokens = response_ids[i]

        if not full_sequence:
            # First turn of the current sample: its observation becomes the prompt.
            delta_ob = ob_tokens
            initial_prompt = list(delta_ob)
        elif _is_prefix(full_sequence, ob_tokens):
            delta_ob = ob_tokens[len(full_sequence):]
            # Interleave: delta_ob goes into response stream with zero loss so prompt+response = full sequence
            acc_response_ids.extend(delta_ob)
            acc_rewards.extend([0.0] * len(delta_ob))
            acc_loss_masks.extend([0] * len(delta_ob))
            if track_logprobs:
                acc_logprobs.extend([0.0] * len(delta_ob))
        else:
            # Observation does not continue the previous sequence: emit what we
            # have and start a fresh sample from this turn.
            prefix_mismatch_count += 1
            flush()
            delta_ob = ob_tokens
            initial_prompt = list(delta_ob)

        full_sequence.extend(delta_ob)
        full_sequence.extend(ac_tokens)
        acc_response_ids.extend(ac_tokens)
        acc_rewards.extend(rewards[i])
        acc_loss_masks.extend(loss_masks[i])
        if track_logprobs:
            acc_logprobs.extend(rollout_logprobs[i])
        acc_is_last_step = acc_is_last_step or is_last_step[i]

    flush()
    return merged, prefix_mismatch_count
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |||||
| import shutil | ||||||
| from collections import defaultdict | ||||||
| from dataclasses import asdict | ||||||
| from itertools import groupby | ||||||
| from pathlib import Path | ||||||
| from typing import Any, Dict, List, Optional, Tuple, Union | ||||||
|
|
||||||
|
|
@@ -52,6 +53,10 @@ | |||||
| GeneratorInput, | ||||||
| GeneratorInterface, | ||||||
| GeneratorOutput, | ||||||
| TrajectoryID, | ||||||
| ) | ||||||
| from skyrl.train.step_wise_merge import ( | ||||||
| merge_step_wise_turns_for_trajectory, | ||||||
| ) | ||||||
| from skyrl.train.generators.utils import ( | ||||||
| get_metrics_from_generator_output, | ||||||
|
|
@@ -613,6 +618,83 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis | |||||
| "rollout_expert_indices", None | ||||||
| ) | ||||||
|
|
||||||
| num_samples_before_merge = len(prompt_ids) | ||||||
|
|
||||||
| if self.cfg.generator.step_wise_trajectories: | ||||||
| assert "trajectory_ids" in generator_output and "is_last_step" in generator_output | ||||||
| trajectory_ids_raw: List[TrajectoryID] = generator_output["trajectory_ids"] | ||||||
| is_last_step_list: List[bool] = generator_output["is_last_step"] | ||||||
|
|
||||||
| # Group consecutive indices by same trajectory (instance_id + repetition_id). | ||||||
| # groupby merges only adjacent runs with the same key; trajectory_ids_raw must list | ||||||
| # all turns of a trajectory in one contiguous block (no interleaving trajectories). | ||||||
| groups: List[Tuple[TrajectoryID, List[int]]] = [] | ||||||
| for _, group in groupby(enumerate(trajectory_ids_raw), key=lambda x: x[1].to_string()): | ||||||
| indices = [i for i, _ in group] | ||||||
| if indices: | ||||||
| groups.append((trajectory_ids_raw[indices[0]], indices)) | ||||||
|
|
||||||
| merged_prompt_ids: List[List[int]] = [] | ||||||
| merged_response_ids: List[List[int]] = [] | ||||||
| merged_rewards: List[List[float]] = [] | ||||||
| merged_loss_masks: List[List[int]] = [] | ||||||
| merged_logprobs: Optional[List[List[float]]] = [] if logprobs else None | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||
| merged_is_last_step: List[bool] = [] | ||||||
| merged_trajectory_ids: List[TrajectoryID] = [] | ||||||
| total_prefix_mismatch = 0 | ||||||
|
|
||||||
| for traj_id, indices in groups: | ||||||
| turn_prompts = [prompt_ids[j] for j in indices] | ||||||
| turn_responses = [response_ids[j] for j in indices] | ||||||
| turn_rewards = [rewards[j] for j in indices] | ||||||
| turn_masks = [loss_masks[j] for j in indices] | ||||||
| turn_is_last = [is_last_step_list[j] for j in indices] | ||||||
| turn_logprobs = [logprobs[j] for j in indices] if logprobs else None | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to a previous line, using
Suggested change
|
||||||
|
|
||||||
| samples, mismatch_count = merge_step_wise_turns_for_trajectory( | ||||||
| prompt_token_ids=turn_prompts, | ||||||
| response_ids=turn_responses, | ||||||
| rewards=turn_rewards, | ||||||
| loss_masks=turn_masks, | ||||||
| is_last_step=turn_is_last, | ||||||
| rollout_logprobs=turn_logprobs, | ||||||
| ) | ||||||
| total_prefix_mismatch += mismatch_count | ||||||
| for s in samples: | ||||||
| merged_prompt_ids.append(s.prompt_token_ids) | ||||||
| merged_response_ids.append(s.response_ids) | ||||||
| merged_rewards.append(s.rewards) | ||||||
| merged_loss_masks.append(s.loss_masks) | ||||||
| if merged_logprobs is not None: | ||||||
| merged_logprobs.append(s.rollout_logprobs) | ||||||
| merged_is_last_step.append(s.is_last_step) | ||||||
| merged_trajectory_ids.append(traj_id) | ||||||
|
|
||||||
| num_samples_after_merge = len(merged_prompt_ids) | ||||||
| prompt_ids = merged_prompt_ids | ||||||
| response_ids = merged_response_ids | ||||||
| rewards = merged_rewards | ||||||
| loss_masks = merged_loss_masks | ||||||
| logprobs = merged_logprobs | ||||||
| generator_output = { | ||||||
| **generator_output, | ||||||
| "prompt_token_ids": prompt_ids, | ||||||
| "response_ids": response_ids, | ||||||
| "rewards": rewards, | ||||||
| "loss_masks": loss_masks, | ||||||
| "rollout_logprobs": logprobs, | ||||||
| "is_last_step": merged_is_last_step, | ||||||
| "trajectory_ids": merged_trajectory_ids, | ||||||
| } | ||||||
| uids = [tid.instance_id for tid in merged_trajectory_ids] | ||||||
|
|
||||||
| self.all_metrics["trainer/stepwise_num_samples_before"] = num_samples_before_merge | ||||||
| self.all_metrics["trainer/stepwise_num_samples_after"] = num_samples_after_merge | ||||||
| self.all_metrics["trainer/stepwise_merge_ratio"] = ( | ||||||
| num_samples_after_merge / num_samples_before_merge if num_samples_before_merge else 0.0 | ||||||
| ) | ||||||
| self.all_metrics["trainer/stepwise_prefix_mismatch_count"] = total_prefix_mismatch | ||||||
|
|
||||||
| ( | ||||||
| sequences_tensor, | ||||||
| attention_masks_tensor, | ||||||
|
|
@@ -676,9 +758,19 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis | |||||
| training_input.metadata["trajectory_ids"] = [ | ||||||
| trajectory_id.to_string() for trajectory_id in generator_output["trajectory_ids"] | ||||||
| ] | ||||||
| training_input.metadata["avg_response_length"] = sum( | ||||||
| len(sample_response_ids) for sample_response_ids in response_ids | ||||||
| ) / len(response_ids) | ||||||
| last_step_response_lens = [ | ||||||
| len(sample_response_ids) | ||||||
| for sample_response_ids, is_last in zip(response_ids, generator_output["is_last_step"]) | ||||||
| if is_last | ||||||
| ] | ||||||
| num_last_steps = len(last_step_response_lens) | ||||||
| training_input.metadata["avg_response_length"] = ( | ||||||
| sum(last_step_response_lens) / num_last_steps if num_last_steps else 0.0 | ||||||
| ) | ||||||
| else: | ||||||
| training_input.metadata["avg_response_length"] = sum( | ||||||
| len(sample_response_ids) for sample_response_ids in response_ids | ||||||
| ) / len(response_ids) | ||||||
|
|
||||||
| logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}") | ||||||
| training_input = self.pad_batch(training_input) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,113 @@ | ||||||||||||||||||
| """ | ||||||||||||||||||
| Unit tests for prefix-aware step-wise merge (issue #1277). | ||||||||||||||||||
|
|
||||||||||||||||||
| Run: uv run --isolated --extra dev pytest tests/train/test_step_wise_merge.py -v | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| import pytest | ||||||||||||||||||
| from skyrl.train.step_wise_merge import ( | ||||||||||||||||||
| _is_prefix, | ||||||||||||||||||
| merge_step_wise_turns_for_trajectory, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
def test_is_prefix():
    # (sequence, candidate) -> expected prefix verdict
    cases = [
        (([], [1, 2, 3]), True),
        (([1, 2], [1, 2, 3]), True),
        (([1, 2, 3], [1, 2, 3]), True),
        (([1, 2, 3], [1, 2]), False),
        (([1, 99], [1, 2, 3]), False),
        (([1], [1]), True),
    ]
    for (seq, cand), expected in cases:
        assert _is_prefix(seq, cand) is expected
|
|
||||||||||||||||||
|
|
||||||||||||||||||
def test_merge_works():
    """Turn2 observation = Turn1(obs+act) + extra → 1 merged sample."""
    # Turn 1: prompt [10,20,30], response [40,50] → full sequence [10,20,30,40,50]
    # Turn 2: prompt [10,20,30,40,50,60,70] (prefix match + extra 60,70)
    prompt_ids = [
        [10, 20, 30],
        [10, 20, 30, 40, 50, 60, 70],
    ]
    response_ids = [[40, 50], [80, 90]]
    rewards = [[0.0, 0.0], [0.0, 0.0]]
    loss_masks = [[1, 1], [1, 1]]
    is_last_step = [False, True]

    merged, mismatch_count = merge_step_wise_turns_for_trajectory(
        prompt_token_ids=prompt_ids,
        response_ids=response_ids,
        rewards=rewards,
        loss_masks=loss_masks,
        is_last_step=is_last_step,
    )

    assert len(merged) == 1
    assert mismatch_count == 0
    # Prompt is the FIRST turn's observation only; the delta observation tokens
    # [60, 70] are interleaved into the response stream with zero loss mask so
    # that prompt + response reconstructs the full sequence with no duplication.
    assert merged[0].prompt_token_ids == [10, 20, 30]
    assert merged[0].response_ids == [40, 50, 60, 70, 80, 90]
    assert merged[0].loss_masks == [1, 1, 0, 0, 1, 1]
    assert merged[0].rewards == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    assert merged[0].is_last_step is True
|
|
||||||||||||||||||
|
|
||||||||||||||||||
def test_prefix_mismatch():
    """Turn2 observation does not start with previous sequence → 2 samples."""
    prompt_ids = [
        [10, 20, 30],
        [99, 88, 77],
    ]
    response_ids = [[40, 50], [11, 22]]
    rewards = [[0.0, 0.0], [0.0, 0.0]]
    loss_masks = [[1, 1], [1, 1]]
    is_last_step = [False, True]

    merged, mismatch_count = merge_step_wise_turns_for_trajectory(
        prompt_token_ids=prompt_ids,
        response_ids=response_ids,
        rewards=rewards,
        loss_masks=loss_masks,
        is_last_step=is_last_step,
    )

    assert mismatch_count == 1
    assert len(merged) == 2
    # Each turn stands alone: observation/action pairs are kept verbatim.
    first, second = merged
    assert first.prompt_token_ids == [10, 20, 30]
    assert first.response_ids == [40, 50]
    assert first.is_last_step is False
    assert second.prompt_token_ids == [99, 88, 77]
    assert second.response_ids == [11, 22]
    assert second.is_last_step is True
|
|
||||||||||||||||||
|
|
||||||||||||||||||
def test_partial_merge():
    """Turn1→Turn2 merge, Turn3 mismatches → 2 samples."""
    # Turn 1: [1,2,3] + [4,5]
    # Turn 2: [1,2,3,4,5,6,7] (prefix) + [8,9]
    # Turn 3: [100,200] (mismatch) + [11,22]
    prompt_ids = [
        [1, 2, 3],
        [1, 2, 3, 4, 5, 6, 7],
        [100, 200],
    ]
    response_ids = [[4, 5], [8, 9], [11, 22]]
    rewards = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
    loss_masks = [[1, 1], [1, 1], [1, 1]]
    is_last_step = [False, False, True]

    merged, mismatch_count = merge_step_wise_turns_for_trajectory(
        prompt_token_ids=prompt_ids,
        response_ids=response_ids,
        rewards=rewards,
        loss_masks=loss_masks,
        is_last_step=is_last_step,
    )

    assert len(merged) == 2
    assert mismatch_count == 1
    # First sample: merged turn1+2. Prompt is turn1's observation only; the
    # delta observation [6, 7] is zero-masked inside the response stream.
    assert merged[0].prompt_token_ids == [1, 2, 3]
    assert merged[0].response_ids == [4, 5, 6, 7, 8, 9]
    assert merged[0].loss_masks == [1, 1, 0, 0, 1, 1]
    assert merged[0].is_last_step is False
    # Second sample: turn3 only
    assert merged[1].prompt_token_ids == [100, 200]
    assert merged[1].response_ids == [11, 22]
    assert merged[1].is_last_step is True
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation for handling
rollout_logprobsis inefficient when it isNone. It creates and extends lists with zeros, which are then discarded in theflushfunction. This can be optimized by only performing logprob-related operations whenrollout_logprobsis notNone.