Skip to content

Commit 10c2d9e

Browse files
author
Winter Deng
committed
reformat all codes
1 parent b04e1cc commit 10c2d9e

File tree

5 files changed

+50
-37
lines changed

5 files changed

+50
-37
lines changed

libmultilabel/common_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def argsort_top_k(vals, k, axis=-1):
8282
k: Consider only the top k elements for each query
8383
axis: Axis along which to sort. The default is -1 (the last axis).
8484
85-
Returns:
85+
Returns:
8686
Array of indices that sort vals along the specified axis.
8787
"""
8888
unsorted_top_k_idx = np.argpartition(vals, -k, axis=axis)[:, -k:]

libmultilabel/linear/tree.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
self.flat_model = flat_model
5757
self.node_ptr = node_ptr
5858
self.multiclass = False
59-
self._model_separated = False # Indicates whether the model has been separated for pruning tree.
59+
self._model_separated = False # Indicates whether the model has been separated for pruning tree.
6060

6161
def predict_values(
6262
self,
@@ -75,44 +75,42 @@ def predict_values(
7575
if beam_width >= len(self.root.children):
7676
# Beam_width is sufficiently large; pruning not applied.
7777
# Calculates decision values for all nodes.
78-
all_preds = linear.predict_values(self.flat_model, x) # number of instances * (number of labels + total number of metalabels)
78+
all_preds = linear.predict_values(
79+
self.flat_model, x
80+
) # number of instances * (number of labels + total number of metalabels)
7981
else:
8082
# Beam_width is small; pruning applied to reduce computation.
8183
if not self._model_separated:
8284
self._separate_model_for_pruning_tree()
8385
self._model_separated = True
84-
all_preds = self._prune_tree_and_predict_values(x, beam_width) # number of instances * (number of labels + total number of metalabels)
86+
all_preds = self._prune_tree_and_predict_values(
87+
x, beam_width
88+
) # number of instances * (number of labels + total number of metalabels)
8589
return np.vstack([self._beam_search(all_preds[i], beam_width) for i in range(all_preds.shape[0])])
8690

8791
def _separate_model_for_pruning_tree(self):
8892
"""
8993
This function separates the weights for the root node and its children into (K+1) FlatModel
9094
for efficient beam search traversal in Python.
9195
"""
92-
tree_flat_model_params = {
93-
'bias': self.root.model.bias,
94-
'thresholds': 0,
95-
'multiclass': False
96-
}
96+
tree_flat_model_params = {"bias": self.root.model.bias, "thresholds": 0, "multiclass": False}
9797
slice = np.s_[:, self.node_ptr[self.root.index] : self.node_ptr[self.root.index + 1]]
9898
self.root_model = linear.FlatModel(
99-
name="root-flattened-tree",
100-
weights=self.flat_model.weights[slice].tocsr(),
101-
**tree_flat_model_params
99+
name="root-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), **tree_flat_model_params
102100
)
103101

104102
self.subtree_models = []
105103
for i in range(len(self.root.children)):
106104
subtree_weights_start = self.node_ptr[self.root.children[i].index]
107-
subtree_weights_end = self.node_ptr[self.root.children[i+1].index] if i+1 < len(self.root.children) else self.node_ptr[-1]
105+
subtree_weights_end = (
106+
self.node_ptr[self.root.children[i + 1].index] if i + 1 < len(self.root.children) else self.node_ptr[-1]
107+
)
108108
slice = np.s_[:, subtree_weights_start:subtree_weights_end]
109109
subtree_flatmodel = linear.FlatModel(
110-
name="subtree-flattened-tree",
111-
weights=self.flat_model.weights[slice].tocsr(),
112-
**tree_flat_model_params
110+
name="subtree-flattened-tree", weights=self.flat_model.weights[slice].tocsr(), **tree_flat_model_params
113111
)
114112
self.subtree_models.append(subtree_flatmodel)
115-
113+
116114
def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int) -> np.ndarray:
117115
"""Calculates the selective decision values associated with instances x by evaluating only the most relevant subtrees.
118116
@@ -143,7 +141,7 @@ def _prune_tree_and_predict_values(self, x: sparse.csr_matrix, beam_width: int)
143141
# Build a mask where mask[i, j] is True if the j-th subtree is among the top beam_width subtrees for the i-th instance
144142
mask = np.zeros_like(children_scores, dtype=np.bool_)
145143
np.put_along_axis(mask, top_beam_width_indices, True, axis=1)
146-
144+
147145
# Calculate predictions for each subtree with its corresponding instances
148146
for subtree_idx in range(len(self.root.children)):
149147
subtree_model = self.subtree_models[subtree_idx]
@@ -427,7 +425,7 @@ def train_ensemble_tree(
427425
seed: int = None,
428426
) -> EnsembleTreeModel:
429427
"""Trains an ensemble of tree models (Parabel/Bonsai-style).
430-
428+
431429
Args:
432430
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
433431
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
@@ -443,7 +441,7 @@ def train_ensemble_tree(
443441
"""
444442
if seed is None:
445443
seed = 42
446-
444+
447445
tree_models = []
448446
for i in range(n_trees):
449447
np.random.seed(seed + i)

libmultilabel/linear/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,13 @@ class MultiLabelEstimator(sklearn.base.BaseEstimator):
7676
scoring_metric (str, optional): The scoring metric. Defaults to 'P@1'.
7777
"""
7878

79-
def __init__(self, options: str = "", linear_technique: str = "1vsrest", scoring_metric: str = "P@1", multiclass: bool = False):
79+
def __init__(
80+
self,
81+
options: str = "",
82+
linear_technique: str = "1vsrest",
83+
scoring_metric: str = "P@1",
84+
multiclass: bool = False,
85+
):
8086
super().__init__()
8187
self.options = options
8288
self.linear_technique = linear_technique
@@ -97,9 +103,7 @@ def predict(self, X: sparse.csr_matrix) -> np.ndarray:
97103

98104
def score(self, X: sparse.csr_matrix, y: sparse.csr_matrix) -> float:
99105
metrics = linear.get_metrics(
100-
monitor_metrics=[self.scoring_metric],
101-
num_classes=y.shape[1],
102-
multiclass=self.multiclass
106+
monitor_metrics=[self.scoring_metric], num_classes=y.shape[1], multiclass=self.multiclass
103107
)
104108
preds = self.predict(X)
105109
metrics.update(preds, y.toarray())

main.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ def add_all_arguments(parser):
4141
parser.add_argument("--checkpoint_path", help="The checkpoint to warm-up with (default: %(default)s)")
4242

4343
# data
44-
parser.add_argument("--data_name", default="unnamed_data", help="Dataset name for generating the output directory (default: %(default)s)")
44+
parser.add_argument(
45+
"--data_name",
46+
default="unnamed_data",
47+
help="Dataset name for generating the output directory (default: %(default)s)",
48+
)
4549
parser.add_argument("--training_file", help="Path to training data (default: %(default)s)")
4650
parser.add_argument("--val_file", help="Path to validation data (default: %(default)s)")
4751
parser.add_argument("--test_file", help="Path to test data (default: %(default)s)")
@@ -104,7 +108,9 @@ def add_all_arguments(parser):
104108
# pretrained vocab / embeddings
105109
parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabuaries (default: %(default)s)")
106110
parser.add_argument(
107-
"--embed_file", type=str, help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)"
111+
"--embed_file",
112+
type=str,
113+
help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)",
108114
)
109115

110116
# train
@@ -235,7 +241,10 @@ def add_all_arguments(parser):
235241
"--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)"
236242
)
237243
parser.add_argument(
238-
"--tree_ensemble_models", type=int, default=1, help="Number of models in the tree ensemble (default: %(default)s)"
244+
"--tree_ensemble_models",
245+
type=int,
246+
default=1,
247+
help="Number of models in the tree ensemble (default: %(default)s)",
239248
)
240249
parser.add_argument(
241250
"--beam_width",

search_params.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -191,16 +191,18 @@ def load_static_data(config):
191191
)
192192
return {
193193
"datasets": datasets,
194-
"word_dict": None
195-
if config.embed_file is None
196-
else data_utils.load_or_build_text_dict(
197-
dataset=datasets["train"],
198-
vocab_file=config.vocab_file,
199-
min_vocab_freq=config.min_vocab_freq,
200-
embed_file=config.embed_file,
201-
embed_cache_dir=config.embed_cache_dir,
202-
silent=config.silent,
203-
normalize_embed=config.normalize_embed,
194+
"word_dict": (
195+
None
196+
if config.embed_file is None
197+
else data_utils.load_or_build_text_dict(
198+
dataset=datasets["train"],
199+
vocab_file=config.vocab_file,
200+
min_vocab_freq=config.min_vocab_freq,
201+
embed_file=config.embed_file,
202+
embed_cache_dir=config.embed_cache_dir,
203+
silent=config.silent,
204+
normalize_embed=config.normalize_embed,
205+
)
204206
),
205207
"classes": data_utils.load_or_build_label(datasets, config.label_file, config.include_test_labels),
206208
}

0 commit comments

Comments
 (0)