Update inference method involving pruned-tree prediction. #11

Open · wants to merge 15 commits into base: master
115 changes: 99 additions & 16 deletions libmultilabel/linear/tree.py
@@ -46,13 +46,14 @@ def __init__(
        self,
        root: Node,
        flat_model: linear.FlatModel,
-        weight_map: np.ndarray,
+        node_ptr: np.ndarray,
    ):
        self.name = "tree"
        self.root = root
        self.flat_model = flat_model
-        self.weight_map = weight_map
+        self.node_ptr = node_ptr
        self.multiclass = False
+        self._model_separated = False  # Indicates whether the model has been separated for tree pruning.

    def predict_values(
        self,
@@ -68,10 +69,93 @@ def predict_values(
        Returns:
            np.ndarray: A matrix with dimension number of instances * number of classes.
        """
-        # number of instances * number of labels + total number of metalabels
-        all_preds = linear.predict_values(self.flat_model, x)
+        if beam_width >= len(self.root.children):
+            # beam_width is sufficiently large; no pruning is applied.
+            # Calculate decision values for all nodes:
+            # number of instances * (number of labels + total number of metalabels)
+            all_preds = linear.predict_values(self.flat_model, x)
+        else:
+            # beam_width is small; prune the tree to reduce computation.
+            if not self._model_separated:
+                self._separate_model_for_pruning_tree()
+                self._model_separated = True
+            # number of instances * (number of labels + total number of metalabels)
+            all_preds = self._prune_tree_and_predict_values(x, beam_width)
        return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])])
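
A minimal usage sketch of the two prediction paths (hypothetical names: model is a TreeModel returned by train_tree, x_test a sparse test feature matrix):

    import numpy as np

    # Pruning is used only when beam_width < len(model.root.children);
    # otherwise decision values for all nodes are computed in one flat pass.
    preds = model.predict_values(x_test, beam_width=10)  # instances x labels
    top5 = np.argsort(-preds, axis=1)[:, :5]             # 5 highest-scoring labels per instance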

+    def _separate_model_for_pruning_tree(self):
+        """Separates the weights of the root node and its children into (K + 1) FlatModels
+        for efficient beam search traversal in Python.
+        """
+        tree_flat_model_params = {
+            "bias": self.root.model.bias,
+            "thresholds": 0,
+            "multiclass": False,
+        }
+        slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
+        self.root_model = linear.FlatModel(
+            name="root-flattened-tree",
+            weights=self.flat_model.weights[slice].tocsr(),
+            **tree_flat_model_params,
+        )
+
+        self.subtree_models = []
+        for i in range(len(self.root.children)):
+            subtree_weights_start = self.node_ptr[self.root.children[i].index]
+            # The last child's subtree extends to the end of the weight matrix.
+            subtree_weights_end = self.node_ptr[self.root.children[i + 1].index] if i + 1 < len(self.root.children) else self.node_ptr[-1]
+            slice = np.s_[:, subtree_weights_start:subtree_weights_end]
+            subtree_flatmodel = linear.FlatModel(
+                name="subtree-flattened-tree",
+                weights=self.flat_model.weights[slice].tocsr(),
+                **tree_flat_model_params,
+            )
+            self.subtree_models.append(subtree_flatmodel)
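
The per-subtree slicing above relies on _flatten_model visiting nodes in DFS order, so each child's subtree occupies a contiguous column range. A small sketch with hypothetical widths:

    import numpy as np

    # DFS order: root, A, A1, A2, B, B1, where A and B are the root's children
    # and A1, A2, B1 are leaves; one classifier count per node.
    widths = [2, 2, 3, 3, 2, 4]
    node_ptr = np.cumsum([0] + widths)  # array([ 0,  2,  4,  7, 10, 12, 16])
    # Root classifiers:  node_ptr[0]:node_ptr[1]  -> columns 0:2
    # Subtree under A:   node_ptr[1]:node_ptr[4]  -> columns 2:10 (up to sibling B)
    # Subtree under B:   node_ptr[4]:node_ptr[-1] -> columns 10:16 (last child runs to the end)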

+    def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray:
+        """Calculates the decision values associated with instances x, evaluating only the most relevant subtrees.
+
+        Only the subtrees corresponding to the top beam_width candidates at the root are evaluated;
+        the rest are skipped to avoid unnecessary computation.
+
+        Args:
+            x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
+            beam_width (int): Number of top candidate branches considered for prediction.
+
+        Returns:
+            np.ndarray: A matrix with dimension number of instances * (number of labels + total number of metalabels).
+        """
+        # Initialize space for all predictions with negative infinity.
+        # node_ptr[-1] equals the number of labels plus the total number of metalabels.
+        num_instances, num_columns = x.shape[0], self.node_ptr[-1]
+        all_preds = np.full((num_instances, num_columns), -np.inf)
+
+        # Calculate the root's decision values and the resulting children scores.
+        root_preds = linear.predict_values(self.root_model, x)
+        children_scores = -np.square(np.maximum(0, 1 - root_preds))
+
+        slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
+        all_preds[slice] = root_preds
+
+        # Select the indices of the top beam_width subtrees for each instance.
+        top_beam_width_indices = np.argsort(-children_scores, axis=1, kind="stable")[:, :beam_width]
+
+        # Build a mask where mask[i, j] is True iff the j-th subtree is among the top beam_width subtrees for the i-th instance.
+        mask = np.zeros_like(children_scores, dtype=np.bool_)
+        np.put_along_axis(mask, top_beam_width_indices, True, axis=1)
+
+        # Calculate predictions for each subtree on its selected instances.
+        for subtree_idx in range(len(self.root.children)):
+            subtree_model = self.subtree_models[subtree_idx]
+            instances_mask = mask[:, subtree_idx]
+            reduced_instances = x[instances_mask, :]
+
+            # Locate the subtree's columns in the flattened weight matrix of all nodes.
+            subtree_weights_start = self.node_ptr[self.root.children[subtree_idx].index]
+            subtree_weights_end = subtree_weights_start + subtree_model.weights.shape[1]
+
+            slice = np.s_[instances_mask, subtree_weights_start:subtree_weights_end]
+            all_preds[slice] = linear.predict_values(subtree_model, reduced_instances)
+
+        return all_preds
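
The top-beam_width selection combines argsort with np.put_along_axis; a self-contained sketch with made-up scores:

    import numpy as np

    scores = np.array([[0.9, 0.1, 0.5],
                       [0.2, 0.8, 0.3]])  # 2 instances, 3 subtrees
    top_idx = np.argsort(-scores, axis=1, kind="stable")[:, :2]  # beam_width = 2
    mask = np.zeros_like(scores, dtype=np.bool_)
    np.put_along_axis(mask, top_idx, True, axis=1)
    # mask == [[ True, False,  True],
    #          [False,  True,  True]]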

    def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray:
        """Predict with beam search using cached probability estimates for a single instance.

@@ -93,7 +177,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray:
                if node.isLeaf():
                    next_level.append((node, score))
                    continue
-                slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]]
+                slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
                pred = instance_preds[slice]
                children_score = score - np.square(np.maximum(0, 1 - pred))
                next_level.extend(zip(node.children, children_score.tolist()))
@@ -102,9 +186,9 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray:
            next_level = []

        num_labels = len(self.root.label_map)
-        scores = np.full(num_labels, 0.0)
+        scores = np.zeros(num_labels)
        for node, score in cur_level:
-            slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]]
+            slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
            pred = instance_preds[slice]
            scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred)))
        return scores
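
The score accumulated along a root-to-leaf path is the sum of negative squared hinge losses of the decision values, and the final exp() maps it into (0, 1]. A sketch with made-up decision values along one path:

    import numpy as np

    d = np.array([2.3, 0.7, 1.1])                    # decision values at each depth
    score = -np.square(np.maximum(0, 1 - d)).sum()   # only d < 1 contributes loss
    prob_like = np.exp(score)                        # about 0.914 here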
@@ -173,8 +257,8 @@ def visit(node):
    root.dfs(visit)
    pbar.close()

-    flat_model, weight_map = _flatten_model(root)
-    return TreeModel(root, flat_model, weight_map)
+    flat_model, node_ptr = _flatten_model(root)
+    return TreeModel(root, flat_model, node_ptr)


def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node:
@@ -261,11 +345,10 @@ def _flatten_model(root: Node) -> tuple[linear.FlatModel, np.ndarray]:
"""Flattens tree weight matrices into a single weight matrix. The flattened weight
matrix is used to predict all possible values, which is cached for beam search.
This pessimizes complexity but is faster in practice.
Consecutive values of the returned map denotes the start and end indices of the
weights of each node. Conceptually, given root and node:
flat_model, weight_map = _flatten_model(root)
slice = np.s_[weight_map[node.index]:
weight_map[node.index+1]]
Consecutive values of the returned array denotes the start and end indices of each node in the weight matrix.
To extract a node's classifiers:
slice = np.s_[node_ptr[node.index]:
node_ptr[node.index+1]]
node.model.weights == flat_model.weights[:, slice]

    Args:
@@ -296,6 +379,6 @@ def visit(node):
    )

    # w.shape[1] is the number of labels/metalabels of each node
-    weight_map = np.cumsum([0] + list(map(lambda w: w.shape[1], weights)))
+    node_ptr = np.cumsum([0] + list(map(lambda w: w.shape[1], weights)))

-    return model, weight_map
+    return model, node_ptr
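
For concreteness, a sketch of the resulting node_ptr for three nodes with hypothetical classifier counts:

    import numpy as np

    widths = [3, 5, 2]                  # w.shape[1] for three nodes in DFS order
    node_ptr = np.cumsum([0] + widths)  # array([ 0,  3,  8, 10])
    # Node i's classifiers are flat_model.weights[:, node_ptr[i]:node_ptr[i+1]].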