diff --git a/rtree/index.py b/rtree/index.py index df7ad7e0..bf4d6ac2 100644 --- a/rtree/index.py +++ b/rtree/index.py @@ -1094,7 +1094,13 @@ def intersection_v(self, mins, maxs): ids = ids.resize(2 * len(ids), refcheck=False) def nearest_v( - self, mins, maxs, num_results=1, strict=False, return_max_dists=False + self, + mins, + maxs, + num_results=1, + max_dists=None, + strict=False, + return_max_dists=False, ): import numpy as np @@ -1114,10 +1120,18 @@ def nearest_v( ids = np.empty(n * num_results, dtype=np.int64) counts = np.empty(n, dtype=np.uint64) - dists = np.empty(n) if return_max_dists else None nr = ctypes.c_int64(0) offn, offi = 0, 0 + if max_dists is not None: + assert len(max_dists) == n + + dists = max_dists.astype(np.float64).copy() + elif return_max_dists: + dists = np.zeros(n) + else: + dists = None + while True: core.rt.Index_NearestNeighbors_id_v( self.handle, @@ -1131,7 +1145,7 @@ def nearest_v( maxs[offn:].ctypes.data, ids[offi:].ctypes.data, counts[offn:].ctypes.data, - dists[offn:].ctypes.data if return_max_dists else None, + dists[offn:].ctypes.data if dists is not None else None, ctypes.byref(nr), )