55distances in aeon.distances.
66"""
77
8+ import numbers
9+ from typing import Optional
10+
811__maintainer__ = []
912__all__ = ["KNeighborsTimeSeriesClassifier" ]
1013
11- from typing import Callable , Optional , Union
14+ from typing import Callable , Union
1215
1316import numpy as np
14- from joblib import Parallel , delayed
1517
1618from aeon .classification .base import BaseClassifier
17- from aeon .distances import get_distance_function
19+ from aeon .distances import pairwise_distance
1820from aeon .utils .validation import check_n_jobs
1921
2022WEIGHTS_SUPPORTED = ["uniform" , "distance" ]
@@ -47,21 +49,17 @@ class KNeighborsTimeSeriesClassifier(BaseClassifier):
4749 n_timepoints)` as input and returns a ``float``.
4850 distance_params : dict, default = None
4951 Dictionary for metric parameters for the case that distance is a ``str``.
50- n_jobs : int, default = 1
51- The number of parallel jobs to run for neighbors search.
52- ``None`` means 1 unless in a :obj:``joblib.parallel_backend`` context.
53- ``-1`` means using all processors.
54- parallel_backend : str, ParallelBackendBase instance or None, default=None
55- Specify the parallelisation backend implementation in joblib, if None
56- a ‘prefer’ value of “threads” is used by default. Valid options are
57- “loky”, “multiprocessing”, “threading” or a custom backend.
58- See the joblib Parallel documentation for more details.
52+ n_jobs : int, default=1
53+ The number of jobs to run in parallel. If -1, then the number of jobs is set
54+ to the number of CPU cores. If 1, then the function is executed in a single
55+ thread. If greater than 1, then the function is executed in parallel.
5956
6057 Raises
6158 ------
6259 ValueError
6360 If ``weights`` is not among the supported values.
6461 See the ``weights`` parameter description for valid options.
62+ Dictionary for metric parameters for the case that distance is a str.
6563
6664 Examples
6765 --------
@@ -90,13 +88,11 @@ def __init__(
9088 n_neighbors : int = 1 ,
9189 weights : Union [str , Callable ] = "uniform" ,
9290 n_jobs : int = 1 ,
93- parallel_backend : str = None ,
9491 ) -> None :
9592 self .distance = distance
9693 self .distance_params = distance_params
9794 self .n_neighbors = n_neighbors
9895 self .n_jobs = n_jobs
99- self .parallel_backend = parallel_backend
10096
10197 self ._distance_params = distance_params
10298 if self ._distance_params is None :
@@ -124,7 +120,6 @@ def _fit(self, X, y):
124120 y : array-like, shape = (n_cases)
125121 The class labels.
126122 """
127- self .metric_ = get_distance_function (method = self .distance )
128123 self .X_ = X
129124 _ , self .y_ = np .unique (y , return_inverse = True )
130125 self ._n_jobs = check_n_jobs (self .n_jobs )
@@ -148,10 +143,26 @@ def _predict_proba(self, X):
148143 The class probabilities of the input samples. Classes are ordered
149144 by lexicographic order.
150145 """
151- preds = Parallel (n_jobs = self ._n_jobs , backend = self .parallel_backend )(
152- delayed (self ._proba_row )(x ) for x in X
153- )
154- return np .array (preds )
146+ preds = np .zeros ((len (X ), len (self .classes_ )))
147+ for i in range (len (X )):
148+ neigh_dist , neigh_ind = self .kneighbors (X [i : i + 1 ])
149+ neigh_dist = neigh_dist [0 ]
150+ neigh_ind = neigh_ind [0 ]
151+
152+ if self .weights == "distance" :
153+ weights = 1 / (neigh_dist + np .finfo (float ).eps )
154+ elif self .weights == "uniform" :
155+ weights = np .repeat (1.0 , len (neigh_ind ))
156+ else :
157+ raise Exception (f"Invalid kNN weights: { self .weights } " )
158+
159+ for id , w in zip (neigh_ind , weights ):
160+ predicted_class = self .y_ [id ]
161+ preds [i , predicted_class ] += w
162+
163+ preds [i ] = preds [i ] / np .sum (preds [i ])
164+
165+ return preds
155166
156167 def _predict (self , X ):
157168 """
@@ -170,70 +181,136 @@ def _predict(self, X):
170181 y : array of shape (n_cases)
171182 Class labels for each data sample.
172183 """
173- preds = Parallel (n_jobs = self ._n_jobs , backend = self .parallel_backend )(
174- delayed (self ._predict_row )(x ) for x in X
184+ self ._check_is_fitted ()
185+
186+ neigh_ind = self ._kneighbors (
187+ X , n_neighbors = 1 , return_distance = False , query_is_train = False
175188 )
176- return np .array (preds , dtype = self .classes_ .dtype )
189+ indexes = neigh_ind [:, 0 ]
190+ return self .classes_ [self .y_ [indexes ]]
177191
178- def _proba_row (self , x ):
179- scores = self ._predict_scores (x )
180- return scores / np .sum (scores )
192+ def kneighbors (self , X = None , n_neighbors = None , return_distance = True ):
193+ """Find the K-neighbors of a point.
181194
182- def _predict_row (self , x ):
183- scores = self ._predict_scores (x )
184- return self .classes_ [np .argmax (scores )]
195+ Returns indices of and distances to the neighbors of each point.
185196
186- def _predict_scores (self , x ):
187- scores = np .zeros (len (self .classes_ ))
188- idx , weights = self ._kneighbors (x )
189- for id , weight in zip (idx , weights ):
190- predicted_class = self .y_ [id ]
191- scores [predicted_class ] += weight
192- return scores
197+ Parameters
198+ ----------
199+ X : 3D np.ndarray of shape = (n_cases, n_channels, n_timepoints) or list of
200+ shape [n_cases] of 2D arrays shape (n_channels,n_timepoints_i)
201+ The query point or points.
202+ If not provided, neighbors of each indexed point are returned.
203+ In this case, the query point is not considered its own neighbor.
204+ n_neighbors : int, default=None
205+ Number of neighbors required for each sample. The default is the value
206+ passed to the constructor.
207+ return_distance : bool, default=True
208+ Whether or not to return the distances.
193209
194- def _kneighbors (self , X ):
210+ Returns
211+ -------
212+ neigh_dist : ndarray of shape (n_queries, n_neighbors)
213+ Array representing the distances to points, only present if
214+ return_distance=True.
215+ neigh_ind : ndarray of shape (n_queries, n_neighbors)
216+ Indices of the nearest points in the population matrix.
195217 """
196- Find the K-neighbors of a point.
218+ self ._check_is_fitted ()
219+
220+ # Input validation
221+ if n_neighbors is None :
222+ n_neighbors = self .n_neighbors
223+ elif not isinstance (n_neighbors , numbers .Integral ):
224+ raise TypeError (
225+ f"n_neighbors does not take { type (n_neighbors )} value, enter integer "
226+ f"value"
227+ )
228+ elif n_neighbors <= 0 :
229+ raise ValueError (f"Expected n_neighbors > 0. Got { n_neighbors } " )
230+
231+ if not isinstance (return_distance , bool ):
232+ raise TypeError (
233+ f"return_distance must be a boolean, got { type (return_distance )} "
234+ )
235+
236+ # Preprocess X if provided
237+ query_is_train = X is None
238+ if query_is_train :
239+ X = self .X_
240+ else :
241+ X = self ._preprocess_collection (X , store_metadata = False )
242+ self ._check_shape (X )
243+
244+ # Validate n_neighbors against data size
245+ n_samples_fit = len (self .X_ )
246+ if query_is_train :
247+ if not (n_neighbors < n_samples_fit ):
248+ raise ValueError (
249+ "Expected n_neighbors < n_samples_fit, but "
250+ f"n_neighbors = { n_neighbors } , n_samples_fit = { n_samples_fit } , "
251+ f"n_samples = { len (X )} "
252+ )
253+ else :
254+ if not (n_neighbors <= n_samples_fit ):
255+ raise ValueError (
256+ "Expected n_neighbors <= n_samples_fit, but "
257+ f"n_neighbors = { n_neighbors } , n_samples_fit = { n_samples_fit } , "
258+ f"n_samples = { len (X )} "
259+ )
197260
198- Returns indices and weights of each point.
261+ return self ._kneighbors (X , n_neighbors , return_distance , query_is_train )
262+
263+ def _kneighbors (self , X , n_neighbors , return_distance , query_is_train ):
264+ """Find the K-neighbors of a point.
265+
266+ Returns indices of and distances to the neighbors of each point.
199267
200268 Parameters
201269 ----------
202- X : np.ndarray
203- A single time series instance if shape = ``(n_channels, n_timepoints)``
270+ X : 3D np.ndarray of shape = (n_cases, n_channels, n_timepoints) or list of
271+ shape [n_cases] of 2D arrays shape (n_channels,n_timepoints_i)
272+ The query point or points.
273+ n_neighbors : int
274+ Number of neighbors required for each sample.
275+ return_distance : bool
276+ Whether or not to return the distances.
277+ query_is_train : bool
278+ Whether the query points are from the training set.
204279
205280 Returns
206281 -------
207- ind : array
282+ neigh_dist : ndarray of shape (n_queries, n_neighbors)
283+ Array representing the distances to points, only present if
284+ return_distance=True.
285+ neigh_ind : ndarray of shape (n_queries, n_neighbors)
208286 Indices of the nearest points in the population matrix.
209- ws : array
210- Array representing the weights of each neighbor.
211287 """
212- distances = np .array (
213- [
214- self .metric_ (X , self .X_ [j ], ** self ._distance_params )
215- for j in range (len (self .X_ ))
216- ]
288+ distances = pairwise_distance (
289+ X ,
290+ None if query_is_train else self .X_ ,
291+ method = self .distance ,
292+ n_jobs = self .n_jobs ,
293+ ** self ._distance_params ,
217294 )
218295
219- # Find indices of k nearest neighbors using partitioning:
220- # [0..k-1], [k], [k+1..n-1]
221- # They might not be ordered within themselves,
222- # but it is not necessary and partitioning is
223- # O(n) while sorting is O(nlogn)
224- closest_idx = np . argpartition ( distances , self . n_neighbors )
225- closest_idx = closest_idx [: self . n_neighbors ]
226-
227- if self . weights == "distance" :
228- ws = distances [closest_idx ]
229- # Using epsilon ~= 0 to avoid division by zero
230- ws = 1 / ( ws + np .finfo ( float ). eps )
231- elif self . weights == "uniform" :
232- ws = np . repeat ( 1.0 , self . n_neighbors )
233- else :
234- raise Exception ( f"Invalid kNN weights: { self . weights } " )
235-
236- return closest_idx , ws
296+ # If querying the training set, exclude self by setting diag to +inf
297+ if query_is_train :
298+ np . fill_diagonal ( distances , np . inf )
299+
300+ k = n_neighbors
301+ # 1) partial select smallest k
302+ idx_part = np . argpartition ( distances , kth = k - 1 , axis = 1 )[:, : k ]
303+ # 2) sort those k by (distance, index)
304+ row_idx = np . arange ( distances . shape [ 0 ])[:, None ]
305+ part_d = distances [row_idx , idx_part ]
306+ # argsort by distance, then by index for ties (lexsort uses last key as primary)
307+ order = np .lexsort (( idx_part , part_d ), axis = 1 )
308+ neigh_ind = idx_part [ row_idx , order ]
309+
310+ if return_distance :
311+ neigh_dist = distances [ row_idx , neigh_ind ]
312+ return neigh_dist , neigh_ind
313+ return neigh_ind
237314
238315 @classmethod
239316 def _get_test_params (
0 commit comments