 import logging
 import os
+import pickle

 import numpy as np
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -25,6 +26,8 @@ class TorchTrainer:
             Defaults to True.
     """

+    WORD_DICT_NAME = "word_dict.pickle"
+
     def __init__(
         self,
         config: dict,
@@ -44,6 +47,11 @@ def __init__(
         self.device = init_device(use_cpu=config.cpu)
         self.config = config

+        # Set dataset meta info
+        self.embed_vecs = embed_vecs
+        self.word_dict = word_dict
+        self.classes = classes
+
         # Load pretrained tokenizer for dataset loader
         self.tokenizer = None
         tokenize_text = "lm_weight" not in config.network_config
@@ -69,8 +77,9 @@ def __init__(
         # Note that AttentionXML produces two models. checkpoint_path directs to model_1
         if config.checkpoint_path is None:
             if self.config.embed_file is not None:
-                logging.info("Load word dictionary ")
-                word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
+                logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
+                self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
                     dataset=self.datasets["train"] + self.datasets["val"],
                     vocab_file=config.vocab_file,
                     min_vocab_freq=config.min_vocab_freq,
@@ -79,9 +88,11 @@ def __init__(
                     normalize_embed=config.normalize_embed,
                     embed_cache_dir=config.embed_cache_dir,
                 )
+                with open(word_dict_path, "wb") as f:
+                    pickle.dump(self.word_dict, f)

-            if not classes:
-                classes = data_utils.load_or_build_label(
+            if not self.classes:
+                self.classes = data_utils.load_or_build_label(
                     self.datasets, self.config.label_file, self.config.include_test_labels
                 )

@@ -98,15 +109,12 @@ def __init__(
                 f"Add {self.config.val_metric} to `monitor_metrics`."
             )
             self.config.monitor_metrics += [self.config.val_metric]
-            self.trainer = PLTTrainer(self.config, classes=classes, embed_vecs=embed_vecs, word_dict=word_dict)
+            self.trainer = PLTTrainer(
+                self.config, classes=self.classes, embed_vecs=self.embed_vecs, word_dict=self.word_dict
+            )
             return
-        self._setup_model(
-            classes=classes,
-            word_dict=word_dict,
-            embed_vecs=embed_vecs,
-            log_path=self.log_path,
-            checkpoint_path=config.checkpoint_path,
-        )
+
+        self._setup_model(log_path=self.log_path, checkpoint_path=config.checkpoint_path)
         self.trainer = init_trainer(
             checkpoint_dir=self.checkpoint_dir,
             epochs=config.epochs,
@@ -125,19 +133,13 @@ def __init__(

     def _setup_model(
         self,
-        classes: list = None,
-        word_dict: dict = None,
-        embed_vecs=None,
         log_path: str = None,
         checkpoint_path: str = None,
     ):
         """Setup model from checkpoint if a checkpoint path is passed in or specified in the config.
         Otherwise, initialize model from scratch.

         Args:
-            classes(list): List of class names.
-            word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
-            embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
             log_path (str): Path to the log file. The log file contains the validation
                 results for each epoch and the test results. If the `log_path` is None, no performance
                 results will be logged.
@@ -149,11 +151,16 @@ def _setup_model(
         if checkpoint_path is not None:
             logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
             self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
+            word_dict_path = os.path.join(os.path.dirname(checkpoint_path), self.WORD_DICT_NAME)
+            if os.path.exists(word_dict_path):
+                with open(word_dict_path, "rb") as f:
+                    self.word_dict = pickle.load(f)
         else:
             logging.info("Initialize model from scratch.")
             if self.config.embed_file is not None:
-                logging.info("Load word dictionary ")
-                word_dict, embed_vecs = data_utils.load_or_build_text_dict(
+                word_dict_path = os.path.join(self.checkpoint_dir, self.WORD_DICT_NAME)
+                logging.info(f"Load and cache the word dictionary into {word_dict_path}.")
+                self.word_dict, self.embed_vecs = data_utils.load_or_build_text_dict(
                     dataset=self.datasets["train"],
                     vocab_file=self.config.vocab_file,
                     min_vocab_freq=self.config.min_vocab_freq,
@@ -162,8 +169,11 @@ def _setup_model(
                     normalize_embed=self.config.normalize_embed,
                     embed_cache_dir=self.config.embed_cache_dir,
                 )
-            if not classes:
-                classes = data_utils.load_or_build_label(
+                with open(word_dict_path, "wb") as f:
+                    pickle.dump(self.word_dict, f)
+
+            if not self.classes:
+                self.classes = data_utils.load_or_build_label(
                     self.datasets, self.config.label_file, self.config.include_test_labels
                 )

@@ -184,9 +194,8 @@ def _setup_model(
             self.model = init_model(
                 model_name=self.config.model_name,
                 network_config=dict(self.config.network_config),
-                classes=classes,
-                word_dict=word_dict,
-                embed_vecs=embed_vecs,
+                classes=self.classes,
+                embed_vecs=self.embed_vecs,
                 init_weight=self.config.init_weight,
                 log_path=log_path,
                 learning_rate=self.config.learning_rate,
@@ -222,7 +231,7 @@ def _get_dataset_loader(self, split, shuffle=False):
             batch_size=self.config.batch_size if split == "train" else self.config.eval_batch_size,
             shuffle=shuffle,
             data_workers=self.config.data_workers,
-            word_dict=self.model.word_dict,
+            word_dict=self.word_dict,
             tokenizer=self.tokenizer,
             add_special_tokens=self.config.add_special_tokens,
         )
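
In effect, the change persists the fitted word dictionary as word_dict.pickle next to the model checkpoint when it is first built, and restores it from that file when a checkpoint is later loaded, instead of rebuilding the vocabulary from the training data. A minimal standalone sketch of the same save/restore pattern follows; the checkpoint_dir path and the vocab object are hypothetical placeholders, not part of the diff.

import os
import pickle

WORD_DICT_NAME = "word_dict.pickle"
checkpoint_dir = "runs/example"  # hypothetical checkpoint directory
os.makedirs(checkpoint_dir, exist_ok=True)
word_dict_path = os.path.join(checkpoint_dir, WORD_DICT_NAME)

# Save: cache the token-to-index mapping alongside the model checkpoint.
vocab = {"<pad>": 0, "<unk>": 1, "hello": 2}  # stand-in for the real word_dict
with open(word_dict_path, "wb") as f:
    pickle.dump(vocab, f)

# Load: when resuming from a checkpoint, restore the cached mapping if present
# rather than rebuilding the vocabulary from the training set.
if os.path.exists(word_dict_path):
    with open(word_dict_path, "rb") as f:
        vocab = pickle.load(f)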