1313
1414__all__ = ["train_tree" , "TreeModel" , "train_ensemble_tree" , "EnsembleTreeModel" ]
1515
16- K = 100
17- DMAX = 10
16+ DEFAULT_K = 100
17+ DEFAULT_DMAX = 10
1818
1919
2020class Node :
@@ -201,6 +201,8 @@ def train_tree(
201201 y : sparse .csr_matrix ,
202202 x : sparse .csr_matrix ,
203203 options : str = "" ,
204+ K = DEFAULT_K ,
205+ dmax = DEFAULT_DMAX ,
204206 verbose : bool = True ,
205207) -> TreeModel :
206208 """Train a linear model for multi-label data using a divide-and-conquer strategy.
@@ -210,14 +212,16 @@ def train_tree(
210212 y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
211213 x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
212214 options (str): The option string passed to liblinear.
215+ K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
216+ dmax (int, optional): Maximum depth of the tree. Defaults to 10.
213217 verbose (bool, optional): Output extra progress information. Defaults to True.
214218
215219 Returns:
216220 TreeModel: A model which can be used in predict_values.
217221 """
218222 label_representation = (y .T * x ).tocsr ()
219223 label_representation = sklearn .preprocessing .normalize (label_representation , norm = "l2" , axis = 1 )
220- root = _build_tree (label_representation , np .arange (y .shape [1 ]), 0 )
224+ root = _build_tree (label_representation , np .arange (y .shape [1 ]), 0 , K , dmax )
221225 root .is_root = True
222226
223227 num_nodes = 0
@@ -260,18 +264,20 @@ def visit(node):
260264 return TreeModel (root , flat_model , node_ptr )
261265
262266
263- def _build_tree (label_representation : sparse .csr_matrix , label_map : np .ndarray , d : int ) -> Node :
267+ def _build_tree (label_representation : sparse .csr_matrix , label_map : np .ndarray , d : int , K : int , dmax : int ) -> Node :
264268 """Build the tree recursively by kmeans clustering.
265269
266270 Args:
267271 label_representation (sparse.csr_matrix): A matrix with dimensions number of classes under this node * number of features.
268272 label_map (np.ndarray): Maps 0..label_representation.shape[0] to the original label indices.
269273 d (int): Current depth.
274+ K (int): Maximum degree of nodes in the tree.
275+ dmax (int): Maximum depth of the tree.
270276
271277 Returns:
272278 Node: Root of the (sub)tree built from label_representation.
273279 """
274- if d >= DMAX or label_representation .shape [0 ] <= K :
280+ if d >= dmax or label_representation .shape [0 ] <= K :
275281 return Node (label_map = label_map , children = [])
276282
277283 metalabels = (
@@ -291,7 +297,7 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
291297 for i in range (K ):
292298 child_representation = label_representation [metalabels == i ]
293299 child_map = label_map [metalabels == i ]
294- child = _build_tree (child_representation , child_map , d + 1 )
300+ child = _build_tree (child_representation , child_map , d + 1 , K , dmax )
295301 children .append (child )
296302
297303 return Node (label_map = label_map , children = children )
@@ -413,6 +419,8 @@ def train_ensemble_tree(
413419 y : sparse .csr_matrix ,
414420 x : sparse .csr_matrix ,
415421 options : str = "" ,
422+ K : int = DEFAULT_K ,
423+ dmax : int = DEFAULT_DMAX ,
416424 n_trees : int = 3 ,
417425 seed : int = 42 ,
418426 verbose : bool = True ,
@@ -422,6 +430,8 @@ def train_ensemble_tree(
422430 y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
423431 x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
424432 options (str, optional): The option string passed to liblinear. Defaults to ''.
433+ K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
434+ dmax (int, optional): Maximum depth of the tree. Defaults to 10.
425435 n_trees (int, optional): Number of trees in the ensemble. Defaults to 3.
426436 seed (int, optional): The base random seed for the ensemble. Defaults to 42.
427437 verbose (bool, optional): Output extra progress information. Defaults to True.
@@ -433,10 +443,9 @@ def train_ensemble_tree(
433443 for i in range (n_trees ):
434444 np .random .seed (seed + i )
435445
436- tree_model = train_tree (y , x , options , verbose = False )
446+ tree_model = train_tree (y , x , options , K , dmax , verbose )
437447 tree_models .append (tree_model )
438448
439- if verbose :
440- print ("Ensemble training completed." )
449+ print ("Ensemble training completed." )
441450
442451 return EnsembleTreeModel (tree_models )
0 commit comments