Skip to content

Commit dbfbfc4

Browse files
committed
rebase fixes
1 parent 0a26ba7 commit dbfbfc4

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

genai_bench/cli/cli.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010

1111
from genai_bench.analysis.excel_report import create_workbook
1212
from genai_bench.analysis.experiment_loader import load_one_experiment
13+
from genai_bench.analysis.flexible_plot_report import plot_experiment_data_flexible
1314
from genai_bench.analysis.plot_report import (
1415
plot_single_scenario_inference_speed_vs_throughput,
1516
)
16-
from genai_bench.analysis.flexible_plot_report import plot_experiment_data_flexible
17-
1817
from genai_bench.auth.unified_factory import UnifiedAuthFactory
1918
from genai_bench.cli.option_groups import (
2019
api_options,

tests/sampling/test_text.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -267,12 +267,14 @@ def test_sample_chat_prefix_ratio_request(self):
267267

268268
# Mock encode to return list with length equal to number of characters in input
269269
def mock_encode(text, add_special_tokens=False):
270-
return [1] * len(text)
270+
# ignore space
271+
encoded_text = [1] * len(text.replace(" ", ""))
272+
return encoded_text
271273

272274
self.tokenizer.encode = mock_encode
273275

274276
# Mock decode to return the original text
275-
def mock_decode(tokens):
277+
def mock_decode(tokens, skip_special_tokens=True):
276278
if isinstance(tokens, list):
277279
return "a" * len(tokens) # Return 'a' repeated for the token count
278280
return "decoded_text"
@@ -290,7 +292,6 @@ def mock_decode(tokens):
290292
model=self.model,
291293
output_modality=self.output_modality,
292294
data=self.test_data,
293-
use_scenario=True,
294295
prompt_prefix_ratio=0.5, # 50% of 20 tokens = 10 tokens
295296
)
296297
result = prefix_sampler.sample(scenario)
@@ -299,7 +300,7 @@ def mock_decode(tokens):
299300
self.assertTrue(isinstance(result.prompt, str))
300301
self.assertGreater(len(result.prompt), 0)
301302
self.assertTrue(result.prompt.startswith(prefix_sampler.prefix))
302-
self.assertEqual(len(result.prompt), 20)
303+
self.assertEqual(len(mock_encode(result.prompt)), 20)
303304

304305
def test_short_prompt_request(self):
305306
"""Test that short prompts are handled correctly."""
@@ -335,7 +336,6 @@ def test_empty_dataset(self):
335336
model=self.model,
336337
output_modality=self.output_modality,
337338
data=[],
338-
use_scenario=True,
339339
)
340340
scenario = NormalDistribution(10, 0, 10, 0)
341341

0 commit comments

Comments (0)