
Commit 33e407f

Apply black formatter.
1 parent c01dedb commit 33e407f

3 files changed: 14 additions & 16 deletions


libmultilabel/nn/attentionxml.py

Lines changed: 6 additions & 4 deletions
@@ -428,7 +428,7 @@ def test(self, dataset):
             save_k_predictions=self.save_k_predictions,
             metrics=self.metrics,
         )
-
+
         word_dict_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), self.WORD_DICT_NAME)
         if os.path.exists(word_dict_path):
             with open(word_dict_path, "rb") as f:
@@ -494,9 +494,11 @@ def reformat_text(self, dataset):
         # Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length.
         encoded_text = list(
             map(
-                lambda text: torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
-                if text
-                else torch.tensor([self.word_dict[UNK]], dtype=torch.int64),
+                lambda text: (
+                    torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
+                    if text
+                    else torch.tensor([self.word_dict[UNK]], dtype=torch.int64)
+                ),
                 [instance["text"][: self.max_seq_length] for instance in dataset],
             )
         )
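This hunk shows black's treatment of a conditional expression that has to span several lines: the whole if/else expression is wrapped in its own parentheses instead of hanging off the lambda, and behavior is unchanged (empty texts still fall back to a single UNK index). A minimal runnable sketch of the same pattern, laid out by hand in the style of the new code; vocab, UNK_ID, and texts are made-up names for this example only:

# Illustrative sketch of the parenthesized conditional-expression layout above.
# vocab, UNK_ID, and texts are invented for this example, not repository names.
UNK_ID = 0
vocab = {"deep": 1, "learning": 2}
texts = [["deep", "learning"], []]

encoded = list(
    map(
        lambda text: (
            [vocab.get(word, UNK_ID) for word in text]  # normal path: map words to ids
            if text
            else [UNK_ID]  # empty input: fall back to a single unknown-token id
        ),
        texts,
    )
)
print(encoded)  # [[1, 2], [0]]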

libmultilabel/nn/model.py

Lines changed: 1 addition & 8 deletions
@@ -187,14 +187,7 @@ class Model(MultiLabelModel):
         log_path (str): Path to a directory holding the log files and models.
     """

-    def __init__(
-        self,
-        classes,
-        network,
-        loss_function="binary_cross_entropy_with_logits",
-        log_path=None,
-        **kwargs
-    ):
+    def __init__(self, classes, network, loss_function="binary_cross_entropy_with_logits", log_path=None, **kwargs):
         super().__init__(num_classes=len(classes), log_path=log_path, **kwargs)
         self.save_hyperparameters(
             ignore=["log_path"]
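Here the only change is that the constructor signature is joined onto one line. black does that when the joined line fits within the configured line length; the joined signature is well over 100 characters, so the project presumably runs black with a limit larger than the default 88. A sketch of reproducing the decision with black's Python API; line_length=120 is an assumption about the project's configuration, not something stated in the commit:

# Sketch: reproduce the signature-joining decision with black.format_str.
# line_length=120 is an assumed project setting, not confirmed by this commit.
import black

SRC = '''def __init__(
    self,
    classes,
    network,
    loss_function="binary_cross_entropy_with_logits",
    log_path=None,
    **kwargs
):
    pass
'''

# Under the default 88-column limit the parameters stay on separate lines;
# under a 120-column limit black joins the signature onto a single line.
print(black.format_str(SRC, mode=black.Mode(line_length=88)))
print(black.format_str(SRC, mode=black.Mode(line_length=120)))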

torch_trainer.py

Lines changed: 7 additions & 4 deletions
@@ -25,6 +25,7 @@ class TorchTrainer:
         save_checkpoints (bool, optional): Whether to save the last and the best checkpoint or not.
             Defaults to True.
     """
+
     WORD_DICT_NAME = "word_dict.pickle"

     def __init__(
@@ -87,7 +88,7 @@ def __init__(
                 normalize_embed=config.normalize_embed,
                 embed_cache_dir=config.embed_cache_dir,
             )
-            with open(word_dict_path, "wb") as f:
+            with open(word_dict_path, "wb") as f:
                 pickle.dump(self.word_dict, f)

         if not self.classes:
@@ -108,9 +109,11 @@ def __init__(
                     f"Add {self.config.val_metric} to `monitor_metrics`."
                 )
                 self.config.monitor_metrics += [self.config.val_metric]
-            self.trainer = PLTTrainer(self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict)
+            self.trainer = PLTTrainer(
+                self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict
+            )
             return
-
+
         self._setup_model(log_path=self.log_path, checkpoint_path=config.checkpoint_path)
         self.trainer = init_trainer(
             checkpoint_dir=self.checkpoint_dir,
@@ -144,7 +147,7 @@ def _setup_model(
         """
         if "checkpoint_path" in self.config and self.config.checkpoint_path is not None:
             checkpoint_path = self.config.checkpoint_path
-
+
         if checkpoint_path is not None:
             logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
             self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
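Since the commit only applies a formatter, every hunk above should be behavior-preserving: whitespace, blank lines, line wrapping, and parenthesization change, but not semantics. One way to sanity-check a formatting-only change is to compare the parsed ASTs of a file before and after; the sketch below uses only the standard library, and the two file paths are placeholders rather than anything produced by this commit. black itself runs a similar AST-equivalence check on every file it reformats unless invoked with --fast.

# Sketch: confirm a formatting-only change left the syntax tree untouched.
# The paths are placeholders for the pre- and post-formatting versions of a file.
import ast

def same_ast(old_path: str, new_path: str) -> bool:
    with open(old_path) as f:
        old_tree = ast.parse(f.read())
    with open(new_path) as f:
        new_tree = ast.parse(f.read())
    # ast.dump omits line/column info by default, so purely cosmetic edits
    # (indentation, wrapping, trailing whitespace) compare equal.
    return ast.dump(old_tree) == ast.dump(new_tree)

print(same_ast("torch_trainer.py.orig", "torch_trainer.py"))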
