diff --git a/engine/base_client/search.py b/engine/base_client/search.py
index a52ab47d..b07c0fe8 100644
--- a/engine/base_client/search.py
+++ b/engine/base_client/search.py
@@ -13,7 +13,13 @@
 DEFAULT_TOP = 10
 MAX_QUERIES = int(os.getenv("MAX_QUERIES", -1))
 
+def chunkify(lst, n):
+    """Split list into n approximately equal chunks."""
+    return [lst[i::n] for i in range(n)]
+def process_chunk(chunk, search_one):
+    """Process a chunk of queries."""
+    return [search_one(q) for q in chunk]
 
 class BaseSearcher:
     MP_CONTEXT = None
@@ -58,6 +64,7 @@ def _search_one(cls, query, top: Optional[int] = None):
         precision = len(ids.intersection(query.expected_result[:top])) / top
         return precision, end - start
 
+
     def search_all(
         self,
         distance,
@@ -72,7 +79,7 @@ def search_all(
         self.setup_search()
 
         search_one = functools.partial(self.__class__._search_one, top=top)
 
-        used_queries = queries
+        used_queries = list(queries)
 
         if MAX_QUERIES > 0:
@@ -86,6 +93,7 @@ def search_all(
             )
         else:
             ctx = get_context(self.get_mp_start_method())
+            query_chunks = chunkify(used_queries, parallel)
 
             with ctx.Pool(
                 processes=parallel,
@@ -100,9 +108,9 @@ def search_all(
                 if parallel > 10:
                     time.sleep(15)  # Wait for all processes to start
                 start = time.perf_counter()
-                precisions, latencies = list(
-                    zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(used_queries)))
-                )
+                results = pool.starmap(process_chunk, [(chunk, search_one) for chunk in query_chunks])
+
+                precisions, latencies = zip(*[item for sublist in results for item in sublist])
                 total_time = time.perf_counter() - start
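
Note: `chunkify` produces round-robin slices (`lst[i::n]`), so chunks are interleaved rather than contiguous, and the final `zip(*...)` flattens the per-chunk result lists back into two tuples. Below is a minimal standalone sketch of the dispatch pattern this patch introduces, not the benchmark code itself; `fake_search_one` is a hypothetical stand-in for the real `_search_one` partial:

from multiprocessing import Pool

def chunkify(lst, n):
    """Round-robin split of lst into n approximately equal chunks."""
    return [lst[i::n] for i in range(n)]

def process_chunk(chunk, fn):
    """Apply fn to every item in a chunk, preserving chunk order."""
    return [fn(q) for q in chunk]

def fake_search_one(q):
    # Stand-in for _search_one: returns a (precision, latency) pair.
    return (1.0, 0.001 * q)

if __name__ == "__main__":
    queries = list(range(10))
    chunks = chunkify(queries, 4)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
    with Pool(processes=4) as pool:
        # One starmap task per chunk; each worker processes its whole chunk.
        results = pool.starmap(process_chunk, [(c, fake_search_one) for c in chunks])
    # Flatten the per-chunk lists, then unzip the pairs into two tuples.
    precisions, latencies = zip(*[item for sub in results for item in sub])
    print(len(precisions), len(latencies))  # 10 10

Compared to the old per-query `imap_unordered`, dispatching one task per chunk reduces inter-process communication to one round trip per worker, at the cost of the per-query `tqdm` progress bar and of any load rebalancing between workers while the chunks run.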