From 3512cac890094b4f3808457469ba46916d190eb3 Mon Sep 17 00:00:00 2001 From: Freddie Witherden Date: Mon, 30 Dec 2024 07:54:43 -0600 Subject: [PATCH] Update bulk interface to match changes in libspatialindex. (#343) * Update bulk interface to match changes in libspatialindex. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- rtree/index.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/rtree/index.py b/rtree/index.py index df7ad7e0..bf4d6ac2 100644 --- a/rtree/index.py +++ b/rtree/index.py @@ -1094,7 +1094,13 @@ def intersection_v(self, mins, maxs): ids = ids.resize(2 * len(ids), refcheck=False) def nearest_v( - self, mins, maxs, num_results=1, strict=False, return_max_dists=False + self, + mins, + maxs, + num_results=1, + max_dists=None, + strict=False, + return_max_dists=False, ): import numpy as np @@ -1114,10 +1120,18 @@ def nearest_v( ids = np.empty(n * num_results, dtype=np.int64) counts = np.empty(n, dtype=np.uint64) - dists = np.empty(n) if return_max_dists else None nr = ctypes.c_int64(0) offn, offi = 0, 0 + if max_dists is not None: + assert len(max_dists) == n + + dists = max_dists.astype(np.float64).copy() + elif return_max_dists: + dists = np.zeros(n) + else: + dists = None + while True: core.rt.Index_NearestNeighbors_id_v( self.handle, @@ -1131,7 +1145,7 @@ def nearest_v( maxs[offn:].ctypes.data, ids[offi:].ctypes.data, counts[offn:].ctypes.data, - dists[offn:].ctypes.data if return_max_dists else None, + dists[offn:].ctypes.data if dists is not None else None, ctypes.byref(nr), )