@@ -60,7 +60,12 @@ def test_check_discrepancy_no_warning(self):
6060 logger .removeHandler (ch )
6161
6262 def test_sample_chat_request (self ):
63- self .tokenizer .encode .return_value = [1 , 2 , 3 , 4 , 5 ]
63+ def mock_get_token_length (text , add_special_tokens = False ):
64+ return len (text ) // 4 # Simple approximation: 4 chars per token
65+
66+ # Override the get_token_length method with our mock
67+ self .sampler .get_token_length = mock_get_token_length
68+
6469 scenario = NormalDistribution (
6570 mean_input_tokens = 10 ,
6671 stddev_input_tokens = 2 ,
@@ -84,7 +89,13 @@ def test_sample_chat_request_with_dataset(self):
8489 data = self .test_data ,
8590 use_scenario = False ,
8691 )
87- self .tokenizer .encode .return_value = [1 , 2 , 3 , 4 , 5 ]
92+
93+ def mock_get_token_length (text , add_special_tokens = False ):
94+ return len (text ) // 4 # Simple approximation: 4 chars per token
95+
96+ # Override the get_token_length method with our mock
97+ no_scenario_sampler .get_token_length = mock_get_token_length
98+
8899 scenario = NormalDistribution (
89100 mean_input_tokens = 10 ,
90101 stddev_input_tokens = 2 ,
@@ -102,13 +113,18 @@ def test_sample_chat_request_with_dataset(self):
102113 ) # Should be None for non-scenario sampling
103114
104115 def test_sample_embedding_request (self ):
105- self .tokenizer .encode .return_value = [1 , 2 , 3 , 4 , 5 ]
116+ def mock_get_token_length (text , add_special_tokens = False ):
117+ return len (text ) // 4 # Simple approximation: 4 chars per token
118+
106119 embedding_sampler = TextSampler (
107120 tokenizer = self .tokenizer ,
108121 model = self .model ,
109122 output_modality = "embeddings" ,
110123 data = self .test_data ,
111124 )
125+ # Override the get_token_length method with our mock
126+ embedding_sampler .get_token_length = mock_get_token_length
127+
112128 scenario = EmbeddingScenario (tokens_per_document = 1024 )
113129
114130 request = embedding_sampler .sample (scenario )
@@ -119,13 +135,20 @@ def test_sample_embedding_request(self):
119135 self .assertTrue (len (request .documents ) > 0 )
120136
121137 def test_sample_rerank_request (self ):
122- self .tokenizer .encode .return_value = [1 , 2 , 3 , 4 , 5 ]
138+ # Mock get_token_length to return different values based on input length
139+ # This prevents infinite loops in _sample_text()
140+ def mock_get_token_length (text , add_special_tokens = False ):
141+ return len (text ) // 2 # Simple approximation: 2 chars per token
142+
123143 rerank_sampler = TextSampler (
124144 tokenizer = self .tokenizer ,
125145 model = self .model ,
126146 output_modality = "rerank" ,
127147 data = self .test_data ,
128148 )
149+ # Override the get_token_length method with our mock
150+ rerank_sampler .get_token_length = mock_get_token_length
151+
129152 scenario = ReRankScenario (tokens_per_document = 1024 , tokens_per_query = 100 )
130153
131154 request = rerank_sampler .sample (scenario )
0 commit comments