-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_eval.py
51 lines (39 loc) · 1.26 KB
/
custom_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import os
import pickle
import torch
from custom_env import CustomEnv
from rsl_rl.runners import OnPolicyRunner
import genesis as gs
def main():
    """Replay a trained locomotion policy for ``CustomEnv`` in the Genesis viewer.

    Loads the configs pickled by training to ``logs/<exp_name>/cfgs.pkl``,
    builds a single-environment ``CustomEnv`` with the viewer enabled,
    restores ``model_<ckpt>.pt`` from the same directory, and then steps the
    environment with the inference policy until the process is interrupted.

    Command-line options:
        -e/--exp_name: experiment name; selects ``logs/<exp_name>``
            (default ``custom-walking``).
        --ckpt: iteration number of the checkpoint file ``model_<ckpt>.pt``
            to load (default 1000).
        --device: torch device passed to the runner and the inference policy
            (default ``cuda:0``).

    Raises:
        FileNotFoundError: if the config pickle or the checkpoint is missing.
    """
    parser = argparse.ArgumentParser(description="Visualize a trained CustomEnv policy.")
    parser.add_argument("-e", "--exp_name", type=str, default="custom-walking")
    parser.add_argument("--ckpt", type=int, default=1000)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="torch device used by the runner and the inference policy",
    )
    args = parser.parse_args()

    log_dir = os.path.join("logs", args.exp_name)
    cfg_path = os.path.join(log_dir, "cfgs.pkl")
    resume_path = os.path.join(log_dir, f"model_{args.ckpt}.pt")
    # Fail fast, before initializing Genesis and opening the viewer, with a
    # message that names the missing file.
    for path in (cfg_path, resume_path):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"{path} not found; check --exp_name / --ckpt")

    gs.init()

    # NOTE(security): pickle.load can execute arbitrary code. Only load
    # cfgs.pkl files produced by your own training runs.
    with open(cfg_path, "rb") as f:
        env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(f)
    # Empty the reward-scale table for evaluation; presumably this makes
    # CustomEnv skip reward terms entirely (TODO confirm in CustomEnv).
    reward_cfg["reward_scales"] = {}

    env = CustomEnv(
        num_envs=1,
        env_cfg=env_cfg,
        obs_cfg=obs_cfg,
        reward_cfg=reward_cfg,
        command_cfg=command_cfg,
        show_viewer=True,
    )

    runner = OnPolicyRunner(env, train_cfg, log_dir, device=args.device)
    runner.load(resume_path)
    policy = runner.get_inference_policy(device=args.device)

    obs, _ = env.reset()
    # Roll out indefinitely; there is no exit condition, so stop with Ctrl+C.
    with torch.no_grad():
        while True:
            actions = policy(obs)
            # step() returns a 5-tuple; only the next observation is used here.
            obs, _, _rews, _dones, _infos = env.step(actions)
if __name__ == "__main__":
    main()

"""
# evaluation (run from the directory containing logs/<exp_name>/)
python custom_eval.py -e custom-walking --ckpt 1000
"""