diff --git a/lerobot/common/policies/tdmpc2/configuration_tdmpc2.py b/lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
index 8456f305d..8037875ff 100644
--- a/lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
+++ b/lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
@@ -102,8 +102,8 @@ class TDMPC2Config:
     """

     # Input / output structure.
-    n_action_repeats: int = 2
-    horizon: int = 5
+    n_action_repeats: int = 1
+    horizon: int = 3
     n_action_steps: int = 1

     input_shapes: dict[str, list[int]] = field(
@@ -128,7 +128,7 @@ class TDMPC2Config:
     # Neural networks.
     image_encoder_hidden_dim: int = 32
     state_encoder_hidden_dim: int = 256
-    latent_dim: int = 50
+    latent_dim: int = 512
     q_ensemble_size: int = 5
     mlp_dim: int = 512
     # Reinforcement learning.
@@ -137,39 +137,34 @@ class TDMPC2Config:
     # actor
     log_std_min: float = -10
     log_std_max: float = 2
-    entropy_coef: float = 1e-4

     # critic
     num_bins: int = 101
     vmin: int = -10
     vmax: int = +10
-    rho: float = 0.5
-    tau: float = 0.01

     # Inference.
     use_mpc: bool = True
     cem_iterations: int = 6
     max_std: float = 2.0
     min_std: float = 0.05
     n_gaussian_samples: int = 512
-    n_pi_samples: int = 51
-    uncertainty_regularizer_coeff: float = 1.0
-    n_elites: int = 50
+    n_pi_samples: int = 24
+    n_elites: int = 64
     elite_weighting_temperature: float = 0.5
-    gaussian_mean_momentum: float = 0.1

     # Training and loss computation.
     max_random_shift_ratio: float = 0.0476
     # Loss coefficients.
     reward_coeff: float = 0.1
-    expectile_weight: float = 0.9
     value_coeff: float = 0.1
     consistency_coeff: float = 20.0
-    advantage_scaling: float = 3.0
-    pi_coeff: float = 0.5
+    entropy_coef: float = 1e-4
     temporal_decay_coeff: float = 0.5
-    # Target model.
-    target_model_momentum: float = 0.995
+    # Target model. NOTE (michel_aractingi) this is equivalent to
+    # 1 - target_model_momentum of our TD-MPC1 implementation because
+    # of the use of `torch.lerp`
+    target_model_momentum: float = 0.01

     def __post_init__(self):
         """Input validation (not exhaustive)."""
diff --git a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
index 27b3295da..e3127168b 100644
--- a/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
+++ b/lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
@@ -108,7 +108,7 @@ def __init__(
         if "observation.environment_state" in config.input_shapes:
             self._use_env_state = True

-        self.scale = RunningScale(self.config.tau)
+        self.scale = RunningScale(self.config.target_model_momentum)
         self.discount = self.config.discount  # TODO (michel-aractingi) downscale discount according to episode length

         self.reset()
@@ -249,19 +249,14 @@ def plan(self, z: Tensor) -> Tensor:
             score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
             score /= score.sum(axis=0, keepdim=True)
             # (horizon, batch, action_dim)
-            _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
-            _std = torch.sqrt(
+            mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (score.sum(0) + 1e-9)
+            std = torch.sqrt(
                 torch.sum(
                     einops.rearrange(score, "n b -> n b 1")
-                    * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
+                    * (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
                     dim=1,
-                )
-            )
-            # Update mean with an exponential moving average, and std with a direct replacement.
-            mean = (
-                self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
-            )
-            std = _std.clamp_(self.config.min_std, self.config.max_std)
+                ) / (score.sum(0) + 1e-9)
+            ).clamp_(self.config.min_std, self.config.max_std)
             # Keep track of the mean for warm-starting subsequent steps.
             self._prev_mean = mean

@@ -687,20 +682,20 @@ def __init__(self, config: TDMPC2Config):
         elif "observation.state" in config.input_shapes:
             encoder_module = nn.ModuleList()
-            encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
+            encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
             assert config.num_enc_layers > 0
             for _ in range(config.num_enc_layers - 1):
-                encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
-            encoder_module.append(NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
+                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
+            encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
             encoder_module = nn.Sequential(*encoder_module)

         elif "observation.environment_state" in config.input_shapes:
             encoder_module = nn.ModuleList()
-            encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.enc_dim))
+            encoder_module.append(NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim))
             assert config.num_enc_layers > 0
             for _ in range(config.num_enc_layers - 1):
-                encoder_module.append(NormedLinear(config.enc_dim, config.enc_dim))
-            encoder_module.append(NormedLinear(config.enc_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
+                encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim))
+            encoder_module.append(NormedLinear(config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)))
             encoder_module = nn.Sequential(*encoder_module)

         else:
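A quick illustration of the EMA-convention NOTE in the config diff above: with `torch.lerp`, the target update is `target <- target + w * (online - target)`, so the weight `w` plays the role of `1 - momentum` in the conventional formulation. A minimal sketch (variable names are mine, not from the PR):

```python
import torch

tau = 0.01  # target_model_momentum in this config

target, online = torch.randn(4), torch.randn(4)

lerp_update = torch.lerp(target, online, tau)  # `torch.lerp` convention used here
conventional = 0.99 * target + 0.01 * online   # classic EMA with momentum = 0.99
assert torch.allclose(lerp_update, conventional)
```

This is why `target_model_momentum` drops from 0.995 to 0.01: the value now feeds `torch.lerp` directly rather than multiplying the target parameters.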
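For reference, a self-contained sketch of the elite-weighted statistics that the rewritten `plan` hunk computes. The shapes and dummy values are assumptions for illustration (the real code uses `einops` over `(horizon, n_elites, batch, action_dim)` tensors); the point is the normalized weighted mean/std with the `1e-9` guard and the std clamp replacing the old EMA-on-mean update:

```python
import torch

temperature, min_std, max_std = 0.5, 0.05, 2.0
horizon, n_elites, batch, action_dim = 3, 64, 2, 6

elite_value = torch.randn(n_elites, batch)
elite_actions = torch.randn(horizon, n_elites, batch, action_dim)

# Exponentiated, max-shifted values act as soft weights over the elites.
score = torch.exp(temperature * (elite_value - elite_value.max(0, keepdim=True).values))
score = score / score.sum(0, keepdim=True)

# Weighted mean/std over the elite dimension. After normalization the
# denominator is ~1; the 1e-9 only guards against division by zero.
w = score.reshape(1, n_elites, batch, 1)
denom = score.sum(0).reshape(1, batch, 1) + 1e-9
mean = (w * elite_actions).sum(dim=1) / denom
std = torch.sqrt((w * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / denom)
std = std.clamp(min_std, max_std)  # keep CEM sampling noise within configured bounds
print(mean.shape, std.shape)  # torch.Size([3, 2, 6]) torch.Size([3, 2, 6])
```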
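And for context on what the encoder hunk builds: a rough sketch of the state-encoder stack, assuming `NormedLinear` (linear, then LayerNorm, then Mish) and `SimNorm` (softmax over groups of `simnorm_dim`) along the lines of the TD-MPC2 reference implementation. These class bodies are assumptions, not code from this PR:

```python
import torch
from torch import nn
import torch.nn.functional as F


class SimNorm(nn.Module):
    """Assumed: simplicial normalization, i.e. softmax over groups of `dim` units."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shp = x.shape
        x = x.view(*shp[:-1], -1, self.dim)
        return F.softmax(x, dim=-1).view(*shp)


class NormedLinear(nn.Module):
    """Assumed layout: linear -> LayerNorm -> activation (Mish by default)."""

    def __init__(self, in_dim: int, out_dim: int, act: nn.Module | None = None):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.norm = nn.LayerNorm(out_dim)
        self.act = act if act is not None else nn.Mish()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.norm(self.linear(x)))


# Mirrors the rewritten hunk: hidden layers at state_encoder_hidden_dim,
# final projection to latent_dim with SimNorm as the output activation.
state_dim, hidden, latent, simnorm_dim, num_enc_layers = 24, 256, 512, 8, 2
layers = [NormedLinear(state_dim, hidden)]
layers += [NormedLinear(hidden, hidden) for _ in range(num_enc_layers - 1)]
layers += [NormedLinear(hidden, latent, act=SimNorm(simnorm_dim))]
encoder = nn.Sequential(*layers)
print(encoder(torch.randn(1, state_dim)).shape)  # torch.Size([1, 512])
```

The renaming from `config.enc_dim` to `config.state_encoder_hidden_dim` matters because `enc_dim` is not a field of `TDMPC2Config`; the old code would have raised an `AttributeError` on any state-only input.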