diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py index 188fad8c..7c318e01 100644 --- a/rsl_rl/algorithms/ppo.py +++ b/rsl_rl/algorithms/ppo.py @@ -11,7 +11,7 @@ from itertools import chain from tensordict import TensorDict -from rsl_rl.modules import ActorCritic, ActorCriticRecurrent +from rsl_rl.modules import ActorCritic, ActorCriticPerceptive, ActorCriticRecurrent from rsl_rl.modules.rnd import RandomNetworkDistillation from rsl_rl.storage import RolloutStorage from rsl_rl.utils import string_to_callable @@ -20,12 +20,12 @@ class PPO: """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347).""" - policy: ActorCritic | ActorCriticRecurrent + policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive """The actor critic module.""" def __init__( self, - policy: ActorCritic | ActorCriticRecurrent, + policy: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive, num_learning_epochs: int = 5, num_mini_batches: int = 4, clip_param: float = 0.2, diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py index efb8613a..7803aa08 100644 --- a/rsl_rl/modules/__init__.py +++ b/rsl_rl/modules/__init__.py @@ -6,6 +6,7 @@ """Definitions for neural-network components for RL-agents.""" from .actor_critic import ActorCritic +from .actor_critic_perceptive import ActorCriticPerceptive from .actor_critic_recurrent import ActorCriticRecurrent from .rnd import RandomNetworkDistillation, resolve_rnd_config from .student_teacher import StudentTeacher @@ -14,6 +15,7 @@ __all__ = [ "ActorCritic", + "ActorCriticPerceptive", "ActorCriticRecurrent", "RandomNetworkDistillation", "StudentTeacher", diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py index 9f01b2f4..da55e704 100644 --- a/rsl_rl/modules/actor_critic.py +++ b/rsl_rl/modules/actor_critic.py @@ -49,9 +49,8 @@ def __init__( assert len(obs[obs_group].shape) == 2, "The ActorCritic module only supports 1D observations." 
num_critic_obs += obs[obs_group].shape[-1] - self.state_dependent_std = state_dependent_std - # Actor + self.state_dependent_std = state_dependent_std if self.state_dependent_std: self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation) else: @@ -121,7 +120,7 @@ def action_std(self) -> torch.Tensor: def entropy(self) -> torch.Tensor: return self.distribution.entropy().sum(dim=-1) - def _update_distribution(self, obs: TensorDict) -> None: + def _update_distribution(self, obs: torch.Tensor) -> None: if self.state_dependent_std: # Compute mean and standard deviation mean_and_std = self.actor(obs) diff --git a/rsl_rl/modules/actor_critic_perceptive.py b/rsl_rl/modules/actor_critic_perceptive.py new file mode 100644 index 00000000..7d693223 --- /dev/null +++ b/rsl_rl/modules/actor_critic_perceptive.py @@ -0,0 +1,255 @@ +# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import torch +import torch.nn as nn +from tensordict import TensorDict +from torch.distributions import Normal +from typing import Any + +from rsl_rl.networks import CNN, MLP, EmpiricalNormalization + +from .actor_critic import ActorCritic + + +class ActorCriticPerceptive(ActorCritic): + def __init__( + self, + obs: TensorDict, + obs_groups: dict[str, list[str]], + num_actions: int, + actor_obs_normalization: bool = False, + critic_obs_normalization: bool = False, + actor_hidden_dims: list[int] = [256, 256, 256], + critic_hidden_dims: list[int] = [256, 256, 256], + actor_cnn_cfg: dict[str, dict] | dict | None = None, + critic_cnn_cfg: dict[str, dict] | dict | None = None, + activation: str = "elu", + init_noise_std: float = 1.0, + noise_std_type: str = "scalar", + state_dependent_std: bool = False, + **kwargs: dict[str, Any], + ) -> None: + if kwargs: + print( + "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: " + + str([key for key 
in kwargs]) + ) + nn.Module.__init__(self) + + # Get the observation dimensions + self.obs_groups = obs_groups + num_actor_obs = 0 + num_actor_in_channels = [] + self.actor_obs_groups_1d = [] + self.actor_obs_groups_2d = [] + for obs_group in obs_groups["policy"]: + if len(obs[obs_group].shape) == 4: # B, C, H, W + self.actor_obs_groups_2d.append(obs_group) + num_actor_in_channels.append(obs[obs_group].shape[1]) + elif len(obs[obs_group].shape) == 2: # B, C + self.actor_obs_groups_1d.append(obs_group) + num_actor_obs += obs[obs_group].shape[-1] + else: + raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") + num_critic_obs = 0 + num_critic_in_channels = [] + self.critic_obs_groups_1d = [] + self.critic_obs_groups_2d = [] + for obs_group in obs_groups["critic"]: + if len(obs[obs_group].shape) == 4: # B, C, H, W + self.critic_obs_groups_2d.append(obs_group) + num_critic_in_channels.append(obs[obs_group].shape[1]) + elif len(obs[obs_group].shape) == 2: # B, C + self.critic_obs_groups_1d.append(obs_group) + num_critic_obs += obs[obs_group].shape[-1] + else: + raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") + + # Actor CNN + if self.actor_obs_groups_2d: + assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations." + + # Check if multiple 2D actor observations are provided + if len(self.actor_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_cfg.values()): + assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), ( + "The number of CNN configurations must match the number of 2D actor observations." + ) + elif len(self.actor_obs_groups_2d) > 1: + print( + "Only one CNN configuration for multiple 2D actor observations given, using the same configuration " + "for all groups." 
+ ) + actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg] * len(self.actor_obs_groups_2d))) + else: + actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg])) + + # Create CNNs for each 2D actor observation + self.actor_cnns = nn.ModuleDict() + encoding_dims = [] + for idx, obs_group in enumerate(self.actor_obs_groups_2d): + self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_cfg[obs_group]) + print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}") + + # Compute the encoding dimension (cpu necessary as model not moved to device yet) + encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1]) + encoding_dim = sum(encoding_dims) + else: + self.actor_cnns = None + encoding_dim = 0 + + # Actor MLP + self.state_dependent_std = state_dependent_std + if self.state_dependent_std: + self.actor = MLP(num_actor_obs + encoding_dim, [2, num_actions], actor_hidden_dims, activation) + else: + self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation) + print(f"Actor MLP: {self.actor}") + + # Actor observation normalization (only for 1D actor observations) + self.actor_obs_normalization = actor_obs_normalization + if actor_obs_normalization: + self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs) + else: + self.actor_obs_normalizer = torch.nn.Identity() + + # Critic CNN + if self.critic_obs_groups_2d: + assert critic_cnn_cfg is not None, " A critic CNN configuration is required for 2D critic observations." + + # check if multiple 2D critic observations are provided + if len(self.critic_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_cfg.values()): + assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), ( + "The number of CNN configurations must match the number of 2D critic observations." 
+ ) + elif len(self.critic_obs_groups_2d) > 1: + print( + "Only one CNN configuration for multiple 2D critic observations given, using the same configuration" + " for all groups." + ) + critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg] * len(self.critic_obs_groups_2d))) + else: + critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg])) + + # Create CNNs for each 2D critic observation + self.critic_cnns = nn.ModuleDict() + encoding_dims = [] + for idx, obs_group in enumerate(self.critic_obs_groups_2d): + self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_cfg[obs_group]) + print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}") + + # Compute the encoding dimension (cpu necessary as model not moved to device yet) + encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1]) + encoding_dim = sum(encoding_dims) + else: + self.critic_cnns = None + encoding_dim = 0 + + # Critic MLP + self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation) + print(f"Critic MLP: {self.critic}") + + # Critic observation normalization (only for 1D critic observations) + self.critic_obs_normalization = critic_obs_normalization + if critic_obs_normalization: + self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs) + else: + self.critic_obs_normalizer = torch.nn.Identity() + + # Action noise + self.noise_std_type = noise_std_type + if self.state_dependent_std: + torch.nn.init.zeros_(self.actor[-2].weight[num_actions:]) + if self.noise_std_type == "scalar": + torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std) + elif self.noise_std_type == "log": + torch.nn.init.constant_( + self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7)) + ) + else: + raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. 
Should be 'scalar' or 'log'") + else: + if self.noise_std_type == "scalar": + self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) + elif self.noise_std_type == "log": + self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions))) + else: + raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") + + # Action distribution + # Note: Populated in update_distribution + self.distribution = None + + # Disable args validation for speedup + Normal.set_default_validate_args(False) + + def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None: + if self.actor_cnns is not None: + # Encode the 2D actor observations + cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d] + cnn_enc = torch.cat(cnn_enc_list, dim=-1) + # Concatenate to the MLP observations + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) + + super()._update_distribution(mlp_obs) + + def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: + mlp_obs, cnn_obs = self.get_actor_obs(obs) + mlp_obs = self.actor_obs_normalizer(mlp_obs) + self._update_distribution(mlp_obs, cnn_obs) + return self.distribution.sample() + + def act_inference(self, obs: TensorDict) -> torch.Tensor: + mlp_obs, cnn_obs = self.get_actor_obs(obs) + mlp_obs = self.actor_obs_normalizer(mlp_obs) + + if self.actor_cnns is not None: + # Encode the 2D actor observations + cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d] + cnn_enc = torch.cat(cnn_enc_list, dim=-1) + # Concatenate to the MLP observations + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) + + if self.state_dependent_std: + return self.actor(mlp_obs)[..., 0, :] + else: + return self.actor(mlp_obs) + + def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: + mlp_obs, cnn_obs = self.get_critic_obs(obs) + mlp_obs = 
self.critic_obs_normalizer(mlp_obs) + + if self.critic_cnns is not None: + # Encode the 2D critic observations + cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d] + cnn_enc = torch.cat(cnn_enc_list, dim=-1) + # Concatenate to the MLP observations + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) + + return self.critic(mlp_obs) + + def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d] + obs_dict_2d = {} + for obs_group in self.actor_obs_groups_2d: + obs_dict_2d[obs_group] = obs[obs_group] + return torch.cat(obs_list_1d, dim=-1), obs_dict_2d + + def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d] + obs_dict_2d = {} + for obs_group in self.critic_obs_groups_2d: + obs_dict_2d[obs_group] = obs[obs_group] + return torch.cat(obs_list_1d, dim=-1), obs_dict_2d + + def update_normalization(self, obs: TensorDict) -> None: + if self.actor_obs_normalization: + actor_obs, _ = self.get_actor_obs(obs) + self.actor_obs_normalizer.update(actor_obs) + if self.critic_obs_normalization: + critic_obs, _ = self.get_critic_obs(obs) + self.critic_obs_normalizer.update(critic_obs) diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py index 0930702a..2fcc8b96 100644 --- a/rsl_rl/modules/actor_critic_recurrent.py +++ b/rsl_rl/modules/actor_critic_recurrent.py @@ -61,9 +61,8 @@ def __init__( assert len(obs[obs_group].shape) == 2, "The ActorCriticRecurrent module only supports 1D observations." 
num_critic_obs += obs[obs_group].shape[-1] - self.state_dependent_std = state_dependent_std - # Actor + self.state_dependent_std = state_dependent_std self.memory_a = Memory(num_actor_obs, rnn_hidden_dim, rnn_num_layers, rnn_type) if self.state_dependent_std: self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation) @@ -138,7 +137,7 @@ def reset(self, dones: torch.Tensor | None = None) -> None: def forward(self) -> NoReturn: raise NotImplementedError - def _update_distribution(self, obs: TensorDict) -> None: + def _update_distribution(self, obs: torch.Tensor) -> None: if self.state_dependent_std: # Compute mean and standard deviation mean_and_std = self.actor(obs) diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py index 106a3c53..e95b2622 100644 --- a/rsl_rl/networks/__init__.py +++ b/rsl_rl/networks/__init__.py @@ -5,11 +5,13 @@ """Definitions for components of modules.""" +from .cnn import CNN from .memory import Memory from .mlp import MLP from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization __all__ = [ + "CNN", "MLP", "EmpiricalDiscountedVariationNormalization", "EmpiricalNormalization", diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py new file mode 100644 index 00000000..feffb255 --- /dev/null +++ b/rsl_rl/networks/cnn.py @@ -0,0 +1,109 @@ +# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import torch +from torch import nn as nn + +from rsl_rl.utils import resolve_nn_activation + + +class CNN(nn.Sequential): + """Convolutional Neural Network (CNN). + + The CNN network is a sequence of convolutional layers, optional batch normalization, activation functions, and + optional max pooling. The final output can be flattened or pooled depending on the configuration. 
+ """ + + def __init__( + self, + in_channels: int, + activation: str, + out_channels: list[int], + kernel_size: list[tuple[int, int]] | tuple[int, int], + stride: list[int] | int = 1, + flatten: bool = True, + avg_pool: tuple[int, int] | None = None, + batchnorm: bool | list[bool] = False, + max_pool: bool | list[bool] = False, + ) -> None: + """Initialize the CNN. + + Args: + in_channels: Number of input channels. + activation: Activation function to use. + out_channels: List of output channels for each convolutional layer. + kernel_size: List of kernel sizes for each convolutional layer or a single kernel size for all layers. + stride: List of strides for each convolutional layer or a single stride for all layers. + flatten: Whether to flatten the output tensor. + avg_pool: If specified, applies an adaptive average pooling to the given output size after the convolutions. + batchnorm: Whether to apply batch normalization after each convolutional layer. + max_pool: Whether to apply max pooling after each convolutional layer. + + .. note:: + Do not save config to allow for the model to be jit compiled. 
+ """ + super().__init__() + + # If parameters are not lists, convert them to lists + if isinstance(batchnorm, bool): + batchnorm = [batchnorm] * len(out_channels) + if isinstance(max_pool, bool): + max_pool = [max_pool] * len(out_channels) + if isinstance(kernel_size, tuple): + kernel_size = [kernel_size] * len(out_channels) + if isinstance(stride, int): + stride = [stride] * len(out_channels) + + # Resolve activation function + activation_function = resolve_nn_activation(activation) + + # Create layers sequentially + layers = [] + for idx in range(len(out_channels)): + in_channels = in_channels if idx == 0 else out_channels[idx - 1] + layers.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels[idx], + kernel_size=kernel_size[idx], + stride=stride[idx], + ) + ) + if batchnorm[idx]: + layers.append(nn.BatchNorm2d(num_features=out_channels[idx])) + layers.append(activation_function) + if max_pool[idx]: + layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + # Register the layers + for idx, layer in enumerate(layers): + self.add_module(f"{idx}", layer) + + # Add avgpool if specified + if avg_pool is not None: + self.avgpool = nn.AdaptiveAvgPool2d(avg_pool) + else: + self.avgpool = None + + # Save flatten flag for forward function + self.flatten = flatten + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self: + x = layer(x) + if self.flatten: + x = x.flatten(start_dim=1) + elif self.avgpool is not None: + x = self.avgpool(x) + x = x.flatten(start_dim=1) + return x + + def init_weights(self) -> None: + """Initialize the weights of the CNN with Xavier initialization.""" + for idx, module in enumerate(self): + if isinstance(module, nn.Conv2d): + nn.init.xavier_uniform_(module.weight) diff --git a/rsl_rl/networks/memory.py b/rsl_rl/networks/memory.py index ef6ffc53..0029a3a7 100644 --- a/rsl_rl/networks/memory.py +++ b/rsl_rl/networks/memory.py @@ -12,9 +12,9 @@ class Memory(nn.Module): - """Memory module for 
recurrent networks. + """Memory network for recurrent architectures. - This module is used to store the hidden states of the policy. It currently only supports GRU and LSTM. + This network is used to store the hidden states of the policy. It currently only supports GRU and LSTM. """ def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None: diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 46a9b524..2b0d7664 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -16,7 +16,13 @@ import rsl_rl from rsl_rl.algorithms import PPO from rsl_rl.env import VecEnv -from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, resolve_rnd_config, resolve_symmetry_config +from rsl_rl.modules import ( + ActorCritic, + ActorCriticPerceptive, + ActorCriticRecurrent, + resolve_rnd_config, + resolve_symmetry_config, +) from rsl_rl.utils import resolve_obs_groups, store_code_state @@ -414,7 +420,7 @@ def _construct_algorithm(self, obs: TensorDict) -> PPO: # Initialize the policy actor_critic_class = eval(self.policy_cfg.pop("class_name")) - actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class( + actor_critic: ActorCritic | ActorCriticRecurrent | ActorCriticPerceptive = actor_critic_class( obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg ).to(self.device)