Skip to content
63 changes: 60 additions & 3 deletions libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import scipy.sparse as sparse
from sparsekmeans import LloydKmeans, ElkanKmeans
import sklearn.preprocessing
from scipy.special import log_expit
from tqdm import tqdm
import psutil

from . import linear
from . import metrics

__all__ = ["train_tree", "TreeModel", "train_ensemble_tree", "EnsembleTreeModel"]

Expand Down Expand Up @@ -57,7 +59,11 @@ def __init__(
self.node_ptr = node_ptr
self.multiclass = False
self._model_separated = False # Indicates whether the model has been separated for pruning tree.
self.estimator_parameter = 3

def sigmoid_A(self, x):
return log_expit(self.estimator_parameter * x)

def predict_values(
self,
x: sparse.csr_matrix,
Expand All @@ -68,10 +74,12 @@ def predict_values(
Args:
x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
beam_width (int, optional): Number of candidates considered during beam search. Defaults to 10.
estimation_parameter (int, optional): The tunable parameter of probability estimation function, that is sigmoid(estimation_parameter * preds).

Returns:
np.ndarray: A matrix with dimension number of instances * number of classes.
"""

if beam_width >= len(self.root.children):
# Beam_width is sufficiently large; pruning not applied.
# Calculates decision values for all nodes.
Expand Down Expand Up @@ -132,7 +140,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int)

# Calculate root decision values and scores
root_preds = linear.predict_values(self.root_model, x)
children_scores = 0.0 - np.square(np.maximum(0, 1 - root_preds))
children_scores = 0.0 + self.sigmoid_A(root_preds)

slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
all_preds[slice] = root_preds
Expand Down Expand Up @@ -182,7 +190,7 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra
continue
slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
pred = instance_preds[slice]
children_score = score - np.square(np.maximum(0, 1 - pred))
children_score = score + self.sigmoid_A(pred)
next_level.extend(zip(node.children, children_score.tolist()))

cur_level = sorted(next_level, key=lambda pair: -pair[1])[:beam_width]
Expand All @@ -193,9 +201,58 @@ def _beam_search(self, instance_preds: np.ndarray, beam_width: int) -> np.ndarra
for node, score in cur_level:
slice = np.s_[self.node_ptr[node.index] : self.node_ptr[node.index + 1]]
pred = instance_preds[slice]
scores[node.label_map] = np.exp(score - np.square(np.maximum(0, 1 - pred)))
scores[node.label_map] = np.exp(score + self.sigmoid_A(pred))
return scores

def tuning_A_by_cross_validation(
self,
y: sparse.csr_matrix,
x: sparse.csr_matrix,
n_folds: int,
batch_size: int,
beamwidth: int,
metric: list,
A_candidates: list,
options: str = "",
K=100,
dmax=10,
):
data_splits = []
for n in range(n_folds):
start = np.ceil(n/n_folds*x.shape[0]).astype(int)
end = np.ceil((n+1)/n_folds*x.shape[0]).astype(int)
data_splits.append({'x':x[start:end, :], 'y':y[start:end ,:]})

score = {m:{A:0 for A in A_candidates} for m in metric}
for n in range(n_folds):
data_y = sparse.vstack([data_splits[j]["y"] for j in range(n_folds) if j != n])
data_x = sparse.vstack([data_splits[j]["x"] for j in range(n_folds) if j != n])

model = train_tree(
data_y,
data_x,
options,
K,
dmax,
)

for A in A_candidates:
model.estimator_parameter = A

num_instances = data_splits[n]["x"].shape[0]
num_batch = np.ceil(num_instances/batch_size).astype(int)
metric_eval = metrics.get_metrics(metric ,num_classes = data_y.shape[1])
for i in range(num_batch):
valid_x = data_splits[n]["x"][i * batch_size : (i+1) * batch_size]
valid_y = data_splits[n]["y"][i * batch_size : (i+1) * batch_size]
preds = model.predict_values(valid_x, beam_width=beamwidth)
metric_eval.update(preds, valid_y)

eval = metric_eval.compute()
for k in eval.keys():
score[k][A] += eval[k]

self.estimator_parameter = max(score[k], key=score[k].get)

def train_tree(
y: sparse.csr_matrix,
Expand Down