@@ -25,6 +25,7 @@ class TorchTrainer:
2525 save_checkpoints (bool, optional): Whether to save the last and the best checkpoint or not.
2626 Defaults to True.
2727 """
28+
2829 WORD_DICT_NAME = "word_dict.pickle"
2930
3031 def __init__ (
@@ -87,7 +88,7 @@ def __init__(
8788 normalize_embed = config .normalize_embed ,
8889 embed_cache_dir = config .embed_cache_dir ,
8990 )
90- with open (word_dict_path , "wb" ) as f :
91+ with open (word_dict_path , "wb" ) as f :
9192 pickle .dump (self .word_dict , f )
9293
9394 if not self .classes :
@@ -108,9 +109,11 @@ def __init__(
108109 f"Add { self .config .val_metric } to `monitor_metrics`."
109110 )
110111 self .config .monitor_metrics += [self .config .val_metric ]
111- self .trainer = PLTTrainer (self .config , classes = self .classes , embed_vecs = self .embed_vecs , word_dict = self .word_dict )
112+ self .trainer = PLTTrainer (
113+ self .config , classes = self .classes , embed_vecs = self .embed_vecs , word_dict = self .word_dict
114+ )
112115 return
113-
116+
114117 self ._setup_model (log_path = self .log_path , checkpoint_path = config .checkpoint_path )
115118 self .trainer = init_trainer (
116119 checkpoint_dir = self .checkpoint_dir ,
@@ -144,7 +147,7 @@ def _setup_model(
144147 """
145148 if "checkpoint_path" in self .config and self .config .checkpoint_path is not None :
146149 checkpoint_path = self .config .checkpoint_path
147-
150+
148151 if checkpoint_path is not None :
149152 logging .info (f"Loading model from `{ checkpoint_path } ` with the previously saved hyper-parameter..." )
150153 self .model = Model .load_from_checkpoint (checkpoint_path , log_path = log_path )
0 commit comments