Skip to content

Commit 7da8a48

Browse files
author
Jakub Pieszczek
committed
save state_dict.pt as a separate file
1 parent 2c5e40c commit 7da8a48

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

timm/utils/checkpoint_saver.py

+5
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@ def save_checkpoint(self, epoch, metric=None):
107107
"model_kwargs": self.args.model_kwargs,
108108
}
109109
torch.save(model_dict, temp_location)
110+
torch.save(
111+
get_state_dict(self.model, self.unwrap_fn),
112+
os.path.join(temp_dir, "state_dict.pt"),
113+
)
110114
mlflow.log_artifact(temp_location)
115+
mlflow.log_artifact(os.path.join(temp_dir, "state_dict.pt"))
111116

112117

113118
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

0 commit comments

Comments
 (0)