Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions skyrl/train/step_wise_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""
Prefix-aware merging of step-wise trajectory turns for training.

When step_wise_trajectories=True, each turn is initially a separate sample.
We merge consecutive turns into fewer samples only when the next turn's prompt
token IDs have the previous full sequence (prompt + response) as an exact prefix.
Otherwise we keep them as separate samples (token-id prefix match only).
"""

from dataclasses import dataclass
from typing import List, Optional, Tuple


def _is_prefix(sequence: List[int], candidate: List[int]) -> bool:
"""Check if sequence is a prefix of candidate (exact token-id match)."""
if len(sequence) > len(candidate):
return False
return sequence == candidate[: len(sequence)]


@dataclass
class MergedStepWiseSample:
"""A single training sample after merging one or more step-wise turns."""

prompt_token_ids: List[int]
response_ids: List[int]
rewards: List[float]
loss_masks: List[int]
rollout_logprobs: Optional[List[float]] = None
is_last_step: bool = False


def merge_step_wise_turns_for_trajectory(
prompt_token_ids: List[List[int]],
response_ids: List[List[int]],
rewards: List[List[float]],
loss_masks: List[List[int]],
is_last_step: List[bool],
rollout_logprobs: Optional[List[List[float]]] = None,
) -> Tuple[List[MergedStepWiseSample], int]:
"""
Merge consecutive turns for a single trajectory when the next observation
has the previous full sequence (prompt + response) as an exact prefix.

No data leakage: prompt is the first turn's observation only; response is
resp1 + delta_ob2 + resp2 + ... (delta_ob tokens have zero loss mask).

Args:
prompt_token_ids: Per-turn prompt (observation) token IDs.
response_ids: Per-turn response (action) token IDs.
rewards: Per-turn per-token rewards (list of lists).
loss_masks: Per-turn loss masks (list of lists).
is_last_step: Per-turn flag True only on the final turn of the trajectory.
rollout_logprobs: Optional per-turn rollout logprobs (list of lists).

Returns:
(merged_samples, prefix_mismatch_count)
- merged_samples: List of merged training samples for this trajectory.
- prefix_mismatch_count: Number of times we did not merge due to prefix mismatch.
"""
n = len(prompt_token_ids)
assert n == len(response_ids) == len(rewards) == len(loss_masks) == len(is_last_step)
if rollout_logprobs is not None:
assert len(rollout_logprobs) == n

merged: List[MergedStepWiseSample] = []
prefix_mismatch_count = 0

# Full sequence so far (obs + response) for prefix check only
full_sequence: List[int] = []
# Initial observation only — prompt so that prompt + response = correct full sequence with no overlap
initial_prompt: List[int] = []
# Response stream: resp1 + delta_ob2 + resp2 + ... (delta_ob with zero loss so no duplicate tokens)
acc_response_ids: List[int] = []
acc_rewards: List[float] = []
acc_loss_masks: List[int] = []
acc_logprobs: List[float] = []
acc_is_last_step = False

def flush() -> None:
"""Emit current accumulated sample and reset."""
nonlocal full_sequence, initial_prompt, acc_response_ids, acc_rewards, acc_loss_masks, acc_logprobs, acc_is_last_step
if not initial_prompt and not acc_response_ids:
return
merged.append(
MergedStepWiseSample(
prompt_token_ids=list(initial_prompt),
response_ids=list(acc_response_ids),
rewards=list(acc_rewards),
loss_masks=list(acc_loss_masks),
rollout_logprobs=list(acc_logprobs) if (rollout_logprobs is not None) else None,
is_last_step=acc_is_last_step,
)
)
full_sequence = []
initial_prompt = []
acc_response_ids = []
acc_rewards = []
acc_loss_masks = []
acc_logprobs = []
acc_is_last_step = False

for i in range(n):
ob_tokens = prompt_token_ids[i]
ac_tokens = response_ids[i]
ac_rewards = rewards[i]
ac_masks = loss_masks[i]
ac_logprobs_i = rollout_logprobs[i] if rollout_logprobs is not None else [0.0] * len(ac_tokens)

if len(full_sequence) == 0:
delta_ob = ob_tokens
initial_prompt = list(delta_ob)
elif _is_prefix(full_sequence, ob_tokens):
delta_ob = ob_tokens[len(full_sequence) :]
# Interleave: delta_ob goes into response stream with zero loss so prompt+response = full sequence
acc_response_ids.extend(delta_ob)
acc_rewards.extend([0.0] * len(delta_ob))
acc_loss_masks.extend([0] * len(delta_ob))
acc_logprobs.extend([0.0] * len(delta_ob))
else:
prefix_mismatch_count += 1
flush()
delta_ob = ob_tokens
initial_prompt = list(delta_ob)

full_sequence.extend(delta_ob)
full_sequence.extend(ac_tokens)
acc_response_ids.extend(ac_tokens)
acc_rewards.extend(ac_rewards)
acc_loss_masks.extend(ac_masks)
acc_logprobs.extend(ac_logprobs_i)
Comment on lines +108 to +131
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation for handling rollout_logprobs is inefficient when it is None. It creates and extends lists with zeros, which are then discarded in the flush function. This can be optimized by only performing logprob-related operations when rollout_logprobs is not None.

        if len(full_sequence) == 0:
            delta_ob = ob_tokens
            initial_prompt = list(delta_ob)
        elif _is_prefix(full_sequence, ob_tokens):
            delta_ob = ob_tokens[len(full_sequence) :]
            # Interleave: delta_ob goes into response stream with zero loss so prompt+response = full sequence
            acc_response_ids.extend(delta_ob)
            acc_rewards.extend([0.0] * len(delta_ob))
            acc_loss_masks.extend([0] * len(delta_ob))
            if rollout_logprobs is not None:
                acc_logprobs.extend([0.0] * len(delta_ob))
        else:
            prefix_mismatch_count += 1
            flush()
            delta_ob = ob_tokens
            initial_prompt = list(delta_ob)

        full_sequence.extend(delta_ob)
        full_sequence.extend(ac_tokens)
        acc_response_ids.extend(ac_tokens)
        acc_rewards.extend(ac_rewards)
        acc_loss_masks.extend(ac_masks)
        if rollout_logprobs is not None:
            acc_logprobs.extend(rollout_logprobs[i])

acc_is_last_step = acc_is_last_step or is_last_step[i]

flush()
return merged, prefix_mismatch_count
98 changes: 95 additions & 3 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
from collections import defaultdict
from dataclasses import asdict
from itertools import groupby
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -52,6 +53,10 @@
GeneratorInput,
GeneratorInterface,
GeneratorOutput,
TrajectoryID,
)
from skyrl.train.step_wise_merge import (
merge_step_wise_turns_for_trajectory,
)
from skyrl.train.generators.utils import (
get_metrics_from_generator_output,
Expand Down Expand Up @@ -613,6 +618,83 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
"rollout_expert_indices", None
)

num_samples_before_merge = len(prompt_ids)

if self.cfg.generator.step_wise_trajectories:
assert "trajectory_ids" in generator_output and "is_last_step" in generator_output
trajectory_ids_raw: List[TrajectoryID] = generator_output["trajectory_ids"]
is_last_step_list: List[bool] = generator_output["is_last_step"]

# Group consecutive indices by same trajectory (instance_id + repetition_id).
# groupby merges only adjacent runs with the same key; trajectory_ids_raw must list
# all turns of a trajectory in one contiguous block (no interleaving trajectories).
groups: List[Tuple[TrajectoryID, List[int]]] = []
for _, group in groupby(enumerate(trajectory_ids_raw), key=lambda x: x[1].to_string()):
indices = [i for i, _ in group]
if indices:
groups.append((trajectory_ids_raw[indices[0]], indices))

merged_prompt_ids: List[List[int]] = []
merged_response_ids: List[List[int]] = []
merged_rewards: List[List[float]] = []
merged_loss_masks: List[List[int]] = []
merged_logprobs: Optional[List[List[float]]] = [] if logprobs else None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using if logprobs can be ambiguous as it evaluates to False for both None and an empty list ([]). To distinguish between the case where logprobs are not provided (None) and the case where they are provided but the list is empty, it's safer to use if logprobs is not None. This ensures that an empty list of logprobs is handled correctly as an empty list, not as None.

Suggested change
merged_logprobs: Optional[List[List[float]]] = [] if logprobs else None
merged_logprobs: Optional[List[List[float]]] = [] if logprobs is not None else None

merged_is_last_step: List[bool] = []
merged_trajectory_ids: List[TrajectoryID] = []
total_prefix_mismatch = 0

for traj_id, indices in groups:
turn_prompts = [prompt_ids[j] for j in indices]
turn_responses = [response_ids[j] for j in indices]
turn_rewards = [rewards[j] for j in indices]
turn_masks = [loss_masks[j] for j in indices]
turn_is_last = [is_last_step_list[j] for j in indices]
turn_logprobs = [logprobs[j] for j in indices] if logprobs else None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to a previous line, using if logprobs can be ambiguous as it's False for both None and an empty list. Using if logprobs is not None makes the intent clearer and correctly handles the case of an empty list of logprobs.

Suggested change
turn_logprobs = [logprobs[j] for j in indices] if logprobs else None
turn_logprobs = [logprobs[j] for j in indices] if logprobs is not None else None


samples, mismatch_count = merge_step_wise_turns_for_trajectory(
prompt_token_ids=turn_prompts,
response_ids=turn_responses,
rewards=turn_rewards,
loss_masks=turn_masks,
is_last_step=turn_is_last,
rollout_logprobs=turn_logprobs,
)
total_prefix_mismatch += mismatch_count
for s in samples:
merged_prompt_ids.append(s.prompt_token_ids)
merged_response_ids.append(s.response_ids)
merged_rewards.append(s.rewards)
merged_loss_masks.append(s.loss_masks)
if merged_logprobs is not None:
merged_logprobs.append(s.rollout_logprobs)
merged_is_last_step.append(s.is_last_step)
merged_trajectory_ids.append(traj_id)

num_samples_after_merge = len(merged_prompt_ids)
prompt_ids = merged_prompt_ids
response_ids = merged_response_ids
rewards = merged_rewards
loss_masks = merged_loss_masks
logprobs = merged_logprobs
generator_output = {
**generator_output,
"prompt_token_ids": prompt_ids,
"response_ids": response_ids,
"rewards": rewards,
"loss_masks": loss_masks,
"rollout_logprobs": logprobs,
"is_last_step": merged_is_last_step,
"trajectory_ids": merged_trajectory_ids,
}
uids = [tid.instance_id for tid in merged_trajectory_ids]

self.all_metrics["trainer/stepwise_num_samples_before"] = num_samples_before_merge
self.all_metrics["trainer/stepwise_num_samples_after"] = num_samples_after_merge
self.all_metrics["trainer/stepwise_merge_ratio"] = (
num_samples_after_merge / num_samples_before_merge if num_samples_before_merge else 0.0
)
self.all_metrics["trainer/stepwise_prefix_mismatch_count"] = total_prefix_mismatch

(
sequences_tensor,
attention_masks_tensor,
Expand Down Expand Up @@ -676,9 +758,19 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
training_input.metadata["trajectory_ids"] = [
trajectory_id.to_string() for trajectory_id in generator_output["trajectory_ids"]
]
training_input.metadata["avg_response_length"] = sum(
len(sample_response_ids) for sample_response_ids in response_ids
) / len(response_ids)
last_step_response_lens = [
len(sample_response_ids)
for sample_response_ids, is_last in zip(response_ids, generator_output["is_last_step"])
if is_last
]
num_last_steps = len(last_step_response_lens)
training_input.metadata["avg_response_length"] = (
sum(last_step_response_lens) / num_last_steps if num_last_steps else 0.0
)
else:
training_input.metadata["avg_response_length"] = sum(
len(sample_response_ids) for sample_response_ids in response_ids
) / len(response_ids)

logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
training_input = self.pad_batch(training_input)
Expand Down
113 changes: 113 additions & 0 deletions tests/train/test_step_wise_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
Unit tests for prefix-aware step-wise merge (issue #1277).

Run: uv run --isolated --extra dev pytest tests/train/test_step_wise_merge.py -v
"""

import pytest
from skyrl.train.step_wise_merge import (
_is_prefix,
merge_step_wise_turns_for_trajectory,
)


def test_is_prefix():
    """Exact token-id prefix check, including empty and equal sequences."""
    positive_cases = [
        ([], [1, 2, 3]),
        ([1, 2], [1, 2, 3]),
        ([1, 2, 3], [1, 2, 3]),
        ([1], [1]),
    ]
    negative_cases = [
        ([1, 2, 3], [1, 2]),
        ([1, 99], [1, 2, 3]),
    ]
    for prefix, full in positive_cases:
        assert _is_prefix(prefix, full) is True
    for prefix, full in negative_cases:
        assert _is_prefix(prefix, full) is False


def test_merge_works():
    """Turn2 observation = Turn1(obs+act) + extra → 1 merged sample.

    Per the merge contract, the merged prompt is the FIRST turn's observation
    only, and the merged response is resp1 + delta_obs2 + resp2, so that
    prompt + response reconstructs the full sequence with no duplicated tokens.
    """
    # Turn 1: prompt [10,20,30], response [40,50] → full sequence [10,20,30,40,50]
    # Turn 2: prompt [10,20,30,40,50,60,70] (prefix match + extra 60,70)
    prompt_ids = [
        [10, 20, 30],
        [10, 20, 30, 40, 50, 60, 70],
    ]
    response_ids = [[40, 50], [80, 90]]
    rewards = [[0.0, 0.0], [0.0, 0.0]]
    loss_masks = [[1, 1], [1, 1]]
    is_last_step = [False, True]

    merged, mismatch_count = merge_step_wise_turns_for_trajectory(
        prompt_token_ids=prompt_ids,
        response_ids=response_ids,
        rewards=rewards,
        loss_masks=loss_masks,
        is_last_step=is_last_step,
    )

    assert len(merged) == 1
    assert mismatch_count == 0
    # Prompt = first observation only; response interleaves the delta
    # observation [60, 70] between the two actions.
    assert merged[0].prompt_token_ids == [10, 20, 30]
    assert merged[0].response_ids == [40, 50, 60, 70, 80, 90]
    # delta_obs tokens must be masked out of the loss.
    assert merged[0].loss_masks == [1, 1, 0, 0, 1, 1]
    assert merged[0].is_last_step is True


def test_prefix_mismatch():
    """Turn2 observation does not start with previous sequence → 2 samples."""
    merged, mismatch_count = merge_step_wise_turns_for_trajectory(
        prompt_token_ids=[[10, 20, 30], [99, 88, 77]],
        response_ids=[[40, 50], [11, 22]],
        rewards=[[0.0, 0.0], [0.0, 0.0]],
        loss_masks=[[1, 1], [1, 1]],
        is_last_step=[False, True],
    )

    assert mismatch_count == 1
    assert len(merged) == 2

    first, second = merged
    assert first.prompt_token_ids == [10, 20, 30]
    assert first.response_ids == [40, 50]
    assert first.is_last_step is False
    assert second.prompt_token_ids == [99, 88, 77]
    assert second.response_ids == [11, 22]
    assert second.is_last_step is True


def test_partial_merge():
    """Turn1→Turn2 merge, Turn3 mismatches → 2 samples.

    The first merged sample keeps turn 1's observation as the prompt and
    interleaves turn 2's delta observation into the response stream.
    """
    # Turn 1: [1,2,3] + [4,5]
    # Turn 2: [1,2,3,4,5,6,7] (prefix) + [8,9]
    # Turn 3: [100,200] (mismatch) + [11,22]
    prompt_ids = [
        [1, 2, 3],
        [1, 2, 3, 4, 5, 6, 7],
        [100, 200],
    ]
    response_ids = [[4, 5], [8, 9], [11, 22]]
    rewards = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
    loss_masks = [[1, 1], [1, 1], [1, 1]]
    is_last_step = [False, False, True]

    merged, mismatch_count = merge_step_wise_turns_for_trajectory(
        prompt_token_ids=prompt_ids,
        response_ids=response_ids,
        rewards=rewards,
        loss_masks=loss_masks,
        is_last_step=is_last_step,
    )

    assert len(merged) == 2
    assert mismatch_count == 1
    # First sample: merged turn1+2 — prompt is turn 1's observation only,
    # response is resp1 + delta_obs [6,7] + resp2 (delta_obs zero-masked).
    assert merged[0].prompt_token_ids == [1, 2, 3]
    assert merged[0].response_ids == [4, 5, 6, 7, 8, 9]
    assert merged[0].loss_masks == [1, 1, 0, 0, 1, 1]
    assert merged[0].is_last_step is False
    # Second sample: turn3 only
    assert merged[1].prompt_token_ids == [100, 200]
    assert merged[1].response_ids == [11, 22]
    assert merged[1].is_last_step is True
Loading