Skip to content

Commit 88c6732

Browse files
committed
Feat: modify ACRL2 for chunk_size
1 parent 3b8185d commit 88c6732

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

rsl_rl/modules/actor_critic_rl2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def __init__(
5050

5151
# ActorCritic base expects num_actor_obs = rnn_hidden_dim
5252
super().__init__(
53-
num_actor_obs=rnn_hidden_dim + num_actor_obs, # 现在的实验条件默认critic_obs=actor_obs
54-
num_critic_obs=rnn_hidden_dim + num_critic_obs,
53+
num_actor_obs=num_actor_obs, # 现在的实验条件默认critic_obs=actor_obs
54+
num_critic_obs=num_critic_obs,
5555
num_actions=num_actions,
5656
actor_hidden_dims=actor_hidden_dims,
5757
critic_hidden_dims=critic_hidden_dims,
@@ -111,7 +111,7 @@ def act(self, observations, prev_actions, masks=None, hidden_states=None):
111111
input_a = torch.cat([observations, prev_actions], dim=-1)
112112
input_a = self.memory_a(input_a, masks, hidden_states)
113113
mlp_a_input = torch.cat([input_a.squeeze(0), observations], dim=-1)
114-
return super().act(mlp_a_input)
114+
return super().act(observations)
115115

116116
# 脚本训练过程用不到,应该不影响训练,暂时不修改
117117
def act_inference(self, observations, prev_actions):
@@ -124,7 +124,7 @@ def evaluate(self, critic_observations, prev_action, masks=None, hidden_states=N
124124
# actor和critic共用一个RNN
125125
input_c = self.memory_a(input_c, masks, hidden_states)
126126
mlp_c_input = torch.cat([input_c.squeeze(0), critic_observations], dim=-1)
127-
return super().evaluate(mlp_c_input)
127+
return super().evaluate(critic_observations)
128128

129129
# # 我们改成critic和actor使用同一个RNN,输入相同context和obs拼接
130130
# def evaluate(self, observations, prev_actions, masks=None, hidden_states=None):

0 commit comments

Comments
 (0)