|
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | import scipy.sparse as sparse |
7 | | -import sklearn.cluster |
| 7 | +from sparsekmeans import LloydKmeans, ElkanKmeans |
8 | 8 | import sklearn.preprocessing |
9 | 9 | from tqdm import tqdm |
10 | 10 | import psutil |
@@ -274,28 +274,29 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, |
274 | 274 | Returns: |
275 | 275 | Node: Root of the (sub)tree built from label_representation. |
276 | 276 | """ |
277 | | - if d >= dmax or label_representation.shape[0] <= K: |
278 | | - return Node(label_map=label_map, children=[]) |
279 | | - |
280 | | - metalabels = ( |
281 | | - sklearn.cluster.KMeans( |
282 | | - K, |
283 | | - random_state=np.random.randint(2**31 - 1), |
284 | | - n_init=1, |
285 | | - max_iter=300, |
286 | | - tol=0.0001, |
287 | | - algorithm="elkan", |
| 277 | + children = [] |
| 278 | + if d < dmax and label_representation.shape[0] > K: |
| 279 | + if label_representation.shape[0] > 10000: |
| 280 | + kmeans_algo = ElkanKmeans |
| 281 | + else: |
| 282 | + kmeans_algo = LloydKmeans |
| 283 | + |
| 284 | + kmeans = kmeans_algo( |
| 285 | + n_clusters=K, max_iter=300, tol=0.0001, random_state=np.random.randint(2**31 - 1), verbose=True |
288 | 286 | ) |
289 | | - .fit(label_representation) |
290 | | - .labels_ |
291 | | - ) |
| 287 | + metalabels = kmeans.fit(label_representation) |
292 | 288 |
|
293 | | - children = [] |
294 | | - for i in range(K): |
295 | | - child_representation = label_representation[metalabels == i] |
296 | | - child_map = label_map[metalabels == i] |
297 | | - child = _build_tree(child_representation, child_map, d + 1, K, dmax) |
298 | | - children.append(child) |
| 289 | + unique_labels = np.unique(metalabels) |
| 290 | + if len(unique_labels) == K: |
| 291 | + create_child_node = lambda i: _build_tree( |
| 292 | + label_representation[metalabels == i], label_map[metalabels == i], d + 1, K, dmax |
| 293 | + ) |
| 294 | + else: |
| 295 | + create_child_node = lambda i: Node(label_map=label_map[metalabels == i], children=[]) |
| 296 | + |
| 297 | + for i in range(K): |
| 298 | + child = create_child_node(i) |
| 299 | + children.append(child) |
299 | 300 |
|
300 | 301 | return Node(label_map=label_map, children=children) |
301 | 302 |
|
|
0 commit comments