6 changes: 6 additions & 0 deletions docs/user-guide/run-benchmark.md
@@ -180,6 +180,12 @@ For heavier traffic scenarios, like `D(16000,200)` or `D(128000,200)`, use the f
--num-concurrency 32 \
```

To benchmark with prefix caching, you can make a given fraction of each prompt a common prefix using `--prompt-prefix-ratio`. For example, to make the first half of each prompt a shared prefix, use:

```shell
--prompt-prefix-ratio 0.5 \
```
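
With the sampler's ratio-based calculation, a ratio of 0.5 under a scenario such as `D(16000,200)` makes roughly the first 8,000 input tokens of every request identical, which is what exercises the server's prefix cache.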

## Distributed Benchmark

If you see the message below in the genai-bench logs, it indicates that a single process is insufficient to generate the desired load.
2 changes: 2 additions & 0 deletions genai_bench/cli/cli.py
@@ -122,6 +122,7 @@ def benchmark(
spawn_rate,
upload_results,
namespace,
prompt_prefix_ratio,
# Storage auth options
storage_provider,
storage_bucket,
@@ -287,6 +288,7 @@ def benchmark(
data=data,
additional_request_params=additional_request_params,
dataset_config=dataset_config_obj,
prompt_prefix_ratio=prompt_prefix_ratio,
)

# If user did not provide scenarios but provided a dataset, default to dataset mode
8 changes: 8 additions & 0 deletions genai_bench/cli/option_groups.py
@@ -369,6 +369,14 @@ def server_options(func):

# Group experiment-related options
def experiment_options(func):
func = click.option(
"--prompt-prefix-ratio",
type=click.FloatRange(0.0, 1.0),
default=0.0,
help="The ratio of prefix length to overall input length "
"to prepend to all inputs to test prefix caching. "
"Value should be between 0.0 and 1.0. ",
)(func)
func = click.option(
"--experiment-folder-name",
type=str,
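The range check is enforced by click itself; a minimal, standalone sketch of that behavior (the command name and help string below are illustrative, not genai-bench code):

```python
import click


@click.command()
@click.option(
    "--prompt-prefix-ratio",
    type=click.FloatRange(0.0, 1.0),
    default=0.0,
    help="Fraction of each prompt that is a shared prefix (sketch only).",
)
def demo(prompt_prefix_ratio: float) -> None:
    """Echo the validated ratio."""
    click.echo(f"prompt_prefix_ratio = {prompt_prefix_ratio}")


if __name__ == "__main__":
    demo()
```

Invoking this with, for example, `--prompt-prefix-ratio 1.5` fails with a usage error before any benchmark code runs, which is the same protection the new option gets.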
93 changes: 91 additions & 2 deletions genai_bench/sampling/text.py
@@ -31,6 +31,7 @@ def __init__(
model: str,
output_modality: str,
data: List[str],
prompt_prefix_ratio: float = 0.0,
additional_request_params: Optional[Dict[str, Any]] = None,
dataset_config: Optional[DatasetConfig] = None,
**kwargs,
@@ -41,6 +42,8 @@

self.data = data
self.batch_size = 1 # Default batch size
self.prompt_prefix_ratio = prompt_prefix_ratio
self.prefix = ""

def sample(self, scenario: Optional[Scenario]) -> UserRequest:
"""
@@ -165,6 +168,68 @@ def _validate_scenario(self, scenario: Optional[Scenario]) -> None:
f"{type(scenario.scenario_type)}"
)

def _sample_prefix(self, current_prefix_length: int) -> str:
"""
Generates a prefix of current_prefix_length tokens to be
prepended to all input prompts.
"""

data_copy = self.data.copy()

if not self.data:
raise ValueError("Cannot generate prefix from an empty dataset")

prefix = ""
prefix_tokens_len = 0
# Generate the prefix
while prefix_tokens_len < current_prefix_length:
random.shuffle(data_copy)
for line in data_copy:
line_tokens = self.tokenizer.encode(line)
num_line_tokens = len(line_tokens)
if prefix_tokens_len + num_line_tokens > current_prefix_length:
remaining_prefix_len = current_prefix_length - prefix_tokens_len
truncated_text = self.tokenizer.decode(
line_tokens[:remaining_prefix_len]
)
prefix += truncated_text
return prefix
prefix += line
prefix_tokens_len = len(self.tokenizer.encode(prefix))

return prefix

def _get_current_prefix(self, prefix_length: int) -> str:
"""
Returns the prefix for the current prompt of the specified length.

Args:
prefix_length (int): The desired length of the prefix in tokens.
"""

# Prefix of the current prompt being generated
current_prefix: str = self.prefix

# Get the difference in length between the existing
# prefix and the desired prefix length

current_prefix_tokens = self.tokenizer.encode(current_prefix)
current_prefix_length = len(current_prefix_tokens)
prefix_length_diff: int = prefix_length - current_prefix_length

# Generate the prefix if it hasn't been created yet, or add
# to its length if it's not long enough
if prefix_length_diff > 0:
self.prefix += self._sample_prefix(prefix_length_diff)
current_prefix = self.prefix

elif prefix_length_diff < 0:
# If the prefix is longer than needed, truncate it
current_prefix = self.tokenizer.decode(
current_prefix_tokens[:prefix_length]
)
return current_prefix

def _sample_text(self, num_input_tokens: Optional[int]) -> str:
"""
Samples text from a list of lines based on the specified number of
@@ -174,16 +239,40 @@ def _sample_text(self, num_input_tokens: Optional[int]) -> str:
Args:
num_input_tokens (int): The target number of input tokens.

Raises:
ValueError: If the dataset is empty or the requested prompt
length is not longer than the prefix length.

Returns:
str: A text prompt containing the desired number of tokens.
"""
if not num_input_tokens:
return random.choice(self.data)

# Calculate the prefix length from the configured ratio
current_prefix_length = 0
if self.prompt_prefix_ratio > 0.0:
current_prefix_length = round(num_input_tokens * self.prompt_prefix_ratio)

data_copy = self.data.copy()
prompt = ""
left_tokens_to_sample = num_input_tokens

if not self.data:
raise ValueError("Cannot sample text from an empty dataset")

if num_input_tokens <= current_prefix_length:
raise ValueError("Prefix length must be shorter than total input length")

# Get the prompt prefix
current_prefix: str = self._get_current_prefix(current_prefix_length)

# Prepend the prefix followed by a randomly picked 4-digit number
prompt = f"{current_prefix}{random.randint(1000,9999)}"

prompt_tokens = self.tokenizer.encode(prompt)
left_tokens_to_sample = num_input_tokens - len(prompt_tokens)

if left_tokens_to_sample < 0:
return self.tokenizer.decode(prompt_tokens[:num_input_tokens])
while left_tokens_to_sample > 0:
random.shuffle(data_copy)
for line in data_copy:
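Taken together, `_sample_prefix`, `_get_current_prefix`, and `_sample_text` build one shared prefix lazily, reuse it across prompts, and top each prompt up with a random 4-digit tag plus filler text until the token budget is met. A self-contained sketch of the same scheme, using a character-level stand-in for the tokenizer (all names here are illustrative, not genai-bench code):

```python
import random

# Toy corpus standing in for the sampler's dataset; every character is one "token".
DATA = ["the quick brown fox ", "jumps over the lazy dog ", "lorem ipsum dolor sit amet "]


def build_prefix(target_tokens: int, data: list[str]) -> str:
    """Concatenate shuffled corpus lines until the prefix reaches target_tokens."""
    prefix = ""
    pool = data.copy()
    while len(prefix) < target_tokens:
        random.shuffle(pool)
        for line in pool:
            if len(prefix) + len(line) > target_tokens:
                # Truncate the last line so the prefix hits the target exactly.
                return prefix + line[: target_tokens - len(prefix)]
            prefix += line
    return prefix


def sample_prompt(num_input_tokens: int, prefix_ratio: float, shared_prefix: str) -> str:
    """Shared prefix + random 4-digit tag + sampled filler, capped at the token budget."""
    prefix_len = round(num_input_tokens * prefix_ratio)
    prompt = shared_prefix[:prefix_len] + str(random.randint(1000, 9999))
    while len(prompt) < num_input_tokens:
        prompt += random.choice(DATA)
    return prompt[:num_input_tokens]


shared = build_prefix(round(100 * 0.5), DATA)
prompts = [sample_prompt(100, 0.5, shared) for _ in range(3)]
assert all(len(p) == 100 for p in prompts)              # every prompt meets the budget
assert all(p[:50] == prompts[0][:50] for p in prompts)  # first half is shared across prompts
```

The real implementation does the same bookkeeping in token space via `tokenizer.encode`/`decode`, growing the cached `self.prefix` when a longer prefix is requested and truncating a copy when a shorter one is.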
91 changes: 89 additions & 2 deletions tests/sampling/test_text.py
@@ -221,6 +226,11 @@ def mock_encode(text, add_special_tokens=False):
# Count actual tokens in result
# Need to handle mixed content (original lines + decoded text)
total_tokens = 0

# All prompts start with a 4-digit number, which counts as 1 token
total_tokens += 1
result = result[4:]

# Split by our test lines to count tokens properly
remaining = result
for line in self.test_data:
@@ -255,6 +260,88 @@ def test_sample_text_truncation(self):
_ = self.sampler._sample_text(requested_tokens)

# Verify decode was called with truncated tokens
self.tokenizer.decode.assert_called_with(
line_tokens[:requested_tokens], skip_special_tokens=True
)
self.tokenizer.decode.assert_called_with(line_tokens[:requested_tokens])

def test_sample_chat_prefix_ratio_request(self):
"""Test prefix generation using ratio."""

# Mock encode to return a list whose length equals the number of characters in the input
def mock_encode(text, add_special_tokens=False):
# ignore space
encoded_text = [1] * len(text.replace(" ", ""))
return encoded_text

self.tokenizer.encode = mock_encode

# Mock decode to return the original text
def mock_decode(tokens, skip_special_tokens=True):
if isinstance(tokens, list):
return "a" * len(tokens) # Return 'a' repeated for the token count
return "decoded_text"

self.tokenizer.decode = mock_decode

scenario = NormalDistribution(
mean_input_tokens=20,
stddev_input_tokens=0,
mean_output_tokens=20,
stddev_output_tokens=0,
)
prefix_sampler = TextSampler(
tokenizer=self.tokenizer,
model=self.model,
output_modality=self.output_modality,
data=self.test_data,
prompt_prefix_ratio=0.5, # 50% of 20 tokens = 10 tokens
)
result = prefix_sampler.sample(scenario)
self.assertIsInstance(result, UserChatRequest)
self.assertEqual(result.model, self.model)
self.assertTrue(isinstance(result.prompt, str))
self.assertGreater(len(result.prompt), 0)
self.assertTrue(result.prompt.startswith(prefix_sampler.prefix))
self.assertEqual(len(mock_encode(result.prompt)), 20)

def test_short_prompt_request(self):
"""Test that short prompts are handled correctly."""

def mock_encode(text, add_special_tokens=False):
return [1] * len(text)

self.tokenizer.encode = mock_encode

# Mock decode to return the original text
def mock_decode(tokens):
if isinstance(tokens, list):
return "a" * len(tokens) # Return 'a' repeated for the token count
return "decoded_text"

self.tokenizer.decode = mock_decode

self.sampler.data = ["2"]

# Scenario asks for only 1 input token
scenario = NormalDistribution(1, 0, 1, 0)

result = self.sampler.sample(scenario)
self.assertIsInstance(result, UserChatRequest)
# The prompt will be the 4-digit number, truncated to 1 char
self.assertEqual(len(result.prompt), 1)
self.assertGreater(len(result.prompt), 0)

def test_empty_dataset(self):
"""Test sampling from an empty dataset."""
empty_sampler = TextSampler(
tokenizer=self.tokenizer,
model=self.model,
output_modality=self.output_modality,
data=[],
)
scenario = NormalDistribution(10, 0, 10, 0)

with self.assertRaises(ValueError) as context:
empty_sampler.sample(scenario)

self.assertEqual(
str(context.exception), "Cannot sample text from an empty dataset"
)