-
Notifications
You must be signed in to change notification settings - Fork 50
Expand file tree
/
Copy pathexport_model.py
More file actions
30 lines (27 loc) · 1.43 KB
/
export_model.py
File metadata and controls
30 lines (27 loc) · 1.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import glob
import yaml
import argparse
import torch
from utils.model import *
if __name__ == "__main__":
    # Export the actor network of a trained ActorCritic checkpoint to TorchScript.
    #
    # Usage: python export_model.py --task <name> [--checkpoint <path>]
    #
    # The task config is read from envs/<task>.yaml. The checkpoint path comes
    # from --checkpoint if given, otherwise from cfg["basic"]["checkpoint"].
    # If that is empty/missing or "-1"/-1, the most recently modified *.pth file
    # under logs/ (searched recursively) is used instead. The scripted actor is
    # saved next to the checkpoint with a .pt extension.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", required=True, type=str, help="Name of the task to run.")
    parser.add_argument("--checkpoint", type=str, help="Path of model checkpoint to load. Overrides config file if provided.")
    args = parser.parse_args()

    cfg_file = os.path.join("envs", f"{args.task}.yaml")
    with open(cfg_file, "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader is kept so existing configs parse exactly as
        # before; this assumes envs/*.yaml are trusted local files.
        cfg = yaml.load(f.read(), Loader=yaml.FullLoader)

    if args.checkpoint is not None:
        cfg["basic"]["checkpoint"] = args.checkpoint
    # .get() so a config without a "checkpoint" key falls back to the latest
    # checkpoint instead of crashing with KeyError.
    checkpoint = cfg["basic"].get("checkpoint")
    # An empty value, "-1" (string) or -1 (int) all mean "use the latest checkpoint".
    if not checkpoint or checkpoint == "-1" or checkpoint == -1:
        candidates = sorted(
            glob.glob(os.path.join("logs", "**/*.pth"), recursive=True),
            key=os.path.getmtime,
        )
        if not candidates:
            # Fail with a clear message instead of an IndexError on [-1].
            raise FileNotFoundError(
                "No checkpoint specified and no *.pth files found under 'logs/'. "
                "Pass --checkpoint or set basic.checkpoint in the task config."
            )
        checkpoint = candidates[-1]
    cfg["basic"]["checkpoint"] = checkpoint

    # Build the model only after the checkpoint is resolved so a missing
    # checkpoint fails fast.
    model = ActorCritic(cfg["env"]["num_actions"], cfg["env"]["num_observations"], cfg["env"]["num_privileged_obs"])
    print(f"Loading model from {checkpoint}")
    model_dict = torch.load(checkpoint, map_location="cpu", weights_only=True)
    model.load_state_dict(model_dict["model"])
    model.eval()

    # Only the actor sub-module is exported.
    script_module = torch.jit.script(model.actor)
    stem, ext = os.path.splitext(checkpoint)
    if ext.lower() == ".pt":
        # Replacing the extension would yield the checkpoint's own path; write
        # to a distinct file so the trained weights are not overwritten.
        save_path = stem + "_jit.pt"
    else:
        save_path = stem + ".pt"
    script_module.save(save_path)
    print(f"Saved model to {save_path}")