Skip to content

Commit

Permalink
updated configuration parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
michel-aractingi committed Nov 22, 2024
1 parent 3198464 commit 166c1fc
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 32 deletions.
25 changes: 10 additions & 15 deletions lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ class TDMPC2Config:
"""

# Input / output structure.
n_action_repeats: int = 2
horizon: int = 5
n_action_repeats: int = 1
horizon: int = 3
n_action_steps: int = 1

input_shapes: dict[str, list[int]] = field(
Expand All @@ -128,7 +128,7 @@ class TDMPC2Config:
# Neural networks.
image_encoder_hidden_dim: int = 32
state_encoder_hidden_dim: int = 256
latent_dim: int = 50
latent_dim: int = 512
q_ensemble_size: int = 5
mlp_dim: int = 512
# Reinforcement learning.
Expand All @@ -137,39 +137,34 @@ class TDMPC2Config:
# actor
log_std_min: float = -10
log_std_max: float = 2
entropy_coef: float = 1e-4

# critic
num_bins: int = 101
vmin: int = -10
vmax: int = +10

rho: float = 0.5
tau: float = 0.01
# Inference.
use_mpc: bool = True
cem_iterations: int = 6
max_std: float = 2.0
min_std: float = 0.05
n_gaussian_samples: int = 512
n_pi_samples: int = 51
uncertainty_regularizer_coeff: float = 1.0
n_elites: int = 50
n_pi_samples: int = 24
n_elites: int = 64
elite_weighting_temperature: float = 0.5
gaussian_mean_momentum: float = 0.1

# Training and loss computation.
max_random_shift_ratio: float = 0.0476
# Loss coefficients.
reward_coeff: float = 0.1
expectile_weight: float = 0.9
value_coeff: float = 0.1
consistency_coeff: float = 20.0
advantage_scaling: float = 3.0
pi_coeff: float = 0.5
entropy_coef: float = 1e-4
temporal_decay_coeff: float = 0.5
# Target model.
target_model_momentum: float = 0.995
# Target model. NOTE (michel_aractingi) this is equivelant to
# 1 - target_model_momentum of our TD-MPC1 implementation because
# of the use of `torch.lerp`
target_model_momentum: float = 0.01

def __post_init__(self):
"""Input validation (not exhaustive)."""
Expand Down
29 changes: 12 additions & 17 deletions lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True

self.scale = RunningScale(self.config.tau)
self.scale = RunningScale(self.config.target_model_momentum)
self.discount = self.config.discount #TODO (michel-aractingi) downscale discount according to episode length

self.reset()
Expand Down Expand Up @@ -249,19 +249,14 @@ def plan(self, z: Tensor) -> Tensor:
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_std = torch.sqrt(
mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (score.sum(0) + 1e-9)
std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
* (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
) / (score.sum(0) + 1e-9)
).clamp_(self.config.min_std, self.config.max_std)

# Keep track of the mean for warm-starting subsequent steps.
self._prev_mean = mean
Expand Down Expand Up @@ -687,20 +682,20 @@ def __init__(self, config: TDMPC2Config):

elif "observation.state" in config.input_shapes:
encoder_module = nn.ModuleList()
encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
assert config.num_enc_layers > 0
for _ in range(config.num_enc_layers - 1):
encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
encoder_module.append(NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
encoder_module = nn.Sequential(*encoder_module)

elif "observation.environment_state" in config.input_shapes:
encoder_module = nn.ModuleList()
encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
assert config.num_enc_layers > 0
for _ in range(config.num_enc_layers - 1):
encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
encoder_module.append(NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
encoder_module = nn.Sequential(*encoder_module)

else:
Expand Down

0 comments on commit 166c1fc

Please sign in to comment.