forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathddpg_agent.py
739 lines (648 loc) · 28.3 KB
/
ddpg_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A DDPG/NAF agent.
Implements the Deep Deterministic Policy Gradient (DDPG) algorithm from
"Continuous control with deep reinforcement learning" - Lillicrap et al.
https://arxiv.org/abs/1509.02971, and the Normalized Advantage Functions (NAF)
algorithm "Continuous Deep Q-Learning with Model-based Acceleration" - Gu et al.
https://arxiv.org/pdf/1603.00748.
"""
import tensorflow as tf
slim = tf.contrib.slim
import gin.tf
from utils import utils
from agents import ddpg_networks as networks
@gin.configurable
class DdpgAgent(object):
  """An RL agent that learns using the DDPG algorithm.

  Example usage:

    def critic_net(states, actions, for_critic_loss=False):
      ...
    def actor_net(states, action_spec):
      ...

  Given a tensorflow environment tf_env,
  (of type learning.deepmind.rl.environments.tensorflow.python.tfpyenvironment)

    obs_spec = tf_env.observation_spec()
    action_spec = tf_env.action_spec()

    ddpg_agent = agent.DdpgAgent(obs_spec,
                                 action_spec,
                                 actor_net=actor_net,
                                 critic_net=critic_net)

  we can perform actions on the environment as follows:

    state = tf_env.observations()[0]
    action = ddpg_agent.actor_net(tf.expand_dims(state, 0))[0, :]
    transition_type, reward, discount = tf_env.step([action])

  Train:

    critic_loss = ddpg_agent.critic_loss(states, actions, rewards, discounts,
                                         next_states)
    actor_loss = ddpg_agent.actor_loss(states)

    critic_train_op = slim.learning.create_train_op(
        critic_loss,
        critic_optimizer,
        variables_to_train=ddpg_agent.get_trainable_critic_vars(),
    )

    actor_train_op = slim.learning.create_train_op(
        actor_loss,
        actor_optimizer,
        variables_to_train=ddpg_agent.get_trainable_actor_vars(),
    )
  """

  ACTOR_NET_SCOPE = 'actor_net'
  CRITIC_NET_SCOPE = 'critic_net'
  TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
  TARGET_CRITIC_NET_SCOPE = 'target_critic_net'

  def __init__(self,
               observation_spec,
               action_spec,
               actor_net=networks.actor_net,
               critic_net=networks.critic_net,
               td_errors_loss=tf.losses.huber_loss,
               dqda_clipping=0.,
               actions_regularizer=0.,
               target_q_clipping=None,
               residual_phi=0.0,
               debug_summaries=False):
    """Constructs a DDPG agent.

    Args:
      observation_spec: A sequence of TensorSpecs; only the first element is
        used as the observation spec.
      action_spec: A sequence of BoundedTensorSpecs; only the first element is
        used as the action spec.
      actor_net: A callable that creates the actor network. It is called as
        actor_net(states, action_spec). Please see networks.actor_net for an
        example.
      critic_net: A callable that creates the critic network. It is called as
        critic_net(states, actions, for_critic_loss=...). Please see
        networks.critic_net for an example.
      td_errors_loss: A callable defining the loss function for the critic
        td error, called as td_errors_loss(targets, predictions).
      dqda_clipping: (float) clips the gradient dqda element-wise between
        [-dqda_clipping, dqda_clipping]. Does not perform clipping if
        dqda_clipping == 0.
      actions_regularizer: A scalar, when positive penalizes the norm of the
        actions. This can prevent saturation of actions for the actor_loss.
      target_q_clipping: (tuple of floats) clips target q values within
        (low, high) values when computing the critic loss.
      residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
        interpolates between Q-learning and residual gradient algorithm.
        http://www.leemon.com/papers/1995b.pdf
      debug_summaries: If True, add summaries to help debug behavior.

    Raises:
      ValueError: If 'dqda_clipping' is < 0.
    """
    self._observation_spec = observation_spec[0]
    self._action_spec = action_spec[0]
    # Leading None allows any batch size in the shape checks below.
    self._state_shape = tf.TensorShape([None]).concatenate(
        self._observation_spec.shape)
    self._action_shape = tf.TensorShape([None]).concatenate(
        self._action_spec.shape)
    self._num_action_dims = self._action_spec.shape.num_elements()

    # Remember the enclosing variable scope so the per-network variables can
    # be looked up later by joined scope name.
    self._scope = tf.get_variable_scope().name
    # Templates share variables across calls; each gets its own scope now.
    self._actor_net = tf.make_template(
        self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._critic_net = tf.make_template(
        self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)
    self._target_actor_net = tf.make_template(
        self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True)
    self._target_critic_net = tf.make_template(
        self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True)

    self._td_errors_loss = td_errors_loss
    if dqda_clipping < 0:
      raise ValueError('dqda_clipping must be >= 0.')
    self._dqda_clipping = dqda_clipping
    self._actions_regularizer = actions_regularizer
    self._target_q_clipping = target_q_clipping
    self._residual_phi = residual_phi
    self._debug_summaries = debug_summaries

  def _batch_state(self, state):
    """Converts a state to a batched state.

    Args:
      state: Either a state tensor, or a list/tuple whose first element is a
        state tensor.

    Returns:
      The state with a leading batch dimension of 1 added if it was rank 1;
      otherwise the (first) state tensor unchanged.
    """
    if isinstance(state, (tuple, list)):
      state = state[0]
    if state.get_shape().ndims == 1:
      state = tf.expand_dims(state, 0)
    return state

  def action(self, state):
    """Returns the next action for the state.

    Args:
      state: A [num_state_dims] tensor representing a state.

    Returns:
      A [num_action_dims] tensor representing the action.
    """
    return self.actor_net(self._batch_state(state), stop_gradients=True)[0, :]

  @gin.configurable('ddpg_sample_action')
  def sample_action(self, state, stddev=1.0):
    """Returns the action for the state with additive Gaussian noise.

    Args:
      state: A [num_state_dims] tensor representing a state.
      stddev: Standard deviation of the i.i.d. Gaussian noise added to each
        action dimension.

    Returns:
      A [num_action_dims] action tensor, clipped to the action spec.
    """
    agent_action = self.action(state)
    agent_action += tf.random_normal(tf.shape(agent_action)) * stddev
    return utils.clip_to_spec(agent_action, self._action_spec)

  def actor_net(self, states, stop_gradients=False):
    """Returns the output of the actor network.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      stop_gradients: (boolean) if true, gradients cannot be propagated through
        this operation.

    Returns:
      A [batch_size, num_action_dims] tensor of actions.

    Raises:
      ValueError: If `states` does not have the expected dimensions.
    """
    self._validate_states(states)
    actions = self._actor_net(states, self._action_spec)
    if stop_gradients:
      actions = tf.stop_gradient(actions)
    return actions

  def critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the critic network.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      for_critic_loss: Forwarded to the critic network callable; True when the
        output is used to build the critic loss.

    Returns:
      q values: A [batch_size] tensor of q values.

    Raises:
      ValueError: If `states` or `actions' do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    return self._critic_net(states, actions,
                            for_critic_loss=for_critic_loss)

  def target_actor_net(self, states):
    """Returns the output of the target actor network.

    The target network is used to compute stable targets for training.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.

    Returns:
      A [batch_size, num_action_dims] tensor of actions (gradients stopped).

    Raises:
      ValueError: If `states` does not have the expected dimensions.
    """
    self._validate_states(states)
    actions = self._target_actor_net(states, self._action_spec)
    return tf.stop_gradient(actions)

  def target_critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the target critic network.

    The target network is used to compute stable targets for training.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      for_critic_loss: Forwarded to the critic network callable.

    Returns:
      q values: A [batch_size] tensor of q values (gradients stopped).

    Raises:
      ValueError: If `states` or `actions' do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    return tf.stop_gradient(
        self._target_critic_net(states, actions,
                                for_critic_loss=for_critic_loss))

  def value_net(self, states, for_critic_loss=False):
    """Returns the output of the critic evaluated with the actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      for_critic_loss: Forwarded to the critic network callable.

    Returns:
      q values: A [batch_size] tensor of q values.
    """
    actions = self.actor_net(states)
    return self.critic_net(states, actions,
                           for_critic_loss=for_critic_loss)

  def target_value_net(self, states, for_critic_loss=False):
    """Returns the output of the target critic evaluated with the target actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      for_critic_loss: Forwarded to the critic network callable.

    Returns:
      q values: A [batch_size] tensor of q values.
    """
    target_actions = self.target_actor_net(states)
    return self.target_critic_net(states, target_actions,
                                  for_critic_loss=for_critic_loss)

  def critic_loss(self, states, actions, rewards, discounts,
                  next_states):
    """Computes a loss for training the critic network.

    The loss is `td_errors_loss` between the critic's Q value predictions and
    one-step TD targets built from the target networks. When residual_phi > 0
    it is interpolated with a residual-gradient loss whose targets come from
    the online networks.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      rewards: A [batch_size, ...] tensor representing a batch of rewards,
        broadcastable to the critic net output.
      discounts: A [batch_size, ...] tensor representing a batch of discounts,
        broadcastable to the critic net output.
      next_states: A [batch_size, num_state_dims] tensor representing a batch
        of next states.

    Returns:
      A rank-0 tensor representing the critic loss.

    Raises:
      ValueError: If any of the inputs do not have the expected dimensions, or
        if their batch_sizes do not match.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    self._validate_states(next_states)

    target_q_values = self.target_value_net(next_states, for_critic_loss=True)
    td_targets = target_q_values * discounts + rewards
    if self._target_q_clipping is not None:
      td_targets = tf.clip_by_value(td_targets, self._target_q_clipping[0],
                                    self._target_q_clipping[1])
    q_values = self.critic_net(states, actions, for_critic_loss=True)
    td_errors = td_targets - q_values
    if self._debug_summaries:
      gen_debug_td_error_summaries(
          target_q_values, q_values, td_targets, td_errors)

    loss = self._td_errors_loss(td_targets, q_values)

    if self._residual_phi > 0.0:  # compute residual gradient loss
      # Residual-gradient term (Baird, 1995): the bootstrapped target uses the
      # online actor/critic, so gradients flow through both Q(s, a) and
      # Q(s', pi(s')).
      residual_q_values = self.value_net(next_states, for_critic_loss=True)
      residual_td_targets = residual_q_values * discounts + rewards
      if self._target_q_clipping is not None:
        residual_td_targets = tf.clip_by_value(residual_td_targets,
                                               self._target_q_clipping[0],
                                               self._target_q_clipping[1])
      # The Bellman residual is target - Q(s, a). Comparing the target against
      # Q(s', pi(s')) -- the quantity it is built from -- does not measure it.
      residual_loss = self._td_errors_loss(residual_td_targets, q_values)
      loss = (loss * (1.0 - self._residual_phi) +
              residual_loss * self._residual_phi)
    return loss

  def actor_loss(self, states):
    """Computes a loss for training the actor network.

    Note that output does not represent an actual loss. It is called a loss only
    in the sense that its gradient w.r.t. the actor network weights is the
    correct gradient for training the actor network,
    i.e. dloss/dweights = (dq/da)*(da/dweights)
    which is the gradient used in Algorithm 1 of Lillicrap et al.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.

    Returns:
      A rank-0 tensor representing the actor loss.

    Raises:
      ValueError: If `states` does not have the expected dimensions.
    """
    self._validate_states(states)
    actions = self.actor_net(states, stop_gradients=False)
    critic_values = self.critic_net(states, actions)
    q_values = self.critic_function(critic_values, states)
    dqda = tf.gradients([q_values], [actions])[0]
    dqda_unclipped = dqda
    if self._dqda_clipping > 0:
      dqda = tf.clip_by_value(dqda, -self._dqda_clipping, self._dqda_clipping)

    actions_norm = tf.norm(actions)
    if self._debug_summaries:
      with tf.name_scope('dqda'):
        tf.summary.scalar('actions_norm', actions_norm)
        tf.summary.histogram('dqda', dqda)
        tf.summary.histogram('dqda_unclipped', dqda_unclipped)
        tf.summary.histogram('actions', actions)
        for a in range(self._num_action_dims):
          tf.summary.histogram('dqda_unclipped_%d' % a, dqda_unclipped[:, a])
          tf.summary.histogram('dqda_%d' % a, dqda[:, a])

    actions_norm *= self._actions_regularizer
    # MSE against a stopped (actions + dqda) target: its gradient w.r.t. the
    # actions is proportional to -dqda, so minimizing it ascends Q.
    return slim.losses.mean_squared_error(tf.stop_gradient(dqda + actions),
                                          actions,
                                          scope='actor_loss') + actions_norm

  @gin.configurable('ddpg_critic_function')
  def critic_function(self, critic_values, states, weights=None):
    """Computes q values based on critic_net outputs, states, and weights.

    Args:
      critic_values: A tf.float32 [batch_size, ...] tensor representing outputs
        from the critic net.
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states (unused).
      weights: A list or Numpy array or tensor with a shape broadcastable to
        `critic_values`.

    Returns:
      A tf.float32 [batch_size] tensor representing q values.
    """
    del states  # unused args
    if weights is not None:
      weights = tf.convert_to_tensor(weights, dtype=critic_values.dtype)
      critic_values *= weights
    ndims = critic_values.shape.ndims
    # Sum out every non-batch axis (e.g. multi-dimensional rewards).
    if ndims is not None and ndims > 1:
      critic_values = tf.reduce_sum(critic_values, list(range(1, ndims)))
    critic_values.shape.assert_has_rank(1)
    return critic_values

  @gin.configurable('ddpg_update_targets')
  def update_targets(self, tau=1.0):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the actor/critic networks, and its corresponding
    weight w_t in the target actor/critic networks, a soft update is:
    w_t = (1 - tau) x w_t + tau x w_s

    Args:
      tau: A float scalar in [0, 1]; tau=1.0 copies the weights outright.

    Returns:
      An operation that performs a soft update of the target network parameters.

    Raises:
      ValueError: If `tau` is not in [0, 1].
    """
    if tau < 0 or tau > 1:
      raise ValueError('Input `tau` should be in [0, 1].')
    update_actor = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
        tau)
    update_critic = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
        tau)
    return tf.group(update_actor, update_critic, name='update_targets')

  def get_trainable_critic_vars(self):
    """Returns a list of trainable variables in the critic network.

    Returns:
      A list of trainable variables in the critic network.
    """
    return slim.get_trainable_variables(
        utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))

  def get_trainable_actor_vars(self):
    """Returns a list of trainable variables in the actor network.

    Returns:
      A list of trainable variables in the actor network.
    """
    return slim.get_trainable_variables(
        utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))

  def get_critic_vars(self):
    """Returns a list of all variables in the critic network.

    Returns:
      A list of all model variables (trainable or not) in the critic network.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))

  def get_actor_vars(self):
    """Returns a list of all variables in the actor network.

    Returns:
      A list of all model variables (trainable or not) in the actor network.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))

  def _validate_states(self, states):
    """Raises a value error if `states` does not have the expected shape.

    Args:
      states: A tensor.

    Raises:
      ValueError: If states.shape or states.dtype are not compatible with
        observation_spec.
    """
    states.shape.assert_is_compatible_with(self._state_shape)
    if not states.dtype.is_compatible_with(self._observation_spec.dtype):
      raise ValueError('states.dtype={} is not compatible with'
                       ' observation_spec.dtype={}'.format(
                           states.dtype, self._observation_spec.dtype))

  def _validate_actions(self, actions):
    """Raises a value error if `actions` does not have the expected shape.

    Args:
      actions: A tensor.

    Raises:
      ValueError: If actions.shape or actions.dtype are not compatible with
        action_spec.
    """
    actions.shape.assert_is_compatible_with(self._action_shape)
    if not actions.dtype.is_compatible_with(self._action_spec.dtype):
      raise ValueError('actions.dtype={} is not compatible with'
                       ' action_spec.dtype={}'.format(
                           actions.dtype, self._action_spec.dtype))
@gin.configurable
class TD3Agent(DdpgAgent):
  """An RL agent that learns using the TD3 algorithm.

  Compared to DdpgAgent this keeps a second critic (and target critic), uses
  the minimum of the two target critics as the bootstrap value, and smooths
  the target policy with clipped Gaussian noise.

  NOTE(review): with for_critic_loss=True, `critic_net` and `target_value_net`
  return (values1, values2) pairs. The inherited `critic_loss` combines these
  tuples with tensors arithmetically, presumably relying on TF to stack them
  into [2, batch_size] tensors so both critics regress onto the shared target —
  confirm against callers.
  """

  ACTOR_NET_SCOPE = 'actor_net'
  CRITIC_NET_SCOPE = 'critic_net'
  CRITIC_NET2_SCOPE = 'critic_net2'
  TARGET_ACTOR_NET_SCOPE = 'target_actor_net'
  TARGET_CRITIC_NET_SCOPE = 'target_critic_net'
  TARGET_CRITIC_NET2_SCOPE = 'target_critic_net2'

  def __init__(self,
               observation_spec,
               action_spec,
               actor_net=networks.actor_net,
               critic_net=networks.critic_net,
               td_errors_loss=tf.losses.huber_loss,
               dqda_clipping=0.,
               actions_regularizer=0.,
               target_q_clipping=None,
               residual_phi=0.0,
               debug_summaries=False,
               target_policy_noise=0.2,
               target_policy_noise_clip=0.5):
    """Constructs a TD3 agent.

    Args:
      observation_spec: A sequence of TensorSpecs; only the first element is
        used as the observation spec.
      action_spec: A sequence of BoundedTensorSpecs; only the first element is
        used as the action spec.
      actor_net: A callable that creates the actor network. It is called as
        actor_net(states, action_spec). Please see networks.actor_net for an
        example.
      critic_net: A callable that creates the critic network. It is called as
        critic_net(states, actions, for_critic_loss=...). Please see
        networks.critic_net for an example. Both critics use this callable.
      td_errors_loss: A callable defining the loss function for the critic
        td error.
      dqda_clipping: (float) clips the gradient dqda element-wise between
        [-dqda_clipping, dqda_clipping]. Does not perform clipping if
        dqda_clipping == 0.
      actions_regularizer: A scalar, when positive penalizes the norm of the
        actions. This can prevent saturation of actions for the actor_loss.
      target_q_clipping: (tuple of floats) clips target q values within
        (low, high) values when computing the critic loss.
      residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that
        interpolates between Q-learning and residual gradient algorithm.
        http://www.leemon.com/papers/1995b.pdf
      debug_summaries: If True, add summaries to help debug behavior.
      target_policy_noise: Stddev of the Gaussian noise added to target
        actions (target policy smoothing).
      target_policy_noise_clip: The smoothing noise is clipped to
        [-target_policy_noise_clip, target_policy_noise_clip].

    Raises:
      ValueError: If 'dqda_clipping' is < 0.
    """
    super(TD3Agent, self).__init__(
        observation_spec,
        action_spec,
        actor_net=actor_net,
        critic_net=critic_net,
        td_errors_loss=td_errors_loss,
        dqda_clipping=dqda_clipping,
        actions_regularizer=actions_regularizer,
        target_q_clipping=target_q_clipping,
        residual_phi=residual_phi,
        debug_summaries=debug_summaries)
    # The second online/target critic pair used for clipped double-Q learning.
    self._critic_net2 = tf.make_template(
        self.CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
    self._target_critic_net2 = tf.make_template(
        self.TARGET_CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True)
    self._target_policy_noise = target_policy_noise
    self._target_policy_noise_clip = target_policy_noise_clip

  def get_trainable_critic_vars(self):
    """Returns a list of trainable variables in the critic network.

    NOTE: This gets the vars of both critic networks. slim scope filtering is
    a prefix match, so the CRITIC_NET_SCOPE ('critic_net') filter also matches
    variables under CRITIC_NET2_SCOPE ('critic_net2').

    Returns:
      A list of trainable variables in both critic networks.
    """
    return slim.get_trainable_variables(
        utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))

  def critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the critic networks.

    Both critics are always evaluated so that their variables are created on
    the first call, whichever output is requested.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      for_critic_loss: If True, return the outputs of both critics; otherwise
        return only the first critic's output.

    Returns:
      q values: A [batch_size] tensor of q values from the first critic, or a
        (values1, values2) pair when `for_critic_loss` is True.

    Raises:
      ValueError: If `states` or `actions' do not have the expected dimensions.
    """
    self._validate_states(states)
    self._validate_actions(actions)
    values1 = self._critic_net(states, actions,
                               for_critic_loss=for_critic_loss)
    values2 = self._critic_net2(states, actions,
                                for_critic_loss=for_critic_loss)
    if for_critic_loss:
      return values1, values2
    return values1

  def _target_critic_values(self, states, actions, for_critic_loss):
    """Evaluates both target critics; returns (values1, values2), no grads."""
    self._validate_states(states)
    self._validate_actions(actions)
    values1 = tf.stop_gradient(
        self._target_critic_net(states, actions,
                                for_critic_loss=for_critic_loss))
    values2 = tf.stop_gradient(
        self._target_critic_net2(states, actions,
                                 for_critic_loss=for_critic_loss))
    return values1, values2

  def target_critic_net(self, states, actions, for_critic_loss=False):
    """Returns the output of the target critic networks.

    The target networks are used to compute stable targets for training.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      for_critic_loss: If True, return the outputs of both target critics;
        otherwise return only the first target critic's output.

    Returns:
      q values: A [batch_size] tensor of q values from the first target
        critic, or a (values1, values2) pair when `for_critic_loss` is True.

    Raises:
      ValueError: If `states` or `actions' do not have the expected dimensions.
    """
    values1, values2 = self._target_critic_values(states, actions,
                                                  for_critic_loss)
    if for_critic_loss:
      return values1, values2
    return values1

  def value_net(self, states, for_critic_loss=False):
    """Returns the output of the critic evaluated with the actor.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      for_critic_loss: Forwarded to `critic_net`.

    Returns:
      q values: A [batch_size] tensor of q values, or a (values1, values2)
        pair when `for_critic_loss` is True.
    """
    actions = self.actor_net(states)
    return self.critic_net(states, actions,
                           for_critic_loss=for_critic_loss)

  def target_value_net(self, states, for_critic_loss=False):
    """Returns the clipped double-Q target value with target policy smoothing.

    The target actor's action is perturbed with clipped Gaussian noise, kept
    inside the action spec, and scored by both target critics. The element-wise
    minimum of the two is the target value.

    Args:
      states: A [batch_size, num_state_dims] tensor representing a batch
        of states.
      for_critic_loss: Forwarded to the target critic networks. When True the
        minimum is returned twice (one copy per online critic).

    Returns:
      q values: A [batch_size] tensor of q values, or a (values, values) pair
        when `for_critic_loss` is True.
    """
    target_actions = self.target_actor_net(states)
    noise = tf.clip_by_value(
        tf.random_normal(tf.shape(target_actions),
                         stddev=self._target_policy_noise),
        -self._target_policy_noise_clip, self._target_policy_noise_clip)
    # Keep smoothed actions inside the valid range so the critics are not
    # queried on actions the agent can never take (as in the reference TD3).
    noisy_actions = utils.clip_to_spec(target_actions + noise,
                                       self._action_spec)
    # Evaluate both critics directly: target_critic_net only returns the first
    # one when for_critic_loss=False, so unpacking its output fails there.
    values1, values2 = self._target_critic_values(states, noisy_actions,
                                                  for_critic_loss)
    values = tf.minimum(values1, values2)
    if for_critic_loss:
      return values, values
    return values

  @gin.configurable('td3_update_targets')
  def update_targets(self, tau=1.0):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the actor/critic networks, and its corresponding
    weight w_t in the target actor/critic networks, a soft update is:
    w_t = (1 - tau) x w_t + tau x w_s

    Args:
      tau: A float scalar in [0, 1]; tau=1.0 copies the weights outright.

    Returns:
      An operation that performs a soft update of the target network parameters.

    Raises:
      ValueError: If `tau` is not in [0, 1].
    """
    if tau < 0 or tau > 1:
      raise ValueError('Input `tau` should be in [0, 1].')
    update_actor = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
        tau)
    # NOTE: This updates both critic networks: the scope filters are prefix
    # matches, so 'critic_net' also selects 'critic_net2' and
    # 'target_critic_net' also selects 'target_critic_net2'.
    update_critic = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
        tau)
    return tf.group(update_actor, update_critic, name='update_targets')
def gen_debug_td_error_summaries(
    target_q_values, q_values, td_targets, td_errors):
  """Adds histogram and scalar debug summaries for a batch of critic updates.

  Args:
    target_q_values: Values predicted by the target networks for the next
      states.
    q_values: Values predicted by the critic network for the current states.
    td_targets: Discounted target_q_values plus the observed rewards.
    td_errors: The difference between td_targets and q_values.
  """
  # Order matters only for op naming; keep it identical to the summaries'
  # historical creation order.
  tracked = (
      ('td_targets', td_targets),
      ('q_values', q_values),
      ('target_q_values', target_q_values),
      ('td_errors', td_errors),
  )

  # All distributions share a single 'td_errors' name scope.
  with tf.name_scope('td_errors'):
    for tag, tensor in tracked:
      tf.summary.histogram(tag, tensor)

  # One scope per quantity holding its mean/max/min (TD errors also get the
  # mean absolute error).
  for tag, tensor in tracked:
    with tf.name_scope(tag):
      tf.summary.scalar('mean', tf.reduce_mean(tensor))
      tf.summary.scalar('max', tf.reduce_max(tensor))
      tf.summary.scalar('min', tf.reduce_min(tensor))
      if tag == 'td_errors':
        tf.summary.scalar('mean_abs', tf.reduce_mean(tf.abs(tensor)))