-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathacktr_disc.py
155 lines (129 loc) · 6.67 KB
/
acktr_disc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os.path as osp
import time
import joblib
import numpy as np
import tensorflow as tf
from baselines import logger
from baselines.common import set_global_seeds, explained_variance
from baselines.a2c.a2c import Runner
from baselines.a2c.utils import discount_with_dones
from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.a2c.utils import cat_entropy, mse
from baselines.acktr import kfac
class Model(object):
    """ACKTR actor-critic model for discrete action spaces.

    Builds two policies that share variables:
    - a step model that acts one step per environment;
    - a train model that sees a whole ``nenvs * nsteps`` batch.

    It also builds the policy-gradient / value / entropy training loss, the
    joint Fisher loss used for K-FAC statistics, and the K-FAC update ops.
    Callable entry points are attached as attributes: ``train``, ``save``,
    ``load``, ``step``, ``value``, and ``initial_state``.

    Args:
        policy: policy constructor, called as
            ``policy(sess, ob_space, ac_space, nbatch, nsteps, reuse=...)``.
        ob_space: observation space forwarded to ``policy``.
        ac_space: action space forwarded to ``policy``.
        nenvs: number of parallel environments.
        total_timesteps: ``nvalues`` horizon of the learning-rate ``Scheduler``.
        nprocs: intra-/inter-op thread count for the TF session.
        nsteps: rollout length per environment per update.
        ent_coef: entropy-bonus coefficient in the policy loss.
        vf_coef: value-loss coefficient in the training loss.
        vf_fisher_coef: coefficient of the value term in the Fisher loss.
        lr: base learning rate; the scheduled value is fed through ``PG_LR``.
        max_grad_norm: forwarded to ``kfac.KfacOptimizer``.
        kfac_clip: forwarded to ``kfac.KfacOptimizer`` as ``clip_kl``.
        lrschedule: schedule name forwarded to ``Scheduler``.
    """

    def __init__(self, policy, ob_space, ac_space, nenvs, total_timesteps, nprocs=32, nsteps=20,
                 ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
                 kfac_clip=0.001, lrschedule='linear'):
        import inspect  # local: only used to pick the optimizer's async-flag name

        config = tf.ConfigProto(allow_soft_placement=True,
                                intra_op_parallelism_threads=nprocs,
                                inter_op_parallelism_threads=nprocs)
        config.gpu_options.allow_growth = True
        self.sess = sess = tf.Session(config=config)
        nbatch = nenvs * nsteps

        A = tf.placeholder(tf.int32, [nbatch])
        ADV = tf.placeholder(tf.float32, [nbatch])
        R = tf.placeholder(tf.float32, [nbatch])
        PG_LR = tf.placeholder(tf.float32, [])

        # Both policies share weights: the second is built with reuse=True.
        self.model = step_model = policy(sess, ob_space, ac_space, nenvs, 1, reuse=False)
        self.model2 = train_model = policy(sess, ob_space, ac_space, nenvs * nsteps, nsteps, reuse=True)

        # Despite its name, `logpac` is the NEGATIVE log-probability of the
        # taken actions (softmax cross-entropy of the logits against A).
        logpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.pi, labels=A)
        self.logits = train_model.pi

        # Training loss.
        pg_loss = tf.reduce_mean(ADV * logpac)
        entropy = tf.reduce_mean(cat_entropy(train_model.pi))
        pg_loss = pg_loss - ent_coef * entropy
        vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.vf), R))
        train_loss = pg_loss + vf_coef * vf_loss

        # Fisher loss construction: the value term regresses the value head
        # toward a Gaussian-perturbed copy of itself (gradient stopped).
        self.pg_fisher = pg_fisher_loss = -tf.reduce_mean(logpac)
        sample_net = train_model.vf + tf.random_normal(tf.shape(train_model.vf))
        self.vf_fisher = vf_fisher_loss = -vf_fisher_coef * tf.reduce_mean(
            tf.pow(train_model.vf - tf.stop_gradient(sample_net), 2))
        self.joint_fisher = joint_fisher_loss = pg_fisher_loss + vf_fisher_loss

        self.params = params = find_trainable_variables("model")
        self.grads_check = grads = tf.gradients(train_loss, params)

        # `async=1` is a SyntaxError on Python >= 3.7 (async is a reserved
        # keyword). Newer kfac releases name the flag `is_async`; use whichever
        # name the installed optimizer accepts and pass it via ** so this file
        # parses on every Python 3 version.
        kfac_init_params = inspect.signature(kfac.KfacOptimizer.__init__).parameters
        async_flag_name = 'is_async' if 'is_async' in kfac_init_params else 'async'
        async_kwargs = {async_flag_name: 1}

        with tf.device('/gpu:0'):
            self.optim = optim = kfac.KfacOptimizer(
                learning_rate=PG_LR, clip_kl=kfac_clip,
                momentum=0.9, kfac_update=1, epsilon=0.01,
                stats_decay=0.99, cold_iter=10, max_grad_norm=max_grad_norm,
                **async_kwargs)
            # Return value is not used below; the call is kept as in upstream.
            update_stats_op = optim.compute_and_apply_stats(joint_fisher_loss, var_list=params)
            train_op, q_runner = optim.apply_gradients(list(zip(grads, params)))
        self.q_runner = q_runner
        self.lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

        def train(obs, states, rewards, masks, actions, values):
            """Run one optimizer update on a rollout batch.

            Advantages are ``rewards - values``. Returns the batch's
            (policy_loss, value_loss, policy_entropy).

            Raises:
                ValueError: if ``obs`` is empty (no learning rate can be drawn).
            """
            if len(obs) == 0:
                raise ValueError("train() requires a non-empty batch of observations")
            advs = rewards - values
            # Step the LR schedule once per sample, because the Scheduler's
            # horizon is total_timesteps. Only the schedule is stepped in this
            # loop; the optimizer runs exactly once per call, below.
            for _ in range(len(obs)):
                cur_lr = self.lr.value()

            td_map = {train_model.X: obs, A: actions, ADV: advs, R: rewards, PG_LR: cur_lr}
            if states is not None:
                # Recurrent policies also need their state and done-masks.
                td_map[train_model.S] = states
                td_map[train_model.M] = masks

            policy_loss, value_loss, policy_entropy, _ = sess.run(
                [pg_loss, vf_loss, entropy, train_op],
                td_map
            )
            return policy_loss, value_loss, policy_entropy

        def save(save_path):
            """Write the current values of all trainable params to `save_path` (joblib)."""
            ps = sess.run(params)
            joblib.dump(ps, save_path)

        def load(load_path):
            """Restore trainable params from a file written by `save`.

            SECURITY: joblib.load unpickles the file and can execute arbitrary
            code -- only load checkpoints from trusted sources.

            Raises:
                ValueError: if the file holds a different number of arrays than
                    this model has trainable params.
            """
            loaded_params = joblib.load(load_path)
            if len(loaded_params) != len(params):
                raise ValueError(
                    "checkpoint %r has %d arrays but the model has %d trainable params"
                    % (load_path, len(loaded_params), len(params)))
            restores = [p.assign(loaded_p) for p, loaded_p in zip(params, loaded_params)]
            sess.run(restores)

        self.train = train
        self.save = save
        self.load = load
        self.train_model = train_model
        self.step_model = step_model
        self.step = step_model.step
        self.value = step_model.value
        self.initial_state = step_model.initial_state
        tf.global_variables_initializer().run(session=sess)
def learn(policy, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1, nprocs=32, nsteps=20,
          ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
          kfac_clip=0.001, save_interval=None, lrschedule='linear'):
    """Train a discrete-action ACKTR `Model` on a vectorized environment.

    The function resets the default TF graph, seeds RNGs, and builds the
    model. It then runs ``total_timesteps // (env.num_envs * nsteps)``
    rollout/update iterations, logging every ``log_interval`` updates. When
    ``save_interval`` is set and a logger directory exists, it pickles the
    model factory once and writes a checkpoint every ``save_interval``
    updates (and after the first update).

    The K-FAC queue-runner threads are always stopped and ``env`` is always
    closed when training ends, including when it ends with an exception.

    Args:
        policy: policy constructor forwarded to `Model`.
        env: vectorized env exposing ``num_envs``, ``observation_space``,
            ``action_space`` and ``close()``; consumed by ``Runner``.
        seed: global RNG seed.
        total_timesteps: total environment steps to train for.
        gamma: discount factor forwarded to ``Runner``.
        log_interval: log every this many updates.
        save_interval: checkpoint every this many updates (disabled if falsy).
        Remaining args are forwarded unchanged to `Model`.
    """
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    make_model = lambda: Model(policy, ob_space, ac_space, nenvs, total_timesteps, nprocs=nprocs,
                               nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
                               vf_fisher_coef=vf_fisher_coef, lr=lr, max_grad_norm=max_grad_norm,
                               kfac_clip=kfac_clip, lrschedule=lrschedule)
    if save_interval and logger.get_dir():
        # Persist the model factory next to the checkpoints so a model with
        # the same graph can be rebuilt later.
        import cloudpickle
        with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
            fh.write(cloudpickle.dumps(make_model))
    model = make_model()

    runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
    nbatch = nenvs * nsteps
    tstart = time.time()
    coord = tf.train.Coordinator()
    enqueue_threads = model.q_runner.create_threads(model.sess, coord=coord, start=True)
    try:
        for update in range(1, total_timesteps // nbatch + 1):
            obs, states, rewards, masks, actions, values = runner.run()
            policy_loss, value_loss, policy_entropy = model.train(
                obs, states, rewards, masks, actions, values)
            # Not read anywhere in this file; kept for external callers.
            model.old_obs = obs
            nseconds = time.time() - tstart
            fps = int((update * nbatch) / nseconds)
            if update % log_interval == 0 or update == 1:
                ev = explained_variance(values, rewards)
                logger.record_tabular("nupdates", update)
                logger.record_tabular("total_timesteps", update * nbatch)
                logger.record_tabular("fps", fps)
                logger.record_tabular("policy_entropy", float(policy_entropy))
                logger.record_tabular("policy_loss", float(policy_loss))
                logger.record_tabular("value_loss", float(value_loss))
                logger.record_tabular("explained_variance", float(ev))
                logger.dump_tabular()

            if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
                savepath = osp.join(logger.get_dir(), 'checkpoint%.5i' % update)
                print('Saving to', savepath)
                model.save(savepath)
    finally:
        # Always shut down the enqueue threads and release the env, even when
        # training is interrupted by an exception; otherwise both leak.
        coord.request_stop()
        coord.join(enqueue_threads)
        env.close()