|
19 | 19 | apply_token_matches,
|
20 | 20 | find_mm_placeholders,
|
21 | 21 | find_text_matches, find_token_matches,
|
22 |
| - iter_token_matches) |
| 22 | + iter_token_matches, |
| 23 | + replace_token_matches) |
23 | 24 | # yapf: enable
|
24 | 25 | from vllm.multimodal.profiling import MultiModalProfiler
|
25 | 26 | from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
@@ -89,6 +90,58 @@ def test_iter_token_matches(token_ids, match_ids, expected):
|
89 | 90 | assert all(match_len == len(match_ids) for match_len in match_lens)
|
90 | 91 |
|
91 | 92 |
|
| 93 | +# yapf: disable |
| 94 | +@pytest.mark.parametrize( |
| 95 | + ("token_ids", "match_ids", "new_ids", "expected"), |
| 96 | + [ |
| 97 | + ([], [], [-1], []), |
| 98 | + ([], [32000], [-1], []), |
| 99 | + ( |
| 100 | + [32000, 32000, 32000], |
| 101 | + [32000], |
| 102 | + [-1], |
| 103 | + [-1, -1, -1], |
| 104 | + ), |
| 105 | + ( |
| 106 | + [32000, 32000, 32000], |
| 107 | + [32000, 32000], |
| 108 | + [-1], |
| 109 | + [-1, 32000], |
| 110 | + ), |
| 111 | + ( |
| 112 | + [32000, 32000, 32000], |
| 113 | + [32000, 32000, 32000], |
| 114 | + [-1], |
| 115 | + [-1], |
| 116 | + ), |
| 117 | + ( |
| 118 | + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], |
| 119 | + [28747, 32000], |
| 120 | + [-1], |
| 121 | + [9833, -1, 32000, 32000, 9833, -1, 32000, 918], |
| 122 | + ), |
| 123 | + ( |
| 124 | + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], |
| 125 | + [28747, 32000, 32000, 32000], |
| 126 | + [-1], |
| 127 | + [9833, -1, 9833, 28747, 32000, 32000, 918], |
| 128 | + ), |
| 129 | + ( |
| 130 | + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], |
| 131 | + [28747, 0, 32000], |
| 132 | + [-1], |
| 133 | + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], |
| 134 | + ), |
| 135 | + ], |
| 136 | +) |
| 137 | +# yapf: enable |
| 138 | +def test_replace_token_matches(token_ids, match_ids, new_ids, expected): |
| 139 | + result = replace_token_matches(token_ids, match_ids, new_ids) |
| 140 | + |
| 141 | + # Manually constructed results |
| 142 | + assert result == expected |
| 143 | + |
| 144 | + |
92 | 145 | # yapf: disable
|
93 | 146 | @pytest.mark.parametrize(
|
94 | 147 | ("prompt", "target_by_key", "expected_by_key"),
|
|
0 commit comments