Skip to content

Commit 0669c9f

Browse files
committed
making dmax and K as global default value
1 parent 24f9a6b commit 0669c9f

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

libmultilabel/linear/tree.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414
__all__ = ["train_tree", "TreeModel", "train_ensemble_tree", "EnsembleTreeModel"]
1515

16-
K = 100
17-
DMAX = 10
16+
DEFAULT_K = 100
17+
DEFAULT_DMAX = 10
1818

1919

2020
class Node:
@@ -201,6 +201,8 @@ def train_tree(
201201
y: sparse.csr_matrix,
202202
x: sparse.csr_matrix,
203203
options: str = "",
204+
K=DEFAULT_K,
205+
dmax=DEFAULT_DMAX,
204206
verbose: bool = True,
205207
) -> TreeModel:
206208
"""Train a linear model for multi-label data using a divide-and-conquer strategy.
@@ -210,14 +212,16 @@ def train_tree(
210212
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
211213
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
212214
options (str): The option string passed to liblinear.
215+
K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
216+
dmax (int, optional): Maximum depth of the tree. Defaults to 10.
213217
verbose (bool, optional): Output extra progress information. Defaults to True.
214218
215219
Returns:
216220
TreeModel: A model which can be used in predict_values.
217221
"""
218222
label_representation = (y.T * x).tocsr()
219223
label_representation = sklearn.preprocessing.normalize(label_representation, norm="l2", axis=1)
220-
root = _build_tree(label_representation, np.arange(y.shape[1]), 0)
224+
root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)
221225
root.is_root = True
222226

223227
num_nodes = 0
@@ -260,18 +264,20 @@ def visit(node):
260264
return TreeModel(root, flat_model, node_ptr)
261265

262266

263-
def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int) -> Node:
267+
def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node:
264268
"""Build the tree recursively by kmeans clustering.
265269
266270
Args:
267271
label_representation (sparse.csr_matrix): A matrix with dimensions number of classes under this node * number of features.
268272
label_map (np.ndarray): Maps 0..label_representation.shape[0] to the original label indices.
269273
d (int): Current depth.
274+
K (int): Maximum degree of nodes in the tree.
275+
dmax (int): Maximum depth of the tree.
270276
271277
Returns:
272278
Node: Root of the (sub)tree built from label_representation.
273279
"""
274-
if d >= DMAX or label_representation.shape[0] <= K:
280+
if d >= dmax or label_representation.shape[0] <= K:
275281
return Node(label_map=label_map, children=[])
276282

277283
metalabels = (
@@ -291,7 +297,7 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
291297
for i in range(K):
292298
child_representation = label_representation[metalabels == i]
293299
child_map = label_map[metalabels == i]
294-
child = _build_tree(child_representation, child_map, d + 1)
300+
child = _build_tree(child_representation, child_map, d + 1, K, dmax)
295301
children.append(child)
296302

297303
return Node(label_map=label_map, children=children)
@@ -413,6 +419,8 @@ def train_ensemble_tree(
413419
y: sparse.csr_matrix,
414420
x: sparse.csr_matrix,
415421
options: str = "",
422+
K: int = DEFAULT_K,
423+
dmax: int = DEFAULT_DMAX,
416424
n_trees: int = 3,
417425
seed: int = 42,
418426
verbose: bool = True,
@@ -422,6 +430,8 @@ def train_ensemble_tree(
422430
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
423431
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
424432
options (str, optional): The option string passed to liblinear. Defaults to ''.
433+
K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
434+
dmax (int, optional): Maximum depth of the tree. Defaults to 10.
425435
n_trees (int, optional): Number of trees in the ensemble. Defaults to 3.
426436
seed (int, optional): The base random seed for the ensemble. Defaults to 42.
427437
verbose (bool, optional): Output extra progress information. Defaults to True.
@@ -433,10 +443,9 @@ def train_ensemble_tree(
433443
for i in range(n_trees):
434444
np.random.seed(seed + i)
435445

436-
tree_model = train_tree(y, x, options, verbose=False)
446+
tree_model = train_tree(y, x, options, K, dmax, verbose)
437447
tree_models.append(tree_model)
438448

439-
if verbose:
440-
print("Ensemble training completed.")
449+
print("Ensemble training completed.")
441450

442451
return EnsembleTreeModel(tree_models)

linear_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,14 @@ def linear_train(datasets, config):
4848
if config.linear_technique == "tree":
4949
if multiclass:
5050
raise ValueError("Tree model should only be used with multilabel datasets.")
51-
linear.tree.K = config.tree_degree
52-
linear.tree.DMAX = config.tree_max_depth
51+
5352
if config.tree_ensemble_models > 1:
5453
model = train_ensemble_tree(
5554
datasets["train"]["y"],
5655
datasets["train"]["x"],
5756
options=config.liblinear_options,
57+
K=config.tree_degree,
58+
dmax=config.tree_max_depth,
5859
n_trees=config.tree_ensemble_models,
5960
seed=config.seed if config.seed is not None else 42,
6061
)
@@ -63,6 +64,8 @@ def linear_train(datasets, config):
6364
datasets["train"]["y"],
6465
datasets["train"]["x"],
6566
options=config.liblinear_options,
67+
K=config.tree_degree,
68+
dmax=config.tree_max_depth,
6669
)
6770
else:
6871
model = LINEAR_TECHNIQUES[config.linear_technique](

0 commit comments

Comments
 (0)