@@ -56,7 +56,7 @@ def __init__(
5656 self .flat_model = flat_model
5757 self .node_ptr = node_ptr
5858 self .multiclass = False
59- self ._model_separated = False # Indicates whether the model has been separated for pruning tree.
59+ self ._model_separated = False # Indicates whether the model has been separated for pruning tree.
6060
6161 def predict_values (
6262 self ,
@@ -75,44 +75,42 @@ def predict_values(
7575 if beam_width >= len (self .root .children ):
7676 # Beam_width is sufficiently large; pruning not applied.
7777 # Calculates decision values for all nodes.
78- all_preds = linear .predict_values (self .flat_model , x ) # number of instances * (number of labels + total number of metalabels)
78+ all_preds = linear .predict_values (
79+ self .flat_model , x
80+ ) # number of instances * (number of labels + total number of metalabels)
7981 else :
8082 # Beam_width is small; pruning applied to reduce computation.
8183 if not self ._model_separated :
8284 self ._separate_model_for_pruning_tree ()
8385 self ._model_separated = True
84- all_preds = self ._prune_tree_and_predict_values (x , beam_width ) # number of instances * (number of labels + total number of metalabels)
86+ all_preds = self ._prune_tree_and_predict_values (
87+ x , beam_width
88+ ) # number of instances * (number of labels + total number of metalabels)
8589 return np .vstack ([self ._beam_search (all_preds [i ], beam_width ) for i in range (all_preds .shape [0 ])])
8690
8791 def _separate_model_for_pruning_tree (self ):
8892 """
8993 This function separates the weights for the root node and its children into (K+1) FlatModel
9094 for efficient beam search traversal in Python.
9195 """
92- tree_flat_model_params = {
93- 'bias' : self .root .model .bias ,
94- 'thresholds' : 0 ,
95- 'multiclass' : False
96- }
96+ tree_flat_model_params = {"bias" : self .root .model .bias , "thresholds" : 0 , "multiclass" : False }
9797 slice = np .s_ [:, self .node_ptr [self .root .index ] : self .node_ptr [self .root .index + 1 ]]
9898 self .root_model = linear .FlatModel (
99- name = "root-flattened-tree" ,
100- weights = self .flat_model .weights [slice ].tocsr (),
101- ** tree_flat_model_params
99+ name = "root-flattened-tree" , weights = self .flat_model .weights [slice ].tocsr (), ** tree_flat_model_params
102100 )
103101
104102 self .subtree_models = []
105103 for i in range (len (self .root .children )):
106104 subtree_weights_start = self .node_ptr [self .root .children [i ].index ]
107- subtree_weights_end = self .node_ptr [self .root .children [i + 1 ].index ] if i + 1 < len (self .root .children ) else self .node_ptr [- 1 ]
105+ subtree_weights_end = (
106+ self .node_ptr [self .root .children [i + 1 ].index ] if i + 1 < len (self .root .children ) else self .node_ptr [- 1 ]
107+ )
108108 slice = np .s_ [:, subtree_weights_start :subtree_weights_end ]
109109 subtree_flatmodel = linear .FlatModel (
110- name = "subtree-flattened-tree" ,
111- weights = self .flat_model .weights [slice ].tocsr (),
112- ** tree_flat_model_params
110+ name = "subtree-flattened-tree" , weights = self .flat_model .weights [slice ].tocsr (), ** tree_flat_model_params
113111 )
114112 self .subtree_models .append (subtree_flatmodel )
115-
113+
116114 def _prune_tree_and_predict_values (self , x : sparse .csr_matrix , beam_width : int ) -> np .ndarray :
117115 """Calculates the selective decision values associated with instances x by evaluating only the most relevant subtrees.
118116
@@ -143,7 +141,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int)
143141 # Build a mask where mask[i, j] is True if the j-th subtree is among the top beam_width subtrees for the i-th instance
144142 mask = np .zeros_like (children_scores , dtype = np .bool_ )
145143 np .put_along_axis (mask , top_beam_width_indices , True , axis = 1 )
146-
144+
147145 # Calculate predictions for each subtree with its corresponding instances
148146 for subtree_idx in range (len (self .root .children )):
149147 subtree_model = self .subtree_models [subtree_idx ]
@@ -427,7 +425,7 @@ def train_ensemble_tree(
427425 seed : int = None ,
428426) -> EnsembleTreeModel :
429427 """Trains an ensemble of tree models (Parabel/Bonsai-style).
430-
428+
431429 Args:
432430 y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
433431 x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
@@ -443,7 +441,7 @@ def train_ensemble_tree(
443441 """
444442 if seed is None :
445443 seed = 42
446-
444+
447445 tree_models = []
448446 for i in range (n_trees ):
449447 np .random .seed (seed + i )
0 commit comments