 import psutil
 
 from . import linear
-from scipy.special import log_expit
-
 
 __all__ = ["train_tree", "TreeModel", "train_ensemble_tree", "EnsembleTreeModel"]
 
@@ -49,31 +47,13 @@ def __init__(
         root: Node,
         flat_model: linear.FlatModel,
         node_ptr: np.ndarray,
-        options: str,
     ):
         self.name = "tree"
         self.root = root
         self.flat_model = flat_model
         self.node_ptr = node_ptr
-        self.options = options
         self.multiclass = False
-        self._model_separated = False  # Indicates whether the model has been separated for pruning tree.
-
-    def _is_lr(self) -> bool:
-        options = self.options or ""
-        options_split = options.split()
-        if "-s" in options_split:
-            i = options_split.index("-s")
-            if i + 1 < len(options_split):
-                solver_type = options_split[i + 1]
-                return solver_type in ["0", "6", "7"]
-        return False
-
-    def _get_scores(self, pred: np.ndarray, parent_score: float = 0.0) -> np.ndarray:
-        if self._is_lr():
-            return parent_score + log_expit(pred)
-        else:
-            return parent_score - np.square(np.maximum(0, 1 - pred))
+        self._model_separated = False  # Indicates whether the model has been separated for pruning tree.
 
     def predict_values(
         self,
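The removed `_get_scores` helper chose between two scorings: `log_expit(pred)` when the LIBLINEAR options selected a logistic-regression solver (`-s 0`, `6`, or `7`), and a squared-hinge penalty otherwise. After this commit only the squared-hinge form remains, inlined at each call site. A minimal standalone sketch of the surviving scoring rule (the function name is hypothetical, for illustration only):

import numpy as np

def squared_hinge_score(pred: np.ndarray, parent_score: float = 0.0) -> np.ndarray:
    # Children inherit the parent's accumulated score minus the squared hinge
    # loss of their decision values, so a child's score never exceeds its parent's.
    return parent_score - np.square(np.maximum(0, 1 - pred))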
@@ -92,42 +72,44 @@ def predict_values(
         if beam_width >= len(self.root.children):
             # Beam_width is sufficiently large; pruning not applied.
             # Calculates decision values for all nodes.
-            all_preds = linear.predict_values(
-                self.flat_model, x
-            )  # number of instances * (number of labels + total number of metalabels)
+            all_preds = linear.predict_values(self.flat_model, x)  # number of instances * (number of labels + total number of metalabels)
         else:
             # Beam_width is small; pruning applied to reduce computation.
             if not self._model_separated:
                 self._separate_model_for_pruning_tree()
                 self._model_separated = True
-            all_preds = self._prune_tree_and_predict_values(
-                x, beam_width
-            )  # number of instances * (number of labels + total number of metalabels)
+            all_preds = self._prune_tree_and_predict_values(x, beam_width)  # number of instances * (number of labels + total number of metalabels)
         return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])])
 
     def _separate_model_for_pruning_tree(self):
         """
         This function separates the weights for the root node and its children into (K+1) FlatModels
         for efficient beam search traversal in Python.
         """
-        tree_flat_model_params = {"bias": self.root.model.bias, "thresholds": 0, "multiclass": False}
+        tree_flat_model_params = {
+            'bias': self.root.model.bias,
+            'thresholds': 0,
+            'multiclass': False
+        }
         slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
         self.root_model = linear.FlatModel(
-            name="root-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), **tree_flat_model_params
+            name="root-flattened-tree",
+            weights=self.flat_model.weights[slice].tocsr(),
+            **tree_flat_model_params
         )
 
         self.subtree_models = []
         for i in range(len(self.root.children)):
             subtree_weights_start = self.node_ptr[self.root.children[i].index]
-            subtree_weights_end = (
-                self.node_ptr[self.root.children[i + 1].index] if i + 1 < len(self.root.children) else -1
-            )
+            subtree_weights_end = self.node_ptr[self.root.children[i + 1].index] if i + 1 < len(self.root.children) else -1
             slice = np.s_[:, subtree_weights_start:subtree_weights_end]
             subtree_flatmodel = linear.FlatModel(
-                name="subtree-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), **tree_flat_model_params
+                name="subtree-flattened-tree",
+                weights=self.flat_model.weights[slice].tocsr(),
+                **tree_flat_model_params
             )
             self.subtree_models.append(subtree_flatmodel)
-
+
     def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray:
         """Calculates the selective decision values associated with instances x by evaluating only the most relevant subtrees.
 
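For reference, `node_ptr` holds cumulative column offsets into the flat weight matrix, so node i's weights occupy the half-open column range `node_ptr[i] : node_ptr[i + 1]`; the separation above simply materializes the root slice and one slice per child subtree as independent FlatModels. A toy illustration of that slicing with hypothetical shapes (not the library's API):

import numpy as np
from scipy import sparse

# Hypothetical setup: 5 features, 3 nodes owning 2, 2, and 3 weight columns.
node_ptr = np.array([0, 2, 4, 7])
weights = sparse.random(5, node_ptr[-1], density=0.5, format="csc", random_state=0)

# Node i's weights are the column slice [node_ptr[i], node_ptr[i + 1]).
per_node = [weights[:, node_ptr[i] : node_ptr[i + 1]].tocsr() for i in range(len(node_ptr) - 1)]
assert sum(m.shape[1] for m in per_node) == weights.shape[1]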
@@ -147,7 +129,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int)
 
         # Calculate root decision values and scores
         root_preds = linear.predict_values(self.root_model, x)
-        children_scores = self._get_scores(root_preds)
+        children_scores = 0.0 - np.square(np.maximum(0, 1 - root_preds))
 
         slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
         all_preds[slice] = root_preds
@@ -158,7 +140,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int)
         # Build a mask where mask[i, j] is True if the j-th subtree is among the top beam_width subtrees for the i-th instance
         mask = np.zeros_like(children_scores, dtype=np.bool_)
         np.put_along_axis(mask, top_beam_width_indices, True, axis=1)
-
+
         # Calculate predictions for each subtree with its corresponding instances
         for subtree_idx in range(len(self.root.children)):
             subtree_model = self.subtree_models[subtree_idx]
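The mask construction keeps, per instance, only the beam_width best-scoring subtrees; all other subtrees are skipped entirely. A toy sketch of the same pattern (the `argpartition` call is an assumption about how `top_beam_width_indices` is produced upstream; any top-k selection yields the same mask):

import numpy as np

# Hypothetical scores: 2 instances, 4 subtrees, beam_width = 2.
children_scores = np.array([[-0.9, -0.1, -0.5, -0.3],
                            [-0.2, -0.8, -0.4, -0.6]])
beam_width = 2
top_beam_width_indices = np.argpartition(-children_scores, beam_width - 1, axis=1)[:, :beam_width]
mask = np.zeros_like(children_scores, dtype=np.bool_)
np.put_along_axis(mask, top_beam_width_indices, True, axis=1)
# mask -> [[False,  True, False,  True],
#          [ True, False,  True, False]]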
@@ -197,7 +179,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray
                     continue
                 slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
                 pred = instance_preds[slice]
-                children_score = self._get_scores(pred, score)
+                children_score = score - np.square(np.maximum(0, 1 - pred))
                 next_level.extend(zip(node.children, children_score.tolist()))
 
             cur_level = sorted(next_level, key=lambda pair: -pair[1])[:beam_width]
@@ -208,7 +190,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarray
         for node, score in cur_level:
             slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
             pred = instance_preds[slice]
-            scores[node.label_map] = np.exp(self._get_scores(pred, score))
+            scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred)))
         return scores
 
 
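Because scores accumulate additively down the tree and are exponentiated only at the leaves, a label's final score is a product of per-level factors exp(-max(0, 1 - pred)^2), each in (0, 1]. A worked micro-example of the accumulation, with hypothetical decision values:

import numpy as np

# Hypothetical path root -> metalabel -> label with decision values 0.7 and 0.4.
path_score = 0.0
for pred in (0.7, 0.4):
    # Same update as children_score above: subtract the squared hinge loss.
    path_score -= np.square(np.maximum(0, 1 - pred))

final_score = np.exp(path_score)  # exp(-0.09 - 0.36) == exp(-0.09) * exp(-0.36)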
@@ -276,7 +258,7 @@ def visit(node):
     pbar.close()
 
     flat_model, node_ptr = _flatten_model(root)
-    return TreeModel(root, flat_model, node_ptr, options)
+    return TreeModel(root, flat_model, node_ptr)
 
 
 def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray, d: int, K: int, dmax: int) -> Node:
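With `options` no longer stored on the model, training and prediction are unchanged from the caller's side. A minimal usage sketch, assuming this module's `train_tree(y, x, options)` signature and toy random data (names and shapes are illustrative, not a definitive recipe):

import numpy as np
from scipy import sparse
import libmultilabel.linear as linear

# Toy data: 4 instances, 10 features, 6 labels (binary indicator matrix).
rng = np.random.default_rng(0)
x = sparse.random(4, 10, density=0.5, format="csr", random_state=rng)
y = sparse.random(4, 6, density=0.3, format="csr", random_state=rng)
y.data[:] = 1

# "-s 1" selects LIBLINEAR's L2-loss (squared hinge) SVM, matching the scoring above.
model = linear.train_tree(y, x, options="-s 1")
scores = model.predict_values(x, beam_width=10)  # one row of label scores per instance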