Skip to content

Commit f8fb696

Browse files
Assume default model dir if train is not present (#814)
1 parent b9667d3 commit f8fb696

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

silnlp/nmt/experiment.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test(self):
8080
scorers=self.scorers,
8181
produce_multiple_translations=self.produce_multiple_translations,
8282
save_confidences=self.save_confidences,
83-
save_checkpoints=self.save_checkpoints,
83+
use_default_model_dir=self.save_checkpoints,
8484
)
8585

8686
def translate(self):
@@ -94,7 +94,7 @@ def translate(self):
9494
translator = TranslationTask(
9595
name=self.name,
9696
checkpoint=config.get("checkpoint", "last"),
97-
use_default_model_dir=True if not (self.run_train) else self.save_checkpoints,
97+
use_default_model_dir=self.save_checkpoints,
9898
commit=self.commit,
9999
)
100100

@@ -155,7 +155,12 @@ def main() -> None:
155155
help="Run remotely on ClearML queue. Default: None - don't register with ClearML. The queue 'local' will run "
156156
+ "it locally and register it with ClearML.",
157157
)
158-
parser.add_argument("--save-checkpoints", default=False, action="store_true", help="Save checkpoints to S3 bucket")
158+
parser.add_argument(
159+
"--save-checkpoints",
160+
default=False,
161+
action="store_true",
162+
help="Save checkpoints to bucket. Only used if running the train step.",
163+
)
159164
parser.add_argument("--preprocess", default=False, action="store_true", help="Run the preprocess step.")
160165
parser.add_argument("--train", default=False, action="store_true", help="Run the train step.")
161166
parser.add_argument("--test", default=False, action="store_true", help="Run the test step.")
@@ -206,6 +211,9 @@ def main() -> None:
206211
args.train = True
207212
args.test = True
208213

214+
if not args.train:
215+
args.save_checkpoints = True
216+
209217
exp = SILExperiment(
210218
name=args.experiment,
211219
make_stats=args.stats,

silnlp/nmt/test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,10 @@ def test(
592592
by_book: bool = False,
593593
produce_multiple_translations: bool = False,
594594
save_confidences: bool = False,
595-
save_checkpoints: bool = False,
595+
use_default_model_dir: bool = False,
596596
):
597597
exp_name = experiment
598-
config = load_config(exp_name, save_checkpoints)
598+
config = load_config(exp_name, use_default_model_dir)
599599

600600
if not any(config.exp_dir.glob("test*.src.txt")):
601601
LOGGER.info("No test dataset.")

0 commit comments

Comments
 (0)