Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions renderers/qwen35.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,17 +814,17 @@ def flush_buf() -> None:

# Merge prev mm_data (images from earlier turns) with the new turn's.
merged_hashes: dict[str, list[str]] = (
dict(previous_multi_modal_data.mm_hashes)
{modality: list(vals) for modality, vals in previous_multi_modal_data.mm_hashes.items()}
if previous_multi_modal_data
else {}
)
merged_placeholders: dict[str, list[PlaceholderRange]] = (
dict(previous_multi_modal_data.mm_placeholders)
{modality: list(vals) for modality, vals in previous_multi_modal_data.mm_placeholders.items()}
if previous_multi_modal_data
else {}
)
merged_items: dict[str, list[dict[str, Any]]] = (
dict(previous_multi_modal_data.mm_items)
{modality: list(vals) for modality, vals in previous_multi_modal_data.mm_items.items()}
if previous_multi_modal_data
else {}
)
Expand Down
6 changes: 3 additions & 3 deletions renderers/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,17 +804,17 @@ def render_media_content(content: Any) -> None:

# Merge prev mm_data with the new turn's items.
merged_hashes = (
dict(previous_multi_modal_data.mm_hashes)
{modality: list(vals) for modality, vals in previous_multi_modal_data.mm_hashes.items()}
if previous_multi_modal_data
else {}
)
merged_placeholders = (
dict(previous_multi_modal_data.mm_placeholders)
{modality: list(vals) for modality, vals in previous_multi_modal_data.mm_placeholders.items()}
if previous_multi_modal_data
else {}
)
merged_items = (
dict(previous_multi_modal_data.mm_items)
{modality: list(vals) for modality, vals in previous_multi_modal_data.mm_items.items()}
if previous_multi_modal_data
else {}
)
Expand Down
97 changes: 97 additions & 0 deletions tests/test_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,103 @@ def test_multimodal_bridge_extends_and_carries_mm_data(
)


@pytest.mark.parametrize(
"mm_model_name,modality", _CASES, ids=[f"{m}|{mo}" for m, mo in _CASES]
)
def test_multimodal_bridge_does_not_mutate_previous_mm_data(
mm_model_name, modality, tiny_image
):
"""``bridge_to_next_turn`` must not mutate ``previous_multi_modal_data``.

Regression for a shallow-copy bug: ``dict(prev.mm_items)`` copies the
outer mapping but leaves each per-modality list aliased to the
original. The bridge then ``.extend(...)`` on that list, mutating
the prior turn's ``MultiModalData`` in place. Callers that retain
the prior ``RenderedTokens`` (e.g. trainers that keep per-step
snapshots for loss reconstruction) silently see their earlier
turns' image lists grow on every bridge.
"""
if not _hf_snapshot_cached(mm_model_name):
pytest.skip(f"{mm_model_name}: HF snapshot not cached locally")

kit = _modality_kit(modality, mm_model_name)
tokenizer, _, renderer = _load_processor_and_renderer(mm_model_name)

initial = [
{
"role": "user",
"content": [
kit["make_part"](tiny_image),
{"type": "text", "text": "Turn one."},
],
}
]
new = [
{
"role": "user",
"content": [
kit["make_part"](tiny_image),
{"type": "text", "text": "Turn two."},
],
}
]

initial_rendered = renderer.render(initial, add_generation_prompt=True)
assert initial_rendered.multi_modal_data is not None
prev_mm = initial_rendered.multi_modal_data

# Snapshot the prior lists' identities and contents BEFORE bridging.
prev_items_list = prev_mm.mm_items.get(modality, [])
prev_placeholders_list = prev_mm.mm_placeholders.get(modality, [])
prev_hashes_list = prev_mm.mm_hashes.get(modality, [])
items_id_before = id(prev_items_list)
placeholders_id_before = id(prev_placeholders_list)
hashes_id_before = id(prev_hashes_list)
items_snapshot = list(prev_items_list)
placeholders_snapshot = list(prev_placeholders_list)
hashes_snapshot = list(prev_hashes_list)

im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
completion_ids = tokenizer.encode("Saw it.", add_special_tokens=False) + [im_end_id]

bridged = renderer.bridge_to_next_turn(
previous_prompt_ids=initial_rendered.token_ids,
previous_completion_ids=completion_ids,
new_messages=new,
previous_multi_modal_data=prev_mm,
)
assert bridged is not None and getattr(bridged, "multi_modal_data", None) is not None

# The prior MultiModalData must be untouched.
assert prev_mm.mm_items.get(modality, []) == items_snapshot, (
f"{mm_model_name} / {modality}: bridge mutated previous mm_items list "
f"(expected len {len(items_snapshot)}, got {len(prev_mm.mm_items.get(modality, []))})"
)
assert prev_mm.mm_placeholders.get(modality, []) == placeholders_snapshot, (
f"{mm_model_name} / {modality}: bridge mutated previous mm_placeholders list"
)
assert prev_mm.mm_hashes.get(modality, []) == hashes_snapshot, (
f"{mm_model_name} / {modality}: bridge mutated previous mm_hashes list"
)

# And the bridged data's inner lists must not be the same objects
# as the prior turn's — otherwise a later second bridge would mutate
# this turn's lists, too.
bridged_items = bridged.multi_modal_data.mm_items.get(modality, [])
bridged_placeholders = bridged.multi_modal_data.mm_placeholders.get(modality, [])
bridged_hashes = bridged.multi_modal_data.mm_hashes.get(modality, [])
assert id(bridged_items) != items_id_before, (
f"{mm_model_name} / {modality}: bridged mm_items list is aliased to "
"previous_multi_modal_data.mm_items — outer-dict-only copy detected"
)
assert id(bridged_placeholders) != placeholders_id_before, (
f"{mm_model_name} / {modality}: bridged mm_placeholders list aliased to prior"
)
assert id(bridged_hashes) != hashes_id_before, (
f"{mm_model_name} / {modality}: bridged mm_hashes list aliased to prior"
)


def test_modality_registry_models_route_to_renderer():
"""Every model in ``MULTIMODAL_MODELS`` resolves to a concrete renderer
via ``create_renderer(renderer='auto')``. Guards against drift between
Expand Down
Loading