@@ -50,8 +50,8 @@ def __init__(
5050
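5151 # ActorCritic base is given the plain observation sizes here (rnn_hidden_dim is not added) — TODO confirm this matches the tensors act()/evaluate() pass to the MLPs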
5252 super ().__init__ (
53- num_actor_obs = rnn_hidden_dim + num_actor_obs , # 现在的实验条件默认critic_obs=actor_obs
54- num_critic_obs = rnn_hidden_dim + num_critic_obs ,
53+ num_actor_obs = num_actor_obs , # 现在的实验条件默认critic_obs=actor_obs
54+ num_critic_obs = num_critic_obs ,
5555 num_actions = num_actions ,
5656 actor_hidden_dims = actor_hidden_dims ,
5757 critic_hidden_dims = critic_hidden_dims ,
@@ -111,7 +111,7 @@ def act(self, observations, prev_actions, masks=None, hidden_states=None):
111111 input_a = torch .cat ([observations , prev_actions ], dim = - 1 )
112112 input_a = self .memory_a (input_a , masks , hidden_states )
113113 mlp_a_input = torch .cat ([input_a .squeeze (0 ), observations ], dim = - 1 )
114- return super ().act (mlp_a_input )
114+ return super ().act (observations )
115115
116116 # Not used by the training script, so it should not affect training; left unchanged for now
117117 def act_inference (self , observations , prev_actions ):
@@ -124,7 +124,7 @@ def evaluate(self, critic_observations, prev_action, masks=None, hidden_states=N
124124 # actor和critic共用一个RNN
125125 input_c = self .memory_a (input_c , masks , hidden_states )
126126 mlp_c_input = torch .cat ([input_c .squeeze (0 ), critic_observations ], dim = - 1 )
127- return super ().evaluate (mlp_c_input )
127+ return super ().evaluate (critic_observations )
128128
129129 # # We changed the critic and actor to share one RNN, fed the same concatenation of context and obs
130130 # def evaluate(self, observations, prev_actions, masks=None, hidden_states=None):
0 commit comments