From bc271de0ca6f22b155ffa5e0e1423e60f49fa1b6 Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Tue, 18 Feb 2025 20:36:06 +0800 Subject: [PATCH 01/15] Update inference method involving prune tree prediction. --- libmultilabel/linear/tree.py | 138 +++++++++++++++++++++++++++++------ 1 file changed, 114 insertions(+), 24 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index 0db2f86..d4c9a54 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Optional import numpy as np import scipy.sparse as sparse @@ -46,13 +46,15 @@ def __init__( self, root: Node, flat_model: linear.FlatModel, - weight_map: np.ndarray, + weight_map: Optional[np.ndarray] = None, + subtrees: Optional[list[TreeModel]] = None, ): self.name = "tree" self.root = root self.flat_model = flat_model - self.weight_map = weight_map + self.weight_map = weight_map if weight_map is not None else np.array([]) self.multiclass = False + self.subtrees = subtrees if subtrees else [] def predict_values( self, @@ -69,9 +71,54 @@ def predict_values( np.ndarray: A matrix with dimension number of instances * number of classes. """ # number of instances * number of labels + total number of metalabels - all_preds = linear.predict_values(self.flat_model, x) + all_preds = self._prune_tree_predictions(x, beam_width) return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])]) + + def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray: + """Calculates the decision values associated with x. + + If the beam width is smaller than the number of nodes at a some level, many nodes become unreachable, resulting in unnecessary computations. + In LibMultiLabel's default setting, the beam width is smaller than the root's degree in the tree. + To mitigate unnecessary computations, pruning is applied to predictions starting from the root. + + Args: + x (sparse.csr_matrix): A matrix with dimension number of instances * number of features. + beam_width (int): Number of top candidate branches considered for prediction. + + Returns: + np.ndarray: A matrix with dimension number of instances * (number of labels + total number of metalabels). + """ + # Initialize space for all predictions with negative infinity + num_instances, num_labels = x.shape[0], self.weight_map[-1] + all_preds= np.full((num_instances, num_labels), np.NINF) + + # Calculate root decision value and scores + root_preds = linear.predict_values(self.flat_model, x) + children_scores = 0.0 - np.maximum(0, 1 - root_preds) ** 2 + + # Find the top k subtree for each instance + top_k_indices = np.argsort(-children_scores, axis=1, kind='stable')[:, :beam_width] + + # Building a mapping from subtree to instances + subtree_to_instances = {subtree: np.where(top_k_indices == subtree)[0] for subtree in np.unique(top_k_indices)} + + slice = np.s_[:num_instances, self.weight_map[self.root.index]: self.weight_map[self.root.index+1]] + all_preds[slice] = root_preds + + # Calculate predictions for each subtree with its corresponding instances + for subtree, instances in subtree_to_instances.items(): + current_subtree = self.subtrees[subtree] + reduced_instances = x[np.s_[instances], :] + # Locate the position of the subtree root in the weight mapping of all nodes. + subtree_weights_start = self.weight_map[current_subtree.root.index] + subtree_weights_end = subtree_weights_start+current_subtree.flat_model.weights.shape[1] + + slice = np.s_[instances, subtree_weights_start:subtree_weights_end] + all_preds[slice] = linear.predict_values(current_subtree.flat_model, reduced_instances) + + return all_preds + def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray: """Predict with beam search using cached probability estimates for a single instance. @@ -102,7 +149,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra next_level = [] num_labels = len(self.root.label_map) - scores = np.full(num_labels, 0.0) + scores = np.zeros(num_labels) for node, score in cur_level: slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]] pred = instance_preds[slice] @@ -162,7 +209,12 @@ def count(node): pbar = tqdm(total=num_nodes, disable=not verbose) + index = 0 + def visit(node): + nonlocal index + node.index = index + index += 1 if node.is_root: _train_node(y, x, options, node) else: @@ -173,9 +225,7 @@ def visit(node): root.dfs(visit) pbar.close() - flat_model, weight_map = _flatten_model(root) - return TreeModel(root, flat_model, weight_map) - + return _tree_model(root) def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node: """Builds the tree recursively by kmeans clustering. @@ -257,32 +307,23 @@ def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: node.model.weights = sparse.csc_matrix(node.model.weights) -def _flatten_model(root: Node) -> tuple[linear.FlatModel, np.ndarray]: +def _flatten_model(root: Node) -> linear.FlatModel: """Flattens tree weight matrices into a single weight matrix. The flattened weight matrix is used to predict all possible values, which is cached for beam search. This pessimizes complexity but is faster in practice. - Consecutive values of the returned map denotes the start and end indices of the - weights of each node. Conceptually, given root and node: - flat_model, weight_map = _flatten_model(root) - slice = np.s_[weight_map[node.index]: - weight_map[node.index+1]] - node.model.weights == flat_model.weights[:, slice] - + flat_model = _flatten_model(root) Args: root (Node): Root of the tree. Returns: - tuple[linear.FlatModel, np.ndarray]: The flattened model and the ranges of each node. + linear.FlatModel: The flattened model. """ - index = 0 weights = [] bias = root.model.bias + def visit(node): assert bias == node.model.bias - nonlocal index - node.index = index - index += 1 weights.append(node.model.__dict__.pop("weights")) root.dfs(visit) @@ -295,7 +336,56 @@ def visit(node): multiclass=False, ) - # w.shape[1] is the number of labels/metalabels of each node - weight_map = np.cumsum([0] + list(map(lambda w: w.shape[1], weights))) + return model + +def _tree_model(root: Node) -> TreeModel: + """Constructs a tree model by aggregating the weights of all nodes in the tree. + To speed up inference in Python, we avoid using a single flattened weight matrix, + which would involve many unnecessary computations. + Instead, we build a hierarchical tree model by aggregating the weights of each root's child + into different flattened weight matrices, representing subtrees as `TreeModel` instances. + Additionally, the root itself is also a `TreeModel`, containing subtree `TreeModel` instances. + + Consecutive values of the weight map denotes the start and end indices of the + weights of each node. Conceptually, given root and node: + slice = np.s_[weight_map[node.index]: + weight_map[node.index+1]] + node.model.weights == flat_model.weights[:, slice] + + Args: + root (Node): Root of the tree. + + Returns: + Tree Model: A tree model containing the root's flattened model, + weight index mappings of all nodes, and subtrees. + """ + # Build weights mapping which contains the start and end indices of the weights of each node. + weight_map = [0] + subtrees = [] + bias = root.model.bias + - return model, weight_map + def visit(node): + assert bias == node.model.bias + # weights.shape[1] is the number of labels/metalabels of each node + weight_map.append(node.model.weights.shape[1]) + + root.dfs(visit) + + weight_map = np.cumsum(weight_map) + + # Build root's subtrees + for child in root.children: + child_flat_model = _flatten_model(child) + subtrees.append(TreeModel(child, child_flat_model)) + + # Build root's flatten model with root model weights + model = linear.FlatModel( + name="root-flattened-tree", + weights=root.model.__dict__.pop("weights"), + bias=root.model.bias, + thresholds=0, + multiclass=False, + ) + + return TreeModel(root, model, weight_map, subtrees) \ No newline at end of file From a3f562a3216176b47a9ea417547e78e90adcc718 Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Thu, 20 Feb 2025 18:51:20 +0800 Subject: [PATCH 02/15] Fix a bug in the ew inference implementation. --- libmultilabel/linear/tree.py | 37 ++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index d4c9a54..f8d8544 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -97,25 +97,26 @@ def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.n root_preds = linear.predict_values(self.flat_model, x) children_scores = 0.0 - np.maximum(0, 1 - root_preds) ** 2 - # Find the top k subtree for each instance - top_k_indices = np.argsort(-children_scores, axis=1, kind='stable')[:, :beam_width] - - # Building a mapping from subtree to instances - subtree_to_instances = {subtree: np.where(top_k_indices == subtree)[0] for subtree in np.unique(top_k_indices)} - slice = np.s_[:num_instances, self.weight_map[self.root.index]: self.weight_map[self.root.index+1]] - all_preds[slice] = root_preds - - # Calculate predictions for each subtree with its corresponding instances - for subtree, instances in subtree_to_instances.items(): - current_subtree = self.subtrees[subtree] - reduced_instances = x[np.s_[instances], :] - # Locate the position of the subtree root in the weight mapping of all nodes. - subtree_weights_start = self.weight_map[current_subtree.root.index] - subtree_weights_end = subtree_weights_start+current_subtree.flat_model.weights.shape[1] - - slice = np.s_[instances, subtree_weights_start:subtree_weights_end] - all_preds[slice] = linear.predict_values(current_subtree.flat_model, reduced_instances) + all_preds[slice] = root_preds + + if not self.root.isLeaf(): + # Find the top k subtree for each instance + top_k_indices = np.argsort(-children_scores, axis=1, kind='stable')[:, :beam_width] + + # Building a mapping from subtree to instances + subtree_to_instances = {subtree: np.where(top_k_indices == subtree)[0] for subtree in np.unique(top_k_indices)} + + # Calculate predictions for each subtree with its corresponding instances + for subtree, instances in subtree_to_instances.items(): + current_subtree = self.subtrees[subtree] + reduced_instances = x[np.s_[instances], :] + # Locate the position of the subtree root in the weight mapping of all nodes. + subtree_weights_start = self.weight_map[current_subtree.root.index] + subtree_weights_end = subtree_weights_start+current_subtree.flat_model.weights.shape[1] + + slice = np.s_[instances, subtree_weights_start:subtree_weights_end] + all_preds[slice] = linear.predict_values(current_subtree.flat_model, reduced_instances) return all_preds From dfd4782df73a51d85fba3c14bd079057baaa713f Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Thu, 27 Feb 2025 15:47:10 +0800 Subject: [PATCH 03/15] Revise code for better readability in new implementations. --- libmultilabel/linear/tree.py | 80 +++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 34 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index f8d8544..66485b0 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, Optional +from typing import Callable import numpy as np import scipy.sparse as sparse @@ -46,15 +46,15 @@ def __init__( self, root: Node, flat_model: linear.FlatModel, - weight_map: Optional[np.ndarray] = None, - subtrees: Optional[list[TreeModel]] = None, + weight_map: np.ndarray, + subtrees: list[SubTree], ): self.name = "tree" self.root = root self.flat_model = flat_model - self.weight_map = weight_map if weight_map is not None else np.array([]) + self.weight_map = weight_map self.multiclass = False - self.subtrees = subtrees if subtrees else [] + self.subtrees = subtrees def predict_values( self, @@ -74,12 +74,11 @@ def predict_values( all_preds = self._prune_tree_predictions(x, beam_width) return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])]) - def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray: """Calculates the decision values associated with x. - If the beam width is smaller than the number of nodes at a some level, many nodes become unreachable, resulting in unnecessary computations. - In LibMultiLabel's default setting, the beam width is smaller than the root's degree in the tree. + If the beam width is smaller than the number of nodes at a some level, many nodes become unreachable, resulting in unnecessary computations. + In LibMultiLabel's default setting, the beam width is smaller than the root's degree in the tree. To mitigate unnecessary computations, pruning is applied to predictions starting from the root. Args: @@ -91,32 +90,34 @@ def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.n """ # Initialize space for all predictions with negative infinity num_instances, num_labels = x.shape[0], self.weight_map[-1] - all_preds= np.full((num_instances, num_labels), np.NINF) + all_preds = np.full((num_instances, num_labels), np.NINF) # Calculate root decision value and scores root_preds = linear.predict_values(self.flat_model, x) children_scores = 0.0 - np.maximum(0, 1 - root_preds) ** 2 - slice = np.s_[:num_instances, self.weight_map[self.root.index]: self.weight_map[self.root.index+1]] - all_preds[slice] = root_preds - + slice = np.s_[:num_instances, self.weight_map[self.root.index] : self.weight_map[self.root.index + 1]] + all_preds[slice] = root_preds + if not self.root.isLeaf(): # Find the top k subtree for each instance - top_k_indices = np.argsort(-children_scores, axis=1, kind='stable')[:, :beam_width] + top_k_indices = np.argsort(-children_scores, axis=1, kind="stable")[:, :beam_width] # Building a mapping from subtree to instances - subtree_to_instances = {subtree: np.where(top_k_indices == subtree)[0] for subtree in np.unique(top_k_indices)} + subtree_to_instances = { + self.subtrees[subtree_idx]: np.where(top_k_indices == subtree_idx)[0] + for subtree_idx in np.unique(top_k_indices) + } # Calculate predictions for each subtree with its corresponding instances for subtree, instances in subtree_to_instances.items(): - current_subtree = self.subtrees[subtree] reduced_instances = x[np.s_[instances], :] # Locate the position of the subtree root in the weight mapping of all nodes. - subtree_weights_start = self.weight_map[current_subtree.root.index] - subtree_weights_end = subtree_weights_start+current_subtree.flat_model.weights.shape[1] + subtree_weights_start = self.weight_map[subtree.root.index] + subtree_weights_end = subtree_weights_start + subtree.flat_model.weights.shape[1] slice = np.s_[instances, subtree_weights_start:subtree_weights_end] - all_preds[slice] = linear.predict_values(current_subtree.flat_model, reduced_instances) + all_preds[slice] = linear.predict_values(subtree.flat_model, reduced_instances) return all_preds @@ -158,6 +159,18 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra return scores +class SubTree: + """Represents a subtree with its root node and the linear flattened model which builts from the subtree's root.""" + + def __init__( + self, + root: Node, + flat_model: linear.FlatModel, + ): + self.root = root + self.flat_model = flat_model + + def train_tree( y: sparse.csr_matrix, x: sparse.csr_matrix, @@ -226,7 +239,8 @@ def visit(node): root.dfs(visit) pbar.close() - return _tree_model(root) + return _tree_model(root) + def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node: """Builds the tree recursively by kmeans clustering. @@ -322,7 +336,6 @@ def _flatten_model(root: Node) -> linear.FlatModel: weights = [] bias = root.model.bias - def visit(node): assert bias == node.model.bias weights.append(node.model.__dict__.pop("weights")) @@ -339,11 +352,12 @@ def visit(node): return model + def _tree_model(root: Node) -> TreeModel: """Constructs a tree model by aggregating the weights of all nodes in the tree. - To speed up inference in Python, we avoid using a single flattened weight matrix, + To speed up inference in Python, we avoid using a single flattened weight matrix, which would involve many unnecessary computations. - Instead, we build a hierarchical tree model by aggregating the weights of each root's child + Instead, we build a hierarchical tree model by aggregating the weights of each root's child into different flattened weight matrices, representing subtrees as `TreeModel` instances. Additionally, the root itself is also a `TreeModel`, containing subtree `TreeModel` instances. @@ -357,7 +371,7 @@ def _tree_model(root: Node) -> TreeModel: root (Node): Root of the tree. Returns: - Tree Model: A tree model containing the root's flattened model, + Tree Model: A tree model containing the root's flattened model, weight index mappings of all nodes, and subtrees. """ # Build weights mapping which contains the start and end indices of the weights of each node. @@ -365,7 +379,6 @@ def _tree_model(root: Node) -> TreeModel: subtrees = [] bias = root.model.bias - def visit(node): assert bias == node.model.bias # weights.shape[1] is the number of labels/metalabels of each node @@ -378,15 +391,14 @@ def visit(node): # Build root's subtrees for child in root.children: child_flat_model = _flatten_model(child) - subtrees.append(TreeModel(child, child_flat_model)) - + subtrees.append(SubTree(child, child_flat_model)) # Build root's flatten model with root model weights model = linear.FlatModel( - name="root-flattened-tree", - weights=root.model.__dict__.pop("weights"), - bias=root.model.bias, - thresholds=0, - multiclass=False, - ) - - return TreeModel(root, model, weight_map, subtrees) \ No newline at end of file + name="root-flattened-tree", + weights=root.model.__dict__.pop("weights"), + bias=root.model.bias, + thresholds=0, + multiclass=False, + ) + + return TreeModel(root, model, weight_map, subtrees) From dc61b28595fcc19e4150e3d1ae1881600f322900 Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Thu, 6 Mar 2025 13:59:10 +0800 Subject: [PATCH 04/15] Revise the type of subTree class. --- libmultilabel/linear/tree.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index 66485b0..d5477f2 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -8,6 +8,7 @@ import sklearn.preprocessing from tqdm import tqdm import psutil +from dataclasses import dataclass from . import linear @@ -158,17 +159,11 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred))) return scores - +@dataclass(frozen=True) class SubTree: """Represents a subtree with its root node and the linear flattened model which builts from the subtree's root.""" - - def __init__( - self, - root: Node, - flat_model: linear.FlatModel, - ): - self.root = root - self.flat_model = flat_model + root: Node + flat_model: linear.FlatModel def train_tree( @@ -177,6 +172,7 @@ def train_tree( options: str = "", K=100, dmax=10, + path=None, verbose: bool = True, ) -> TreeModel: """Trains a linear model for multi-label data using a divide-and-conquer strategy. From ae886950f54724623ec0dfb64ee89279a0bc12e7 Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Thu, 6 Mar 2025 14:02:28 +0800 Subject: [PATCH 05/15] Fix the bug. --- libmultilabel/linear/tree.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index d5477f2..08327ad 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -172,7 +172,6 @@ def train_tree( options: str = "", K=100, dmax=10, - path=None, verbose: bool = True, ) -> TreeModel: """Trains a linear model for multi-label data using a divide-and-conquer strategy. From 3af921266792abcc8a59c9c0bcec129454387baa Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Mon, 31 Mar 2025 23:47:53 +0800 Subject: [PATCH 06/15] Update the preprocess for pruning tree predictions. --- libmultilabel/linear/tree.py | 137 +++++++++++++++-------------------- 1 file changed, 59 insertions(+), 78 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index 08327ad..81cec14 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -8,7 +8,7 @@ import sklearn.preprocessing from tqdm import tqdm import psutil -from dataclasses import dataclass +import itertools from . import linear @@ -48,14 +48,12 @@ def __init__( root: Node, flat_model: linear.FlatModel, weight_map: np.ndarray, - subtrees: list[SubTree], ): self.name = "tree" self.root = root self.flat_model = flat_model self.weight_map = weight_map self.multiclass = False - self.subtrees = subtrees def predict_values( self, @@ -72,14 +70,45 @@ def predict_values( np.ndarray: A matrix with dimension number of instances * number of classes. """ # number of instances * number of labels + total number of metalabels + self._preprocess_beam_search() all_preds = self._prune_tree_predictions(x, beam_width) return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])]) + def _preprocess_beam_search(self): + """ Preprocess the flattened model for beam search. + + This function extracts the weights for the root node and its children into separate FlatModel instances for efficient beam search traversal in Python. + """ + if not hasattr(self, "root_weights"): + slice = np.s_[:, self.weight_map[self.root.index]:self.weight_map[self.root.index+1]] + self.root_weights = linear.FlatModel( + name="root-flattened-tree", + weights=self.flat_model.weights[slice], + bias=self.root.model.bias, + thresholds=0, + multiclass=False, + ) + + self.subtrees_weights = [] + subtree_indices = [self.weight_map[child.index] for child in self.root.children] + [self.weight_map[-1]] + + for subtree_start, subtree_end in itertools.pairwise(subtree_indices): + slice = np.s_[:, subtree_start:subtree_end] + subtree_flatmodel = linear.FlatModel( + name="subtree-flattened-tree", + weights=self.flat_model.weights[slice], + bias=self.root.model.bias, + thresholds=0, + multiclass=False, + ) + self.subtrees.append(subtree_flatmodel) + def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray: """Calculates the decision values associated with x. - If the beam width is smaller than the number of nodes at a some level, many nodes become unreachable, resulting in unnecessary computations. - In LibMultiLabel's default setting, the beam width is smaller than the root's degree in the tree. + In LibMultiLabel, we concatenate all nodes' weights into a single matrix to avoid the large overhead of performing multiple matrix multiplications in Python. + However, if the beam width is smaller than the number of nodes at a certain level, many nodes become unreachable, leading to unnecessary computations. + For example, in LibMultiLabel's default setting, the beam width is smaller than the root's degree in the tree. To mitigate unnecessary computations, pruning is applied to predictions starting from the root. Args: @@ -94,7 +123,7 @@ def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.n all_preds = np.full((num_instances, num_labels), np.NINF) # Calculate root decision value and scores - root_preds = linear.predict_values(self.flat_model, x) + root_preds = linear.predict_values(self.root_weights, x) children_scores = 0.0 - np.maximum(0, 1 - root_preds) ** 2 slice = np.s_[:num_instances, self.weight_map[self.root.index] : self.weight_map[self.root.index + 1]] @@ -106,19 +135,20 @@ def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.n # Building a mapping from subtree to instances subtree_to_instances = { - self.subtrees[subtree_idx]: np.where(top_k_indices == subtree_idx)[0] + subtree_idx: np.where(top_k_indices == subtree_idx)[0] for subtree_idx in np.unique(top_k_indices) } # Calculate predictions for each subtree with its corresponding instances - for subtree, instances in subtree_to_instances.items(): + for subtree_idx, instances in subtree_to_instances.items(): + subtree = self.subtrees[subtree_idx] reduced_instances = x[np.s_[instances], :] # Locate the position of the subtree root in the weight mapping of all nodes. - subtree_weights_start = self.weight_map[subtree.root.index] - subtree_weights_end = subtree_weights_start + subtree.flat_model.weights.shape[1] + subtree_weights_start = self.weight_map[self.root.children[subtree_idx].index] + subtree_weights_end = subtree_weights_start + subtree.weights.shape[1] slice = np.s_[instances, subtree_weights_start:subtree_weights_end] - all_preds[slice] = linear.predict_values(subtree.flat_model, reduced_instances) + all_preds[slice] = linear.predict_values(subtree, reduced_instances) return all_preds @@ -159,12 +189,6 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred))) return scores -@dataclass(frozen=True) -class SubTree: - """Represents a subtree with its root node and the linear flattened model which builts from the subtree's root.""" - root: Node - flat_model: linear.FlatModel - def train_tree( y: sparse.csr_matrix, @@ -218,12 +242,7 @@ def count(node): pbar = tqdm(total=num_nodes, disable=not verbose) - index = 0 - def visit(node): - nonlocal index - node.index = index - index += 1 if node.is_root: _train_node(y, x, options, node) else: @@ -233,8 +252,8 @@ def visit(node): root.dfs(visit) pbar.close() - - return _tree_model(root) + flat_model, weight_map = _flatten_model(root) + return TreeModel(root, flat_model, weight_map) def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node: @@ -317,22 +336,32 @@ def _train_node(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str, node: node.model.weights = sparse.csc_matrix(node.model.weights) -def _flatten_model(root: Node) -> linear.FlatModel: +def _flatten_model(root: Node) -> tuple[linear.FlatModel, np.ndarray]: """Flattens tree weight matrices into a single weight matrix. The flattened weight matrix is used to predict all possible values, which is cached for beam search. This pessimizes complexity but is faster in practice. - flat_model = _flatten_model(root) + Consecutive values of the returned map denotes the start and end indices of the + weights of each node. Conceptually, given root and node: + flat_model, weight_map = _flatten_model(root) + slice = np.s_[weight_map[node.index]: + weight_map[node.index+1]] + node.model.weights == flat_model.weights[:, slice] + Args: root (Node): Root of the tree. Returns: - linear.FlatModel: The flattened model. + tuple[linear.FlatModel, np.ndarray]: The flattened model and the ranges of each node. """ + index = 0 weights = [] bias = root.model.bias def visit(node): assert bias == node.model.bias + nonlocal index + node.index = index + index += 1 weights.append(node.model.__dict__.pop("weights")) root.dfs(visit) @@ -345,55 +374,7 @@ def visit(node): multiclass=False, ) - return model - - -def _tree_model(root: Node) -> TreeModel: - """Constructs a tree model by aggregating the weights of all nodes in the tree. - To speed up inference in Python, we avoid using a single flattened weight matrix, - which would involve many unnecessary computations. - Instead, we build a hierarchical tree model by aggregating the weights of each root's child - into different flattened weight matrices, representing subtrees as `TreeModel` instances. - Additionally, the root itself is also a `TreeModel`, containing subtree `TreeModel` instances. - - Consecutive values of the weight map denotes the start and end indices of the - weights of each node. Conceptually, given root and node: - slice = np.s_[weight_map[node.index]: - weight_map[node.index+1]] - node.model.weights == flat_model.weights[:, slice] - - Args: - root (Node): Root of the tree. - - Returns: - Tree Model: A tree model containing the root's flattened model, - weight index mappings of all nodes, and subtrees. - """ - # Build weights mapping which contains the start and end indices of the weights of each node. - weight_map = [0] - subtrees = [] - bias = root.model.bias - - def visit(node): - assert bias == node.model.bias - # weights.shape[1] is the number of labels/metalabels of each node - weight_map.append(node.model.weights.shape[1]) - - root.dfs(visit) - - weight_map = np.cumsum(weight_map) - - # Build root's subtrees - for child in root.children: - child_flat_model = _flatten_model(child) - subtrees.append(SubTree(child, child_flat_model)) - # Build root's flatten model with root model weights - model = linear.FlatModel( - name="root-flattened-tree", - weights=root.model.__dict__.pop("weights"), - bias=root.model.bias, - thresholds=0, - multiclass=False, - ) + # w.shape[1] is the number of labels/metalabels of each node + weight_map = np.cumsum([0] + list(map(lambda w: w.shape[1], weights))) - return TreeModel(root, model, weight_map, subtrees) + return model, weight_map From 8221f695e3133a9ec2c9f8625ff3b0f4c2a13b7d Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Wed, 2 Apr 2025 18:07:36 +0800 Subject: [PATCH 07/15] Remove itertools.pairwise function call due to compatibility issues. --- libmultilabel/linear/tree.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index 81cec14..552a543 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -8,7 +8,7 @@ import sklearn.preprocessing from tqdm import tqdm import psutil -import itertools +from more_itertools import pairwise from . import linear @@ -89,10 +89,10 @@ def _preprocess_beam_search(self): multiclass=False, ) - self.subtrees_weights = [] + self.subtrees = [] subtree_indices = [self.weight_map[child.index] for child in self.root.children] + [self.weight_map[-1]] - for subtree_start, subtree_end in itertools.pairwise(subtree_indices): + for subtree_start, subtree_end in zip(subtree_indices, subtree_indices[1:]): slice = np.s_[:, subtree_start:subtree_end] subtree_flatmodel = linear.FlatModel( name="subtree-flattened-tree", From 6d8322edc06d9dea74d2fd47f2747639c2415004 Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Wed, 2 Apr 2025 18:14:04 +0800 Subject: [PATCH 08/15] Remove the more-iteratools module --- libmultilabel/linear/tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index 552a543..d20ea14 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -8,7 +8,6 @@ import sklearn.preprocessing from tqdm import tqdm import psutil -from more_itertools import pairwise from . import linear @@ -252,6 +251,7 @@ def visit(node): root.dfs(visit) pbar.close() + flat_model, weight_map = _flatten_model(root) return TreeModel(root, flat_model, weight_map) From 846cec1ab41a36be45fa0561fad25308f7b3f56a Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Thu, 3 Apr 2025 16:08:23 +0800 Subject: [PATCH 09/15] Rename the variable. --- libmultilabel/linear/tree.py | 60 +++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index d20ea14..a5ac669 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -69,38 +69,41 @@ def predict_values( np.ndarray: A matrix with dimension number of instances * number of classes. """ # number of instances * number of labels + total number of metalabels - self._preprocess_beam_search() - all_preds = self._prune_tree_predictions(x, beam_width) + if beam_width >= len(self.root.children): + all_preds = linear.predict_values(self.flat_model, x) + else: + self._preprocess_beam_search() + all_preds = self._prune_tree_predictions(x, beam_width) return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])]) def _preprocess_beam_search(self): - """ Preprocess the flattened model for beam search. + """Preprocess the flattened model for beam search. This function extracts the weights for the root node and its children into separate FlatModel instances for efficient beam search traversal in Python. """ - if not hasattr(self, "root_weights"): - slice = np.s_[:, self.weight_map[self.root.index]:self.weight_map[self.root.index+1]] - self.root_weights = linear.FlatModel( - name="root-flattened-tree", - weights=self.flat_model.weights[slice], - bias=self.root.model.bias, - thresholds=0, - multiclass=False, - ) - - self.subtrees = [] + if not hasattr(self, "root_model"): + slice = np.s_[:, self.weight_map[self.root.index] : self.weight_map[self.root.index + 1]] + self.root_model = linear.FlatModel( + name="root-flattened-tree", + weights=self.flat_model.weights[slice].tocsr(), + bias=self.root.model.bias, + thresholds=0, + multiclass=False, + ) + + self.subtree_models = [] subtree_indices = [self.weight_map[child.index] for child in self.root.children] + [self.weight_map[-1]] - + for subtree_start, subtree_end in zip(subtree_indices, subtree_indices[1:]): slice = np.s_[:, subtree_start:subtree_end] subtree_flatmodel = linear.FlatModel( - name="subtree-flattened-tree", - weights=self.flat_model.weights[slice], - bias=self.root.model.bias, - thresholds=0, - multiclass=False, - ) - self.subtrees.append(subtree_flatmodel) + name="subtree-flattened-tree", + weights=self.flat_model.weights[slice].tocsr(), + bias=self.root.model.bias, + thresholds=0, + multiclass=False, + ) + self.subtree_models.append(subtree_flatmodel) def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray: """Calculates the decision values associated with x. @@ -122,7 +125,7 @@ def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.n all_preds = np.full((num_instances, num_labels), np.NINF) # Calculate root decision value and scores - root_preds = linear.predict_values(self.root_weights, x) + root_preds = linear.predict_values(self.root_model, x) children_scores = 0.0 - np.maximum(0, 1 - root_preds) ** 2 slice = np.s_[:num_instances, self.weight_map[self.root.index] : self.weight_map[self.root.index + 1]] @@ -134,20 +137,19 @@ def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.n # Building a mapping from subtree to instances subtree_to_instances = { - subtree_idx: np.where(top_k_indices == subtree_idx)[0] - for subtree_idx in np.unique(top_k_indices) + subtree_idx: np.where(top_k_indices == subtree_idx)[0] for subtree_idx in np.unique(top_k_indices) } # Calculate predictions for each subtree with its corresponding instances for subtree_idx, instances in subtree_to_instances.items(): - subtree = self.subtrees[subtree_idx] + subtree_model = self.subtree_models[subtree_idx] reduced_instances = x[np.s_[instances], :] # Locate the position of the subtree root in the weight mapping of all nodes. subtree_weights_start = self.weight_map[self.root.children[subtree_idx].index] - subtree_weights_end = subtree_weights_start + subtree.weights.shape[1] + subtree_weights_end = subtree_weights_start + subtree_model.weights.shape[1] slice = np.s_[instances, subtree_weights_start:subtree_weights_end] - all_preds[slice] = linear.predict_values(subtree, reduced_instances) + all_preds[slice] = linear.predict_values(subtree_model, reduced_instances) return all_preds @@ -251,7 +253,7 @@ def visit(node): root.dfs(visit) pbar.close() - + flat_model, weight_map = _flatten_model(root) return TreeModel(root, flat_model, weight_map) From 66b50a49f32d805b3a4bf3bc50104444d186fcf7 Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Thu, 10 Apr 2025 15:30:22 +0800 Subject: [PATCH 10/15] Update the inference algorithm. --- libmultilabel/common_utils.py | 10 ++++ libmultilabel/linear/tree.py | 104 +++++++++++++++++----------------- 2 files changed, 61 insertions(+), 53 deletions(-) diff --git a/libmultilabel/common_utils.py b/libmultilabel/common_utils.py index bb6fe3a..7ae2546 100644 --- a/libmultilabel/common_utils.py +++ b/libmultilabel/common_utils.py @@ -141,3 +141,13 @@ def wrapper(*args, **kwargs): return value return wrapper + +def pairwise(iterable): + # pairwise('ABCDEFG') → AB BC CD DE EF FG + + iterator = iter(iterable) + a = next(iterator, None) + + for b in iterator: + yield a, b + a = b \ No newline at end of file diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index a5ac669..2e8f0e7 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -8,7 +8,7 @@ import sklearn.preprocessing from tqdm import tqdm import psutil - +from ..common_utils import pairwise from . import linear __all__ = ["train_tree", "TreeModel"] @@ -46,12 +46,12 @@ def __init__( self, root: Node, flat_model: linear.FlatModel, - weight_map: np.ndarray, + node_ptr: np.ndarray, ): self.name = "tree" self.root = root self.flat_model = flat_model - self.weight_map = weight_map + self.node_ptr = node_ptr self.multiclass = False def predict_values( @@ -72,46 +72,45 @@ def predict_values( if beam_width >= len(self.root.children): all_preds = linear.predict_values(self.flat_model, x) else: - self._preprocess_beam_search() + self._seperate_model_for_partial_predictions() all_preds = self._prune_tree_predictions(x, beam_width) return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])]) - def _preprocess_beam_search(self): - """Preprocess the flattened model for beam search. - - This function extracts the weights for the root node and its children into separate FlatModel instances for efficient beam search traversal in Python. + def _seperate_model_for_partial_predictions(self): + """ + This function seperates the weights for the root node and its children into (K+1) FlatModel + for efficient beam search traversal in Python. """ if not hasattr(self, "root_model"): - slice = np.s_[:, self.weight_map[self.root.index] : self.weight_map[self.root.index + 1]] + tree_flat_model_params = { + 'bias': self.root.model.bias, + 'thresholds': 0, + 'multiclass': False + } + slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]] self.root_model = linear.FlatModel( name="root-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), - bias=self.root.model.bias, - thresholds=0, - multiclass=False, + **tree_flat_model_params ) self.subtree_models = [] - subtree_indices = [self.weight_map[child.index] for child in self.root.children] + [self.weight_map[-1]] + children_indices = [child.index for child in self.root.children] + [-1] - for subtree_start, subtree_end in zip(subtree_indices, subtree_indices[1:]): - slice = np.s_[:, subtree_start:subtree_end] + for cur_child_idx, next_child_idx in pairwise(children_indices): + slice = np.s_[:, self.node_ptr[cur_child_idx]:self.node_ptr[next_child_idx]] subtree_flatmodel = linear.FlatModel( name="subtree-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), - bias=self.root.model.bias, - thresholds=0, - multiclass=False, + **tree_flat_model_params ) self.subtree_models.append(subtree_flatmodel) def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray: - """Calculates the decision values associated with x. + """Calculates the paritial decision values associated with x. - In LibMultiLabel, we concatenate all nodes' weights into a single matrix to avoid the large overhead of performing multiple matrix multiplications in Python. - However, if the beam width is smaller than the number of nodes at a certain level, many nodes become unreachable, leading to unnecessary computations. - For example, in LibMultiLabel's default setting, the beam width is smaller than the root's degree in the tree. - To mitigate unnecessary computations, pruning is applied to predictions starting from the root. + Only subtrees corresponding to the top beam_width candidates from the root are evaluated, + skipping the rest to avoid unnecessary computation. Args: x (sparse.csr_matrix): A matrix with dimension number of instances * number of features. @@ -121,35 +120,36 @@ def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.n np.ndarray: A matrix with dimension number of instances * (number of labels + total number of metalabels). """ # Initialize space for all predictions with negative infinity - num_instances, num_labels = x.shape[0], self.weight_map[-1] + num_instances, num_labels = x.shape[0], self.node_ptr[-1] all_preds = np.full((num_instances, num_labels), np.NINF) # Calculate root decision value and scores root_preds = linear.predict_values(self.root_model, x) - children_scores = 0.0 - np.maximum(0, 1 - root_preds) ** 2 + children_scores = 0.0 - np.square(np.maximum(0, 1 - root_preds)) - slice = np.s_[:num_instances, self.weight_map[self.root.index] : self.weight_map[self.root.index + 1]] + slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]] all_preds[slice] = root_preds - if not self.root.isLeaf(): - # Find the top k subtree for each instance - top_k_indices = np.argsort(-children_scores, axis=1, kind="stable")[:, :beam_width] + # Select indices of the top beam_width subtrees for each instance + top_beam_width_indices = np.argsort(-children_scores, axis=1, kind="stable")[:, :beam_width] - # Building a mapping from subtree to instances - subtree_to_instances = { - subtree_idx: np.where(top_k_indices == subtree_idx)[0] for subtree_idx in np.unique(top_k_indices) - } + # Build a mask indicating whether i-th instance * j-th subtree + mask = np.zeros_like(children_scores, dtype=np.bool_) + row_indices = np.arange(num_instances)[:, np.newaxis] + mask[row_indices, top_beam_width_indices] = True + + # Calculate predictions for each subtree with its corresponding instances + for subtree_idx in range(len(self.root.children)): + subtree_model = self.subtree_models[subtree_idx] + instances_mask = mask[:, subtree_idx] + reduced_instances = x[np.s_[instances_mask], :] - # Calculate predictions for each subtree with its corresponding instances - for subtree_idx, instances in subtree_to_instances.items(): - subtree_model = self.subtree_models[subtree_idx] - reduced_instances = x[np.s_[instances], :] - # Locate the position of the subtree root in the weight mapping of all nodes. - subtree_weights_start = self.weight_map[self.root.children[subtree_idx].index] - subtree_weights_end = subtree_weights_start + subtree_model.weights.shape[1] + # Locate the position of the subtree root in the weight mapping of all nodes + subtree_weights_start = self.node_ptr[self.root.children[subtree_idx].index] + subtree_weights_end = subtree_weights_start + subtree_model.weights.shape[1] - slice = np.s_[instances, subtree_weights_start:subtree_weights_end] - all_preds[slice] = linear.predict_values(subtree_model, reduced_instances) + slice = np.s_[instances_mask, subtree_weights_start:subtree_weights_end] + all_preds[slice] = linear.predict_values(subtree_model, reduced_instances) return all_preds @@ -174,7 +174,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra if node.isLeaf(): next_level.append((node, score)) continue - slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]] + slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]] pred = instance_preds[slice] children_score = score - np.square(np.maximum(0, 1 - pred)) next_level.extend(zip(node.children, children_score.tolist())) @@ -185,7 +185,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra num_labels = len(self.root.label_map) scores = np.zeros(num_labels) for node, score in cur_level: - slice = np.s_[self.weight_map[node.index] : self.weight_map[node.index + 1]] + slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]] pred = instance_preds[slice] scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred))) return scores @@ -254,8 +254,8 @@ def visit(node): root.dfs(visit) pbar.close() - flat_model, weight_map = _flatten_model(root) - return TreeModel(root, flat_model, weight_map) + flat_model, node_ptr = _flatten_model(root) + return TreeModel(root, flat_model, node_ptr) def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node: @@ -342,11 +342,9 @@ def _flatten_model(root: Node) -> tuple[linear.FlatModel, np.ndarray]: """Flattens tree weight matrices into a single weight matrix. The flattened weight matrix is used to predict all possible values, which is cached for beam search. This pessimizes complexity but is faster in practice. - Consecutive values of the returned map denotes the start and end indices of the - weights of each node. Conceptually, given root and node: - flat_model, weight_map = _flatten_model(root) - slice = np.s_[weight_map[node.index]: - weight_map[node.index+1]] + Consecutive values of the returned array where the classifiers for a node are stored in: + slice = np.s_[node_ptr[node.index]: + node_ptr[node.index+1]] node.model.weights == flat_model.weights[:, slice] Args: @@ -377,6 +375,6 @@ def visit(node): ) # w.shape[1] is the number of labels/metalabels of each node - weight_map = np.cumsum([0] + list(map(lambda w: w.shape[1], weights))) + node_ptr = np.cumsum([0] + list(map(lambda w: w.shape[1], weights))) - return model, weight_map + return model, node_ptr From dc570d0845ade8e06f66b417c5889cb6a4581bee Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Sun, 13 Apr 2025 03:38:34 +0800 Subject: [PATCH 11/15] Modify mask implementation and rename variable --- libmultilabel/common_utils.py | 12 +------ libmultilabel/linear/tree.py | 63 +++++++++++++++++------------------ 2 files changed, 32 insertions(+), 43 deletions(-) diff --git a/libmultilabel/common_utils.py b/libmultilabel/common_utils.py index 7ae2546..1f7bdc9 100644 --- a/libmultilabel/common_utils.py +++ b/libmultilabel/common_utils.py @@ -140,14 +140,4 @@ def wrapper(*args, **kwargs): logging.info(f"{repr(func.__name__)} finished in {wall_time:.2f} seconds") return value - return wrapper - -def pairwise(iterable): - # pairwise('ABCDEFG') → AB BC CD DE EF FG - - iterator = iter(iterable) - a = next(iterator, None) - - for b in iterator: - yield a, b - a = b \ No newline at end of file + return wrapper \ No newline at end of file diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index 2e8f0e7..a929de5 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -8,7 +8,6 @@ import sklearn.preprocessing from tqdm import tqdm import psutil -from ..common_utils import pairwise from . import linear __all__ = ["train_tree", "TreeModel"] @@ -53,6 +52,7 @@ def __init__( self.flat_model = flat_model self.node_ptr = node_ptr self.multiclass = False + self.weigths_separated = False # Used for faster prediction def predict_values( self, @@ -68,45 +68,45 @@ def predict_values( Returns: np.ndarray: A matrix with dimension number of instances * number of classes. """ - # number of instances * number of labels + total number of metalabels if beam_width >= len(self.root.children): - all_preds = linear.predict_values(self.flat_model, x) + all_preds = linear.predict_values(self.flat_model, x) # number of instances * (number of labels + total number of metalabels) else: - self._seperate_model_for_partial_predictions() - all_preds = self._prune_tree_predictions(x, beam_width) + if not self.weigths_separated: + self._separate_model_for_partial_predictions() + self.weigths_separated = True + all_preds = self._prune_tree_and_predict_values(x, beam_width) # number of instances * (number of labels + total number of metalabels) return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])]) - def _seperate_model_for_partial_predictions(self): + def _separate_model_for_partial_predictions(self): """ This function seperates the weights for the root node and its children into (K+1) FlatModel for efficient beam search traversal in Python. """ - if not hasattr(self, "root_model"): - tree_flat_model_params = { - 'bias': self.root.model.bias, - 'thresholds': 0, - 'multiclass': False - } - slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]] - self.root_model = linear.FlatModel( - name="root-flattened-tree", + tree_flat_model_params = { + 'bias': self.root.model.bias, + 'thresholds': 0, + 'multiclass': False + } + slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]] + self.root_model = linear.FlatModel( + name="root-flattened-tree", + weights=self.flat_model.weights[slice].tocsr(), + **tree_flat_model_params + ) + + self.subtree_models = [] + for i in range(len(self.root.children)): + subtree_weights_start = self.node_ptr[self.root.children[i].index] + subtree_weights_end = self.node_ptr[self.root.children[i+1].index] if i+1 < len(self.root.children) else -1 + slice = np.s_[:, subtree_weights_start:subtree_weights_end] + subtree_flatmodel = linear.FlatModel( + name="subtree-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), **tree_flat_model_params ) - - self.subtree_models = [] - children_indices = [child.index for child in self.root.children] + [-1] - - for cur_child_idx, next_child_idx in pairwise(children_indices): - slice = np.s_[:, self.node_ptr[cur_child_idx]:self.node_ptr[next_child_idx]] - subtree_flatmodel = linear.FlatModel( - name="subtree-flattened-tree", - weights=self.flat_model.weights[slice].tocsr(), - **tree_flat_model_params - ) - self.subtree_models.append(subtree_flatmodel) - - def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray: + self.subtree_models.append(subtree_flatmodel) + + def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray: """Calculates the paritial decision values associated with x. Only subtrees corresponding to the top beam_width candidates from the root are evaluated, @@ -123,7 +123,7 @@ def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.n num_instances, num_labels = x.shape[0], self.node_ptr[-1] all_preds = np.full((num_instances, num_labels), np.NINF) - # Calculate root decision value and scores + # Calculate root decision values and scores root_preds = linear.predict_values(self.root_model, x) children_scores = 0.0 - np.square(np.maximum(0, 1 - root_preds)) @@ -135,8 +135,7 @@ def _prune_tree_predictions(self, x: sparse.csr_matrix, beam_width: int) -> np.n # Build a mask indicating whether i-th instance * j-th subtree mask = np.zeros_like(children_scores, dtype=np.bool_) - row_indices = np.arange(num_instances)[:, np.newaxis] - mask[row_indices, top_beam_width_indices] = True + np.put_along_axis(mask, top_beam_width_indices, True, axis=1) # Calculate predictions for each subtree with its corresponding instances for subtree_idx in range(len(self.root.children)): From d94d9c0e26860ff7dbebdf2f8dcc617578c301de Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Sun, 13 Apr 2025 17:08:17 +0800 Subject: [PATCH 12/15] Rename the variable and modfiy the comment --- libmultilabel/common_utils.py | 2 +- libmultilabel/linear/tree.py | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/libmultilabel/common_utils.py b/libmultilabel/common_utils.py index 1f7bdc9..bb6fe3a 100644 --- a/libmultilabel/common_utils.py +++ b/libmultilabel/common_utils.py @@ -140,4 +140,4 @@ def wrapper(*args, **kwargs): logging.info(f"{repr(func.__name__)} finished in {wall_time:.2f} seconds") return value - return wrapper \ No newline at end of file + return wrapper diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index a929de5..ac0fb80 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -52,7 +52,7 @@ def __init__( self.flat_model = flat_model self.node_ptr = node_ptr self.multiclass = False - self.weigths_separated = False # Used for faster prediction + self.model_separated = False # Used for faster prediction def predict_values( self, @@ -69,24 +69,26 @@ def predict_values( np.ndarray: A matrix with dimension number of instances * number of classes. """ if beam_width >= len(self.root.children): + # Beam width sufficiently large; pruning skipped, pruning not applied, computing decision values for all nodes. all_preds = linear.predict_values(self.flat_model, x) # number of instances * (number of labels + total number of metalabels) else: - if not self.weigths_separated: + # Beam width is small; pruning applied, computing decision values selectively. + if not self.model_separated: self._separate_model_for_partial_predictions() - self.weigths_separated = True + self.model_separated = True all_preds = self._prune_tree_and_predict_values(x, beam_width) # number of instances * (number of labels + total number of metalabels) return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])]) def _separate_model_for_partial_predictions(self): """ - This function seperates the weights for the root node and its children into (K+1) FlatModel + This function seperates the weights for the root node and its children into (K+1) FlatModel for efficient beam search traversal in Python. """ tree_flat_model_params = { 'bias': self.root.model.bias, 'thresholds': 0, 'multiclass': False - } + } slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]] self.root_model = linear.FlatModel( name="root-flattened-tree", @@ -96,7 +98,7 @@ def _separate_model_for_partial_predictions(self): self.subtree_models = [] for i in range(len(self.root.children)): - subtree_weights_start = self.node_ptr[self.root.children[i].index] + subtree_weights_start = self.node_ptr[self.root.children[i].index] subtree_weights_end = self.node_ptr[self.root.children[i+1].index] if i+1 < len(self.root.children) else -1 slice = np.s_[:, subtree_weights_start:subtree_weights_end] subtree_flatmodel = linear.FlatModel( @@ -107,9 +109,9 @@ def _separate_model_for_partial_predictions(self): self.subtree_models.append(subtree_flatmodel) def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray: - """Calculates the paritial decision values associated with x. + """Calculates the selected decision values associated with instances x. - Only subtrees corresponding to the top beam_width candidates from the root are evaluated, + Only subtrees corresponding to the top beam_width candidates from the root are evaluated, skipping the rest to avoid unnecessary computation. Args: @@ -133,7 +135,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) # Select indices of the top beam_width subtrees for each instance top_beam_width_indices = np.argsort(-children_scores, axis=1, kind="stable")[:, :beam_width] - # Build a mask indicating whether i-th instance * j-th subtree + # Build a mask where mask[i, j] is True if the j-th subtree is among the top beam width subtrees for the i-th instance mask = np.zeros_like(children_scores, dtype=np.bool_) np.put_along_axis(mask, top_beam_width_indices, True, axis=1) From d9ec262a3b5da05592e5685f7b28202e9e70827b Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Sun, 13 Apr 2025 20:27:11 +0800 Subject: [PATCH 13/15] Rename the function --- libmultilabel/linear/tree.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index ac0fb80..f96726d 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -8,6 +8,7 @@ import sklearn.preprocessing from tqdm import tqdm import psutil + from . import linear __all__ = ["train_tree", "TreeModel"] @@ -69,17 +70,18 @@ def predict_values( np.ndarray: A matrix with dimension number of instances * number of classes. """ if beam_width >= len(self.root.children): - # Beam width sufficiently large; pruning skipped, pruning not applied, computing decision values for all nodes. + # Beam width sufficiently large; pruning not applied. + # Calculates decision values for all nodes. all_preds = linear.predict_values(self.flat_model, x) # number of instances * (number of labels + total number of metalabels) else: - # Beam width is small; pruning applied, computing decision values selectively. + # Beam width is small; pruning applied to reduce computation. if not self.model_separated: - self._separate_model_for_partial_predictions() + self._separate_model_for_pruning_tree() self.model_separated = True all_preds = self._prune_tree_and_predict_values(x, beam_width) # number of instances * (number of labels + total number of metalabels) return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])]) - def _separate_model_for_partial_predictions(self): + def _separate_model_for_pruning_tree(self): """ This function seperates the weights for the root node and its children into (K+1) FlatModel for efficient beam search traversal in Python. @@ -109,7 +111,7 @@ def _separate_model_for_partial_predictions(self): self.subtree_models.append(subtree_flatmodel) def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray: - """Calculates the selected decision values associated with instances x. + """Calculates the selective decision values associated with instances x by evaluating only the most relevant subtrees. Only subtrees corresponding to the top beam_width candidates from the root are evaluated, skipping the rest to avoid unnecessary computation. From abcd111b71cf21cbe6807fd7643a70023a20123a Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Sat, 19 Apr 2025 15:21:17 +0800 Subject: [PATCH 14/15] Rename variable --- libmultilabel/linear/tree.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index f96726d..512440f 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -53,7 +53,7 @@ def __init__( self.flat_model = flat_model self.node_ptr = node_ptr self.multiclass = False - self.model_separated = False # Used for faster prediction + self._model_separated = False # Indicates whether the model has been separated for pruning tree. def predict_values( self, @@ -70,14 +70,14 @@ def predict_values( np.ndarray: A matrix with dimension number of instances * number of classes. """ if beam_width >= len(self.root.children): - # Beam width sufficiently large; pruning not applied. + # Beam_width is sufficiently large; pruning not applied. # Calculates decision values for all nodes. all_preds = linear.predict_values(self.flat_model, x) # number of instances * (number of labels + total number of metalabels) else: - # Beam width is small; pruning applied to reduce computation. - if not self.model_separated: + # Beam_width is small; pruning applied to reduce computation. + if not self._model_separated: self._separate_model_for_pruning_tree() - self.model_separated = True + self._model_separated = True all_preds = self._prune_tree_and_predict_values(x, beam_width) # number of instances * (number of labels + total number of metalabels) return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])]) @@ -137,7 +137,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) # Select indices of the top beam_width subtrees for each instance top_beam_width_indices = np.argsort(-children_scores, axis=1, kind="stable")[:, :beam_width] - # Build a mask where mask[i, j] is True if the j-th subtree is among the top beam width subtrees for the i-th instance + # Build a mask where mask[i, j] is True if the j-th subtree is among the top beam_width subtrees for the i-th instance mask = np.zeros_like(children_scores, dtype=np.bool_) np.put_along_axis(mask, top_beam_width_indices, True, axis=1) @@ -345,7 +345,8 @@ def _flatten_model(root: Node) -> tuple[linear.FlatModel, np.ndarray]: """Flattens tree weight matrices into a single weight matrix. The flattened weight matrix is used to predict all possible values, which is cached for beam search. This pessimizes complexity but is faster in practice. - Consecutive values of the returned array where the classifiers for a node are stored in: + Consecutive values of the returned array denotes the start and end indices of each node in the weight matrix. + To extract a node's classifiers: slice = np.s_[node_ptr[node.index]: node_ptr[node.index+1]] node.model.weights == flat_model.weights[:, slice] From 219f178cd1ba13b35e54627d006066c32d6e8126 Mon Sep 17 00:00:00 2001 From: zhi-bao Date: Mon, 21 Apr 2025 18:14:59 +0800 Subject: [PATCH 15/15] Fix the use of np.NINF which is not supported in numpy 2.0 and above. --- libmultilabel/linear/tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libmultilabel/linear/tree.py b/libmultilabel/linear/tree.py index 512440f..39bfbf0 100644 --- a/libmultilabel/linear/tree.py +++ b/libmultilabel/linear/tree.py @@ -125,7 +125,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) """ # Initialize space for all predictions with negative infinity num_instances, num_labels = x.shape[0], self.node_ptr[-1] - all_preds = np.full((num_instances, num_labels), np.NINF) + all_preds = np.full((num_instances, num_labels), -np.inf) # Calculate root decision values and scores root_preds = linear.predict_values(self.root_model, x)