-
Notifications
You must be signed in to change notification settings - Fork 398
Adds perceptive actor-critic class #114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
pascal-roth
wants to merge
3
commits into
main
Choose a base branch
from
feature/perceptive-nav-rl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+357
−8
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
from __future__ import annotations | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.distributions import Normal | ||
|
||
from rsl_rl.networks import CNN, MLP, EmpiricalNormalization | ||
|
||
from .actor_critic import ActorCritic | ||
|
||
|
||
class PerceptiveActorCritic(ActorCritic):
    """Actor-critic that consumes both flat and image-like observation groups.

    Observation groups of shape ``(B, C)`` are concatenated, optionally normalized, and fed to the MLP.
    Observation groups of shape ``(B, C, H, W)`` are each encoded by their own CNN; the encodings are
    concatenated onto the 1D observations before the MLP. Actor and critic own separate CNNs and MLPs.
    """

    def __init__(
        self,
        obs,
        obs_groups,
        num_actions,
        actor_obs_normalization: bool = False,
        critic_obs_normalization: bool = False,
        actor_hidden_dims: list[int] = [256, 256, 256],
        critic_hidden_dims: list[int] = [256, 256, 256],
        actor_cnn_config: dict[str, dict] | dict | None = None,
        critic_cnn_config: dict[str, dict] | dict | None = None,
        activation: str = "elu",
        init_noise_std: float = 1.0,
        noise_std_type: str = "scalar",
        **kwargs,
    ):
        """Build the actor and critic networks from an example observation.

        Args:
            obs: Example observations, indexable by group name; each entry must be a tensor of rank 2
                (``B, C``) or rank 4 (``B, C, H, W``). Also used to probe the CNN output sizes.
            obs_groups: Mapping with the keys ``"policy"`` and ``"critic"``, each listing the observation
                group names used by the actor and the critic respectively.
            num_actions: Output dimension of the actor MLP and size of the noise parameter.
            actor_obs_normalization: Whether to normalize the 1D actor observations.
            critic_obs_normalization: Whether to normalize the 1D critic observations.
            actor_hidden_dims: Hidden layer sizes passed to the actor MLP (not modified here).
            critic_hidden_dims: Hidden layer sizes passed to the critic MLP (not modified here).
            actor_cnn_config: Either one CNN keyword-argument dict shared by all 2D actor groups, or a
                dict mapping every 2D actor group name to its own CNN keyword-argument dict.
            critic_cnn_config: Same as ``actor_cnn_config`` but for the critic.
            activation: Activation name passed to the CNNs and MLPs.
            init_noise_std: Initial action noise standard deviation.
            noise_std_type: ``"scalar"`` to learn the std directly, ``"log"`` to learn its logarithm.
            **kwargs: Ignored; a notice listing the unexpected names is printed.

        Raises:
            ValueError: If an observation has an unsupported rank, a CNN config is missing or does not
                match the 2D observation groups, or ``noise_std_type`` is unknown.
        """
        if kwargs:
            print(
                "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
                + str(list(kwargs))
            )
        # NOTE: nn.Module.__init__ is called directly instead of super().__init__();
        # presumably so the parent does not build its own networks - confirm against ActorCritic.
        nn.Module.__init__(self)

        # get the observation dimensions
        self.obs_groups = obs_groups
        (
            self.actor_obs_group_1d,
            self.actor_obs_group_2d,
            num_actor_obs,
            num_actor_in_channels,
        ) = self._split_obs_groups(obs, obs_groups["policy"])
        (
            self.critic_obs_group_1d,
            self.critic_obs_group_2d,
            num_critic_obs,
            num_critic_in_channels,
        ) = self._split_obs_groups(obs, obs_groups["critic"])

        # actor cnn
        self.actor_cnns, actor_encoding_dim = self._build_cnns(
            obs, self.actor_obs_group_2d, num_actor_in_channels, actor_cnn_config, activation, "actor"
        )

        # actor mlp
        self.actor = MLP(num_actor_obs + actor_encoding_dim, num_actions, actor_hidden_dims, activation)

        # actor observation normalization (only for 1D actor observations)
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()
        print(f"Actor MLP: {self.actor}")

        # critic cnn
        self.critic_cnns, critic_encoding_dim = self._build_cnns(
            obs, self.critic_obs_group_2d, num_critic_in_channels, critic_cnn_config, activation, "critic"
        )

        # critic mlp
        self.critic = MLP(num_critic_obs + critic_encoding_dim, 1, critic_hidden_dims, activation)

        # critic observation normalization (only for 1D critic observations)
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()
        print(f"Critic MLP: {self.critic}")

        # Action noise
        self.noise_std_type = noise_std_type
        if self.noise_std_type == "scalar":
            self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        elif self.noise_std_type == "log":
            self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
        else:
            raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

        # Action distribution (populated in update_distribution)
        self.distribution: Normal = None
        # disable args validation for speedup
        Normal.set_default_validate_args(False)

    @staticmethod
    def _split_obs_groups(obs, group_names):
        """Sort observation groups into 1D and 2D groups based on tensor rank.

        Returns:
            A tuple ``(groups_1d, groups_2d, num_obs_1d, in_channels_2d)`` where ``num_obs_1d`` is the
            summed last dimension of all 1D groups and ``in_channels_2d`` holds dimension 1 of every
            2D group, in the same order as ``groups_2d``.

        Raises:
            ValueError: If a group's tensor is neither rank 2 nor rank 4.
        """
        groups_1d = []
        groups_2d = []
        in_channels_2d = []
        num_obs_1d = 0
        for obs_group in group_names:
            shape = obs[obs_group].shape
            if len(shape) == 4:  # B, C, H, W
                groups_2d.append(obs_group)
                in_channels_2d.append(shape[1])
            elif len(shape) == 2:  # B, C
                groups_1d.append(obs_group)
                num_obs_1d += shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
        return groups_1d, groups_2d, num_obs_1d, in_channels_2d

    @staticmethod
    def _resolve_cnn_configs(cnn_config, groups_2d, role):
        """Return a mapping from every 2D observation group to its CNN keyword arguments.

        A config whose values are all dicts is treated as a per-group config, regardless of how many
        2D groups exist; anything else is treated as one config shared by all groups.

        Raises:
            ValueError: If the config is missing, or a per-group config does not match ``groups_2d``.
        """
        if cnn_config is None:
            raise ValueError(f"{role.capitalize()} CNN config is required for 2D {role} observations.")

        # a flat CNN config always carries non-dict values, so "all values are dicts" identifies
        # the per-group form
        is_per_group = bool(cnn_config) and all(isinstance(item, dict) for item in cnn_config.values())
        if is_per_group:
            if len(cnn_config) != len(groups_2d):
                raise ValueError(f"Number of CNN configs must match number of 2D {role} observations.")
            missing = [obs_group for obs_group in groups_2d if obs_group not in cnn_config]
            if missing:
                raise ValueError(f"Missing {role} CNN config for 2D observation groups: {missing}")
            return {obs_group: cnn_config[obs_group] for obs_group in groups_2d}

        if len(groups_2d) > 1:
            print(f"Only one CNN config for multiple 2D {role} observations given, using the same CNN for all groups.")
        return {obs_group: cnn_config for obs_group in groups_2d}

    def _build_cnns(self, obs, groups_2d, in_channels_2d, cnn_config, activation, role):
        """Create one CNN per 2D observation group and measure the total encoding size.

        Returns:
            ``(cnns, encoding_dim)`` where ``cnns`` is an ``nn.ModuleDict`` keyed by group name, or
            ``(None, 0)`` when there are no 2D groups.
        """
        if not groups_2d:
            return None, 0

        configs = self._resolve_cnn_configs(cnn_config, groups_2d, role)
        cnns = nn.ModuleDict()
        encoding_dim = 0
        for obs_group, num_in_channels in zip(groups_2d, in_channels_2d):
            cnns[obs_group] = CNN(num_in_channels, activation, **configs[obs_group])
            print(f"{role.capitalize()} CNN for {obs_group}: {cnns[obs_group]}")

            # compute the encoding dimension (cpu necessary as model not moved to device yet);
            # only the output shape is needed, so no autograd graph is recorded
            with torch.no_grad():
                encoding_dim += cnns[obs_group](obs[obs_group].to("cpu")).shape[-1]
        return cnns, encoding_dim

    @staticmethod
    def _append_cnn_encodings(mlp_obs, cnn_obs, cnns, groups_2d):
        """Encode the 2D observations and concatenate them onto ``mlp_obs`` along the last dim."""
        if cnns is None:
            return mlp_obs
        cnn_enc = torch.cat([cnns[obs_group](cnn_obs[obs_group]) for obs_group in groups_2d], dim=-1)
        return torch.cat([mlp_obs, cnn_enc], dim=-1)

    @staticmethod
    def _gather_obs(obs, groups_1d, groups_2d):
        """Return the concatenated 1D observations and a dict of the 2D observations."""
        obs_1d = torch.cat([obs[obs_group] for obs_group in groups_1d], dim=-1)
        obs_2d = {obs_group: obs[obs_group] for obs_group in groups_2d}
        return obs_1d, obs_2d

    def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]):
        """Append the actor CNN encodings to ``mlp_obs`` and delegate to the parent implementation."""
        mlp_obs = self._append_cnn_encodings(mlp_obs, cnn_obs, self.actor_cnns, self.actor_obs_group_2d)
        super().update_distribution(mlp_obs)

    def act(self, obs, **kwargs):
        """Update the action distribution from ``obs`` and return a sample from it."""
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        self.update_distribution(mlp_obs, cnn_obs)
        return self.distribution.sample()

    def act_inference(self, obs):
        """Return the actor MLP output for ``obs`` without sampling from the distribution."""
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        mlp_obs = self._append_cnn_encodings(mlp_obs, cnn_obs, self.actor_cnns, self.actor_obs_group_2d)
        return self.actor(mlp_obs)

    def evaluate(self, obs, **kwargs):
        """Return the critic MLP output for ``obs``."""
        mlp_obs, cnn_obs = self.get_critic_obs(obs)
        mlp_obs = self.critic_obs_normalizer(mlp_obs)
        mlp_obs = self._append_cnn_encodings(mlp_obs, cnn_obs, self.critic_cnns, self.critic_obs_group_2d)
        return self.critic(mlp_obs)

    def get_actor_obs(self, obs):
        """Return ``(obs_1d, obs_2d)`` for the actor: concatenated 1D groups and a dict of 2D groups."""
        return self._gather_obs(obs, self.actor_obs_group_1d, self.actor_obs_group_2d)

    def get_critic_obs(self, obs):
        """Return ``(obs_1d, obs_2d)`` for the critic: concatenated 1D groups and a dict of 2D groups."""
        return self._gather_obs(obs, self.critic_obs_group_1d, self.critic_obs_group_2d)

    def update_normalization(self, obs):
        """Update the enabled observation normalizers with the 1D observations from ``obs``."""
        if self.actor_obs_normalization:
            actor_obs, _ = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(actor_obs)
        if self.critic_obs_normalization:
            critic_obs, _ = self.get_critic_obs(obs)
            self.critic_obs_normalizer.update(critic_obs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION | ||
# All rights reserved. | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
from __future__ import annotations | ||
|
||
import torch | ||
from torch import nn as nn | ||
|
||
from rsl_rl.utils import resolve_nn_activation | ||
|
||
|
||
class CNN(nn.Sequential): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add a short description of the architecture here as a class docstring, similar to the one on the MLP (which begins with `"""Multi-layer perceptron.`)?
|
||
def __init__( | ||
self, | ||
in_channels: int, | ||
activation: str, | ||
out_channels: list[int], | ||
kernel_size: list[tuple[int, int]] | tuple[int, int], | ||
stride: list[int] | int = 1, | ||
flatten: bool = True, | ||
avg_pool: tuple[int, int] | None = None, | ||
batchnorm: bool | list[bool] = False, | ||
max_pool: bool | list[bool] = False, | ||
): | ||
""" | ||
Convolutional Neural Network model. | ||
|
||
.. note:: | ||
Do not save config to allow for the model to be jit compiled. | ||
""" | ||
super().__init__() | ||
|
||
if isinstance(batchnorm, bool): | ||
batchnorm = [batchnorm] * len(out_channels) | ||
if isinstance(max_pool, bool): | ||
max_pool = [max_pool] * len(out_channels) | ||
if isinstance(kernel_size, tuple): | ||
kernel_size = [kernel_size] * len(out_channels) | ||
if isinstance(stride, int): | ||
stride = [stride] * len(out_channels) | ||
|
||
# get activation function | ||
activation_function = resolve_nn_activation(activation) | ||
|
||
# build model layers | ||
layers = [] | ||
|
||
for idx in range(len(out_channels)): | ||
in_channels = in_channels if idx == 0 else out_channels[idx - 1] | ||
layers.append( | ||
nn.Conv2d( | ||
in_channels=in_channels, | ||
out_channels=out_channels[idx], | ||
kernel_size=kernel_size[idx], | ||
stride=stride[idx], | ||
) | ||
) | ||
if batchnorm[idx]: | ||
layers.append(nn.BatchNorm2d(num_features=out_channels[idx])) | ||
layers.append(activation_function) | ||
if max_pool[idx]: | ||
layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) | ||
|
||
# register the layers | ||
for idx, layer in enumerate(layers): | ||
self.add_module(f"{idx}", layer) | ||
|
||
if avg_pool is not None: | ||
self.avgpool = nn.AdaptiveAvgPool2d(avg_pool) | ||
else: | ||
self.avgpool = None | ||
|
||
# save flatten config for forward function | ||
self.flatten = flatten | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
for layer in self: | ||
x = layer(x) | ||
if self.flatten: | ||
x = x.flatten(start_dim=1) | ||
elif self.avgpool is not None: | ||
x = self.avgpool(x) | ||
x = x.flatten(start_dim=1) | ||
return x | ||
|
||
def init_weights(self, scales: float | tuple[float]): | ||
"""Initialize the weights of the CNN.""" | ||
|
||
# initialize the weights | ||
for idx, module in enumerate(self): | ||
if isinstance(module, nn.Conv2d): | ||
nn.init.xavier_uniform_(module.weight) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be nice to add the state-dependent action noise that is now part of the actor critic module.