|
40 | 40 |
|
41 | 41 | """
|
42 | 42 |
|
43 |
| -import argparse |
44 | 43 | import multiprocessing
|
45 | 44 | import threading
|
46 | 45 | import time
|
47 | 46 |
|
48 |
| -import gym |
49 | 47 | import numpy as np
|
50 | 48 | import tensorflow as tf
|
51 |
| -import tensorflow_probability as tfp |
52 | 49 |
|
| 50 | +import gym |
| 51 | +import tensorflow_probability as tfp |
53 | 52 | import tensorlayer as tl
|
54 | 53 | from common.buffer import *
|
55 | 54 | from common.networks import *
|
|
66 | 65 | class ACNet(object):
|
67 | 66 |
|
68 | 67 | def __init__(
|
69 |
| - self, scope, entropy_beta, action_dim, state_dim, actor_hidden_dim, actor_hidden_layer, critic_hidden_dim, |
70 |
| - critic_hidden_layer, action_bound, globalAC=None |
| 68 | + self, scope, entropy_beta, action_dim, state_dim, actor_hidden_dim, actor_hidden_layer, critic_hidden_dim, |
| 69 | + critic_hidden_layer, action_bound, globalAC=None |
71 | 70 | ):
|
72 | 71 | self.scope = scope # the scope is for naming networks for each worker differently
|
73 | 72 | self.save_path = './model'
|
@@ -107,7 +106,7 @@ def __init__(
|
107 | 106 |
|
108 | 107 | @tf.function # convert numpy functions to tf.Operations in the TFgraph, return tensor
|
109 | 108 | def update_global(
|
110 |
| - self, buffer_s, buffer_a, buffer_v_target, globalAC |
| 109 | + self, buffer_s, buffer_a, buffer_v_target, globalAC |
111 | 110 | ): # refer to the global Actor-Critic network for updating it with samples
|
112 | 111 | ''' update the global critic '''
|
113 | 112 | with tf.GradientTape() as tape:
|
@@ -164,8 +163,8 @@ def load_ckpt(self): # load trained weights
|
164 | 163 | class Worker(object):
|
165 | 164 |
|
166 | 165 | def __init__(
|
167 |
| - self, env_id, name, globalAC, train_episodes, gamma, update_itr, entropy_beta, action_dim, state_dim, |
168 |
| - actor_hidden_dim, actor_hidden_layer, critic_hidden_dim, critic_hidden_layer, action_bound |
| 166 | + self, env_id, name, globalAC, train_episodes, gamma, update_itr, entropy_beta, action_dim, state_dim, |
| 167 | + actor_hidden_dim, actor_hidden_layer, critic_hidden_dim, critic_hidden_layer, action_bound |
169 | 168 | ):
|
170 | 169 | self.env = make_env(env_id)
|
171 | 170 | self.name = name
|
@@ -242,9 +241,9 @@ def work(self, globalAC):
|
242 | 241 |
|
243 | 242 |
|
244 | 243 | def learn(
|
245 |
| - env_id, train_episodes, test_episodes=1000, max_steps=150, number_workers=0, update_itr=10, gamma=0.99, |
246 |
| - entropy_beta=0.005, actor_lr=5e-5, critic_lr=1e-4, actor_hidden_dim=300, actor_hidden_layer=2, |
247 |
| - critic_hidden_dim=300, critic_hidden_layer=2, seed=2, save_interval=500, mode='train' |
| 244 | + env_id, train_episodes, test_episodes=1000, max_steps=150, number_workers=0, update_itr=10, gamma=0.99, |
| 245 | + entropy_beta=0.005, actor_lr=5e-5, critic_lr=1e-4, actor_hidden_dim=300, actor_hidden_layer=2, |
| 246 | + critic_hidden_dim=300, critic_hidden_layer=2, seed=2, save_interval=500, mode='train' |
248 | 247 | ):
|
249 | 248 | '''
|
250 | 249 | parameters
|
|
0 commit comments