[train] Prefix-aware merge for step-wise trajectories (#1277) #1377

Open
deepsheth3 wants to merge 2 commits into NovaSky-AI:main from deepsheth3:main

Conversation


@deepsheth3 commented Mar 24, 2026

  • Merge consecutive step-wise turns when the next prompt extends the prior prompt + response by prefix; otherwise keep separate samples.
  • Merged rows: first-turn observation = prompt; actions + observation deltas = response, with loss masked on delta tokens.
  • Trainer: group turns with itertools.groupby, append logprobs without a zero fallback, fix avg response length for step-wise (average over last-step responses only).
  • Tests in tests/train/test_step_wise_merge.py.

Test: pytest tests/train/test_step_wise_merge.py -v
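The merge rule in the bullets above can be sketched in simplified form. This is a minimal illustration, not the PR's implementation: `merge_turns` and the `(observation, action)` tuple layout are hypothetical stand-ins, and the real `merge_step_wise_turns_for_trajectory` also tracks rewards and rollout logprobs.

```python
# Minimal sketch of the prefix-aware merge rule (hypothetical names;
# the PR's real code also carries rewards and logprobs per token).
from typing import List, Tuple


def _is_prefix(prefix: List[int], seq: List[int]) -> bool:
    """True if `seq` starts with `prefix`."""
    return len(prefix) <= len(seq) and seq[: len(prefix)] == prefix


def merge_turns(turns: List[Tuple[List[int], List[int]]]):
    """Each turn is (observation tokens, action tokens).

    Returns (prompt, response, loss_mask) samples. Consecutive turns
    merge when the next observation extends the accumulated
    prompt + response by prefix; observation deltas join the response
    stream with loss mask 0, actions with loss mask 1.
    """
    samples = []
    full: List[int] = []  # accumulated prompt + response tokens
    prompt: List[int] = []
    response: List[int] = []
    loss_mask: List[int] = []

    def flush():
        if response:
            samples.append((list(prompt), list(response), list(loss_mask)))

    for ob, ac in turns:
        if not full:
            delta = ob
            prompt = list(ob)
        elif _is_prefix(full, ob):
            delta = ob[len(full):]
            response.extend(delta)              # observation delta: masked out
            loss_mask.extend([0] * len(delta))
        else:
            flush()                             # prefix mismatch: start a new sample
            full, response, loss_mask = [], [], []
            delta = ob
            prompt = list(ob)
        full.extend(delta + ac)
        response.extend(ac)                     # action tokens: trained on
        loss_mask.extend([1] * len(ac))
    flush()
    return samples
```

Under this sketch, two turns `([10, 20, 30], [40, 50])` and `([10, 20, 30, 40, 50, 60, 70], [80, 90])` merge into a single sample with prompt `[10, 20, 30]` and response `[40, 50, 60, 70, 80, 90]`, the delta tokens `60, 70` loss-masked.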



…I#1277)

- Add step_wise_merge.py: _is_prefix(), merge_step_wise_turns_for_trajectory()
- Wire merge into convert_to_training_input() when step_wise_trajectories=True
- Group by trajectory_id, merge turns when next prompt has prev full sequence as prefix
- Add metrics: stepwise_num_samples_before/after, merge_ratio, prefix_mismatch_count
- Add unit tests: merge, prefix_mismatch, partial_merge

Made-with: Cursor
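As a hedged illustration of the "group by trajectory_id" step above: `itertools.groupby` only groups *consecutive* equal keys, so the turns must already be ordered (or sorted) by `trajectory_id`. The dict fields below are hypothetical stand-ins for the trainer's actual sample structure.

```python
# Hypothetical sketch of grouping turns by trajectory_id.
# groupby merges only consecutive equal keys, so sort first
# (or rely on the rollout already emitting turns in order).
from itertools import groupby
from operator import itemgetter

turns = [
    {"trajectory_id": "traj-a", "step": 0},
    {"trajectory_id": "traj-a", "step": 1},
    {"trajectory_id": "traj-b", "step": 0},
]

turns.sort(key=itemgetter("trajectory_id"))  # groupby precondition; stable, so step order survives
by_traj = {
    tid: list(group)
    for tid, group in groupby(turns, key=itemgetter("trajectory_id"))
}
```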

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a prefix-aware merging strategy for step-wise trajectories, which is a significant enhancement for training on multi-turn conversations. The implementation is well-structured, with a dedicated module for the merging logic and corresponding unit tests. My review focuses on improving the efficiency of the merging logic, ensuring robustness in handling optional data, and correcting inconsistencies in the unit tests that could lead to misunderstandings of the feature's behavior.

Comment on lines +47 to +48
assert merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70]
assert merged[0].response_ids == [40, 50, 80, 90]

critical

These assertions seem to contradict the implementation and docstring of merge_step_wise_turns_for_trajectory. The implementation sets prompt_token_ids to the observation of the first turn in a merged sequence, and response_ids includes the actions and observation deltas from subsequent turns. This is also what the function's docstring and the PR description state.

Based on the logic in merge_step_wise_turns_for_trajectory:

  • prompt_token_ids should be [10, 20, 30] (from the first turn's observation).
  • response_ids should be [40, 50, 60, 70, 80, 90] (action from turn 1 + delta observation from turn 2 + action from turn 2).

The current assertions imply a different merging strategy where the prompt is the full context up to the last action. Please verify the intended merging logic and update either the implementation or the test assertions to be consistent.

Suggested change
-assert merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70]
-assert merged[0].response_ids == [40, 50, 80, 90]
+assert merged[0].prompt_token_ids == [10, 20, 30]
+assert merged[0].response_ids == [40, 50, 60, 70, 80, 90]

Comment on lines +107 to +108
assert merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7]
assert merged[0].response_ids == [4, 5, 8, 9]

critical

These assertions are inconsistent with the merging logic in merge_step_wise_turns_for_trajectory for the same reasons as in test_merge_works. The prompt should be from the first turn of the merged sequence, and the response should include observation deltas.

Based on the implementation, the expected values for the first merged sample (turns 1 and 2) should be:

  • prompt_token_ids: [1, 2, 3]
  • response_ids: [4, 5, 6, 7, 8, 9]

Please align the test with the implementation's logic.

Suggested change
-assert merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7]
-assert merged[0].response_ids == [4, 5, 8, 9]
+assert merged[0].prompt_token_ids == [1, 2, 3]
+assert merged[0].response_ids == [4, 5, 6, 7, 8, 9]

Comment on lines +108 to +131
        ac_logprobs_i = rollout_logprobs[i] if rollout_logprobs is not None else [0.0] * len(ac_tokens)

        if len(full_sequence) == 0:
            delta_ob = ob_tokens
            initial_prompt = list(delta_ob)
        elif _is_prefix(full_sequence, ob_tokens):
            delta_ob = ob_tokens[len(full_sequence) :]
            # Interleave: delta_ob goes into response stream with zero loss so prompt+response = full sequence
            acc_response_ids.extend(delta_ob)
            acc_rewards.extend([0.0] * len(delta_ob))
            acc_loss_masks.extend([0] * len(delta_ob))
            acc_logprobs.extend([0.0] * len(delta_ob))
        else:
            prefix_mismatch_count += 1
            flush()
            delta_ob = ob_tokens
            initial_prompt = list(delta_ob)

        full_sequence.extend(delta_ob)
        full_sequence.extend(ac_tokens)
        acc_response_ids.extend(ac_tokens)
        acc_rewards.extend(ac_rewards)
        acc_loss_masks.extend(ac_masks)
        acc_logprobs.extend(ac_logprobs_i)

medium

The current handling of rollout_logprobs is inefficient when it is None: zero-filled lists are created and extended on every turn, only to be discarded in the flush function. Logprob-related work can be skipped entirely unless rollout_logprobs is not None.

        if len(full_sequence) == 0:
            delta_ob = ob_tokens
            initial_prompt = list(delta_ob)
        elif _is_prefix(full_sequence, ob_tokens):
            delta_ob = ob_tokens[len(full_sequence) :]
            # Interleave: delta_ob goes into response stream with zero loss so prompt+response = full sequence
            acc_response_ids.extend(delta_ob)
            acc_rewards.extend([0.0] * len(delta_ob))
            acc_loss_masks.extend([0] * len(delta_ob))
            if rollout_logprobs is not None:
                acc_logprobs.extend([0.0] * len(delta_ob))
        else:
            prefix_mismatch_count += 1
            flush()
            delta_ob = ob_tokens
            initial_prompt = list(delta_ob)

        full_sequence.extend(delta_ob)
        full_sequence.extend(ac_tokens)
        acc_response_ids.extend(ac_tokens)
        acc_rewards.extend(ac_rewards)
        acc_loss_masks.extend(ac_masks)
        if rollout_logprobs is not None:
            acc_logprobs.extend(rollout_logprobs[i])

merged_response_ids: List[List[int]] = []
merged_rewards: List[List[float]] = []
merged_loss_masks: List[List[int]] = []
merged_logprobs: Optional[List[List[float]]] = [] if logprobs else None

medium

Using if logprobs can be ambiguous as it evaluates to False for both None and an empty list ([]). To distinguish between the case where logprobs are not provided (None) and the case where they are provided but the list is empty, it's safer to use if logprobs is not None. This ensures that an empty list of logprobs is handled correctly as an empty list, not as None.

Suggested change
-merged_logprobs: Optional[List[List[float]]] = [] if logprobs else None
+merged_logprobs: Optional[List[List[float]]] = [] if logprobs is not None else None
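The ambiguity the reviewer describes is easy to demonstrate in isolation (the function names below are illustrative, not from the PR):

```python
# Both None and [] are falsy, so a truthiness test conflates
# "logprobs absent" with "logprobs present but empty".
def init_truthy(logprobs):
    return [] if logprobs else None

def init_explicit(logprobs):
    return [] if logprobs is not None else None

assert init_truthy([]) is None       # empty input silently treated as absent
assert init_explicit([]) == []       # empty input preserved as an empty list
assert init_truthy(None) is None
assert init_explicit(None) is None
```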

turn_rewards = [rewards[j] for j in indices]
turn_masks = [loss_masks[j] for j in indices]
turn_is_last = [is_last_step_list[j] for j in indices]
turn_logprobs = [logprobs[j] for j in indices] if logprobs else None

medium

As on the earlier line, `if logprobs` is False for both None and an empty list. `if logprobs is not None` makes the intent explicit and handles an empty list of logprobs correctly.

Suggested change
-turn_logprobs = [logprobs[j] for j in indices] if logprobs else None
+turn_logprobs = [logprobs[j] for j in indices] if logprobs is not None else None


@devin-ai-integration (bot) left a comment


Devin Review found 2 potential issues.

View 3 additional findings in Devin Review.


Comment on lines +47 to +48
assert merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70]
assert merged[0].response_ids == [40, 50, 80, 90]

🔴 Test assertions in test_merge_works contradict merge function's actual output

The test asserts merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70] and merged[0].response_ids == [40, 50, 80, 90], but tracing through merge_step_wise_turns_for_trajectory (skyrl/train/step_wise_merge.py:110-131), the function produces prompt_token_ids=[10, 20, 30] (only the first turn's observation, per the docstring at line 45) and response_ids=[40, 50, 60, 70, 80, 90] (resp1 + delta_ob2 + resp2, with delta_ob having zero loss mask). The code's behavior is correct: concatenating prompt+response in convert_prompts_responses_to_batch_tensors (skyrl/train/dataset/preprocess.py:125) yields a non-duplicated full sequence [10,20,30,40,50,60,70,80,90]. The test's expected values would produce duplicate tokens [...,40,50,60,70,40,50,80,90] when concatenated. This test will fail.

Suggested change
-assert merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70]
-assert merged[0].response_ids == [40, 50, 80, 90]
+assert merged[0].prompt_token_ids == [10, 20, 30]
+assert merged[0].response_ids == [40, 50, 60, 70, 80, 90]
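The duplication argument above can be checked directly, with plain lists standing in for token tensors:

```python
# With the corrected assertions, prompt + response reconstructs the
# full token sequence exactly once.
prompt = [10, 20, 30]
response = [40, 50, 60, 70, 80, 90]
assert prompt + response == [10, 20, 30, 40, 50, 60, 70, 80, 90]

# The test's original expected values would repeat tokens 40 and 50
# once prompt and response are concatenated for the batch.
bad = [10, 20, 30, 40, 50, 60, 70] + [40, 50, 80, 90]
assert bad.count(40) == 2 and bad.count(50) == 2
```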


Comment on lines +107 to +108
assert merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7]
assert merged[0].response_ids == [4, 5, 8, 9]

🔴 Test assertions in test_partial_merge contradict merge function's actual output

Same issue as test_merge_works: the test asserts merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7] and merged[0].response_ids == [4, 5, 8, 9], but the merge function produces prompt_token_ids=[1, 2, 3] (first turn's observation only, per skyrl/train/step_wise_merge.py:112) and response_ids=[4, 5, 6, 7, 8, 9] (resp1 + delta_ob [6,7] + resp2, per lines 116 and 128). The test's expected values would produce duplicate tokens [1,2,3,4,5,6,7,4,5,8,9] when concatenated in convert_prompts_responses_to_batch_tensors. This test will fail.

Suggested change
-assert merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7]
-assert merged[0].response_ids == [4, 5, 8, 9]
+assert merged[0].prompt_token_ids == [1, 2, 3]
+assert merged[0].response_ids == [4, 5, 6, 7, 8, 9]

