
Commit 0a26ba7

prefix sampling

- tests token length fix
- sample prefix tokens not chars
- lint fix
- tests percentage
- lint fixes
- gemini-feedback
- fix tests failing
- rename prompt-prefix-ratio
- make current_prefix_length local
- remove prompt prefix length
- refactor prefix sampling logic
- format
- revert back to 4 digits
- fix prefix length to change with variable distribution
- use line char ratio
- put latest changes in their own function and update prefix truncation
- gemini-feedback
- make number of tokens more accurate to scenario
- use tokenizer.encode
- fix tests
- documentation
- fix merge issues
1 parent dd3fba3 commit 0a26ba7

File tree

5 files changed: +195 -4 lines changed

docs/user-guide/run-benchmark.md

Lines changed: 5 additions & 0 deletions

````diff
@@ -186,6 +186,11 @@ For heavier traffic scenarios, like `D(16000,200)` or `D(128000,200)`, use the f
   --num-concurrency 32 \
 ```
 
+To benchmark with prefix caching, you can make a given fraction of each prompt a common prefix with `--prompt-prefix-ratio`. For example, to set the first half of each prompt to a common prefix, use:
+
+```shell
+--prompt-prefix-ratio 0.5 \
+```
 
 ## Distributed Benchmark
````

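For a scenario requesting N input tokens, the ratio translates into a shared budget of roughly round(N × ratio) prefix tokens. Below is a minimal sketch of that arithmetic; `prefix_token_budget` is a hypothetical helper named here for illustration, not part of the genai-bench API.

```python
# Hypothetical helper illustrating how --prompt-prefix-ratio maps to a
# shared token budget; genai-bench computes this inside its text sampler.
def prefix_token_budget(num_input_tokens: int, prompt_prefix_ratio: float) -> int:
    """Number of tokens shared by every prompt for a given ratio."""
    if not 0.0 <= prompt_prefix_ratio <= 1.0:
        raise ValueError("ratio must be between 0.0 and 1.0")
    return round(num_input_tokens * prompt_prefix_ratio)


# A D(16000,200) scenario with --prompt-prefix-ratio 0.5 shares
# 8000 prefix tokens across all prompts.
print(prefix_token_budget(16000, 0.5))  # 8000
```
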
genai_bench/cli/cli.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -121,6 +121,7 @@ def benchmark(
     spawn_rate,
     upload_results,
     namespace,
+    prompt_prefix_ratio,
     # Storage auth options
     storage_provider,
     storage_bucket,
@@ -285,6 +286,7 @@ def benchmark(
         data=data,
         additional_request_params=additional_request_params,
         dataset_config=dataset_config_obj,
+        prompt_prefix_ratio=prompt_prefix_ratio,
     )
 
     # If user did not provide scenarios but provided a dataset, default to dataset mode
```


genai_bench/cli/option_groups.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -368,6 +368,14 @@ def server_options(func):
 
 # Group experiment-related options
 def experiment_options(func):
+    func = click.option(
+        "--prompt-prefix-ratio",
+        type=click.FloatRange(0.0, 1.0),
+        default=0.0,
+        help="The ratio of the prefix length to the overall input length; "
+        "the prefix is prepended to all inputs to test prefix caching. "
+        "Value should be between 0.0 and 1.0.",
+    )(func)
     func = click.option(
         "--experiment-folder-name",
         type=str,
```

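`click.FloatRange(0.0, 1.0)` rejects out-of-range values at parse time, so the sampler never sees an invalid ratio. A minimal standalone sketch of the same option, assuming a toy `demo` command in place of the real stacked option group:

```python
# Standalone sketch of the option; the real definition is stacked onto
# experiment_options in genai_bench/cli/option_groups.py.
import click


@click.command()
@click.option(
    "--prompt-prefix-ratio",
    type=click.FloatRange(0.0, 1.0),
    default=0.0,
    help="Ratio of the common prefix length to the overall input length.",
)
def demo(prompt_prefix_ratio: float) -> None:
    click.echo(f"prefix ratio: {prompt_prefix_ratio}")


if __name__ == "__main__":
    demo()  # e.g. `--prompt-prefix-ratio 1.5` exits with a range error
```
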

genai_bench/sampling/text.py

Lines changed: 91 additions & 2 deletions

```diff
@@ -31,6 +31,7 @@ def __init__(
         model: str,
         output_modality: str,
         data: List[str],
+        prompt_prefix_ratio: float = 0.0,
         additional_request_params: Optional[Dict[str, Any]] = None,
         dataset_config: Optional[DatasetConfig] = None,
         **kwargs,
@@ -41,6 +42,8 @@ def __init__(
 
         self.data = data
         self.batch_size = 1  # Default batch size
+        self.prompt_prefix_ratio = prompt_prefix_ratio
+        self.prefix = ""
 
     def sample(self, scenario: Optional[Scenario]) -> UserRequest:
         """
@@ -167,6 +170,68 @@ def _validate_scenario(self, scenario: Optional[Scenario]) -> None:
                 f"{type(scenario.scenario_type)}"
             )
 
+    def _sample_prefix(self, current_prefix_length: int) -> str:
+        """
+        Generates a prefix of current_prefix_length tokens to be
+        prepended to all input prompts.
+        """
+
+        if not self.data:
+            raise ValueError("Cannot generate prefix from an empty dataset")
+
+        data_copy = self.data.copy()
+
+        prefix = ""
+        prefix_tokens_len = 0
+        # Generate the prefix
+        while prefix_tokens_len < current_prefix_length:
+            random.shuffle(data_copy)
+            for line in data_copy:
+                line_tokens = self.tokenizer.encode(line)
+                num_line_tokens = len(line_tokens)
+                if prefix_tokens_len + num_line_tokens > current_prefix_length:
+                    remaining_prefix_len = current_prefix_length - prefix_tokens_len
+                    truncated_text = self.tokenizer.decode(
+                        line_tokens[:remaining_prefix_len]
+                    )
+                    prefix += truncated_text
+                    return prefix
+                prefix += line
+                prefix_tokens_len = len(self.tokenizer.encode(prefix))
+
+        return prefix
+
+    def _get_current_prefix(self, prefix_length: int) -> str:
+        """
+        Returns the prefix for the current prompt at the specified length.
+
+        Args:
+            prefix_length (int): The desired length of the prefix in tokens.
+        """
+
+        # Prefix of the current prompt being generated
+        current_prefix: str = self.prefix
+
+        # Get the difference in length between the existing
+        # prefix and the desired prefix length
+
+        current_prefix_tokens = self.tokenizer.encode(current_prefix)
+        current_prefix_length = len(current_prefix_tokens)
+        prefix_length_diff: int = prefix_length - current_prefix_length
+
+        # Generate the prefix if it hasn't been created yet, or add
+        # to its length if it's not long enough
+        if prefix_length_diff > 0:
+            self.prefix += self._sample_prefix(prefix_length_diff)
+            current_prefix = self.prefix
+
+        elif prefix_length_diff < 0:
+            # If the prefix is longer than needed, truncate it
+            current_prefix = self.tokenizer.decode(
+                current_prefix_tokens[:prefix_length]
+            )
+        return current_prefix
+
     def _sample_text(self, num_input_tokens: Optional[int]) -> str:
         """
         Samples text from a list of lines based on the specified number of
@@ -176,16 +241,40 @@ def _sample_text(self, num_input_tokens: Optional[int]) -> str:
         Args:
             num_input_tokens (int): The target number of input tokens.
 
+        Raises:
+            ValueError: if the dataset is empty, or if the requested
+                prompt length is not longer than the prefix length.
+
         Returns:
             str: A text prompt containing the desired number of tokens.
         """
         if not num_input_tokens:
             return random.choice(self.data)
 
+        # Calculate the prefix length for this prompt from the ratio
+        current_prefix_length = 0
+        if self.prompt_prefix_ratio > 0.0:
+            current_prefix_length = round(num_input_tokens * self.prompt_prefix_ratio)
+
         data_copy = self.data.copy()
-        prompt = ""
-        left_tokens_to_sample = num_input_tokens
 
+        if not self.data:
+            raise ValueError("Cannot sample text from an empty dataset")
+
+        if num_input_tokens <= current_prefix_length:
+            raise ValueError("Prefix length must be shorter than total input length")
+
+        # Get the prompt prefix
+        current_prefix: str = self._get_current_prefix(current_prefix_length)
+
+        # Prepend the shared prefix plus 4 random digits so prompts stay unique
+        prompt = f"{current_prefix}{random.randint(1000, 9999)}"
+
+        prompt_tokens = self.tokenizer.encode(prompt)
+        left_tokens_to_sample = num_input_tokens - len(prompt_tokens)
+
+        if left_tokens_to_sample < 0:
+            return self.tokenizer.decode(prompt_tokens[:num_input_tokens])
         while left_tokens_to_sample > 0:
             random.shuffle(data_copy)
             for line in data_copy:
```

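Taken together, `_sample_text` now works in four steps: derive the prefix token budget from the ratio, reuse (or grow/truncate) the sampler-wide prefix, append four random digits so prompts stay distinct despite the shared prefix, then fill the remaining budget from the dataset. Below is a condensed, self-contained sketch of that flow; `WordTokenizer` is a whitespace-splitting stand-in for the real tokenizer, not part of genai-bench.

```python
# Condensed sketch of the prefix-sampling flow, outside the class.
import random
from typing import List


class WordTokenizer:
    """Toy stand-in for the real tokenizer: one token per word."""

    def encode(self, text: str) -> List[str]:
        return text.split()

    def decode(self, tokens: List[str]) -> str:
        return " ".join(tokens)


def sample_prefix(data: List[str], tok: WordTokenizer, budget: int) -> str:
    """Concatenate shuffled dataset lines until the token budget is met."""
    if not data:
        raise ValueError("Cannot generate prefix from an empty dataset")
    lines, prefix_tokens = data.copy(), []
    while len(prefix_tokens) < budget:
        random.shuffle(lines)
        for line in lines:
            remaining = budget - len(prefix_tokens)
            if remaining == 0:
                break
            # Truncate the last line if it would overshoot the budget
            prefix_tokens.extend(tok.encode(line)[:remaining])
    return tok.decode(prefix_tokens)


def sample_text(data: List[str], tok: WordTokenizer,
                num_input_tokens: int, ratio: float) -> str:
    """Shared prefix + 4 random digits + dataset filler, trimmed to budget."""
    budget = round(num_input_tokens * ratio)
    if num_input_tokens <= budget:
        raise ValueError("Prefix length must be shorter than total input length")
    prompt = f"{sample_prefix(data, tok, budget)} {random.randint(1000, 9999)}"
    tokens = tok.encode(prompt)
    while len(tokens) < num_input_tokens:
        tokens.extend(tok.encode(random.choice(data)))
    return tok.decode(tokens[:num_input_tokens])
```

For brevity the sketch regenerates the prefix on every call; the real sampler caches it on `self.prefix` so that every request in a run shares the same text, which is what lets the server's prefix cache get hits.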

tests/sampling/test_text.py

Lines changed: 89 additions & 2 deletions

```diff
@@ -221,6 +221,11 @@ def mock_encode(text, add_special_tokens=False):
         # Count actual tokens in result
         # Need to handle mixed content (original lines + decoded text)
         total_tokens = 0
+
+        # All prompts start with a 4-digit number, which counts as 1 token
+        total_tokens += 1
+        result = result[4:]
+
         # Split by our test lines to count tokens properly
         remaining = result
         for line in self.test_data:
@@ -255,6 +260,88 @@ def test_sample_text_truncation(self):
         _ = self.sampler._sample_text(requested_tokens)
 
         # Verify decode was called with truncated tokens
-        self.tokenizer.decode.assert_called_with(
-            line_tokens[:requested_tokens], skip_special_tokens=True
+        self.tokenizer.decode.assert_called_with(line_tokens[:requested_tokens])
+
+    def test_sample_chat_prefix_ratio_request(self):
+        """Test prefix generation using a ratio."""
+
+        # Mock encode to return a list whose length equals the character count
+        def mock_encode(text, add_special_tokens=False):
+            return [1] * len(text)
+
+        self.tokenizer.encode = mock_encode
+
+        # Mock decode to return text matching the token count
+        def mock_decode(tokens):
+            if isinstance(tokens, list):
+                return "a" * len(tokens)  # Return 'a' repeated for the token count
+            return "decoded_text"
+
+        self.tokenizer.decode = mock_decode
+
+        scenario = NormalDistribution(
+            mean_input_tokens=20,
+            stddev_input_tokens=0,
+            mean_output_tokens=20,
+            stddev_output_tokens=0,
+        )
+        prefix_sampler = TextSampler(
+            tokenizer=self.tokenizer,
+            model=self.model,
+            output_modality=self.output_modality,
+            data=self.test_data,
+            use_scenario=True,
+            prompt_prefix_ratio=0.5,  # 50% of 20 tokens = 10 tokens
+        )
+        result = prefix_sampler.sample(scenario)
+        self.assertIsInstance(result, UserChatRequest)
+        self.assertEqual(result.model, self.model)
+        self.assertIsInstance(result.prompt, str)
+        self.assertGreater(len(result.prompt), 0)
+        self.assertTrue(result.prompt.startswith(prefix_sampler.prefix))
+        self.assertEqual(len(result.prompt), 20)
+
+    def test_short_prompt_request(self):
+        """Test that short prompts are handled correctly."""
+
+        def mock_encode(text, add_special_tokens=False):
+            return [1] * len(text)
+
+        self.tokenizer.encode = mock_encode
+
+        # Mock decode to return text matching the token count
+        def mock_decode(tokens):
+            if isinstance(tokens, list):
+                return "a" * len(tokens)  # Return 'a' repeated for the token count
+            return "decoded_text"
+
+        self.tokenizer.decode = mock_decode
+
+        self.sampler.data = ["2"]
+
+        # Scenario asks for only 1 input token
+        scenario = NormalDistribution(1, 0, 1, 0)
+
+        result = self.sampler.sample(scenario)
+        self.assertIsInstance(result, UserChatRequest)
+        # The prompt will be the 4-digit number, truncated to 1 char
+        self.assertEqual(len(result.prompt), 1)
+        self.assertGreater(len(result.prompt), 0)
+
+    def test_empty_dataset(self):
+        """Test sampling from an empty dataset."""
+        empty_sampler = TextSampler(
+            tokenizer=self.tokenizer,
+            model=self.model,
+            output_modality=self.output_modality,
+            data=[],
+            use_scenario=True,
+        )
+        scenario = NormalDistribution(10, 0, 10, 0)
+
+        with self.assertRaises(ValueError) as context:
+            empty_sampler.sample(scenario)
+
+        self.assertEqual(
+            str(context.exception), "Cannot sample text from an empty dataset"
         )
```

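The exact-length assertions work because the mocks make tokenization deterministic: encode maps one character to one token, so a 20-token scenario must yield exactly 20 characters. A quick check of that arithmetic under the same assumption:

```python
# Arithmetic behind the length assertion in test_sample_chat_prefix_ratio_request,
# assuming the character-level mock tokenizer used in the test.
def mock_encode(text: str, add_special_tokens: bool = False) -> list:
    return [1] * len(text)


num_input_tokens, ratio = 20, 0.5
prefix_chars = round(num_input_tokens * ratio)  # 10 shared prefix characters
digit_chars = 4                                 # the random 4-digit suffix
filler_chars = num_input_tokens - prefix_chars - digit_chars  # 6 sampled chars
assert len(mock_encode("a" * (prefix_chars + digit_chars + filler_chars))) == 20
```
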
