Skip to content

Commit 24f9a6b

Browse files
committed
Make dmax and K global variables
1 parent 5419063 commit 24f9a6b

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

libmultilabel/linear/tree.py

Lines changed: 8 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,9 @@
1313

1414
__all__ = ["train_tree", "TreeModel", "train_ensemble_tree", "EnsembleTreeModel"]
1515

16+
K = 100
17+
DMAX = 10
18+
1619

1720
class Node:
1821
def __init__(
@@ -198,8 +201,6 @@ def train_tree(
198201
y: sparse.csr_matrix,
199202
x: sparse.csr_matrix,
200203
options: str = "",
201-
K=100,
202-
dmax=10,
203204
verbose: bool = True,
204205
) -> TreeModel:
205206
"""Train a linear model for multi-label data using a divide-and-conquer strategy.
@@ -209,16 +210,14 @@ def train_tree(
209210
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
210211
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
211212
options (str): The option string passed to liblinear.
212-
K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
213-
dmax (int, optional): Maximum depth of the tree. Defaults to 10.
214213
verbose (bool, optional): Output extra progress information. Defaults to True.
215214
216215
Returns:
217216
TreeModel: A model which can be used in predict_values.
218217
"""
219218
label_representation = (y.T * x).tocsr()
220219
label_representation = sklearn.preprocessing.normalize(label_representation, norm="l2", axis=1)
221-
root = _build_tree(label_representation, np.arange(y.shape[1]), 0, K, dmax)
220+
root = _build_tree(label_representation, np.arange(y.shape[1]), 0)
222221
root.is_root = True
223222

224223
num_nodes = 0
@@ -261,20 +260,18 @@ def visit(node):
261260
return TreeModel(root, flat_model, node_ptr)
262261

263262

264-
def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node:
263+
def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int) -> Node:
265264
"""Build the tree recursively by kmeans clustering.
266265
267266
Args:
268267
label_representation (sparse.csr_matrix): A matrix with dimensions number of classes under this node * number of features.
269268
label_map (np.ndarray): Maps 0..label_representation.shape[0] to the original label indices.
270269
d (int): Current depth.
271-
K (int): Maximum degree of nodes in the tree.
272-
dmax (int): Maximum depth of the tree.
273270
274271
Returns:
275272
Node: Root of the (sub)tree built from label_representation.
276273
"""
277-
if d >= dmax or label_representation.shape[0] <= K:
274+
if d >= DMAX or label_representation.shape[0] <= K:
278275
return Node(label_map=label_map, children=[])
279276

280277
metalabels = (
@@ -294,7 +291,7 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
294291
for i in range(K):
295292
child_representation = label_representation[metalabels == i]
296293
child_map = label_map[metalabels == i]
297-
child = _build_tree(child_representation, child_map, d + 1, K, dmax)
294+
child = _build_tree(child_representation, child_map, d + 1)
298295
children.append(child)
299296

300297
return Node(label_map=label_map, children=children)
@@ -416,8 +413,6 @@ def train_ensemble_tree(
416413
y: sparse.csr_matrix,
417414
x: sparse.csr_matrix,
418415
options: str = "",
419-
K: int = 100,
420-
dmax: int = 10,
421416
n_trees: int = 3,
422417
seed: int = 42,
423418
verbose: bool = True,
@@ -427,8 +422,6 @@ def train_ensemble_tree(
427422
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
428423
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
429424
options (str, optional): The option string passed to liblinear. Defaults to ''.
430-
K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
431-
dmax (int, optional): Maximum depth of the tree. Defaults to 10.
432425
n_trees (int, optional): Number of trees in the ensemble. Defaults to 3.
433426
seed (int, optional): The base random seed for the ensemble. Defaults to 42.
434427
verbose (bool, optional): Output extra progress information. Defaults to True.
@@ -440,7 +433,7 @@ def train_ensemble_tree(
440433
for i in range(n_trees):
441434
np.random.seed(seed + i)
442435

443-
tree_model = train_tree(y, x, options, K, dmax, verbose=False)
436+
tree_model = train_tree(y, x, options, verbose=False)
444437
tree_models.append(tree_model)
445438

446439
if verbose:

linear_trainer.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -48,14 +48,13 @@ def linear_train(datasets, config):
4848
if config.linear_technique == "tree":
4949
if multiclass:
5050
raise ValueError("Tree model should only be used with multilabel datasets.")
51-
51+
linear.tree.K = config.tree_degree
52+
linear.tree.DMAX = config.tree_max_depth
5253
if config.tree_ensemble_models > 1:
5354
model = train_ensemble_tree(
5455
datasets["train"]["y"],
5556
datasets["train"]["x"],
5657
options=config.liblinear_options,
57-
K=config.tree_degree,
58-
dmax=config.tree_max_depth,
5958
n_trees=config.tree_ensemble_models,
6059
seed=config.seed if config.seed is not None else 42,
6160
)
@@ -64,8 +63,6 @@ def linear_train(datasets, config):
6463
datasets["train"]["y"],
6564
datasets["train"]["x"],
6665
options=config.liblinear_options,
67-
K=config.tree_degree,
68-
dmax=config.tree_max_depth,
6966
)
7067
else:
7168
model = LINEAR_TECHNIQUES[config.linear_technique](

0 commit comments

Comments
 (0)