Commit 88e3322

make number of tokens more accurate to scenario

1 parent: 53c61ea

2 files changed (+29, -6 lines)

genai_bench/sampling/text.py (2 additions, 2 deletions)

```diff
@@ -187,7 +187,7 @@ def _sample_prefix(self, current_prefix_length) -> str:
                 prefix += line[: int(remaining_prefix_len * char_to_token_ratio)]
                 return prefix
             prefix += line
-            prefix_tokens_len += line_tokens_len
+            prefix_tokens_len = self.get_token_length(prefix)
 
         return prefix
 
@@ -269,7 +269,7 @@ def _sample_text(self, num_input_tokens: int) -> str:
                 prompt += line[: int(left_tokens_to_sample * char_to_token_ratio)]
                 return prompt
             prompt += line
-            left_tokens_to_sample -= tokens
+            left_tokens_to_sample = num_input_tokens - self.get_token_length(prompt)
         return prompt
 
     def _sample_prompt(self) -> str:
```
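Why the change: the old code kept a running total (`prefix_tokens_len += line_tokens_len`, `left_tokens_to_sample -= tokens`), which silently assumes tokenization is additive across line boundaries. Subword tokenizers can merge or split text differently once lines are concatenated, so the running total drifts away from the true token count of the accumulated string. Re-measuring the whole `prefix`/`prompt` with `get_token_length` ties the budget to what the tokenizer actually reports. A minimal sketch of the drift, assuming the Hugging Face `transformers` library and the `gpt2` checkpoint purely for illustration (not necessarily the tokenizer genai-bench runs with):

```python
from transformers import AutoTokenizer

# Illustrative tokenizer only; any subword tokenizer shows the effect.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

lines = ["token", "ization is not", " additive across joins"]

# Old approach: accumulate per-line token counts as lines are appended.
running_total = sum(len(tokenizer.encode(line)) for line in lines)

# New approach: re-tokenize the accumulated text as one string, which is
# what get_token_length(prefix) now measures on every iteration.
exact_total = len(tokenizer.encode("".join(lines)))

# The two counts need not agree, because the tokenizer may merge
# characters across the boundaries where the lines were concatenated.
print(running_total, exact_total)
```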

tests/sampling/test_text.py (27 additions, 4 deletions)

```diff
@@ -60,7 +60,12 @@ def test_check_discrepancy_no_warning(self):
         logger.removeHandler(ch)
 
     def test_sample_chat_request(self):
-        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
+        def mock_get_token_length(text, add_special_tokens=False):
+            return len(text) // 4  # Simple approximation: 4 chars per token
+
+        # Override the get_token_length method with our mock
+        self.sampler.get_token_length = mock_get_token_length
+
         scenario = NormalDistribution(
             mean_input_tokens=10,
             stddev_input_tokens=2,
@@ -84,7 +89,13 @@ def test_sample_chat_request_with_dataset(self):
             data=self.test_data,
             use_scenario=False,
         )
-        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
+
+        def mock_get_token_length(text, add_special_tokens=False):
+            return len(text) // 4  # Simple approximation: 4 chars per token
+
+        # Override the get_token_length method with our mock
+        no_scenario_sampler.get_token_length = mock_get_token_length
+
         scenario = NormalDistribution(
             mean_input_tokens=10,
             stddev_input_tokens=2,
@@ -102,13 +113,18 @@ def test_sample_chat_request_with_dataset(self):
         )  # Should be None for non-scenario sampling
 
     def test_sample_embedding_request(self):
-        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
+        def mock_get_token_length(text, add_special_tokens=False):
+            return len(text) // 4  # Simple approximation: 4 chars per token
+
         embedding_sampler = TextSampler(
             tokenizer=self.tokenizer,
             model=self.model,
             output_modality="embeddings",
             data=self.test_data,
         )
+        # Override the get_token_length method with our mock
+        embedding_sampler.get_token_length = mock_get_token_length
+
         scenario = EmbeddingScenario(tokens_per_document=1024)
 
         request = embedding_sampler.sample(scenario)
@@ -119,13 +135,20 @@ def test_sample_embedding_request(self):
         self.assertTrue(len(request.documents) > 0)
 
     def test_sample_rerank_request(self):
-        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
+        # Mock get_token_length to return different values based on input length
+        # This prevents infinite loops in _sample_text()
+        def mock_get_token_length(text, add_special_tokens=False):
+            return len(text) // 2  # Simple approximation: 2 chars per token
+
         rerank_sampler = TextSampler(
             tokenizer=self.tokenizer,
             model=self.model,
             output_modality="rerank",
             data=self.test_data,
         )
+        # Override the get_token_length method with our mock
+        rerank_sampler.get_token_length = mock_get_token_length
+
         scenario = ReRankScenario(tokens_per_document=1024, tokens_per_query=100)
 
         request = rerank_sampler.sample(scenario)
```
