diff --git a/dataset_reader/ann_compound_reader.py b/dataset_reader/ann_compound_reader.py
index a2b03301..f6352927 100644
--- a/dataset_reader/ann_compound_reader.py
+++ b/dataset_reader/ann_compound_reader.py
@@ -1,4 +1,5 @@
 import json
+import random
 from typing import Iterator, List
 
 import numpy as np
@@ -26,7 +27,9 @@ def read_vectors(self) -> Iterator[List[float]]:
 
     def read_queries(self) -> Iterator[Query]:
         with open(self.path / self.QUERIES_FILE) as payloads_fp:
-            for idx, row in enumerate(payloads_fp):
+            lines = payloads_fp.readlines()
+            random.shuffle(lines)
+            for idx, row in enumerate(lines):
                 row_json = json.loads(row)
                 vector = np.array(row_json["query"])
                 if self.normalize:
diff --git a/dataset_reader/ann_h5_reader.py b/dataset_reader/ann_h5_reader.py
index 1bc984ac..660e49a2 100644
--- a/dataset_reader/ann_h5_reader.py
+++ b/dataset_reader/ann_h5_reader.py
@@ -1,6 +1,7 @@
+import random
 from typing import Iterator
 
 import h5py
 import numpy as np
 
 from benchmark import DATASETS_DIR
@@ -15,9 +16,9 @@ def __init__(self, path, normalize=False):
 
     def read_queries(self) -> Iterator[Query]:
         data = h5py.File(self.path)
-        for vector, expected_result, expected_scores in zip(
-            data["test"], data["neighbors"], data["distances"]
-        ):
+        query_data = list(zip(data["test"], data["neighbors"], data["distances"]))
+        random.shuffle(query_data)
+        for vector, expected_result, expected_scores in query_data:
             if self.normalize:
                 vector /= np.linalg.norm(vector)
             yield Query(