[train] Prefix-aware merge for step-wise trajectories (#1277) #1377
deepsheth3 wants to merge 2 commits into NovaSky-AI:main
Conversation
…(#1277)
- Add step_wise_merge.py: _is_prefix(), merge_step_wise_turns_for_trajectory()
- Wire merge into convert_to_training_input() when step_wise_trajectories=True
- Group by trajectory_id; merge turns when the next prompt has the previous full sequence as a prefix
- Add metrics: stepwise_num_samples_before/after, merge_ratio, prefix_mismatch_count
- Add unit tests: merge, prefix_mismatch, partial_merge

Made-with: Cursor
…for merge Made-with: Cursor
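To make the review discussion below easier to follow, here is a minimal, self-contained sketch of the merging behavior the PR describes, assuming turns are already grouped by trajectory_id. `Turn`, `is_prefix`, and `merge_trajectory` are hypothetical stand-ins for the PR's actual structures (which also carry rewards, loss masks, and logprobs); this is not the implementation itself.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Turn:
    trajectory_id: str
    prompt_token_ids: List[int]
    response_ids: List[int]

def is_prefix(prefix: List[int], seq: List[int]) -> bool:
    """True if `seq` starts with `prefix`."""
    return len(prefix) <= len(seq) and seq[: len(prefix)] == prefix

def merge_trajectory(turns: List[Turn]) -> List[Turn]:
    """Merge consecutive turns whose prompt extends the previous full sequence."""
    merged: List[Turn] = []
    for turn in turns:
        if merged:
            prev = merged[-1]
            full = prev.prompt_token_ids + prev.response_ids
            if is_prefix(full, turn.prompt_token_ids):
                # Only the new observation tokens (the delta) join the response
                # stream; in the real code they get zero loss mask and reward.
                delta = turn.prompt_token_ids[len(full):]
                prev.response_ids = prev.response_ids + delta + turn.response_ids
                continue
        # First turn, or prefix mismatch: start a new merged sample.
        merged.append(Turn(turn.trajectory_id, list(turn.prompt_token_ids), list(turn.response_ids)))
    return merged

turns = [
    Turn("t0", [10, 20, 30], [40, 50]),
    Turn("t0", [10, 20, 30, 40, 50, 60, 70], [80, 90]),
]
out = merge_trajectory(turns)
print(out[0].prompt_token_ids)  # [10, 20, 30]
print(out[0].response_ids)      # [40, 50, 60, 70, 80, 90]
```

Note that under this scheme prompt + response reconstructs the full sequence exactly once, which is the invariant the reviewers lean on below.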
Code Review
This pull request introduces a prefix-aware merging strategy for step-wise trajectories, which is a significant enhancement for training on multi-turn conversations. The implementation is well-structured, with a dedicated module for the merging logic and corresponding unit tests. My review focuses on improving the efficiency of the merging logic, ensuring robustness in handling optional data, and correcting inconsistencies in the unit tests that could lead to misunderstandings of the feature's behavior.
assert merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70]
assert merged[0].response_ids == [40, 50, 80, 90]
These assertions seem to contradict the implementation and docstring of merge_step_wise_turns_for_trajectory. The implementation sets prompt_token_ids to the observation of the first turn in a merged sequence, and response_ids includes the actions and observation deltas from subsequent turns. This is also what the function's docstring and the PR description state.
Based on the logic in merge_step_wise_turns_for_trajectory:
- prompt_token_ids should be [10, 20, 30] (from the first turn's observation).
- response_ids should be [40, 50, 60, 70, 80, 90] (action from turn 1 + delta observation from turn 2 + action from turn 2).
The current assertions imply a different merging strategy where the prompt is the full context up to the last action. Please verify the intended merging logic and update either the implementation or the test assertions to be consistent.
Suggested change:
- assert merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70]
- assert merged[0].response_ids == [40, 50, 80, 90]
+ assert merged[0].prompt_token_ids == [10, 20, 30]
+ assert merged[0].response_ids == [40, 50, 60, 70, 80, 90]
assert merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7]
assert merged[0].response_ids == [4, 5, 8, 9]
These assertions are inconsistent with the merging logic in merge_step_wise_turns_for_trajectory for the same reasons as in test_merge_works. The prompt should be from the first turn of the merged sequence, and the response should include observation deltas.
Based on the implementation, the expected values for the first merged sample (turns 1 and 2) should be:
- prompt_token_ids: [1, 2, 3]
- response_ids: [4, 5, 6, 7, 8, 9]
Please align the test with the implementation's logic.
Suggested change:
- assert merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7]
- assert merged[0].response_ids == [4, 5, 8, 9]
+ assert merged[0].prompt_token_ids == [1, 2, 3]
+ assert merged[0].response_ids == [4, 5, 6, 7, 8, 9]
ac_logprobs_i = rollout_logprobs[i] if rollout_logprobs is not None else [0.0] * len(ac_tokens)

if len(full_sequence) == 0:
    delta_ob = ob_tokens
    initial_prompt = list(delta_ob)
elif _is_prefix(full_sequence, ob_tokens):
    delta_ob = ob_tokens[len(full_sequence) :]
    # Interleave: delta_ob goes into response stream with zero loss so prompt+response = full sequence
    acc_response_ids.extend(delta_ob)
    acc_rewards.extend([0.0] * len(delta_ob))
    acc_loss_masks.extend([0] * len(delta_ob))
    acc_logprobs.extend([0.0] * len(delta_ob))
else:
    prefix_mismatch_count += 1
    flush()
    delta_ob = ob_tokens
    initial_prompt = list(delta_ob)

full_sequence.extend(delta_ob)
full_sequence.extend(ac_tokens)
acc_response_ids.extend(ac_tokens)
acc_rewards.extend(ac_rewards)
acc_loss_masks.extend(ac_masks)
acc_logprobs.extend(ac_logprobs_i)
The current implementation for handling rollout_logprobs is inefficient when it is None. It creates and extends lists with zeros, which are then discarded in the flush function. This can be optimized by only performing logprob-related operations when rollout_logprobs is not None.
if len(full_sequence) == 0:
    delta_ob = ob_tokens
    initial_prompt = list(delta_ob)
elif _is_prefix(full_sequence, ob_tokens):
    delta_ob = ob_tokens[len(full_sequence) :]
    # Interleave: delta_ob goes into response stream with zero loss so prompt+response = full sequence
    acc_response_ids.extend(delta_ob)
    acc_rewards.extend([0.0] * len(delta_ob))
    acc_loss_masks.extend([0] * len(delta_ob))
    if rollout_logprobs is not None:
        acc_logprobs.extend([0.0] * len(delta_ob))
else:
    prefix_mismatch_count += 1
    flush()
    delta_ob = ob_tokens
    initial_prompt = list(delta_ob)

full_sequence.extend(delta_ob)
full_sequence.extend(ac_tokens)
acc_response_ids.extend(ac_tokens)
acc_rewards.extend(ac_rewards)
acc_loss_masks.extend(ac_masks)
if rollout_logprobs is not None:
    acc_logprobs.extend(rollout_logprobs[i])

merged_response_ids: List[List[int]] = []
merged_rewards: List[List[float]] = []
merged_loss_masks: List[List[int]] = []
merged_logprobs: Optional[List[List[float]]] = [] if logprobs else None
Using if logprobs can be ambiguous as it evaluates to False for both None and an empty list ([]). To distinguish between the case where logprobs are not provided (None) and the case where they are provided but the list is empty, it's safer to use if logprobs is not None. This ensures that an empty list of logprobs is handled correctly as an empty list, not as None.
Suggested change:
- merged_logprobs: Optional[List[List[float]]] = [] if logprobs else None
+ merged_logprobs: Optional[List[List[float]]] = [] if logprobs is not None else None
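The truthiness pitfall the reviewer is pointing at can be shown in two lines; `logprobs_empty` is a hypothetical value standing in for a provided-but-empty logprobs list:

```python
logprobs_empty = []  # provided, but empty (distinct from None, i.e. not provided)

truthy = [] if logprobs_empty else None          # empty list is falsy
strict = [] if logprobs_empty is not None else None

print(truthy)  # None -- the empty list is misread as "not provided"
print(strict)  # []   -- emptiness is preserved
```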
turn_rewards = [rewards[j] for j in indices]
turn_masks = [loss_masks[j] for j in indices]
turn_is_last = [is_last_step_list[j] for j in indices]
turn_logprobs = [logprobs[j] for j in indices] if logprobs else None
Similar to a previous line, using if logprobs can be ambiguous as it's False for both None and an empty list. Using if logprobs is not None makes the intent clearer and correctly handles the case of an empty list of logprobs.
Suggested change:
- turn_logprobs = [logprobs[j] for j in indices] if logprobs else None
+ turn_logprobs = [logprobs[j] for j in indices] if logprobs is not None else None
assert merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70]
assert merged[0].response_ids == [40, 50, 80, 90]
🔴 Test assertions in test_merge_works contradict merge function's actual output
The test asserts merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70] and merged[0].response_ids == [40, 50, 80, 90], but tracing through merge_step_wise_turns_for_trajectory (skyrl/train/step_wise_merge.py:110-131), the function produces prompt_token_ids=[10, 20, 30] (only the first turn's observation, per the docstring at line 45) and response_ids=[40, 50, 60, 70, 80, 90] (resp1 + delta_ob2 + resp2, with delta_ob having zero loss mask). The code's behavior is correct: concatenating prompt+response in convert_prompts_responses_to_batch_tensors (skyrl/train/dataset/preprocess.py:125) yields a non-duplicated full sequence [10,20,30,40,50,60,70,80,90]. The test's expected values would produce duplicate tokens [...,40,50,60,70,40,50,80,90] when concatenated. This test will fail.
Suggested change:
- assert merged[0].prompt_token_ids == [10, 20, 30, 40, 50, 60, 70]
- assert merged[0].response_ids == [40, 50, 80, 90]
+ assert merged[0].prompt_token_ids == [10, 20, 30]
+ assert merged[0].response_ids == [40, 50, 60, 70, 80, 90]
assert merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7]
assert merged[0].response_ids == [4, 5, 8, 9]
🔴 Test assertions in test_partial_merge contradict merge function's actual output
Same issue as test_merge_works: the test asserts merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7] and merged[0].response_ids == [4, 5, 8, 9], but the merge function produces prompt_token_ids=[1, 2, 3] (first turn's observation only, per skyrl/train/step_wise_merge.py:112) and response_ids=[4, 5, 6, 7, 8, 9] (resp1 + delta_ob [6,7] + resp2, per lines 116 and 128). The test's expected values would produce duplicate tokens [1,2,3,4,5,6,7,4,5,8,9] when concatenated in convert_prompts_responses_to_batch_tensors. This test will fail.
Suggested change:
- assert merged[0].prompt_token_ids == [1, 2, 3, 4, 5, 6, 7]
- assert merged[0].response_ids == [4, 5, 8, 9]
+ assert merged[0].prompt_token_ids == [1, 2, 3]
+ assert merged[0].response_ids == [4, 5, 6, 7, 8, 9]
Test: pytest tests/train/test_step_wise_merge.py -v