Skip to content

Commit a67ee11

Browse files
committed
final fix with seed
1 parent 65521b9 commit a67ee11

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

libmultilabel/linear/tree.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,8 @@ def train_ensemble_tree(
422422
K: int = DEFAULT_K,
423423
dmax: int = DEFAULT_DMAX,
424424
n_trees: int = 3,
425-
seed: int = 42,
426425
verbose: bool = True,
426+
seed: int = None,
427427
) -> EnsembleTreeModel:
428428
"""Trains an ensemble of tree models (Parabel/Bonsai-style).
429429
Args:
@@ -433,12 +433,15 @@ def train_ensemble_tree(
433433
K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
434434
dmax (int, optional): Maximum depth of the tree. Defaults to 10.
435435
n_trees (int, optional): Number of trees in the ensemble. Defaults to 3.
436-
seed (int, optional): The base random seed for the ensemble. Defaults to 42.
437436
verbose (bool, optional): Output extra progress information. Defaults to True.
437+
seed (int, optional): The base random seed for the ensemble. Defaults to None, which will use 42.
438438
439439
Returns:
440440
EnsembleTreeModel: An ensemble model which can be used for prediction.
441441
"""
442+
if seed is None:
443+
seed = 42
444+
442445
tree_models = []
443446
for i in range(n_trees):
444447
np.random.seed(seed + i)

linear_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def linear_train(datasets, config):
5757
K=config.tree_degree,
5858
dmax=config.tree_max_depth,
5959
n_trees=config.tree_ensemble_models,
60-
seed=config.seed if config.seed is not None else 42,
60+
seed=config.seed,
6161
)
6262
else:
6363
model = LINEAR_TECHNIQUES[config.linear_technique](

0 commit comments

Comments
 (0)