
Commit c473d48

Merge pull request #24 from Eleven1Liu/free_worddict
Free worddict
2 parents: 9af0865 + 33e407f

7 files changed: +52 −53 lines changed


docs/examples/plot_KimCNN_quickstart.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -56,7 +56,6 @@
     model_name=model_name,
     network_config=network_config,
     classes=classes,
-    word_dict=word_dict,
     embed_vecs=embed_vecs,
     learning_rate=learning_rate,
     monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
@@ -66,7 +65,7 @@
 # * ``model_name`` leads ``init_model`` function to find a network model.
 # * ``network_config`` contains the configurations of a network model.
 # * ``classes`` is the label set of the data.
-# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
+# * ``embed_vecs`` is the pre-trained word vectors.
 # * ``moniter_metrics`` includes metrics you would like to track.
 #
 #
```
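With `word_dict` gone from the call, the quickstart passes only the label set and the embedding matrix to `init_model`. A minimal sketch of the updated flow; the paths, the network_config values, and the hyper-parameters below are illustrative placeholders, not the tutorial's exact ones:

```python
from libmultilabel.nn import data_utils
from libmultilabel.nn.nn_utils import init_model

# Illustrative sketch of the post-change quickstart; data paths and
# config values are placeholders.
datasets = data_utils.load_datasets("data/rcv1/train.txt", "data/rcv1/test.txt")
classes = data_utils.load_or_build_label(datasets)
word_dict, embed_vecs = data_utils.load_or_build_text_dict(
    dataset=datasets["train"], embed_file="glove.6B.300d"
)

# `word_dict` is still built (the dataset loader needs it later), but it
# is no longer forwarded to init_model.
model = init_model(
    model_name="KimCNN",
    network_config={"filter_sizes": [2, 4, 8], "num_filter_per_size": 128},
    classes=classes,
    embed_vecs=embed_vecs,
    learning_rate=0.0003,
    monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
)
```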

docs/examples/plot_bert_quickstart.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -70,7 +70,6 @@
 # * ``model_name`` leads ``init_model`` function to find a network model.
 # * ``network_config`` contains the configurations of a network model.
 # * ``classes`` is the label set of the data.
-# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
 # * ``moniter_metrics`` includes metrics you would like to track.
 #
 #
```

libmultilabel/nn/attentionxml.py

Lines changed: 13 additions & 8 deletions

```diff
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import logging
+import os
+import pickle
 from functools import partial
 from pathlib import Path
 from typing import Generator, Sequence, Optional
@@ -33,6 +35,7 @@
 
 class PLTTrainer:
     CHECKPOINT_NAME = "model_"
+    WORD_DICT_NAME = "word_dict.pickle"
 
     def __init__(
         self,
@@ -261,7 +264,6 @@ def fit(self, datasets):
             model_name="AttentionXML_0",
             network_config=self.network_config,
             classes=clusters,
-            word_dict=self.word_dict,
             embed_vecs=self.embed_vecs,
             init_weight=self.init_weight,
             log_path=self.log_path,
@@ -380,7 +382,6 @@ def fit(self, datasets):
 
         model_1 = PLTModel(
             classes=self.classes,
-            word_dict=self.word_dict,
             network=network,
             log_path=self.log_path,
             learning_rate=self.learning_rate,
@@ -427,7 +428,11 @@ def test(self, dataset):
             save_k_predictions=self.save_k_predictions,
             metrics=self.metrics,
         )
-        self.word_dict = model_1.word_dict
+
+        word_dict_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), self.WORD_DICT_NAME)
+        if os.path.exists(word_dict_path):
+            with open(word_dict_path, "rb") as f:
+                self.word_dict = pickle.load(f)
         classes = model_1.classes
 
         test_x = self.reformat_text(dataset)
@@ -489,9 +494,11 @@ def reformat_text(self, dataset):
         # Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length.
         encoded_text = list(
             map(
-                lambda text: torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
-                if text
-                else torch.tensor([self.word_dict[UNK]], dtype=torch.int64),
+                lambda text: (
+                    torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
+                    if text
+                    else torch.tensor([self.word_dict[UNK]], dtype=torch.int64)
+                ),
                 [instance["text"][: self.max_seq_length] for instance in dataset],
             )
         )
@@ -519,15 +526,13 @@ class PLTModel(Model):
     def __init__(
         self,
         classes,
-        word_dict,
         network,
         loss_function="binary_cross_entropy_with_logits",
         log_path=None,
         **kwargs,
     ):
         super().__init__(
             classes=classes,
-            word_dict=word_dict,
             network=network,
             loss_function=loss_function,
             log_path=log_path,
```
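`PLTModel` no longer carries the vocabulary, so `test()` restores it from the pickle cached beside the best level-1 checkpoint. A standalone sketch of that load pattern, with a made-up path standing in for `self.get_best_model_path(level=1)`:

```python
import os
import pickle

WORD_DICT_NAME = "word_dict.pickle"

# Made-up stand-in for self.get_best_model_path(level=1).
best_model_path = "runs/example_run/model_1.ckpt"

# The vocabulary is cached next to the checkpoint. The existence check
# keeps older runs (which never wrote the cache) working: in that case
# whatever word_dict the trainer already holds is left untouched.
word_dict_path = os.path.join(os.path.dirname(best_model_path), WORD_DICT_NAME)
if os.path.exists(word_dict_path):
    with open(word_dict_path, "rb") as f:
        word_dict = pickle.load(f)
```

The `reformat_text` hunk, by contrast, is purely cosmetic: wrapping the conditional lambda in parentheses matches the formatter's preferred layout without changing the UNK-fallback behavior.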

libmultilabel/nn/model.py

Lines changed: 1 addition & 11 deletions

```diff
@@ -181,27 +181,17 @@ class Model(MultiLabelModel):
 
     Args:
         classes (list): List of class names.
-        word_dict (dict): A dictionary for mapping tokens to indices.
         network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN).
         loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits,
             cross_entropy). Defaults to 'binary_cross_entropy_with_logits'.
         log_path (str): Path to a directory holding the log files and models.
     """
 
-    def __init__(
-        self,
-        classes,
-        word_dict,
-        network,
-        loss_function="binary_cross_entropy_with_logits",
-        log_path=None,
-        **kwargs
-    ):
+    def __init__(self, classes, network, loss_function="binary_cross_entropy_with_logits", log_path=None, **kwargs):
         super().__init__(num_classes=len(classes), log_path=log_path, **kwargs)
         self.save_hyperparameters(
             ignore=["log_path"]
         )  # If log_path is saved, loading the checkpoint will cause an error since each experiment has unique log_path (result_dir).
-        self.word_dict = word_dict
         self.classes = classes
         self.network = network
         self.configure_loss_function(loss_function)
```
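Removing `word_dict` from the signature also removes it from the checkpoint payload, since Lightning's `save_hyperparameters()` records constructor arguments. A toy illustration of that mechanism (not the library's code):

```python
import lightning.pytorch as pl


class TinyModel(pl.LightningModule):
    """Toy module showing what save_hyperparameters() captures."""

    def __init__(self, classes, log_path=None, **kwargs):
        super().__init__()
        # Only the surviving constructor arguments are recorded, so an
        # argument dropped from the signature never reaches the
        # checkpoint's hyper-parameter payload.
        self.save_hyperparameters(ignore=["log_path"])


model = TinyModel(classes=["A", "B"], log_path="runs/logs.json")
print(model.hparams)  # shows `classes`; nothing vocabulary-sized
```

Freeing the vocabulary from the model keeps checkpoints smaller and lets one cached dict per run serve both training and resumed runs.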

libmultilabel/nn/nn_utils.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -37,7 +37,6 @@ def init_model(
     model_name,
     network_config,
     classes,
-    word_dict=None,
     embed_vecs=None,
     init_weight=None,
     log_path=None,
@@ -61,7 +60,6 @@
         model_name (str): Model to be used such as KimCNN.
         network_config (dict): Configuration for defining the network.
         classes (list): List of class names.
-        word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
         embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape
             (vocab_size, embed_dim). Defaults to None.
         init_weight (str): Weight initialization method from `torch.nn.init`.
@@ -98,7 +96,6 @@
 
     model = Model(
         classes=classes,
-        word_dict=word_dict,
         network=network,
         log_path=log_path,
         learning_rate=learning_rate,
```
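For external callers this is a breaking signature change: the hunk shows no `**kwargs` catch-all on `init_model`, so a call that still passes `word_dict=` would presumably fail rather than be silently ignored. A hypothetical migration, with the other arguments prepared as in the quickstart above:

```python
# Before this commit, word_dict was accepted and forwarded to Model:
#   model = init_model(model_name, network_config, classes,
#                      word_dict=word_dict, embed_vecs=embed_vecs)
# After it, drop the keyword; the old call would likely raise
#   TypeError: init_model() got an unexpected keyword argument 'word_dict'
model = init_model(model_name, network_config, classes, embed_vecs=embed_vecs)
```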

tests/nn/components.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -20,7 +20,7 @@ def get_name(self):
         return "token_to_id"
 
     def get_from_trainer(self, trainer):
-        return trainer.model.word_dict
+        return trainer.word_dict
 
     def compare(self, a, b):
         return a == b
@@ -34,7 +34,7 @@ def get_name(self):
         return "embed_vecs"
 
     def get_from_trainer(self, trainer):
-        return trainer.model.embed_vecs
+        return trainer.embed_vecs
 
     def compare(self, a, b):
         return (a == b).all()
```

torch_trainer.py

Lines changed: 35 additions & 26 deletions

```diff
@@ -1,5 +1,6 @@
 import logging
 import os
+import pickle
 
 import numpy as np
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -25,6 +26,8 @@ class TorchTrainer:
             Defaults to True.
     """
 
+    WORD_DICT_NAME = "word_dict.pickle"
+
     def __init__(
         self,
         config: dict,
@@ -44,6 +47,11 @@ def __init__(
         self.device = init_device(use_cpu=config.cpu)
         self.config = config
 
+        # Set dataset meta info
+        self.embed_vecs = embed_vecs
+        self.word_dict = word_dict
+        self.classes = classes
+
         # Load pretrained tokenizer for dataset loader
         self.tokenizer = None
         tokenize_text = "lm_weight" not in config.network_config
@@ -69,8 +77,9 @@ def __init__(
         # Note that AttentionXML produces two models. checkpoint_path directs to model_1
         if config.checkpoint_path is None:
             if self.config.embed_file is not None:
-                logging.info("Load word dictionary ")
-                word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
+                logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
+                self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
                     dataset=self.datasets["train"] + self.datasets["val"],
                     vocab_file=config.vocab_file,
                     min_vocab_freq=config.min_vocab_freq,
@@ -79,9 +88,11 @@ def __init__(
                     normalize_embed=config.normalize_embed,
                     embed_cache_dir=config.embed_cache_dir,
                 )
+                with open(word_dict_path, "wb") as f:
+                    pickle.dump(self.word_dict, f)
 
-            if not classes:
-                classes = data_utils.load_or_build_label(
+            if not self.classes:
+                self.classes = data_utils.load_or_build_label(
                     self.datasets, self.config.label_file, self.config.include_test_labels
                 )
 
@@ -98,15 +109,12 @@ def __init__(
                     f"Add {self.config.val_metric} to `monitor_metrics`."
                 )
                 self.config.monitor_metrics += [self.config.val_metric]
-            self.trainer = PLTTrainer(self.config, classes=classes, embed_vecs=embed_vecs, word_dict=word_dict)
+            self.trainer = PLTTrainer(
+                self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict
+            )
             return
-        self._setup_model(
-            classes=classes,
-            word_dict=word_dict,
-            embed_vecs=embed_vecs,
-            log_path=self.log_path,
-            checkpoint_path=config.checkpoint_path,
-        )
+
+        self._setup_model(log_path=self.log_path, checkpoint_path=config.checkpoint_path)
         self.trainer = init_trainer(
             checkpoint_dir=self.checkpoint_dir,
             epochs=config.epochs,
@@ -125,19 +133,13 @@ def __init__(
 
     def _setup_model(
         self,
-        classes: list = None,
-        word_dict: dict = None,
-        embed_vecs=None,
         log_path: str = None,
         checkpoint_path: str = None,
     ):
         """Setup model from checkpoint if a checkpoint path is passed in or specified in the config.
         Otherwise, initialize model from scratch.
 
         Args:
-            classes(list): List of class names.
-            word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
-            embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
             log_path (str): Path to the log file. The log file contains the validation
                 results for each epoch and the test results. If the `log_path` is None, no performance
                 results will be logged.
@@ -149,11 +151,16 @@ def _setup_model(
         if checkpoint_path is not None:
             logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
             self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
+            word_dict_path = os.path.join(os.path.dirname(checkpoint_path), self.WORD_DICT_NAME)
+            if os.path.exists(word_dict_path):
+                with open(word_dict_path, "rb") as f:
+                    self.word_dict = pickle.load(f)
         else:
             logging.info("Initialize model from scratch.")
             if self.config.embed_file is not None:
-                logging.info("Load word dictionary ")
-                word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
+                logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
+                self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
                     dataset=self.datasets["train"],
                     vocab_file=self.config.vocab_file,
                     min_vocab_freq=self.config.min_vocab_freq,
@@ -162,8 +169,11 @@
                     normalize_embed=self.config.normalize_embed,
                     embed_cache_dir=self.config.embed_cache_dir,
                 )
-            if not classes:
-                classes = data_utils.load_or_build_label(
+                with open(word_dict_path, "wb") as f:
+                    pickle.dump(self.word_dict, f)
+
+            if not self.classes:
+                self.classes = data_utils.load_or_build_label(
                     self.datasets, self.config.label_file, self.config.include_test_labels
                 )
 
@@ -184,9 +194,8 @@
             self.model = init_model(
                 model_name=self.config.model_name,
                 network_config=dict(self.config.network_config),
-                classes=classes,
-                word_dict=word_dict,
-                embed_vecs=embed_vecs,
+                classes=self.classes,
+                embed_vecs=self.embed_vecs,
                 init_weight=self.config.init_weight,
                 log_path=log_path,
                 learning_rate=self.config.learning_rate,
@@ -222,7 +231,7 @@ def _get_dataset_loader(self, split, shuffle=False):
             batch_size=self.config.batch_size if split == "train" else self.config.eval_batch_size,
             shuffle=shuffle,
             data_workers=self.config.data_workers,
-            word_dict=self.model.word_dict,
+            word_dict=self.word_dict,
             tokenizer=self.tokenizer,
             add_special_tokens=self.config.add_special_tokens,
         )
```
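Taken together, `TorchTrainer` now owns the vocabulary end to end: build it once, cache it beside the checkpoints, and restore it when resuming. A condensed, self-contained sketch of that round trip, where `build_vocab` is a stand-in for `data_utils.load_or_build_text_dict` and the directory is illustrative:

```python
import os
import pickle

WORD_DICT_NAME = "word_dict.pickle"
checkpoint_dir = "runs/example_run"  # illustrative


def build_vocab():
    # Stand-in for data_utils.load_or_build_text_dict: a token-to-index
    # mapping built from the training split.
    return {"<unk>": 0, "the": 1, "cat": 2}


# Fresh run: build the dict and cache it next to the model checkpoints.
os.makedirs(checkpoint_dir, exist_ok=True)
word_dict_path = os.path.join(checkpoint_dir, WORD_DICT_NAME)
word_dict = build_vocab()
with open(word_dict_path, "wb") as f:
    pickle.dump(word_dict, f)

# Resumed run: Model.load_from_checkpoint restores the weights, while the
# vocabulary comes from the cache sitting in the checkpoint's directory.
with open(word_dict_path, "rb") as f:
    restored = pickle.load(f)
assert restored == word_dict
```

The design choice here is that the pickle lives next to the checkpoint rather than inside it, which is why the loading path only needs `os.path.dirname(checkpoint_path)` to find it.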
