From c966c9b84a652a46d11dd27a38911c96bd53ad58 Mon Sep 17 00:00:00 2001 From: LeoSvalov Date: Tue, 24 May 2022 13:18:46 +0300 Subject: [PATCH] Add algorithm implementation --- doc/api.rst | 9 + doc/modules/nswgraph.rst | 35 ++ doc/user_guide.rst | 2 + setup.py | 6 + sklearn_extra/neighbors/__init__.py | 3 + .../_navigable_small_world_graph.pxd | 59 +++ .../_navigable_small_world_graph.pyx | 337 ++++++++++++++++++ sklearn_extra/neighbors/_nswgraph.py | 244 +++++++++++++ sklearn_extra/neighbors/tests/__init__.py | 0 .../neighbors/tests/test_nswgraph.py | 59 +++ sklearn_extra/tests/test_common.py | 2 + 11 files changed, 756 insertions(+) create mode 100644 doc/modules/nswgraph.rst create mode 100644 sklearn_extra/neighbors/__init__.py create mode 100644 sklearn_extra/neighbors/_navigable_small_world_graph.pxd create mode 100644 sklearn_extra/neighbors/_navigable_small_world_graph.pyx create mode 100644 sklearn_extra/neighbors/_nswgraph.py create mode 100644 sklearn_extra/neighbors/tests/__init__.py create mode 100644 sklearn_extra/neighbors/tests/test_nswgraph.py diff --git a/doc/api.rst b/doc/api.rst index 25fc8ed8..de427dba 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -44,3 +44,12 @@ Robust robust.RobustWeightedClassifier robust.RobustWeightedRegressor robust.RobustWeightedKMeans + +Neighbors +==================== + +.. autosummary:: + :toctree: generated/ + :template: class.rst + + neighbors.NSWGraph \ No newline at end of file diff --git a/doc/modules/nswgraph.rst b/doc/modules/nswgraph.rst new file mode 100644 index 00000000..3f92fc3e --- /dev/null +++ b/doc/modules/nswgraph.rst @@ -0,0 +1,35 @@ +.. _neighbors: + +============================================================ +Neighbors search with NSW graphs +============================================================ +.. _nswgraph: +.. 
currentmodule:: sklearn_extra.neighbors + + +A navigable small-world graph is a type of mathematical graph in which most nodes are not neighbors of one another, +but the neighbors of any given node are likely to be neighbors of each other and most nodes can be reached +from every other node by some small number of hops or steps [1]_. +The number of steps is governed by a property that every navigable small-world graph must satisfy: + +* The minimum number of edges that must be traversed to travel between two randomly chosen nodes grows proportionally to the logarithm of the number of nodes in the network [2]_. + +:class:`NSWGraph` is an approximate nearest neighbor algorithm based on navigable small world graphs. +The algorithm tends to perform better on high-dimensional data [3]_ than the +existing Scikit-Learn nearest neighbor algorithms based on :class:`KDTree <sklearn.neighbors.KDTree>` +and :class:`BallTree <sklearn.neighbors.BallTree>`. + +See the `Scikit-Learn User Guide <https://scikit-learn.org/stable/modules/neighbors.html>`_ +for more general information on Nearest Neighbors search. + + +.. topic:: References: + + .. [1] Porter, Mason A. “Small-World Network.” Scholarpedia. + Available at: http://www.scholarpedia.org/article/Small-world_network. + + .. [2] Kleinberg, Jon. "The small-world phenomenon and decentralized search." SIAM News 37.3 (2004): 1-2. + + .. [3] Malkov, Y., Ponomarenko, A., Logvinov, A., & Krylov, V. (2014). + Approximate nearest neighbor algorithm based on navigable small world graphs. + Information Systems, 45, 61-68. 
diff --git a/doc/user_guide.rst b/doc/user_guide.rst index 0c90c2e8..5d9202d6 100644 --- a/doc/user_guide.rst +++ b/doc/user_guide.rst @@ -14,3 +14,5 @@ User guide modules/cluster.rst modules/robust.rst modules/kernel_approximation.rst + modules/nswgraph.rst + diff --git a/setup.py b/setup.py index 6c6399a5..17f847fd 100755 --- a/setup.py +++ b/setup.py @@ -79,6 +79,12 @@ include_dirs=[np.get_include()], language="c++", ), + Extension( + "sklearn_extra.neighbors._navigable_small_world_graph", + ["sklearn_extra/neighbors/_navigable_small_world_graph.pyx"], + include_dirs=[np.get_include()], + language="c++", + ), ] ), "cmdclass": dict(build_ext=build_ext), diff --git a/sklearn_extra/neighbors/__init__.py b/sklearn_extra/neighbors/__init__.py new file mode 100644 index 00000000..e68c0a90 --- /dev/null +++ b/sklearn_extra/neighbors/__init__.py @@ -0,0 +1,3 @@ +from ._nswgraph import NSWGraph + +__all__ = ["NSWGraph"] diff --git a/sklearn_extra/neighbors/_navigable_small_world_graph.pxd b/sklearn_extra/neighbors/_navigable_small_world_graph.pxd new file mode 100644 index 00000000..cee61ccd --- /dev/null +++ b/sklearn_extra/neighbors/_navigable_small_world_graph.pxd @@ -0,0 +1,59 @@ +# distutils: language=c++ +# NSWG-based ANN classification +# Authors: Lev Svalov +# Stanislav Protasov +# License: BSD 3 clause +import numpy as np +cimport numpy as np +np.import_array() +from libcpp.vector cimport vector +from libcpp.set cimport set as set_c +from libcpp.pair cimport pair as pair +from libcpp.queue cimport priority_queue +from libcpp cimport bool +ctypedef np.int_t ITYPE_t +ctypedef np.float64_t DTYPE_t +ctypedef bool BTYPE_t + +cdef class BaseNSWGraph: + """ + Declaration of Cython additional class for the NSWGraph implementation + """ + + # attributes declaration with types + cdef ITYPE_t dimension + cdef ITYPE_t regularity + cdef ITYPE_t guard_hops + cdef ITYPE_t attempts + cdef BTYPE_t quantize + cdef ITYPE_t quantization_levels + cdef ITYPE_t number_nodes + cdef 
DTYPE_t norm_factor + cdef vector[vector[DTYPE_t]] nodes + cdef vector[set_c[ITYPE_t]] neighbors + cdef vector[vector[DTYPE_t]] lookup_table + cdef vector[DTYPE_t] quantization_values + + + # methods declaration with types and non-utilization of Global Interpreter Lock (GIL) + + cdef DTYPE_t eucl_dist(self, vector[DTYPE_t] v1, vector[DTYPE_t] v2) nogil + + cdef priority_queue[pair[DTYPE_t, ITYPE_t]] delete_duplicate(self, priority_queue[pair[DTYPE_t, ITYPE_t]] queue) nogil + + cdef void search_nsw_basic(self, vector[DTYPE_t] query, + set_c[ITYPE_t]* visitedSet, + priority_queue[pair[DTYPE_t, ITYPE_t]]* candidates, + priority_queue[pair[DTYPE_t, ITYPE_t]]* result, + ITYPE_t* res_hops, + ITYPE_t k) nogil + + cdef void _build_navigable_graph(self, vector[vector[DTYPE_t]] values) nogil + + cdef pair[vector[ITYPE_t], vector[DTYPE_t]] _multi_search(self, vector[DTYPE_t] query, ITYPE_t k) nogil + + cdef vector[vector[DTYPE_t]] ndarray_to_vector_2(self, np.ndarray array) + + cdef np.ndarray _get_quantized(self, np.ndarray vector) + + cdef np.ndarray _quantization(self, np.ndarray data) diff --git a/sklearn_extra/neighbors/_navigable_small_world_graph.pyx b/sklearn_extra/neighbors/_navigable_small_world_graph.pyx new file mode 100644 index 00000000..106af992 --- /dev/null +++ b/sklearn_extra/neighbors/_navigable_small_world_graph.pyx @@ -0,0 +1,337 @@ +# distutils: language = c++ +#!python +# cython: boundscheck=False +# cython: wraparound=False +# cython: cdivision=True + +# NSWG-based ANN classification +# Authors: Lev Svalov +# Stanislav Protasov +# License: BSD 3 clause + +cimport numpy as np +np.import_array() +from libcpp.vector cimport vector +from libcpp.set cimport set as set_c +from libcpp.pair cimport pair as pair +from libc.math cimport pow +from libcpp.queue cimport priority_queue +from libc.stdlib cimport rand +import itertools +import numpy as np + +cdef class BaseNSWGraph: + """ + Cython-Optimized implementation of the Navigable small world graph structure + 
+ Parameters + ---------- + regularity : int, default: 16 + The size of the friends list of every vertex in the graph. + Higher regularity leads to more accurate but slower search. + + guard_hops : int, default: 100 + The maximum number of hops (candidate expansions) performed by a single greedy search. + + quantize : bool, default: False + If True, use a product quantization for the preliminary dimensionality reduction of the data. + + quantization_levels : int, default: 20 + (Used if quantize=True) + The number of the values used in quantization approximation of the dataset. + + """ + def __init__(self, ITYPE_t regularity=16, + ITYPE_t guard_hops=100, + ITYPE_t attempts=2, + BTYPE_t quantize=False, + ITYPE_t quantization_levels=20): + self.regularity = regularity + self.guard_hops = guard_hops + self.attempts = attempts + self.quantize = quantize + self.quantization_levels = quantization_levels + + cdef priority_queue[pair[DTYPE_t, ITYPE_t]] delete_duplicate(self, priority_queue[pair[DTYPE_t, ITYPE_t]] queue) nogil: + """ + Auxiliary method for removing the duplicated nodes from the neighbor candidates sequence + + Parameters + ---------- + queue: priority_queue of pairs consisting double distance value and the index of the particular node + """ + cdef priority_queue[pair[DTYPE_t, ITYPE_t]] new_que + cdef set_c[ITYPE_t] tmp_set + new_que.push(queue.top()) + tmp_set.insert(queue.top().second) + queue.pop() + while queue.size() != 0: + if tmp_set.find(queue.top().second) == tmp_set.end(): + tmp_set.insert(queue.top().second) + new_que.push(queue.top()) + queue.pop() + return new_que + + cdef DTYPE_t eucl_dist(self, vector[DTYPE_t] v1, vector[DTYPE_t] v2) nogil: + """ + Calculation of the reduced Euclidean distance between two data vectors + + Parameters + ---------- + v1, v2: vector of double features values + + Returns + ------- + d: double, reduced Euclidean distance value + """ + cdef ITYPE_t i = 0 + cdef DTYPE_t res = 0 + if self.quantize: + for i in range(v1.size()): 
+ res += self.lookup_table[int(v2[i])][int(v1[i])] + else: + for i in range(v1.size()): + res += pow(v1[i] - v2[i], 2) + return res + + + cdef void search_nsw_basic(self, vector[DTYPE_t] query, + set_c[ITYPE_t]* visitedSet, + priority_queue[pair[DTYPE_t, ITYPE_t]]* candidates, + priority_queue[pair[DTYPE_t, ITYPE_t]]* result, + ITYPE_t* res_hops, + ITYPE_t k) nogil: + """ + Single search for neighbors candidates for the provided query vector + + Parameters + ---------- + query: query data vector consisting double features values + visitedSet: pointer set of nodes indices that was already visited by attempted searches + candidates: pointer to sequence of possible neighbors for the query + result: pointer to final sequence of neighbors of the query + res_hops: pointer to the result number of hops obtained after the search + k: number of neighbors + """ + cdef ITYPE_t entry = rand() % self.nodes.size() + cdef ITYPE_t hops = 0 + cdef DTYPE_t closest_dist = 0 + cdef ITYPE_t closest_id = 0 + cdef ITYPE_t e = 0 + cdef DTYPE_t d = 0 + cdef pair[DTYPE_t, ITYPE_t] tmp_pair + + d = self.eucl_dist(query, self.nodes[entry]) + tmp_pair.first = d * (-1) + tmp_pair.second = entry + + if visitedSet[0].find(entry) == visitedSet[0].end(): + candidates[0].push(tmp_pair) + tmp_pair.first = tmp_pair.first * (-1) + result[0].push(tmp_pair) + hops = 0 + + while hops < self.guard_hops: + hops += 1 + if candidates[0].size() == 0: + break + tmp_pair = candidates[0].top() + candidates.pop() + closest_dist = tmp_pair.first * (-1) + closest_id = tmp_pair.second + if result[0].size() >= k: + while result[0].size() > k: + result[0].pop() + + if result[0].top().first < closest_dist: + break + + for e in self.neighbors[closest_id]: + if visitedSet[0].find(e) == visitedSet[0].end(): + d = self.eucl_dist(query, self.nodes[e]) + visitedSet[0].insert(e) + tmp_pair.first = d + tmp_pair.second = e + result.push(tmp_pair) + tmp_pair.first = tmp_pair.first * (-1) + candidates.push(tmp_pair) + res_hops[0] = 
hops + + + cdef np.ndarray _get_quantized(self, np.ndarray vector): + """ Auxiliary method for transformation the initial data vector to quantized version """ + result = [] + for i, data_value in enumerate(vector): + result.append((np.abs(self.quantization_values - data_value)).argmin()) + return np.array(result) + + + cdef pair[vector[ITYPE_t], vector[DTYPE_t]] _multi_search(self, vector[DTYPE_t] query, ITYPE_t k) nogil: + """ + Main neighbors search function that combines results from multiple attempted single search and deletes duplicated results + + Parameters + ---------- + query: query data vector consisting double features values + k: number of neighbors + + Returns + ------- + ind, dist: pair of sequences: indices of neighbor vector and corresponding distance to the query vector + """ + cdef set_c[ITYPE_t] visitedSet + cdef priority_queue[pair[DTYPE_t, ITYPE_t]] candidates + cdef priority_queue[pair[DTYPE_t, ITYPE_t]] result + cdef vector[ITYPE_t] res + cdef vector[DTYPE_t] dist + cdef ITYPE_t i + cdef ITYPE_t hops + cdef pair[DTYPE_t, ITYPE_t] j + cdef ITYPE_t id + cdef DTYPE_t d + + for i in range(self.attempts): + self.search_nsw_basic(query, &visitedSet, &candidates, &result, &hops, k) + result = self.delete_duplicate(result) + while result.size() > k: + result.pop() + while res.size() < k: + if result.empty(): + break + el = result.top().second + d = result.top().first + dist.insert(dist.begin(), d) + res.insert(res.begin(), el) + result.pop() + + return pair[vector[ITYPE_t], vector[DTYPE_t]](res, dist) + + + cdef void _build_navigable_graph(self, vector[vector[DTYPE_t]] X) nogil: + """ + Build the Navigable small world graph + + Parameters + ---------- + X: query data vectors that are constructing the data structure + """ + cdef vector[DTYPE_t] val + cdef vector[ITYPE_t] closest + cdef ITYPE_t c + cdef ITYPE_t i + cdef vector[ITYPE_t] res + cdef set_c[ITYPE_t] tmp_set + if X.size() != self.number_nodes: + raise Exception("Number of nodes don't match") 
+ if X[0].size() != self.dimension: + raise Exception("Dimension doesn't match") + + self.nodes.clear() + self.neighbors.clear() + + self.nodes.push_back(X[0]) + for i in range(self.number_nodes): + self.neighbors.push_back(tmp_set) + + for i in range(1, self.number_nodes): + val = X[i] + closest.clear() + closest = self._multi_search(val, k=self.regularity).first + self.nodes.push_back(val) + for c in closest: + self.neighbors[i].insert(c) + self.neighbors[c].insert(i) + + cdef vector[vector[DTYPE_t]] ndarray_to_vector_2(self, np.ndarray X): + """ + Auxiliary method for conversion the numpy array of data vectors to libcpp 2d vector + + Parameters + ---------- + X: numpy array to convert to the 2d vector + + Returns + ------- + X_vector: libcpp 2d vector + """ + cdef vector[vector[DTYPE_t]] X_vector + cdef ITYPE_t i + for i in range(len(X)): + X_vector.push_back((X[i])) + return X_vector + + cdef np.ndarray _quantization(self, np.ndarray X): + """ + Auxiliary method for quantization of the given data. + It quantizes the data vectors and constructs the lookup table of reduced distances + + Parameters + ---------- + X: the given data to build NSWG + + Returns + ------- + X_quantized: the quantizers data with reduced dimensionality + """ + self.quantization_values = np.linspace(0.0, 1.0, num=self.quantization_levels) + self.lookup_table = np.zeros(shape=(self.quantization_levels,self.quantization_levels)) + for v in itertools.combinations(enumerate(self.quantization_values), 2): + i = v[0][0] + j = v[1][0] + self.lookup_table[i][j] = pow(np.abs(v[0][1]-v[1][1]),2) + self.lookup_table[j][i] = pow(np.abs(v[1][1]-v[0][1]),2) + X_quantized = [] + for i, vector in enumerate(X): + X_quantized.append(self._get_quantized(vector)) + return np.array(X_quantized) + + def build(self, X): + """ + Build BaseNSWGraph on the provided data. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape = (n_samples, n_features), + Training data. 
+ """ + self.number_nodes = len(X) + self.dimension = len(X[0]) + if self.quantize: + quantized_data = self._quantization(X) + X = quantized_data + cdef vector[vector[DTYPE_t]] X_vector = self.ndarray_to_vector_2(X) + self._build_navigable_graph(X_vector) + + def query(self, np.ndarray queries, ITYPE_t k=1): + """Query the BaseNSWGraph for the k nearest neighbors + + Parameters + ---------- + queries : array-like, shape = (n_samples, n_features), + An array of points to query + + k : int, default=1 + The number of nearest neighbors to return + + Returns + ------- + dist: ndarray of shape X.shape[:-1] + (k,), dtype=double + Each entry gives the list of distances to the neighbors of the corresponding point. + + ind: ndarray of shape X.shape[:-1] + (k,), dtype=int + Each entry gives the list of indices of neighbors of the corresponding point. + """ + ind = [] + dist = [] + cdef pair[vector[ITYPE_t], vector[DTYPE_t]] res + cdef vector[vector[DTYPE_t]] query_vector + for query in queries: + if self.quantize: + normalized_query = query + query = self._get_quantized(normalized_query) + query = np.array([query]) + query_vector = self.ndarray_to_vector_2(query) + res = self._multi_search(query_vector[0], k) + ind.append(res.first) + dist.append(res.second) + return np.array(dist, dtype=object), np.array(ind, dtype=object) diff --git a/sklearn_extra/neighbors/_nswgraph.py b/sklearn_extra/neighbors/_nswgraph.py new file mode 100644 index 00000000..b938a632 --- /dev/null +++ b/sklearn_extra/neighbors/_nswgraph.py @@ -0,0 +1,244 @@ +# NSWG-based ANN classification +# Authors: Lev Svalov +# Stanislav Protasov +# License: BSD 3 clause + +from sklearn.base import BaseEstimator, ClassifierMixin +from ._navigable_small_world_graph import BaseNSWGraph +from sklearn.utils.validation import check_array, check_is_fitted, check_X_y +from sklearn.utils.multiclass import type_of_target +import numpy as np + + +def _check_positive_int(value, desc): + """Validates if value is a valid integer 
> 0""" + if value is None or not isinstance(value, (int, np.integer)) or value <= 0: + raise ValueError( + "%s should be a positive integer. " "%s was given" % (desc, value) + ) + + +def _check_label_type(y): + """Validates if labels type is correct for the estimator""" + if type_of_target(y) not in ["binary", "multiclass"]: + raise ValueError("Unknown label type: ") + + +class NSWGraph(BaseEstimator, ClassifierMixin, BaseNSWGraph): + """Nearest Neighbors search using Navigable small world graphs. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + regularity : int, default: 16 + The size of the friends list of every vertex in the graph. + Higher regularity leads to more accurate but slower search. + + guard_hops : int, default: 100 + The number of bi-directional links created for every new element in the graph. + + quantize : bool, default: False + If True, use a product quantization for the preliminary dimensionality reduction of the data. + + quantization_levels : int, default: 20 + (Used if quantize=True) + The number of the values used in quantization approximation of the dataset. + + Attributes + ---------- + classes_ : ndarray of shape (n_classes, ) + A list of class labels known to the classifier. + + y_: ndarray of shape (n_samples, ) + A list of labels for the corresponding targets + + n_features_in_: ndarray of shape (n_features, ) + A number of features in the provided data + + is_fitted_: boolean + The boolean flag, indicates that the estimator was constructed with provided targets, + so that the predict() can be used. 
+ + is_constructed_: boolean + The boolean flag, indicates that the estimator was constructed without provided targets, + so that the eestimator can query for neighbours regardless its labels + + Examples + -------- + >>> from sklearn_extra.neighbors import NSWGraph + >>> import numpy as np + >>> rng = np.random.RandomState(10) + >>> X = rng.random_sample((50, 128)) + >>> nswgraph = NSWGraph() + >>> nswgraph.build(X) + NSWGraph(regularity=16, guard_hops=100, attempts=2, quantize=False, quantization_levels=20) + >>> X_val = rng.random_sample((5, 128)) + >>> dist, ind = nswgraph.query(X_val, k=3) + + References + ---------- + * Malkov, Y., Ponomarenko, A., Logvinov, A., & Krylov, V. (2014). Approximate nearest neighbor algorithm based on navigable small world graphs. Information Systems, 45, 61-68. + + + """ + + def __init__( + self, + regularity=16, + guard_hops=100, + attempts=2, + quantize=False, + quantization_levels=20, + ): + super().__init__() + self.regularity = regularity + self.guard_hops = guard_hops + self.attempts = attempts + self.quantize = quantize + self.quantization_levels = quantization_levels + self._check_init_args() + + def _check_init_args(self): + """Validation of the initialization arguments""" + _check_positive_int(self.regularity, "regularity") + _check_positive_int(self.guard_hops, "guard_hops") + _check_positive_int(self.attempts, "attempts") + _check_positive_int(self.quantization_levels, "quantization_levels") + if not isinstance(self.quantize, bool): + raise ValueError( + "%s should be a boolean. " + "%s was given" % ("Quantization switch", self.quantize) + ) + + def _check_dimension_correspondence(self, X): + if X.shape[1] != self.n_features_in_: + raise ValueError( + "Wrong dimensionality of the data." 
+ "Estimator is built with %s, but %s was given" + % (self.n_features_in_, X.shape[1]) + ) + + def __repr__(self, **kwargs): + return f"NSWGraph(regularity={self.regularity}, guard_hops={self.guard_hops}, attempts={self.attempts}, quantize={self.quantize}, quantization_levels={self.quantization_levels})" + + def build(self, X, y=None): + """Build NSWGraph on the provided data. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape = (n_samples, n_features), + Training data. + + y : array-like of shape (n_samples,) + Target labels. + + Returns + ------- + self: NSWGraph + The constructed NSWGraph + + """ + self._check_init_args() + + if y is not None: + self.fit(X, y) + else: + X = check_array(X, dtype=[np.float64, np.float32]) + super().build(X) + self.is_constructed_ = True + self.n_features_in_ = X.shape[1] + + return self + + def fit(self, X, y): + """Build NSWGraph on the provided data and link it with the labels. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape = (n_samples, n_features), + Training data. + + y : array-like of shape (n_samples,) + Target labels. + + Returns + ------- + self: NSWGraph + The constructed NSWGraph and fitted nearest neighbors classifier. 
+ """ + self._check_init_args() + X = check_array(X) + X, y = check_X_y(X, y) + _check_label_type(y) + super().build(X) + self.classes_ = np.unique(y) + self.y_ = y + self.n_features_in_ = X.shape[1] + self.is_fitted_ = True + self.is_constructed_ = True + + return self + + def query(self, X, k=1, return_distance=True): + """Query the NSWGraph for the k nearest neighbors + + Parameters + ---------- + X : array-like, shape = (n_samples, n_features), + An array of points to query + + k : int, default=1 + The number of nearest neighbors to return + + return_distance: bool, default=True + if True, return a tuple (dist, ind) of dists and indices if False, return array i + + Returns + ------- + dist: ndarray of shape X.shape[:-1] + (k,), dtype=double + Each entry gives the list of distances to the neighbors of the corresponding point. + + ind: ndarray of shape X.shape[:-1] + (k,), dtype=int + Each entry gives the list of indices of neighbors of the corresponding point. + """ + + check_is_fitted( + self, + "is_constructed_", + msg="This %(name)s instance is not constructed yet. Call 'build' with " + "appropriate arguments before using this estimator.", + ) + + X = check_array(X, dtype=[np.float64, np.float32]) + self._check_dimension_correspondence(X) + _check_positive_int(k, "k-closests") + dist, ind = super().query(X, k) + + if return_distance: + return dist, ind + else: + return ind + + def predict(self, X: np.ndarray) -> np.ndarray: + """Predict the class labels for the provided query data. + The label of the closest neighbor is supposed to be predicted label + + Parameters + ---------- + X : array-like, shape = (n_samples, n_features), + An array of data vectors to query + + Returns + ------- + y : ndarray of shape (n_queries,) + Label for each data sample. 
+ """ + + check_is_fitted(self, "is_fitted_") + X = check_array(X, dtype=[np.float64, np.float32]) + self._check_dimension_correspondence(X) + + _, ind = super().query(X, k=1) + result = np.array([self.y_[res[0]] for res in ind]) + return result diff --git a/sklearn_extra/neighbors/tests/__init__.py b/sklearn_extra/neighbors/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sklearn_extra/neighbors/tests/test_nswgraph.py b/sklearn_extra/neighbors/tests/test_nswgraph.py new file mode 100644 index 00000000..721a140b --- /dev/null +++ b/sklearn_extra/neighbors/tests/test_nswgraph.py @@ -0,0 +1,59 @@ +import pytest +from sklearn_extra.neighbors import NSWGraph +from sklearn.utils.validation import check_array, check_random_state +from numpy.testing import assert_array_almost_equal +import numpy as np +from sklearn.metrics import DistanceMetric + + +def brute_force_neighbors(X, Y, k, metric, **kwargs): + """True neighbours for assertion check. Taken from BallTree tests in Scikit-Learn""" + X, Y = check_array(X), check_array(Y) + D = DistanceMetric.get_metric(metric, **kwargs).pairwise(Y, X) + ind = np.argsort(D, axis=1)[:, :k] + return ind + + +def test_array_object_type(): + """Check that we do not accept object dtype array. 
Taken from BallTree tests in Scikit-Learn""" + X = np.array([(1, 2, 3), (2, 5), (5, 5, 1, 2)], dtype=object) + nswgraph = NSWGraph() + with pytest.raises( + ValueError, match="setting an array element with a sequence" + ): + nswgraph.build(X) + + +def test_init_types(): + """Make sure that the init args validation check works properly""" + regularity = -1 + with pytest.raises(ValueError): + nswgraph = NSWGraph(regularity=regularity) + + guard_hops = "something" + with pytest.raises(ValueError): + nswgraph = NSWGraph(guard_hops=guard_hops) + + quantize = "True" + with pytest.raises(ValueError): + nswgraph = NSWGraph(quantize=quantize) + + quantization_levels = 1.5 + with pytest.raises(ValueError): + nswgraph = NSWGraph( + quantize=True, quantization_levels=quantization_levels + ) + + +def test_query(): + """Make sure that neighbours query works satisfactory using NSWGraph""" + + rng = check_random_state(0) + X = rng.random_sample((30, 16)) + nswgraph = NSWGraph() + nswgraph.build(X) + k = 3 + X_val = X[:1] + ind1 = nswgraph.query(X_val, k=k, return_distance=False) + ind2 = brute_force_neighbors(X, X_val, k=k, metric="euclidean") + assert_array_almost_equal(ind1, ind2) diff --git a/sklearn_extra/tests/test_common.py b/sklearn_extra/tests/test_common.py index 3a72dc32..01aef98a 100644 --- a/sklearn_extra/tests/test_common.py +++ b/sklearn_extra/tests/test_common.py @@ -4,6 +4,7 @@ from sklearn_extra.kernel_approximation import Fastfood from sklearn_extra.kernel_methods import EigenProClassifier, EigenProRegressor from sklearn_extra.cluster import KMedoids, CommonNNClustering, CLARA +from sklearn_extra.neighbors import NSWGraph from sklearn_extra.robust import ( RobustWeightedClassifier, RobustWeightedRegressor, @@ -21,6 +22,7 @@ RobustWeightedKMeans, RobustWeightedRegressor, RobustWeightedClassifier, + NSWGraph, ]