Commit ab9cde0

committed
keep the ensemble implementation only and remove the scoring aware
1 parent 605c7ee commit ab9cde0

File tree

1 file changed: +21 -39 lines changed

libmultilabel/linear/tree.py

Lines changed: 21 additions & 39 deletions
@@ -10,8 +10,6 @@
 import psutil
 
 from . import linear
-from scipy.special import log_expit
-
 
 __all__ = ["train_tree", "TreeModel", "train_ensemble_tree", "EnsembleTreeModel"]
 

@@ -49,31 +47,13 @@ def __init__(
         root: Node,
         flat_model: linear.FlatModel,
         node_ptr: np.ndarray,
-        options: str,
     ):
         self.name = "tree"
         self.root = root
         self.flat_model = flat_model
         self.node_ptr = node_ptr
-        self.options = options
         self.multiclass = False
-        self._model_separated = False  # Indicates whether the model has been separated for pruning tree.
-
-    def _is_lr(self) -> bool:
-        options = self.options or ""
-        options_split = options.split()
-        if "-s" in options_split:
-            i = options_split.index("-s")
-            if i + 1 < len(options_split):
-                solver_type = options_split[i + 1]
-                return solver_type in ["0", "6", "7"]
-        return False
-
-    def _get_scores(self, pred: np.ndarray, parent_score: float = 0.0) -> np.ndarray:
-        if self._is_lr():
-            return parent_score + log_expit(pred)
-        else:
-            return parent_score - np.square(np.maximum(0, 1 - pred))
+        self._model_separated = False  # Indicates whether the model has been separated for pruning tree.
 
     def predict_values(
         self,
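For context on what this hunk drops: `_is_lr` parsed the LIBLINEAR `-s` flag (solver types 0, 6, and 7 are the logistic regression solvers), and `_get_scores` then chose between `log_expit` log-probabilities and the squared-hinge surrogate. After this commit only the squared-hinge rule survives, inlined at each former call site. A minimal standalone sketch of the two rules, with made-up decision values:

import numpy as np
from scipy.special import log_expit

pred = np.array([2.0, -0.5, 0.1])  # hypothetical decision values for three child nodes
parent_score = 0.0

# removed scoring-aware path: log-probabilities for LR solvers (-s 0/6/7)
lr_scores = parent_score + log_expit(pred)

# kept path: squared-hinge surrogate, now hard-coded everywhere
hinge_scores = parent_score - np.square(np.maximum(0, 1 - pred))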
@@ -92,42 +72,44 @@ def predict_values(
         if beam_width >= len(self.root.children):
             # Beam_width is sufficiently large; pruning not applied.
             # Calculates decision values for all nodes.
-            all_preds = linear.predict_values(
-                self.flat_model, x
-            )  # number of instances * (number of labels + total number of metalabels)
+            all_preds = linear.predict_values(self.flat_model, x)  # number of instances * (number of labels + total number of metalabels)
         else:
             # Beam_width is small; pruning applied to reduce computation.
             if not self._model_separated:
                 self._separate_model_for_pruning_tree()
                 self._model_separated = True
-            all_preds = self._prune_tree_and_predict_values(
-                x, beam_width
-            )  # number of instances * (number of labels + total number of metalabels)
+            all_preds = self._prune_tree_and_predict_values(x, beam_width)  # number of instances * (number of labels + total number of metalabels)
         return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])])
 
     def _separate_model_for_pruning_tree(self):
         """
         This function separates the weights for the root node and its children into (K+1) FlatModel
         for efficient beam search traversal in Python.
         """
-        tree_flat_model_params = {"bias": self.root.model.bias, "thresholds": 0, "multiclass": False}
+        tree_flat_model_params = {
+            'bias': self.root.model.bias,
+            'thresholds': 0,
+            'multiclass': False
+        }
         slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
         self.root_model = linear.FlatModel(
-            name="root-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), **tree_flat_model_params
+            name="root-flattened-tree",
+            weights=self.flat_model.weights[slice].tocsr(),
+            **tree_flat_model_params
         )
 
         self.subtree_models = []
         for i in range(len(self.root.children)):
             subtree_weights_start = self.node_ptr[self.root.children[i].index]
-            subtree_weights_end = (
-                self.node_ptr[self.root.children[i + 1].index] if i + 1 < len(self.root.children) else -1
-            )
+            subtree_weights_end = self.node_ptr[self.root.children[i+1].index] if i+1 < len(self.root.children) else -1
             slice = np.s_[:, subtree_weights_start:subtree_weights_end]
             subtree_flatmodel = linear.FlatModel(
-                name="subtree-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), **tree_flat_model_params
+                name="subtree-flattened-tree",
+                weights=self.flat_model.weights[slice].tocsr(),
+                **tree_flat_model_params
             )
             self.subtree_models.append(subtree_flatmodel)
-
+
     def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray:
         """Calculates the selective decision values associated with instances x by evaluating only the most relevant subtrees.
 
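The separation above works because `node_ptr` stores the column offset of every node's weight block inside the flattened weight matrix, so `weights[:, node_ptr[i]:node_ptr[i + 1]]` recovers node i's classifiers. A toy sketch of that slicing; the shapes, density, and three-node layout are invented for illustration:

import numpy as np
from scipy import sparse

weights = sparse.random(5, 9, density=0.5, format="csc")  # 5 features, 9 stacked output columns
node_ptr = np.array([0, 3, 6, 9])                         # three nodes, 3 columns each

# same slicing pattern as _separate_model_for_pruning_tree
block = np.s_[:, node_ptr[0] : node_ptr[1]]
root_weights = weights[block].tocsr()  # CSR speeds up the row-major products at prediction time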
@@ -147,7 +129,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int)
 
         # Calculate root decision values and scores
         root_preds = linear.predict_values(self.root_model, x)
-        children_scores = self._get_scores(root_preds)
+        children_scores = 0.0 - np.square(np.maximum(0, 1 - root_preds))
 
         slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
         all_preds[slice] = root_preds
@@ -158,7 +140,7 @@
         # Build a mask where mask[i, j] is True if the j-th subtree is among the top beam_width subtrees for the i-th instance
         mask = np.zeros_like(children_scores, dtype=np.bool_)
         np.put_along_axis(mask, top_beam_width_indices, True, axis=1)
-
+
         # Calculate predictions for each subtree with its corresponding instances
         for subtree_idx in range(len(self.root.children)):
             subtree_model = self.subtree_models[subtree_idx]
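`np.put_along_axis` turns the per-instance top-`beam_width` subtree indices into a boolean selection mask in one call. A small self-contained demonstration; how `top_beam_width_indices` is computed is not shown in this diff, so the `argpartition` step is an assumption:

import numpy as np

children_scores = np.array([[-0.1, -0.9, -0.4],
                            [-0.8, -0.2, -0.3]])
beam_width = 2

# each row's top-2 subtree indices (order within the top-2 is unspecified)
top_beam_width_indices = np.argpartition(-children_scores, beam_width - 1, axis=1)[:, :beam_width]

mask = np.zeros_like(children_scores, dtype=np.bool_)
np.put_along_axis(mask, top_beam_width_indices, True, axis=1)
# mask[i, j] is True iff subtree j is among instance i's top beam_width subtrees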
@@ -197,7 +179,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray:
                     continue
                 slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
                 pred = instance_preds[slice]
-                children_score = self._get_scores(pred, score)
+                children_score = score - np.square(np.maximum(0, 1 - pred))
                 next_level.extend(zip(node.children, children_score.tolist()))
 
             cur_level = sorted(next_level, key=lambda pair: -pair[1])[:beam_width]
@@ -208,7 +190,7 @@
         for node, score in cur_level:
             slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
             pred = instance_preds[slice]
-            scores[node.label_map] = np.exp(self._get_scores(pred, score))
+            scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred)))
         return scores
 
 
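With the scoring-aware branch gone, a label's final score is always the exponential of the squared-hinge losses accumulated along its root-to-leaf path, i.e. exp(sum over path of -max(0, 1 - pred)^2). A compact sketch of that accumulation over a hypothetical two-level path:

import numpy as np

def hinge_score(pred, parent_score=0.0):
    # the single scoring rule this commit keeps
    return parent_score - np.square(np.maximum(0, 1 - pred))

meta_pred = np.array([0.8])         # hypothetical decision value at an internal node
leaf_preds = np.array([1.2, -0.3])  # hypothetical decision values for its two labels

path_score = hinge_score(meta_pred)[0]  # -(1 - 0.8)^2 = -0.04
label_scores = np.exp(hinge_score(leaf_preds, path_score))
# label 0: exp(-0.04 - 0), since pred >= 1 gives zero hinge loss
# label 1: exp(-0.04 - (1 - (-0.3))^2) = exp(-0.04 - 1.69)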
@@ -276,7 +258,7 @@ def visit(node):
     pbar.close()
 
     flat_model, node_ptr = _flatten_model(root)
-    return TreeModel(root, flat_model, node_ptr, options)
+    return TreeModel(root, flat_model, node_ptr)
 
 
 def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node:
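For callers, only the `TreeModel` constructor loses an argument; `train_tree` still accepts LIBLINEAR options for training the node classifiers, it just no longer stores them on the returned model. A hedged usage sketch; the `options` keyword and the `beam_width` value are assumptions, not taken from this diff:

import numpy as np
from scipy import sparse
from libmultilabel.linear.tree import train_tree

# tiny made-up dataset: 4 instances, 3 labels, 5 features
y = sparse.csr_matrix(np.array([[1, 0, 0], [0, 1, 0], [0, 1, 1], [1, 0, 1]]))
x = sparse.random(4, 5, density=0.6, format="csr")

model = train_tree(y, x, options="-s 1")        # options still configure node training
preds = model.predict_values(x, beam_width=10)  # one score per label, shape (4, 3)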
