@@ -267,12 +267,14 @@ def test_sample_chat_prefix_ratio_request(self):
267267
268268 # Mock encode to return a list with one token per non-space character in the input
269269 def mock_encode (text , add_special_tokens = False ):
270- return [1 ] * len (text )
270+ # Ignore spaces so the token count reflects non-space characters only
271+ encoded_text = [1 ] * len (text .replace (" " , "" ))
272+ return encoded_text
271273
272274 self .tokenizer .encode = mock_encode
273275
274276 # Mock decode to return placeholder text: 'a' repeated once per token for a list, otherwise "decoded_text"
275- def mock_decode (tokens ):
277+ def mock_decode (tokens , skip_special_tokens = True ):
276278 if isinstance (tokens , list ):
277279 return "a" * len (tokens ) # Return 'a' repeated for the token count
278280 return "decoded_text"
@@ -290,7 +292,6 @@ def mock_decode(tokens):
290292 model = self .model ,
291293 output_modality = self .output_modality ,
292294 data = self .test_data ,
293- use_scenario = True ,
294295 prompt_prefix_ratio = 0.5 , # 50% of 20 tokens = 10 tokens
295296 )
296297 result = prefix_sampler .sample (scenario )
@@ -299,7 +300,7 @@ def mock_decode(tokens):
299300 self .assertTrue (isinstance (result .prompt , str ))
300301 self .assertGreater (len (result .prompt ), 0 )
301302 self .assertTrue (result .prompt .startswith (prefix_sampler .prefix ))
302- self .assertEqual (len (result .prompt ), 20 )
303+ self .assertEqual (len (mock_encode ( result .prompt ) ), 20 )
303304
304305 def test_short_prompt_request (self ):
305306 """Test that short prompts are handled correctly."""
@@ -335,7 +336,6 @@ def test_empty_dataset(self):
335336 model = self .model ,
336337 output_modality = self .output_modality ,
337338 data = [],
338- use_scenario = True ,
339339 )
340340 scenario = NormalDistribution (10 , 0 , 10 , 0 )
341341
0 commit comments