1313
1414__all__ = ["train_tree" , "TreeModel" , "train_ensemble_tree" , "EnsembleTreeModel" ]
1515
16+ K = 100
17+ DMAX = 10
18+
1619
1720class Node :
1821 def __init__ (
@@ -198,8 +201,6 @@ def train_tree(
198201 y : sparse .csr_matrix ,
199202 x : sparse .csr_matrix ,
200203 options : str = "" ,
201- K = 100 ,
202- dmax = 10 ,
203204 verbose : bool = True ,
204205) -> TreeModel :
205206 """Train a linear model for multi-label data using a divide-and-conquer strategy.
@@ -209,16 +210,14 @@ def train_tree(
209210 y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
210211 x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
211212 options (str): The option string passed to liblinear.
212- K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
213- dmax (int, optional): Maximum depth of the tree. Defaults to 10.
214213 verbose (bool, optional): Output extra progress information. Defaults to True.
215214
216215 Returns:
217216 TreeModel: A model which can be used in predict_values.
218217 """
219218 label_representation = (y .T * x ).tocsr ()
220219 label_representation = sklearn .preprocessing .normalize (label_representation , norm = "l2" , axis = 1 )
221- root = _build_tree (label_representation , np .arange (y .shape [1 ]), 0 , K , dmax )
220+ root = _build_tree (label_representation , np .arange (y .shape [1 ]), 0 )
222221 root .is_root = True
223222
224223 num_nodes = 0
@@ -261,20 +260,18 @@ def visit(node):
261260 return TreeModel (root , flat_model , node_ptr )
262261
263262
264- def _build_tree (label_representation : sparse .csr_matrix , label_map : np .ndarray , d : int , K : int , dmax : int ) -> Node :
263+ def _build_tree (label_representation : sparse .csr_matrix , label_map : np .ndarray , d : int ) -> Node :
265264 """Build the tree recursively by kmeans clustering.
266265
267266 Args:
268267 label_representation (sparse.csr_matrix): A matrix with dimensions number of classes under this node * number of features.
269268 label_map (np.ndarray): Maps 0..label_representation.shape[0] to the original label indices.
270269 d (int): Current depth.
271- K (int): Maximum degree of nodes in the tree.
272- dmax (int): Maximum depth of the tree.
273270
274271 Returns:
275272 Node: Root of the (sub)tree built from label_representation.
276273 """
277- if d >= dmax or label_representation .shape [0 ] <= K :
274+ if d >= DMAX or label_representation .shape [0 ] <= K :
278275 return Node (label_map = label_map , children = [])
279276
280277 metalabels = (
@@ -294,7 +291,7 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
294291 for i in range (K ):
295292 child_representation = label_representation [metalabels == i ]
296293 child_map = label_map [metalabels == i ]
297- child = _build_tree (child_representation , child_map , d + 1 , K , dmax )
294+ child = _build_tree (child_representation , child_map , d + 1 )
298295 children .append (child )
299296
300297 return Node (label_map = label_map , children = children )
@@ -416,8 +413,6 @@ def train_ensemble_tree(
416413 y : sparse .csr_matrix ,
417414 x : sparse .csr_matrix ,
418415 options : str = "" ,
419- K : int = 100 ,
420- dmax : int = 10 ,
421416 n_trees : int = 3 ,
422417 seed : int = 42 ,
423418 verbose : bool = True ,
@@ -427,8 +422,6 @@ def train_ensemble_tree(
427422 y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
428423 x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
429424 options (str, optional): The option string passed to liblinear. Defaults to ''.
430- K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
431- dmax (int, optional): Maximum depth of the tree. Defaults to 10.
432425 n_trees (int, optional): Number of trees in the ensemble. Defaults to 3.
433426 seed (int, optional): The base random seed for the ensemble. Defaults to 42.
434427 verbose (bool, optional): Output extra progress information. Defaults to True.
@@ -440,7 +433,7 @@ def train_ensemble_tree(
440433 for i in range (n_trees ):
441434 np .random .seed (seed + i )
442435
443- tree_model = train_tree (y , x , options , K , dmax , verbose = False )
436+ tree_model = train_tree (y , x , options , verbose = False )
444437 tree_models .append (tree_model )
445438
446439 if verbose :
0 commit comments