From e57b2ffba47b093772f366e6fb5e6553f4962c20 Mon Sep 17 00:00:00 2001
From: kaixi287
Date: Thu, 13 Jun 2024 15:49:15 +0200
Subject: [PATCH 1/3] Change dones to bool mask for hidden states reset

---
 rsl_rl/modules/actor_critic_recurrent.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py
index 6321ec51..433d7ada 100644
--- a/rsl_rl/modules/actor_critic_recurrent.py
+++ b/rsl_rl/modules/actor_critic_recurrent.py
@@ -92,6 +92,8 @@ def forward(self, input, masks=None, hidden_states=None):
         return out
 
     def reset(self, dones=None):
+        if dones is not None:
+            dones = dones.bool()
         # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
         for hidden_state in self.hidden_states:
             hidden_state[..., dones, :] = 0.0

From 4287ee5a01417b129879c8ea6951c7a98e3da95e Mon Sep 17 00:00:00 2001
From: kaixi287
Date: Thu, 8 Aug 2024 15:07:24 +0200
Subject: [PATCH 2/3] Add option for other arguments in PPO class

---
 rsl_rl/algorithms/ppo.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py
index ffe58147..a589b767 100644
--- a/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/algorithms/ppo.py
@@ -30,6 +30,7 @@ def __init__(
         schedule="fixed",
         desired_kl=0.01,
         device="cpu",
+        **kwargs,
     ):
         self.device = device
 

From ddcc99ede7784ccd143c1817f6308ee385fffb7b Mon Sep 17 00:00:00 2001
From: kaixi287
Date: Thu, 8 Aug 2024 15:10:02 +0200
Subject: [PATCH 3/3] Revert "Add option for other arguments in PPO class"

This reverts commit 4287ee5a01417b129879c8ea6951c7a98e3da95e.

---
 rsl_rl/algorithms/ppo.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py
index a589b767..ffe58147 100644
--- a/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/algorithms/ppo.py
@@ -30,7 +30,6 @@ def __init__(
         schedule="fixed",
         desired_kl=0.01,
         device="cpu",
-        **kwargs,
     ):
         self.device = device
 