@@ -80,7 +80,7 @@ def test(self):
8080 scorers = self .scorers ,
8181 produce_multiple_translations = self .produce_multiple_translations ,
8282 save_confidences = self .save_confidences ,
83- save_checkpoints = self .save_checkpoints ,
83+ use_default_model_dir = self .save_checkpoints ,
8484 )
8585
8686 def translate (self ):
@@ -94,7 +94,7 @@ def translate(self):
9494 translator = TranslationTask (
9595 name = self .name ,
9696 checkpoint = config .get ("checkpoint" , "last" ),
97- use_default_model_dir = True if not ( self . run_train ) else self .save_checkpoints ,
97+ use_default_model_dir = self .save_checkpoints ,
9898 commit = self .commit ,
9999 )
100100
@@ -155,7 +155,12 @@ def main() -> None:
155155 help = "Run remotely on ClearML queue. Default: None - don't register with ClearML. The queue 'local' will run "
156156 + "it locally and register it with ClearML." ,
157157 )
158- parser .add_argument ("--save-checkpoints" , default = False , action = "store_true" , help = "Save checkpoints to S3 bucket" )
158+ parser .add_argument (
159+ "--save-checkpoints" ,
160+ default = False ,
161+ action = "store_true" ,
162+ help = "Save checkpoints to bucket. Only used if running the train step." ,
163+ )
159164 parser .add_argument ("--preprocess" , default = False , action = "store_true" , help = "Run the preprocess step." )
160165 parser .add_argument ("--train" , default = False , action = "store_true" , help = "Run the train step." )
161166 parser .add_argument ("--test" , default = False , action = "store_true" , help = "Run the test step." )
@@ -206,6 +211,9 @@ def main() -> None:
206211 args .train = True
207212 args .test = True
208213
214+ if not args .train :
215+ args .save_checkpoints = True
216+
209217 exp = SILExperiment (
210218 name = args .experiment ,
211219 make_stats = args .stats ,
0 commit comments