candidate_selection.py
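"""Train candidate policies with VaRPPO under a chosen reward-inference mode.

Supported modes (see --mode): 'dist' (VaR-EVD / POSTPI), 'mean' (B-REX Mean),
'map' (B-REX MAP), 'trex' (T-REX), and 'pgbroil' (PG-BROIL).
"""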
import os
import pickle
import time

import torch.nn as nn

import safety_gymnasium as gym
from env_wrapper import SafetyGoalFeatureWrapper
from stable_baselines3.ppo.var_ppo import VaRPPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback

def create_env(env_name, mode, log_path, seed):
    """Build the environment, wrap it with reward features, and attach a Monitor."""
    env = gym.make(env_name)
    # T-REX loads a single learned reward; all other modes load reward samples.
    if mode == "trex":
        samples_filename = f"trex_reward_seed{seed}.pkl"
    else:
        samples_filename = f"samples_seed{seed}.pkl"
    env = SafetyGoalFeatureWrapper(env, env_name, samples_filename=samples_filename, mode=mode)
    env = Monitor(env, log_path)
    return env

def candidate_selection(args, verbose=True):
    start_time = time.time()

    # Checkpoint/monitor path.
    log_path = os.path.join("./candidate_selection", args.env_name, args.mode, args.exp_name)
    os.makedirs(log_path, exist_ok=True)
    Monitor.EXT = f"monitor_seed{args.trial_seed}.csv"

    # Create the environment.
    env = create_env(args.env_name, args.mode, log_path, args.trial_seed)
    n_rewards = env.get_n_rewards()

    # Create the model.
    policy_kwargs = dict(activation_fn=nn.ReLU,
                         net_arch=dict(pi=[128, 128], vf=[128, 128]))
    model = VaRPPO("MlpPolicy", env, n_rewards, alpha=args.alpha, verbose=1,
                   policy_kwargs=policy_kwargs, learning_rate=args.lr, seed=args.seed,
                   n_steps=4000)

    if args.mode == "dist":
        # POSTPI: set the expected return of the initial policy under the reward
        # samples so that the algorithm improves over the initial policy.
        with open(os.path.join("evaluation", args.env_name, "demo",
                               f"returns_seed{args.trial_seed}.pkl"), "rb") as f:
            expected_returns = pickle.load(f)["cp_returns"].mean(0)
        expected_returns *= args.epsilon
        model.set_expected_return_init(expected_returns)

    # Save the initial model (checkpoints follow stable-baselines3's
    # "{name_prefix}_{num_timesteps}_steps.zip" naming).
    model.save(os.path.join(log_path, f"rl_model_seed{args.trial_seed}_0_steps.zip"))
    print(model.policy)

    # Checkpoint callback.
    checkpoint_callback = CheckpointCallback(
        save_freq=args.save_freq,
        save_path=log_path,
        name_prefix=f"rl_model_seed{args.trial_seed}",
        save_replay_buffer=True,
    )
    model.learn(args.total_timesteps, callback=checkpoint_callback)
    print("Time:", time.time() - start_time)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("env_name", help="Environment name.")
    parser.add_argument("--exp_name", default="")
    # POSTPI-specific.
    parser.add_argument("--epsilon", type=float, default=1.0,
                        help="Constant to multiply with the expected return of the initial policy.")
    parser.add_argument("--total_timesteps", default=1000000, type=int,
                        help="Total number of timesteps to train.")
    parser.add_argument("--save_freq", default=1000000, type=int,
                        help="Frequency of saving model checkpoints.")
    parser.add_argument("--lr", default=1e-4, type=float, help="Learning rate.")
    parser.add_argument("--mode", type=str, default="dist",
                        help="'dist': VaR-EVD, 'mean': B-REX Mean, 'map': B-REX MAP, "
                             "'trex': T-REX, 'pgbroil': PG-BROIL.")
    parser.add_argument("--alpha", type=float, default=0.975,
                        help="Alpha of Value-at-Risk or CVaR: 0.975 for POSTPI, 0.95 for PG-BROIL; "
                             "unused for other modes.")
    parser.add_argument("--trial_seed", type=int, default=1,
                        help="Seed that determines which trial [1-20].")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    args = parser.parse_args()
    candidate_selection(args)
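
# Example invocation (a sketch; the environment name below is an assumption and
# must be one supported by safety_gymnasium and SafetyGoalFeatureWrapper):
#   python candidate_selection.py SafetyPointGoal1-v0 --mode dist --alpha 0.975 --trial_seed 1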