@@ -41,7 +41,7 @@ def train_libmultilabel_tune(config, datasets, classes, word_dict):
4141 classes = classes ,
4242 word_dict = word_dict ,
4343 search_params = True ,
44- save_checkpoints = False )
44+ save_checkpoints = True )
4545 trainer .train ()
4646
4747
@@ -213,9 +213,20 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
213213
214214 data = load_static_data (
215215 best_config , merge_train_val = best_config .merge_train_val )
216- logging .info (f'Re-training with best config: \n { best_config } ' )
217- trainer = TorchTrainer (config = best_config , ** data )
218- trainer .train ()
216+
217+ if merge_train_val :
218+ logging .info (f'Re-training with best config: \n { best_config } ' )
219+ trainer = TorchTrainer (config = best_config , ** data )
220+ trainer .train ()
221+ else :
222+ # If not merging training and validation data, load the best result from tune experiments.
223+ logging .info (f'Loading best model with best config: \n { best_config } ' )
224+ trainer = TorchTrainer (config = best_config , ** data )
225+ best_checkpoint = os .path .join (best_log_dir , 'best_model.ckpt' )
226+ last_checkpoint = os .path .join (best_log_dir , 'last.ckpt' )
227+ trainer ._setup_model (checkpoint_path = best_checkpoint )
228+ os .popen (f"cp { best_checkpoint } { os .path .join (checkpoint_dir , 'best_model.ckpt' )} " )
229+ os .popen (f"cp { last_checkpoint } { os .path .join (checkpoint_dir , 'last.ckpt' )} " )
219230
220231 if 'test' in data ['datasets' ]:
221232 test_results = trainer .test ()
0 commit comments