From d46bc2ea80ccf7f2b9ba5419a4102065c80a31ba Mon Sep 17 00:00:00 2001 From: Pascal Roth Date: Sat, 13 Sep 2025 18:41:06 +0200 Subject: [PATCH 1/3] add files for perceptive example --- rsl_rl/modules/__init__.py | 2 + rsl_rl/modules/actor_critic.py | 12 +- rsl_rl/modules/perceptive_actor_critic.py | 236 ++++++++++++++++++++++ rsl_rl/networks/__init__.py | 1 + rsl_rl/networks/cnn.py | 94 +++++++++ 5 files changed, 339 insertions(+), 6 deletions(-) create mode 100644 rsl_rl/modules/perceptive_actor_critic.py create mode 100644 rsl_rl/networks/cnn.py diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py index 0d23fff6..474624da 100644 --- a/rsl_rl/modules/__init__.py +++ b/rsl_rl/modules/__init__.py @@ -7,6 +7,7 @@ from .actor_critic import ActorCritic from .actor_critic_recurrent import ActorCriticRecurrent +from .perceptive_actor_critic import PerceptiveActorCritic from .rnd import * from .student_teacher import StudentTeacher from .student_teacher_recurrent import StudentTeacherRecurrent @@ -15,6 +16,7 @@ __all__ = [ "ActorCritic", "ActorCriticRecurrent", + "PerceptiveActorCritic", "StudentTeacher", "StudentTeacherRecurrent", ] diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py index 2d659547..ee993ea2 100644 --- a/rsl_rl/modules/actor_critic.py +++ b/rsl_rl/modules/actor_critic.py @@ -20,12 +20,12 @@ def __init__( obs, obs_groups, num_actions, - actor_obs_normalization=False, - critic_obs_normalization=False, - actor_hidden_dims=[256, 256, 256], - critic_hidden_dims=[256, 256, 256], - activation="elu", - init_noise_std=1.0, + actor_obs_normalization: bool = False, + critic_obs_normalization: bool = False, + actor_hidden_dims: list[int] = [256, 256, 256], + critic_hidden_dims: list[int] = [256, 256, 256], + activation: str = "elu", + init_noise_std: float = 1.0, noise_std_type: str = "scalar", **kwargs, ): diff --git a/rsl_rl/modules/perceptive_actor_critic.py b/rsl_rl/modules/perceptive_actor_critic.py new file mode 100644 index 00000000..ff270645 --- /dev/null +++ b/rsl_rl/modules/perceptive_actor_critic.py @@ -0,0 +1,236 @@ +# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.distributions import Normal + +from .actor_critic import ActorCritic + +from rsl_rl.networks import MLP, CNN, CNNConfig, EmpiricalNormalization + + +class PerceptiveActorCritic(ActorCritic): + def __init__( + self, + obs, + obs_groups, + num_actions, + actor_obs_normalization: bool = False, + critic_obs_normalization: bool = False, + actor_hidden_dims: list[int] = [256, 256, 256], + critic_hidden_dims: list[int] = [256, 256, 256], + actor_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None, + critic_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None, + activation: str = "elu", + init_noise_std: float = 1.0, + noise_std_type: str = "scalar", + **kwargs, + ): + if kwargs: + print( + "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: " + + str([key for key in kwargs.keys()]) + ) + nn.Module.__init__(self) + + # get the observation dimensions + self.obs_groups = obs_groups + num_actor_obs = 0 + num_actor_in_channels = [] + self.actor_obs_group_1d = [] + self.actor_obs_group_2d = [] + for obs_group in obs_groups["policy"]: + if len(obs[obs_group].shape) == 2: # FIXME: should be 3??? 
+ self.actor_obs_group_2d.append(obs_group) + num_actor_in_channels.append(obs[obs_group].shape[0]) + elif len(obs[obs_group].shape) == 1: + self.actor_obs_group_1d.append(obs_group) + num_actor_obs += obs[obs_group].shape[-1] + else: + raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") + + self.critic_obs_group_1d = [] + self.critic_obs_group_2d = [] + num_critic_obs = 0 + num_critic_in_channels = [] + for obs_group in obs_groups["critic"]: + if len(obs[obs_group].shape) == 2: # FIXME: should be 3??? + self.critic_obs_group_2d.append(obs_group) + num_critic_in_channels.append(obs[obs_group].shape[0]) + else: + self.critic_obs_group_1d.append(obs_group) + num_critic_obs += obs[obs_group].shape[-1] + + # actor cnn + if self.actor_obs_group_2d: + assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations." + + # check if multiple 2D actor observations are provided + if len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, CNNConfig): + print(f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups.") + actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d))) + elif len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, dict): + assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations." + elif len(self.actor_obs_group_2d) == 1 and isinstance(actor_cnn_config, CNNConfig): + actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config])) + else: + raise ValueError(f"Invalid combination of 2D actor observations {self.actor_obs_group_2d} and actor CNN config {actor_cnn_config}.") + + self.actor_cnns = {} + encoding_dims = [] + for idx, obs_group in enumerate(self.actor_obs_group_2d): + self.actor_cnns[obs_group] = CNN(actor_cnn_config[obs_group], num_actor_in_channels[idx], activation) + print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}") + + # compute the encoding dimension + encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group]).shape[-1]) + + encoding_dim = sum(encoding_dims) + else: + self.actor_cnns = None + encoding_dim = 0 + + # actor mlp + self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation) + + # actor observation normalization (only for 1D actor observations) + self.actor_obs_normalization = actor_obs_normalization + if actor_obs_normalization: + self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs) + else: + self.actor_obs_normalizer = torch.nn.Identity() + print(f"Actor MLP: {self.actor}") + + # critic cnn + if self.critic_obs_group_2d: + assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations." + + # check if multiple 2D critic observations are provided + if len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, CNNConfig): + print(f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups.") + critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d))) + elif len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, dict): + assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations." 
+ elif len(self.critic_obs_group_2d) == 1 and isinstance(critic_cnn_config, CNNConfig): + critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config])) + else: + raise ValueError(f"Invalid combination of 2D critic observations {self.critic_obs_group_2d} and critic CNN config {critic_cnn_config}.") + + self.critic_cnns = {} + encoding_dims = [] + for idx, obs_group in enumerate(self.critic_obs_group_2d): + self.critic_cnns[obs_group] = CNN(critic_cnn_config[obs_group], num_critic_in_channels[idx], activation) + print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}") + + # compute the encoding dimension + encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group]).shape[-1]) + + encoding_dim = sum(encoding_dims) + else: + self.critic_cnns = None + encoding_dim = 0 + + # critic mlp + self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation) + + # critic observation normalization (only for 1D critic observations) + self.critic_obs_normalization = critic_obs_normalization + if critic_obs_normalization: + self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs) + else: + self.critic_obs_normalizer = torch.nn.Identity() + print(f"Critic MLP: {self.critic}") + + # Action noise + self.noise_std_type = noise_std_type + if self.noise_std_type == "scalar": + self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) + elif self.noise_std_type == "log": + self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions))) + else: + raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") + + # Action distribution (populated in update_distribution) + self.distribution: Normal = None + # disable args validation for speedup + Normal.set_default_validate_args(False) + + def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]): + + if self.actor_cnns is not None: + # encode the 2D actor observations + cnn_enc_list = [] + for obs_group in self.actor_obs_group_2d: + cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group])) + cnn_enc = torch.cat(cnn_enc_list, dim=-1) + # update mlp obs + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) + + super().update_distribution(mlp_obs) + + def act(self, obs, **kwargs): + mlp_obs, cnn_obs = self.get_actor_obs(obs) + mlp_obs = self.actor_obs_normalizer(mlp_obs) + self.update_distribution(mlp_obs, cnn_obs) + return self.distribution.sample() + + def act_inference(self, obs): + mlp_obs, cnn_obs = self.get_actor_obs(obs) + mlp_obs = self.actor_obs_normalizer(mlp_obs) + + if self.actor_cnns is not None: + # encode the 2D actor observations + cnn_enc_list = [] + for obs_group in self.actor_obs_group_2d: + cnn_enc_list.append(self.actor_cnns[obs_group](cnn_obs[obs_group])) + cnn_enc = torch.cat(cnn_enc_list, dim=-1) + # update mlp obs + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) + + return self.actor(mlp_obs) + + def evaluate(self, obs, **kwargs): + mlp_obs, cnn_obs = self.get_critic_obs(obs) + mlp_obs = self.critic_obs_normalizer(mlp_obs) + + if self.critic_cnns is not None: + # encode the 2D critic observations + cnn_enc_list = [] + for obs_group in self.critic_obs_group_2d: + cnn_enc_list.append(self.critic_cnns[obs_group](cnn_obs[obs_group])) + cnn_enc = torch.cat(cnn_enc_list, dim=-1) + # update mlp obs + mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) + + return self.critic(mlp_obs) + + def get_actor_obs(self, obs): + obs_list_1d = [] + obs_dict_2d = {} + for obs_group in 
self.actor_obs_group_1d: + obs_list_1d.append(obs[obs_group]) + for obs_group in self.actor_obs_group_2d: + obs_dict_2d[obs_group] = obs[obs_group] + return torch.cat(obs_list_1d, dim=-1), obs_dict_2d + + def get_critic_obs(self, obs): + obs_list_1d = [] + obs_dict_2d = {} + for obs_group in self.critic_obs_group_1d: + obs_list_1d.append(obs[obs_group]) + for obs_group in self.critic_obs_group_2d: + obs_dict_2d[obs_group] = obs[obs_group] + return torch.cat(obs_list_1d, dim=-1), obs_dict_2d + + def update_normalization(self, obs): + if self.actor_obs_normalization: + actor_obs, _ = self.get_actor_obs(obs) + self.actor_obs_normalizer.update(actor_obs) + if self.critic_obs_normalization: + critic_obs, _ = self.get_critic_obs(obs) + self.critic_obs_normalizer.update(critic_obs) \ No newline at end of file diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py index c18f487a..245c395a 100644 --- a/rsl_rl/networks/__init__.py +++ b/rsl_rl/networks/__init__.py @@ -7,4 +7,5 @@ from .memory import Memory from .mlp import MLP +from .cnn import CNN, CNNConfig from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py new file mode 100644 index 00000000..a1368635 --- /dev/null +++ b/rsl_rl/networks/cnn.py @@ -0,0 +1,94 @@ +# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import torch +from dataclasses import MISSING, dataclass +from torch import nn as nn + +from rsl_rl.utils import resolve_nn_activation + + +@dataclass +class CNNConfig: + out_channels: list[int] = MISSING + kernel_size: list[tuple[int, int]] | tuple[int, int] = MISSING + stride: list[int] | int = 1 + flatten: bool = True + avg_pool: tuple[int, int] | None = None + batchnorm: bool | list[bool] = False + max_pool: bool | list[bool] = False + + +class CNN(nn.Module): + def __init__(self, cfg: CNNConfig, in_channels: int, activation: str): + """ + Convolutional Neural Network model. + + .. note:: + Do not save config to allow for the model to be jit compiled. 
+ """ + super().__init__() + + if isinstance(cfg.batchnorm, bool): + cfg.batchnorm = [cfg.batchnorm] * len(cfg.out_channels) + if isinstance(cfg.max_pool, bool): + cfg.max_pool = [cfg.max_pool] * len(cfg.out_channels) + if isinstance(cfg.kernel_size, tuple): + cfg.kernel_size = [cfg.kernel_size] * len(cfg.out_channels) + if isinstance(cfg.stride, int): + cfg.stride = [cfg.stride] * len(cfg.out_channels) + + # get activation function + activation_function = resolve_nn_activation(activation) + + # build model layers + modules = [] + + for idx in range(len(cfg.out_channels)): + in_channels = cfg.in_channels if idx == 0 else cfg.out_channels[idx - 1] + modules.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=cfg.out_channels[idx], + kernel_size=cfg.kernel_size[idx], + stride=cfg.stride[idx], + ) + ) + if cfg.batchnorm[idx]: + modules.append(nn.BatchNorm2d(num_features=cfg.out_channels[idx])) + modules.append(activation_function) + if cfg.max_pool[idx]: + modules.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + self.architecture = nn.Sequential(*modules) + + if cfg.avg_pool is not None: + self.avgpool = nn.AdaptiveAvgPool2d(cfg.avg_pool) + else: + self.avgpool = None + + # initialize weights + self.init_weights(self.architecture) + + # save flatten config for forward function + self.flatten = cfg.flatten + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.architecture(x) + if self.flatten: + x = x.flatten(start_dim=1) + elif self.avgpool is not None: + x = self.avgpool(x) + x = x.flatten(start_dim=1) + return x + + @staticmethod + def init_weights(sequential): + [ + torch.nn.init.xavier_uniform_(module.weight) + for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Conv2d)) + ] From 3f4d485ba378c9192800335155ebb3a20a483930 Mon Sep 17 00:00:00 2001 From: Pascal Roth Date: Tue, 16 Sep 2025 10:25:45 +0200 Subject: [PATCH 2/3] working training --- rsl_rl/modules/perceptive_actor_critic.py | 56 +++++++------- rsl_rl/networks/__init__.py | 2 +- rsl_rl/networks/cnn.py | 89 ++++++++++------------- rsl_rl/runners/on_policy_runner.py | 4 +- 4 files changed, 69 insertions(+), 82 deletions(-) diff --git a/rsl_rl/modules/perceptive_actor_critic.py b/rsl_rl/modules/perceptive_actor_critic.py index ff270645..9e3e9790 100644 --- a/rsl_rl/modules/perceptive_actor_critic.py +++ b/rsl_rl/modules/perceptive_actor_critic.py @@ -11,7 +11,7 @@ from .actor_critic import ActorCritic -from rsl_rl.networks import MLP, CNN, CNNConfig, EmpiricalNormalization +from rsl_rl.networks import MLP, CNN, EmpiricalNormalization class PerceptiveActorCritic(ActorCritic): @@ -24,8 +24,8 @@ def __init__( critic_obs_normalization: bool = False, actor_hidden_dims: list[int] = [256, 256, 256], critic_hidden_dims: list[int] = [256, 256, 256], - actor_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None, - critic_cnn_config: dict[str, CNNConfig] | CNNConfig | None = None, + actor_cnn_config: dict[str, dict] | dict | None = None, + critic_cnn_config: dict[str, dict] | dict | None = None, activation: str = "elu", init_noise_std: float = 1.0, noise_std_type: str = "scalar", @@ -45,10 +45,10 @@ def __init__( self.actor_obs_group_1d = [] self.actor_obs_group_2d = [] for obs_group in obs_groups["policy"]: - if len(obs[obs_group].shape) == 2: # FIXME: should be 3??? 
+ if len(obs[obs_group].shape) == 4: # B, C, H, W self.actor_obs_group_2d.append(obs_group) - num_actor_in_channels.append(obs[obs_group].shape[0]) - elif len(obs[obs_group].shape) == 1: + num_actor_in_channels.append(obs[obs_group].shape[1]) + elif len(obs[obs_group].shape) == 2: # B, C self.actor_obs_group_1d.append(obs_group) num_actor_obs += obs[obs_group].shape[-1] else: @@ -59,36 +59,36 @@ def __init__( num_critic_obs = 0 num_critic_in_channels = [] for obs_group in obs_groups["critic"]: - if len(obs[obs_group].shape) == 2: # FIXME: should be 3??? + if len(obs[obs_group].shape) == 4: # B, C, H, W self.critic_obs_group_2d.append(obs_group) - num_critic_in_channels.append(obs[obs_group].shape[0]) - else: + num_critic_in_channels.append(obs[obs_group].shape[1]) + elif len(obs[obs_group].shape) == 2: # B, C self.critic_obs_group_1d.append(obs_group) num_critic_obs += obs[obs_group].shape[-1] + else: + raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") # actor cnn if self.actor_obs_group_2d: assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations." # check if multiple 2D actor observations are provided - if len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, CNNConfig): + if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()): + assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations." + elif len(self.actor_obs_group_2d) > 1: print(f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups.") actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d))) - elif len(self.actor_obs_group_2d) > 1 and isinstance(actor_cnn_config, dict): - assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations." - elif len(self.actor_obs_group_2d) == 1 and isinstance(actor_cnn_config, CNNConfig): - actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config])) else: - raise ValueError(f"Invalid combination of 2D actor observations {self.actor_obs_group_2d} and actor CNN config {actor_cnn_config}.") + actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config])) - self.actor_cnns = {} + self.actor_cnns = nn.ModuleDict() encoding_dims = [] for idx, obs_group in enumerate(self.actor_obs_group_2d): - self.actor_cnns[obs_group] = CNN(actor_cnn_config[obs_group], num_actor_in_channels[idx], activation) + self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_config[obs_group]) print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}") - # compute the encoding dimension - encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group]).shape[-1]) + # compute the encoding dimension (cpu necessary as model not moved to device yet) + encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1]) encoding_dim = sum(encoding_dims) else: @@ -111,24 +111,22 @@ def __init__( assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations." 
# check if multiple 2D critic observations are provided - if len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, CNNConfig): + if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()): + assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations." + elif len(self.critic_obs_group_2d) > 1: print(f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups.") critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d))) - elif len(self.critic_obs_group_2d) > 1 and isinstance(critic_cnn_config, dict): - assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations." - elif len(self.critic_obs_group_2d) == 1 and isinstance(critic_cnn_config, CNNConfig): - critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config])) else: - raise ValueError(f"Invalid combination of 2D critic observations {self.critic_obs_group_2d} and critic CNN config {critic_cnn_config}.") + critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config])) - self.critic_cnns = {} + self.critic_cnns = nn.ModuleDict() encoding_dims = [] for idx, obs_group in enumerate(self.critic_obs_group_2d): - self.critic_cnns[obs_group] = CNN(critic_cnn_config[obs_group], num_critic_in_channels[idx], activation) + self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group]) print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}") - # compute the encoding dimension - encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group]).shape[-1]) + # compute the encoding dimension (cpu necessary as model not moved to device yet) + encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1]) encoding_dim = sum(encoding_dims) else: diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py index 245c395a..366941d6 100644 --- a/rsl_rl/networks/__init__.py +++ b/rsl_rl/networks/__init__.py @@ -7,5 +7,5 @@ from .memory import Memory from .mlp import MLP -from .cnn import CNN, CNNConfig +from .cnn import CNN from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py index a1368635..fbe06b07 100644 --- a/rsl_rl/networks/cnn.py +++ b/rsl_rl/networks/cnn.py @@ -6,25 +6,13 @@ from __future__ import annotations import torch -from dataclasses import MISSING, dataclass from torch import nn as nn from rsl_rl.utils import resolve_nn_activation -@dataclass -class CNNConfig: - out_channels: list[int] = MISSING - kernel_size: list[tuple[int, int]] | tuple[int, int] = MISSING - stride: list[int] | int = 1 - flatten: bool = True - avg_pool: tuple[int, int] | None = None - batchnorm: bool | list[bool] = False - max_pool: bool | list[bool] = False - - -class CNN(nn.Module): - def __init__(self, cfg: CNNConfig, in_channels: int, activation: str): +class CNN(nn.Sequential): + def __init__(self, in_channels: int, activation: str, out_channels: list[int], kernel_size: list[tuple[int, int]] | tuple[int, int], stride: list[int] | int = 1, flatten: bool = True, avg_pool: tuple[int, int] | None = None, batchnorm: bool | list[bool] = False, max_pool: bool | list[bool] = False): """ Convolutional Neural Network model. 
@@ -33,52 +21,52 @@ def __init__(self, cfg: CNNConfig, in_channels: int, activation: str): """ super().__init__() - if isinstance(cfg.batchnorm, bool): - cfg.batchnorm = [cfg.batchnorm] * len(cfg.out_channels) - if isinstance(cfg.max_pool, bool): - cfg.max_pool = [cfg.max_pool] * len(cfg.out_channels) - if isinstance(cfg.kernel_size, tuple): - cfg.kernel_size = [cfg.kernel_size] * len(cfg.out_channels) - if isinstance(cfg.stride, int): - cfg.stride = [cfg.stride] * len(cfg.out_channels) + if isinstance(batchnorm, bool): + batchnorm = [batchnorm] * len(out_channels) + if isinstance(max_pool, bool): + max_pool = [max_pool] * len(out_channels) + if isinstance(kernel_size, tuple): + kernel_size = [kernel_size] * len(out_channels) + if isinstance(stride, int): + stride = [stride] * len(out_channels) # get activation function activation_function = resolve_nn_activation(activation) # build model layers - modules = [] + layers = [] - for idx in range(len(cfg.out_channels)): - in_channels = cfg.in_channels if idx == 0 else cfg.out_channels[idx - 1] - modules.append( + for idx in range(len(out_channels)): + in_channels = in_channels if idx == 0 else out_channels[idx - 1] + layers.append( nn.Conv2d( in_channels=in_channels, - out_channels=cfg.out_channels[idx], - kernel_size=cfg.kernel_size[idx], - stride=cfg.stride[idx], + out_channels=out_channels[idx], + kernel_size=kernel_size[idx], + stride=stride[idx], ) ) - if cfg.batchnorm[idx]: - modules.append(nn.BatchNorm2d(num_features=cfg.out_channels[idx])) - modules.append(activation_function) - if cfg.max_pool[idx]: - modules.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) - - self.architecture = nn.Sequential(*modules) - - if cfg.avg_pool is not None: - self.avgpool = nn.AdaptiveAvgPool2d(cfg.avg_pool) + if batchnorm[idx]: + layers.append(nn.BatchNorm2d(num_features=out_channels[idx])) + layers.append(activation_function) + if max_pool[idx]: + layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + + # register the layers + for idx, layer in enumerate(layers): + self.add_module(f"{idx}", layer) + + if avg_pool is not None: + self.avgpool = nn.AdaptiveAvgPool2d(avg_pool) else: self.avgpool = None - # initialize weights - self.init_weights(self.architecture) - # save flatten config for forward function - self.flatten = cfg.flatten + self.flatten = flatten def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.architecture(x) + for layer in self: + x = layer(x) if self.flatten: x = x.flatten(start_dim=1) elif self.avgpool is not None: @@ -86,9 +74,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.flatten(start_dim=1) return x - @staticmethod - def init_weights(sequential): - [ - torch.nn.init.xavier_uniform_(module.weight) - for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Conv2d)) - ] + def init_weights(self, scales: float | tuple[float]): + """Initialize the weights of the CNN.""" + + # initialize the weights + for idx, module in enumerate(self): + if isinstance(module, nn.Conv2d): + nn.init.xavier_uniform_(module.weight) diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 36f11f37..3613aaf4 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -15,7 +15,7 @@ import rsl_rl from rsl_rl.algorithms import PPO from rsl_rl.env import VecEnv -from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, resolve_rnd_config, resolve_symmetry_config +from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, 
PerceptiveActorCritic, resolve_rnd_config, resolve_symmetry_config from rsl_rl.utils import resolve_obs_groups, store_code_state @@ -416,7 +416,7 @@ def _construct_algorithm(self, obs) -> PPO: # initialize the actor-critic actor_critic_class = eval(self.policy_cfg.pop("class_name")) - actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class( + actor_critic: ActorCritic | ActorCriticRecurrent | PerceptiveActorCritic = actor_critic_class( obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg ).to(self.device) From 22cd1ec06a6fdf470eefbaa25c0f1501c995eb91 Mon Sep 17 00:00:00 2001 From: Pascal Roth Date: Tue, 16 Sep 2025 10:27:01 +0200 Subject: [PATCH 3/3] formatter --- rsl_rl/modules/perceptive_actor_critic.py | 56 ++++++++++++++--------- rsl_rl/networks/__init__.py | 2 +- rsl_rl/networks/cnn.py | 13 +++++- rsl_rl/runners/on_policy_runner.py | 8 +++- 4 files changed, 54 insertions(+), 25 deletions(-) diff --git a/rsl_rl/modules/perceptive_actor_critic.py b/rsl_rl/modules/perceptive_actor_critic.py index 9e3e9790..9862a6c6 100644 --- a/rsl_rl/modules/perceptive_actor_critic.py +++ b/rsl_rl/modules/perceptive_actor_critic.py @@ -9,13 +9,13 @@ import torch.nn as nn from torch.distributions import Normal -from .actor_critic import ActorCritic +from rsl_rl.networks import CNN, MLP, EmpiricalNormalization -from rsl_rl.networks import MLP, CNN, EmpiricalNormalization +from .actor_critic import ActorCritic class PerceptiveActorCritic(ActorCritic): - def __init__( + def __init__( # noqa: C901 self, obs, obs_groups, @@ -53,7 +53,7 @@ def __init__( num_actor_obs += obs[obs_group].shape[-1] else: raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}") - + self.critic_obs_group_1d = [] self.critic_obs_group_2d = [] num_critic_obs = 0 @@ -71,12 +71,16 @@ def __init__( # actor cnn if self.actor_obs_group_2d: assert actor_cnn_config is not None, "Actor CNN config is required for 2D actor observations." - + # check if multiple 2D actor observations are provided if len(self.actor_obs_group_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_config.values()): - assert len(actor_cnn_config) == len(self.actor_obs_group_2d), "Number of CNN configs must match number of 2D actor observations." + assert len(actor_cnn_config) == len( + self.actor_obs_group_2d + ), "Number of CNN configs must match number of 2D actor observations." elif len(self.actor_obs_group_2d) > 1: - print(f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups.") + print( + "Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups." 
+ ) actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config] * len(self.actor_obs_group_2d))) else: actor_cnn_config = dict(zip(self.actor_obs_group_2d, [actor_cnn_config])) @@ -89,7 +93,7 @@ def __init__( # compute the encoding dimension (cpu necessary as model not moved to device yet) encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1]) - + encoding_dim = sum(encoding_dims) else: self.actor_cnns = None @@ -97,7 +101,7 @@ def __init__( # actor mlp self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation) - + # actor observation normalization (only for 1D actor observations) self.actor_obs_normalization = actor_obs_normalization if actor_obs_normalization: @@ -109,25 +113,33 @@ def __init__( # critic cnn if self.critic_obs_group_2d: assert critic_cnn_config is not None, "Critic CNN config is required for 2D critic observations." - + # check if multiple 2D critic observations are provided if len(self.critic_obs_group_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_config.values()): - assert len(critic_cnn_config) == len(self.critic_obs_group_2d), "Number of CNN configs must match number of 2D critic observations." + assert len(critic_cnn_config) == len( + self.critic_obs_group_2d + ), "Number of CNN configs must match number of 2D critic observations." elif len(self.critic_obs_group_2d) > 1: - print(f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups.") - critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d))) + print( + "Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups." + ) + critic_cnn_config = dict( + zip(self.critic_obs_group_2d, [critic_cnn_config] * len(self.critic_obs_group_2d)) + ) else: critic_cnn_config = dict(zip(self.critic_obs_group_2d, [critic_cnn_config])) self.critic_cnns = nn.ModuleDict() encoding_dims = [] for idx, obs_group in enumerate(self.critic_obs_group_2d): - self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group]) + self.critic_cnns[obs_group] = CNN( + num_critic_in_channels[idx], activation, **critic_cnn_config[obs_group] + ) print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}") # compute the encoding dimension (cpu necessary as model not moved to device yet) encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1]) - + encoding_dim = sum(encoding_dims) else: self.critic_cnns = None @@ -135,7 +147,7 @@ def __init__( # critic mlp self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation) - + # critic observation normalization (only for 1D critic observations) self.critic_obs_normalization = critic_obs_normalization if critic_obs_normalization: @@ -159,7 +171,7 @@ def __init__( Normal.set_default_validate_args(False) def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]): - + if self.actor_cnns is not None: # encode the 2D actor observations cnn_enc_list = [] @@ -168,7 +180,7 @@ def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Te cnn_enc = torch.cat(cnn_enc_list, dim=-1) # update mlp obs mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) - + super().update_distribution(mlp_obs) def act(self, obs, **kwargs): @@ -180,7 +192,7 @@ def act(self, obs, **kwargs): def act_inference(self, obs): mlp_obs, cnn_obs = self.get_actor_obs(obs) mlp_obs = 
self.actor_obs_normalizer(mlp_obs) - + if self.actor_cnns is not None: # encode the 2D actor observations cnn_enc_list = [] @@ -189,7 +201,7 @@ def act_inference(self, obs): cnn_enc = torch.cat(cnn_enc_list, dim=-1) # update mlp obs mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) - + return self.actor(mlp_obs) def evaluate(self, obs, **kwargs): @@ -204,7 +216,7 @@ def evaluate(self, obs, **kwargs): cnn_enc = torch.cat(cnn_enc_list, dim=-1) # update mlp obs mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1) - + return self.critic(mlp_obs) def get_actor_obs(self, obs): @@ -231,4 +243,4 @@ def update_normalization(self, obs): self.actor_obs_normalizer.update(actor_obs) if self.critic_obs_normalization: critic_obs, _ = self.get_critic_obs(obs) - self.critic_obs_normalizer.update(critic_obs) \ No newline at end of file + self.critic_obs_normalizer.update(critic_obs) diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py index 366941d6..860f5298 100644 --- a/rsl_rl/networks/__init__.py +++ b/rsl_rl/networks/__init__.py @@ -5,7 +5,7 @@ """Definitions for components of modules.""" +from .cnn import CNN from .memory import Memory from .mlp import MLP -from .cnn import CNN from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization diff --git a/rsl_rl/networks/cnn.py b/rsl_rl/networks/cnn.py index fbe06b07..49b576ed 100644 --- a/rsl_rl/networks/cnn.py +++ b/rsl_rl/networks/cnn.py @@ -12,7 +12,18 @@ class CNN(nn.Sequential): - def __init__(self, in_channels: int, activation: str, out_channels: list[int], kernel_size: list[tuple[int, int]] | tuple[int, int], stride: list[int] | int = 1, flatten: bool = True, avg_pool: tuple[int, int] | None = None, batchnorm: bool | list[bool] = False, max_pool: bool | list[bool] = False): + def __init__( + self, + in_channels: int, + activation: str, + out_channels: list[int], + kernel_size: list[tuple[int, int]] | tuple[int, int], + stride: list[int] | int = 1, + flatten: bool = True, + avg_pool: tuple[int, int] | None = None, + batchnorm: bool | list[bool] = False, + max_pool: bool | list[bool] = False, + ): """ Convolutional Neural Network model. diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 3613aaf4..5ae5981f 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -15,7 +15,13 @@ import rsl_rl from rsl_rl.algorithms import PPO from rsl_rl.env import VecEnv -from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, PerceptiveActorCritic, resolve_rnd_config, resolve_symmetry_config +from rsl_rl.modules import ( + ActorCritic, + ActorCriticRecurrent, + PerceptiveActorCritic, + resolve_rnd_config, + resolve_symmetry_config, +) from rsl_rl.utils import resolve_obs_groups, store_code_state
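
Usage sketch (illustrative, not part of the patch series): the snippet below exercises the new CNN network on its own, following the final constructor signature from PATCH 3/3. The channel counts, kernel size, input resolution, and variable names are hypothetical values chosen for illustration.

import torch

from rsl_rl.networks import CNN

# Two conv layers (1 -> 16 -> 32 channels), 3x3 kernels, stride 2, no padding.
# A single tuple/int for kernel_size/stride is broadcast to every layer; pass
# lists for per-layer values. batchnorm and max_pool default to off.
cnn = CNN(
    in_channels=1,
    activation="elu",
    out_channels=[16, 32],
    kernel_size=(3, 3),
    stride=2,
)

x = torch.zeros(8, 1, 32, 32)  # (B, C, H, W)
feat = cnn(x)
# The spatial size shrinks 32 -> 15 -> 7, and flatten=True (the default)
# flattens per sample: feat.shape == (8, 32 * 7 * 7) == (8, 1568).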
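
Usage sketch (illustrative, not part of the patch series): a minimal end-to-end wiring of PerceptiveActorCritic. The observation group names ("proprio", "height_scan"), tensor shapes, and CNN config values are assumptions; per the shape checks in PATCH 2/3, 1D groups are (B, C) and 2D groups are (B, C, H, W). A single CNN config dict is broadcast to all 2D observation groups; pass a dict of dicts keyed by group name to configure groups individually.

import torch

from rsl_rl.modules import PerceptiveActorCritic

# Dummy observations for a batch of 4 environments.
obs = {
    "proprio": torch.zeros(4, 48),             # 1D group: (B, C)
    "height_scan": torch.zeros(4, 1, 64, 64),  # 2D group: (B, C, H, W)
}
obs_groups = {
    "policy": ["proprio", "height_scan"],
    "critic": ["proprio", "height_scan"],
}

# Keys mirror the CNN.__init__ keyword arguments.
cnn_cfg = {"out_channels": [16, 32, 64], "kernel_size": (3, 3), "stride": 2}

policy = PerceptiveActorCritic(
    obs,
    obs_groups,
    num_actions=12,
    actor_cnn_config=cnn_cfg,
    critic_cnn_config=cnn_cfg,
)

actions = policy.act(obs)      # sampled actions, shape (4, 12)
values = policy.evaluate(obs)  # value estimates, shape (4, 1)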