Description
In the `PPO.update()` method in `rsl_rl/algorithms/ppo.py`, there appears to be an issue with how the batch size is calculated when using symmetric data augmentation with RNN models.
Problem
`rsl_rl/rsl_rl/algorithms/ppo.py`, line 218 in 8363520: `original_batch_size = obs_batch.batch_size[0]`

`rsl_rl/rsl_rl/algorithms/ppo.py`, line 237 in 8363520: `num_aug = int(obs_batch.batch_size[0] / original_batch_size)`
When using RNN models, the `obs_batch` returned from `recurrent_mini_batch_generator` is a TensorDict with dimensions `[time_steps, batch_size, obs_dim]`. The `batch_size` attribute of this TensorDict is `[time_steps, batch_size]`, so:
- `obs_batch.batch_size[0]` returns the number of time steps, not the batch size
- `obs_batch.batch_size[1]` returns the actual batch size
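To make the indexing concrete, here is a minimal sketch that uses plain tuples as stand-ins for the `batch_size` attribute (a `torch.Size`, which is a tuple subclass). The sizes are hypothetical, and the non-recurrent shape `[batch_size]` is my assumption about what the non-recurrent generator produces:

```python
# Plain tuples stand in for TensorDict.batch_size (torch.Size is a tuple subclass).
# All sizes are hypothetical and chosen only for illustration.
time_steps, batch_size = 24, 8

# Assumed non-recurrent mini-batch: batch dims are [batch_size]
feedforward_batch_size = (batch_size,)
# Recurrent mini-batch as described above: batch dims are [time_steps, batch_size]
recurrent_batch_size = (time_steps, batch_size)

print(feedforward_batch_size[0])  # index 0 is the batch size here
print(recurrent_batch_size[0])    # index 0 is the number of time steps here
print(recurrent_batch_size[1])    # index 1 is the actual batch size
```

The same index therefore means different things depending on which generator produced the mini-batch.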
However, the current implementation always uses `obs_batch.batch_size[0]`, which means for RNN models it's using the time step count instead of the batch size, leading to incorrect calculations of `num_aug` (number of augmentations).
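One possible direction for a fix, sketched with plain tuples rather than real TensorDicts. `batch_dim_index` and `compute_num_aug` are hypothetical helpers, not existing rsl_rl functions, and the example assumes the augmentation function concatenates augmented samples along the batch dimension (dimension 1 in the recurrent case), which I have not verified:

```python
def batch_dim_index(is_recurrent: bool) -> int:
    """Index of the batch dimension in obs_batch.batch_size (hypothetical helper)."""
    return 1 if is_recurrent else 0


def compute_num_aug(original_size, augmented_size, is_recurrent: bool) -> int:
    """Ratio of augmented to original batch size along the batch dimension."""
    dim = batch_dim_index(is_recurrent)
    return augmented_size[dim] // original_size[dim]


# Hypothetical recurrent mini-batch: 24 time steps, 8 trajectories.
original = (24, 8)
# Assumption: augmentation doubles the batch dimension and leaves time steps unchanged.
augmented = (24, 16)

# Mirrors the current computation, which always reads index 0.
current_num_aug = int(augmented[0] / original[0])
# Reads the batch dimension instead.
fixed_num_aug = compute_num_aug(original, augmented, is_recurrent=True)

print(current_num_aug, fixed_num_aug)
```

Under that assumption, reading index 0 compares time steps with time steps and misses the augmentation entirely, while reading the batch dimension recovers the expected factor.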