Skip to content

Commit 5419063

Browse files
committed
resolve SW's comment
1 parent ab9cde0 commit 5419063

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

linear_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import libmultilabel.linear as linear
88
from libmultilabel.common_utils import dump_log, is_multiclass_dataset
9-
from libmultilabel.linear.tree import train_ensemble_tree
9+
from libmultilabel.linear.tree import EnsembleTreeModel, TreeModel, train_ensemble_tree
1010
from libmultilabel.linear.utils import LINEAR_TECHNIQUES
1111

1212

@@ -22,7 +22,7 @@ def linear_test(config, model, datasets, label_mapping):
2222
scores = []
2323

2424
predict_kwargs = {}
25-
if model.name == "tree" or model.name == "ensemble-tree":
25+
if isinstance(model, (TreeModel, EnsembleTreeModel)):
2626
predict_kwargs["beam_width"] = config.beam_width
2727

2828
for i in tqdm(range(ceil(num_instance / config.eval_batch_size))):

0 commit comments

Comments
 (0)