diff --git a/.flake8 b/.flake8
deleted file mode 100644
index 26574e0c..00000000
--- a/.flake8
+++ /dev/null
@@ -1,22 +0,0 @@
-[flake8]
-show-source=True
-statistics=True
-per-file-ignores=*/__init__.py:F401
-# E402: Module level import not at top of file
-# E501: Line too long
-# W503: Line break before binary operator
-# E203: Whitespace before ':' -> conflicts with black
-# D401: First line should be in imperative mood
-# R504: Unnecessary variable assignment before return statement.
-# R505: Unnecessary elif after return statement
-# SIM102: Use a single if-statement instead of nested if-statements
-# SIM117: Merge with statements for context managers that have same scope.
-ignore=E402,E501,W503,E203,D401,R504,R505,SIM102,SIM117
-max-line-length = 120
-max-complexity = 18
-exclude=_*,.vscode,.git,docs/**
-# docstrings
-docstring-convention=google
-# annotations
-suppress-none-returning=True
-allow-star-arg-any=True
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 74e58528..b40a9353 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,40 +1,21 @@
repos:
- - repo: https://github.com/python/black
- rev: 23.10.1
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.14.0
hooks:
- - id: black
- args: ["--line-length", "120", "--preview"]
- - repo: https://github.com/pycqa/flake8
- rev: 6.1.0
- hooks:
- - id: flake8
- additional_dependencies: [flake8-simplify, flake8-return]
+ - id: ruff-check
+ - id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- - id: trailing-whitespace
- id: check-symlinks
- id: destroyed-symlinks
- id: check-yaml
+ - id: check-toml
- id: check-merge-conflict
- id: check-case-conflict
- id: check-executables-have-shebangs
- - id: check-toml
- - id: end-of-file-fixer
- id: check-shebang-scripts-are-executable
- id: detect-private-key
- - id: debug-statements
- - repo: https://github.com/pycqa/isort
- rev: 5.12.0
- hooks:
- - id: isort
- name: isort (python)
- args: ["--profile", "black", "--filter-files"]
- - repo: https://github.com/asottile/pyupgrade
- rev: v3.15.0
- hooks:
- - id: pyupgrade
- args: ["--py37-plus"]
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index ca06f311..00692ba7 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -17,12 +17,14 @@ Please keep the lists sorted alphabetically.
---
-* Mayank Mittal
* Clemens Schwarke
+* Mayank Mittal
## Authors
+* Clemens Schwarke
* David Hoeller
+* Mayank Mittal
* Nikita Rudin
## Contributors
diff --git a/README.md b/README.md
index 0ea9bebb..d00385b7 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,14 @@
-# RSL RL
+# RSL-RL
-A fast and simple implementation of RL algorithms, designed to run fully on GPU.
-This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac Gym.
+A fast and simple implementation of learning algorithms for robotics. For an overview of the library, please see https://arxiv.org/pdf/2509.10771.
Environment repositories using the framework:
* **`Isaac Lab`** (built on top of NVIDIA Isaac Sim): https://github.com/isaac-sim/IsaacLab
-* **`Legged-Gym`** (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/
+* **`Legged Gym`** (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/
* **`MuJoCo Playground`** (built on top of MuJoCo MJX and Warp): https://github.com/google-deepmind/mujoco_playground/
-The main branch supports **PPO** and **Student-Teacher Distillation** with additional features from our research. These include:
+The library currently supports **PPO** and **Student-Teacher Distillation** with additional features from our research. These include:
* [Random Network Distillation (RND)](https://proceedings.mlr.press/v229/schwarke23a.html) - Encourages exploration by adding
a curiosity driven intrinsic reward.
@@ -22,8 +21,6 @@ information.
**Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA
**Contact**: cschwarke@ethz.ch
-> **Note:** The `algorithms` branch supports additional algorithms (SAC, DDPG, DSAC, and more). However, it isn't currently actively maintained.
-
## Setup
@@ -57,8 +54,7 @@ For documentation, we adopt the [Google Style Guide](https://sphinxcontrib-napol
We use the following tools for maintaining code quality:
- [pre-commit](https://pre-commit.com/): Runs a list of formatters and linters over the codebase.
-- [black](https://black.readthedocs.io/en/stable/): The uncompromising code formatter.
-- [flake8](https://flake8.pycqa.org/en/latest/): A wrapper around PyFlakes, pycodestyle, and McCabe complexity checker.
+- [ruff](https://github.com/astral-sh/ruff): An extremely fast Python linter and code formatter, written in Rust.
Please check [here](https://pre-commit.com/#install) for instructions to set these up. To run over the entire repository, please execute the following command in the terminal:
diff --git a/config/example_config.yaml b/config/example_config.yaml
index 329cb1e5..00a9483f 100644
--- a/config/example_config.yaml
+++ b/config/example_config.yaml
@@ -1,21 +1,21 @@
runner:
class_name: OnPolicyRunner
- # -- general
- num_steps_per_env: 24 # number of steps per environment per iteration
- max_iterations: 1500 # number of policy updates
+ # General
+ num_steps_per_env: 24 # Number of steps per environment per iteration
+ max_iterations: 1500 # Number of policy updates
seed: 1
- # -- observations
- obs_groups: {"policy": ["policy"], "critic": ["policy", "privileged"]} # maps observation groups to types. See `vec_env.py` for more information
- # -- logging parameters
- save_interval: 50 # check for potential saves every `save_interval` iterations
+ # Observations
+  obs_groups: {"policy": ["policy"], "critic": ["policy", "privileged"]} # Maps observation groups to observation sets. See `vec_env.py` for more information
+ # Logging parameters
+ save_interval: 50 # Check for potential saves every `save_interval` iterations
experiment_name: walking_experiment
run_name: ""
- # -- logging writer
+ # Logging writer
logger: tensorboard # tensorboard, neptune, wandb
neptune_project: legged_gym
wandb_project: legged_gym
- # -- policy
+ # Policy
policy:
class_name: ActorCritic
activation: elu
@@ -25,45 +25,46 @@ runner:
critic_hidden_dims: [256, 256, 256]
init_noise_std: 1.0
noise_std_type: "scalar" # 'scalar' or 'log'
+ state_dependent_std: false
- # -- algorithm
+ # Algorithm
algorithm:
class_name: PPO
- # -- training
+ # Training
learning_rate: 0.001
num_learning_epochs: 5
num_mini_batches: 4 # mini batch size = num_envs * num_steps / num_mini_batches
schedule: adaptive # adaptive, fixed
- # -- value function
+ # Value function
value_loss_coef: 1.0
clip_param: 0.2
use_clipped_value_loss: true
- # -- surrogate loss
+ # Surrogate loss
desired_kl: 0.01
entropy_coef: 0.01
gamma: 0.99
lam: 0.95
max_grad_norm: 1.0
- # -- miscellaneous
+ # Miscellaneous
normalize_advantage_per_mini_batch: false
- # -- random network distillation
+ # Random network distillation
rnd_cfg:
- weight: 0.0 # initial weight of the RND reward
- weight_schedule: null # note: this is a dictionary with a required key called "mode". Please check the RND module for more information
- reward_normalization: false # whether to normalize RND reward
- # -- learning parameters
- learning_rate: 0.001 # learning rate for RND
- # -- network parameters
- num_outputs: 1 # number of outputs of RND network. Note: if -1, then the network will use dimensions of the observation
- predictor_hidden_dims: [-1] # hidden dimensions of predictor network
- target_hidden_dims: [-1] # hidden dimensions of target network
+ weight: 0.0 # Initial weight of the RND reward
+ weight_schedule: null # This is a dictionary with a required key called "mode". Please check the RND module for more information
+ reward_normalization: false # Whether to normalize RND reward
+ # Learning parameters
+ learning_rate: 0.001 # Learning rate for RND
+ # Network parameters
+ num_outputs: 1 # Number of outputs of RND network. Note: if -1, then the network will use dimensions of the observation
+ predictor_hidden_dims: [-1] # Hidden dimensions of predictor network
+ target_hidden_dims: [-1] # Hidden dimensions of target network
- # -- symmetry augmentation
+ # Symmetry augmentation
symmetry_cfg:
- use_data_augmentation: true # this adds symmetric trajectories to the batch
- use_mirror_loss: false # this adds symmetry loss term to the loss function
- data_augmentation_func: null # string containing the module and function name to import
+ use_data_augmentation: true # This adds symmetric trajectories to the batch
+ use_mirror_loss: false # This adds symmetry loss term to the loss function
+ data_augmentation_func: null # String containing the module and function name to import
# Example: "legged_gym.envs.locomotion.anymal_c.symmetry:get_symmetric_states"
#
# .. code-block:: python
@@ -73,4 +74,4 @@ runner:
# obs: Optional[torch.Tensor] = None, actions: Optional[torch.Tensor] = None, cfg: "BaseEnvCfg" = None, obs_type: str = "policy"
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#
- mirror_loss_coeff: 0.0 #coefficient for symmetry loss term. If 0, no symmetry loss is used
+ mirror_loss_coeff: 0.0 # Coefficient for symmetry loss term. If 0, no symmetry loss is used
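
For reference, the `data_augmentation_func` entry above expects an importable function along the lines of the signature quoted in the comment. Below is a minimal, hypothetical sketch (not part of this patch): the function name, the mirror signs, and the use of plain tensors instead of the library's observation `TensorDict` are illustrative assumptions, and the `env` keyword follows the call sites in `ppo.py`.

```python
import torch

# Placeholder mirror signs for a hypothetical robot; a real implementation would
# permute and negate observation/action entries according to the robot's symmetry plane.
_OBS_FLIP_SIGN = torch.tensor([1.0, -1.0, 1.0, -1.0])
_ACT_FLIP_SIGN = torch.tensor([1.0, -1.0])


@torch.no_grad()
def get_symmetric_states(obs=None, actions=None, env=None, obs_type="policy"):
    """Return each batch concatenated with its mirrored counterpart (num_aug = 2)."""
    obs_aug, actions_aug = obs, actions
    if obs is not None:
        # Shape becomes [batch_size * 2, obs_dim]: original batch followed by mirrored batch.
        obs_aug = torch.cat([obs, obs * _OBS_FLIP_SIGN.to(obs.device)], dim=0)
    if actions is not None:
        actions_aug = torch.cat([actions, actions * _ACT_FLIP_SIGN.to(actions.device)], dim=0)
    return obs_aug, actions_aug
```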
diff --git a/licenses/dependencies/black-license.txt b/licenses/dependencies/black-license.txt
deleted file mode 100644
index 7a9b891f..00000000
--- a/licenses/dependencies/black-license.txt
+++ /dev/null
@@ -1,21 +0,0 @@
-The MIT License (MIT)
-
-Copyright (c) 2018 Łukasz Langa
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/licenses/dependencies/flake8-license.txt b/licenses/dependencies/flake8-license.txt
deleted file mode 100644
index e5e3d6f9..00000000
--- a/licenses/dependencies/flake8-license.txt
+++ /dev/null
@@ -1,22 +0,0 @@
-== Flake8 License (MIT) ==
-
-Copyright (C) 2011-2013 Tarek Ziade
-Copyright (C) 2012-2016 Ian Cordasco
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of
-this software and associated documentation files (the "Software"), to deal in
-the Software without restriction, including without limitation the rights to
-use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
-of the Software, and to permit persons to whom the Software is furnished to do
-so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/licenses/dependencies/isort-license.txt b/licenses/dependencies/isort-license.txt
deleted file mode 100644
index b5083a50..00000000
--- a/licenses/dependencies/isort-license.txt
+++ /dev/null
@@ -1,21 +0,0 @@
-The MIT License (MIT)
-
-Copyright (c) 2013 Timothy Edmund Crosley
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
diff --git a/licenses/dependencies/numpy_license.txt b/licenses/dependencies/numpy-license.txt
similarity index 100%
rename from licenses/dependencies/numpy_license.txt
rename to licenses/dependencies/numpy-license.txt
diff --git a/licenses/dependencies/pyupgrade-license.txt b/licenses/dependencies/pyupgrade-license.txt
deleted file mode 100644
index 522fbe20..00000000
--- a/licenses/dependencies/pyupgrade-license.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Copyright (c) 2017 Anthony Sottile
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
diff --git a/licenses/dependencies/ruff-license.txt b/licenses/dependencies/ruff-license.txt
new file mode 100644
index 00000000..d779ee9e
--- /dev/null
+++ b/licenses/dependencies/ruff-license.txt
@@ -0,0 +1,430 @@
+MIT License
+
+Copyright (c) 2022 Charles Marsh
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+end of terms and conditions
+
+The externally maintained libraries from which parts of the Software is derived
+are:
+
+- autoflake, licensed as follows:
+ """
+ Copyright (C) 2012-2018 Steven Myint
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
+ this software and associated documentation files (the "Software"), to deal in
+ the Software without restriction, including without limitation the rights to
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+ of the Software, and to permit persons to whom the Software is furnished to do
+ so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- autotyping, licensed as follows:
+ """
+ MIT License
+
+ Copyright (c) 2023 Jelle Zijlstra
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- Flake8, licensed as follows:
+ """
+ == Flake8 License (MIT) ==
+
+ Copyright (C) 2011-2013 Tarek Ziade
+ Copyright (C) 2012-2016 Ian Cordasco
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
+ this software and associated documentation files (the "Software"), to deal in
+ the Software without restriction, including without limitation the rights to
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+ of the Software, and to permit persons to whom the Software is furnished to do
+ so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- flake8-eradicate, licensed as follows:
+ """
+ MIT License
+
+ Copyright (c) 2018 Nikita Sobolev
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- flake8-pyi, licensed as follows:
+ """
+ The MIT License (MIT)
+
+ Copyright (c) 2016 Łukasz Langa
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- flake8-simplify, licensed as follows:
+ """
+ MIT License
+
+ Copyright (c) 2020 Martin Thoma
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- isort, licensed as follows:
+ """
+ The MIT License (MIT)
+
+ Copyright (c) 2013 Timothy Edmund Crosley
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
+ """
+
+- pygrep-hooks, licensed as follows:
+ """
+ Copyright (c) 2018 Anthony Sottile
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
+ """
+
+- pycodestyle, licensed as follows:
+ """
+ Copyright © 2006-2009 Johann C. Rocholl
+ Copyright © 2009-2014 Florent Xicluna
+ Copyright © 2014-2020 Ian Lee
+
+ Licensed under the terms of the Expat License
+
+ Permission is hereby granted, free of charge, to any person
+ obtaining a copy of this software and associated documentation files
+ (the "Software"), to deal in the Software without restriction,
+ including without limitation the rights to use, copy, modify, merge,
+ publish, distribute, sublicense, and/or sell copies of the Software,
+ and to permit persons to whom the Software is furnished to do so,
+ subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be
+ included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- pydocstyle, licensed as follows:
+ """
+ Copyright (c) 2012 GreenSteam,
+
+ Copyright (c) 2014-2020 Amir Rachum,
+
+ Copyright (c) 2020 Sambhav Kothari,
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
+ this software and associated documentation files (the "Software"), to deal in
+ the Software without restriction, including without limitation the rights to
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+ of the Software, and to permit persons to whom the Software is furnished to do
+ so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- Pyflakes, licensed as follows:
+ """
+ Copyright 2005-2011 Divmod, Inc.
+ Copyright 2013-2014 Florent Xicluna
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files (the
+ "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to
+ the following conditions:
+
+ The above copyright notice and this permission notice shall be
+ included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ """
+
+- Pyright, licensed as follows:
+ """
+ MIT License
+
+ Pyright - A static type checker for the Python language
+ Copyright (c) Microsoft Corporation. All rights reserved.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE
+ """
+
+- pyupgrade, licensed as follows:
+ """
+ Copyright (c) 2017 Anthony Sottile
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
+ """
+
+- rome/tools, licensed under the MIT license:
+ """
+ MIT License
+
+ Copyright (c) Rome Tools, Inc. and its affiliates.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- RustPython, licensed as follows:
+ """
+ MIT License
+
+ Copyright (c) 2020 RustPython Team
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+- rust-analyzer/text-size, licensed under the MIT license:
+ """
+ Permission is hereby granted, free of charge, to any
+ person obtaining a copy of this software and associated
+ documentation files (the "Software"), to deal in the
+ Software without restriction, including without
+ limitation the rights to use, copy, modify, merge,
+ publish, distribute, sublicense, and/or sell copies of
+ the Software, and to permit persons to whom the Software
+ is furnished to do so, subject to the following
+ conditions:
+
+ The above copyright notice and this permission notice
+ shall be included in all copies or substantial portions
+ of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
+ ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
+ TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+ PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+ SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
+ IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ DEALINGS IN THE SOFTWARE.
+ """
\ No newline at end of file
diff --git a/licenses/dependencies/tensordict.txt b/licenses/dependencies/tensordict-license.txt
similarity index 100%
rename from licenses/dependencies/tensordict.txt
rename to licenses/dependencies/tensordict-license.txt
diff --git a/licenses/dependencies/torch_license.txt b/licenses/dependencies/torch-license.txt
similarity index 100%
rename from licenses/dependencies/torch_license.txt
rename to licenses/dependencies/torch-license.txt
diff --git a/pyproject.toml b/pyproject.toml
index 1f1817d4..f9b10ce4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "rsl-rl-lib"
version = "3.1.1"
-keywords = ["reinforcement-learning", "isaac", "leggedrobotics", "rl-pytorch"]
+keywords = ["reinforcement-learning", "robotics"]
maintainers = [
{ name="Clemens Schwarke", email="cschwarke@ethz.ch" },
{ name="Mayank Mittal", email="mittalma@ethz.ch" },
@@ -17,10 +17,9 @@ authors = [
{ name="David Hoeller", email="holler.david78@gmail.com" },
]
description = "Fast and simple RL algorithms implemented in PyTorch"
-readme = { file = "README.md", content-type = "text/markdown"}
-license = { text = "BSD-3-Clause" }
-
-requires-python = ">=3.8"
+readme = { file = "README.md", content-type = "text/markdown" }
+license = "BSD-3-Clause"
+requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
@@ -45,51 +44,16 @@ include = ["rsl_rl*"]
[tool.setuptools.package-data]
"rsl_rl" = ["config/*", "licenses/*"]
-[tool.isort]
-
-py_version = 38
-line_length = 120
-group_by_package = true
-
-# Files to skip
-skip_glob = [".vscode/*"]
-
-# Order of imports
-sections = [
- "FUTURE",
- "STDLIB",
- "THIRDPARTY",
- "FIRSTPARTY",
- "LOCALFOLDER",
-]
-
-# Extra standard libraries considered as part of python (permissive licenses)
-extra_standard_library = [
- "numpy",
- "torch",
- "tensordict",
- "warp",
- "typing_extensions",
- "git",
-]
-# Imports from this repository
-known_first_party = "rsl_rl"
-
[tool.pyright]
-
include = ["rsl_rl"]
-
typeCheckingMode = "basic"
-pythonVersion = "3.8"
+pythonVersion = "3.9"
pythonPlatform = "Linux"
enableTypeIgnoreComments = true
-
# This is required as the CI pre-commit does not download the module (i.e. numpy, torch, prettytable)
-# Therefore, we have to ignore missing imports
reportMissingImports = "none"
-# This is required to ignore for type checks of modules with stubs missing.
-reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers
-
-reportGeneralTypeIssues = "none" # -> raises 218 errors (usage of literal MISSING in dataclasses)
-reportOptionalMemberAccess = "warning" # -> raises 8 errors
+# This is required to ignore type checks of modules with stubs missing.
+reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers
+reportGeneralTypeIssues = "none" # -> usage of literal MISSING in dataclasses
+reportOptionalMemberAccess = "warning"
reportPrivateUsage = "warning"
diff --git a/rsl_rl/algorithms/__init__.py b/rsl_rl/algorithms/__init__.py
index effcaa28..b1686144 100644
--- a/rsl_rl/algorithms/__init__.py
+++ b/rsl_rl/algorithms/__init__.py
@@ -3,7 +3,7 @@
#
# SPDX-License-Identifier: BSD-3-Clause
-"""Implementation of different RL agents."""
+"""Implementation of different learning algorithms."""
from .distillation import Distillation
from .ppo import PPO
diff --git a/rsl_rl/algorithms/distillation.py b/rsl_rl/algorithms/distillation.py
index 3a86e002..5da039a9 100644
--- a/rsl_rl/algorithms/distillation.py
+++ b/rsl_rl/algorithms/distillation.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
+from tensordict import TensorDict
from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
from rsl_rl.storage import RolloutStorage
@@ -19,20 +20,21 @@ class Distillation:
def __init__(
self,
- policy,
- num_learning_epochs=1,
- gradient_length=15,
- learning_rate=1e-3,
- max_grad_norm=None,
- loss_type="mse",
- optimizer="adam",
- device="cpu",
+ policy: StudentTeacher | StudentTeacherRecurrent,
+ num_learning_epochs: int = 1,
+ gradient_length: int = 15,
+ learning_rate: float = 1e-3,
+ max_grad_norm: float | None = None,
+ loss_type: str = "mse",
+ optimizer: str = "adam",
+ device: str = "cpu",
# Distributed training parameters
multi_gpu_cfg: dict | None = None,
- ):
- # device-related parameters
+ ) -> None:
+ # Device-related parameters
self.device = device
self.is_multi_gpu = multi_gpu_cfg is not None
+
# Multi-GPU parameters
if multi_gpu_cfg is not None:
self.gpu_global_rank = multi_gpu_cfg["global_rank"]
@@ -41,25 +43,25 @@ def __init__(
self.gpu_global_rank = 0
self.gpu_world_size = 1
- # distillation components
+ # Distillation components
self.policy = policy
self.policy.to(self.device)
- self.storage = None # initialized later
+ self.storage = None # Initialized later
- # initialize the optimizer
+ # Initialize the optimizer
self.optimizer = resolve_optimizer(optimizer)(self.policy.parameters(), lr=learning_rate)
- # initialize the transition
+ # Initialize the transition
self.transition = RolloutStorage.Transition()
- self.last_hidden_states = None
+ self.last_hidden_states = (None, None)
- # distillation parameters
+ # Distillation parameters
self.num_learning_epochs = num_learning_epochs
self.gradient_length = gradient_length
self.learning_rate = learning_rate
self.max_grad_norm = max_grad_norm
- # initialize the loss function
+ # Initialize the loss function
loss_fn_dict = {
"mse": nn.functional.mse_loss,
"huber": nn.functional.huber_loss,
@@ -71,8 +73,15 @@ def __init__(
self.num_updates = 0
- def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, actions_shape):
- # create rollout storage
+ def init_storage(
+ self,
+ training_type: str,
+ num_envs: int,
+ num_transitions_per_env: int,
+ obs: TensorDict,
+ actions_shape: tuple[int],
+ ) -> None:
+ # Create rollout storage
self.storage = RolloutStorage(
training_type,
num_envs,
@@ -82,27 +91,29 @@ def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, ac
self.device,
)
- def act(self, obs):
- # compute the actions
+ def act(self, obs: TensorDict) -> torch.Tensor:
+ # Compute the actions
self.transition.actions = self.policy.act(obs).detach()
self.transition.privileged_actions = self.policy.evaluate(obs).detach()
- # record the observations
+ # Record the observations
self.transition.observations = obs
return self.transition.actions
- def process_env_step(self, obs, rewards, dones, extras):
- # update the normalizers
+ def process_env_step(
+ self, obs: TensorDict, rewards: torch.Tensor, dones: torch.Tensor, extras: dict[str, torch.Tensor]
+ ) -> None:
+ # Update the normalizers
self.policy.update_normalization(obs)
- # record the rewards and dones
+ # Record the rewards and dones
self.transition.rewards = rewards
self.transition.dones = dones
- # record the transition
+ # Record the transition
self.storage.add_transitions(self.transition)
self.transition.clear()
self.policy.reset(dones)
- def update(self):
+ def update(self) -> dict[str, float]:
self.num_updates += 1
mean_behavior_loss = 0
loss = 0
@@ -112,19 +123,18 @@ def update(self):
self.policy.reset(hidden_states=self.last_hidden_states)
self.policy.detach_hidden_states()
for obs, _, privileged_actions, dones in self.storage.generator():
-
- # inference the student for gradient computation
+            # Run inference on the student for gradient computation
actions = self.policy.act_inference(obs)
- # behavior cloning loss
+ # Behavior cloning loss
behavior_loss = self.loss_fn(actions, privileged_actions)
- # total loss
+ # Total loss
loss = loss + behavior_loss
mean_behavior_loss += behavior_loss.item()
cnt += 1
- # gradient step
+ # Gradient step
if cnt % self.gradient_length == 0:
self.optimizer.zero_grad()
loss.backward()
@@ -136,7 +146,7 @@ def update(self):
self.policy.detach_hidden_states()
loss = 0
- # reset dones
+ # Reset dones
self.policy.reset(dones.view(-1))
self.policy.detach_hidden_states(dones.view(-1))
@@ -145,25 +155,21 @@ def update(self):
self.last_hidden_states = self.policy.get_hidden_states()
self.policy.detach_hidden_states()
- # construct the loss dictionary
+ # Construct the loss dictionary
loss_dict = {"behavior": mean_behavior_loss}
return loss_dict
- """
- Helper functions
- """
-
- def broadcast_parameters(self):
+ def broadcast_parameters(self) -> None:
"""Broadcast model parameters to all GPUs."""
- # obtain the model parameters on current GPU
+ # Obtain the model parameters on current GPU
model_params = [self.policy.state_dict()]
- # broadcast the model parameters
+ # Broadcast the model parameters
torch.distributed.broadcast_object_list(model_params, src=0)
- # load the model parameters on all GPUs from source GPU
+ # Load the model parameters on all GPUs from source GPU
self.policy.load_state_dict(model_params[0])
- def reduce_parameters(self):
+ def reduce_parameters(self) -> None:
"""Collect gradients from all GPUs and average them.
This function is called after the backward pass to synchronize the gradients across all GPUs.
@@ -179,7 +185,7 @@ def reduce_parameters(self):
for param in self.policy.parameters():
if param.grad is not None:
numel = param.numel()
- # copy data back from shared buffer
+ # Copy data back from shared buffer
param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
- # update the offset for the next parameter
+ # Update the offset for the next parameter
offset += numel
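
Both `Distillation.reduce_parameters` here and `PPO.reduce_parameters` below follow the same pattern: flatten all gradients into one buffer, all-reduce it, divide by the world size, and copy the averaged values back. A minimal standalone sketch of that pattern, assuming `torch.distributed` has already been initialized by the caller:

```python
import torch
import torch.distributed as dist


def average_gradients(parameters, world_size: int) -> None:
    """Average .grad tensors across ranks in place (flatten -> all_reduce -> copy back)."""
    params = [p for p in parameters if p.grad is not None]
    if not params:
        return
    # One flat buffer lets a single collective call synchronize every gradient.
    all_grads = torch.cat([p.grad.view(-1) for p in params])
    dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
    all_grads /= world_size
    # Scatter the averaged values back into each parameter's gradient tensor.
    offset = 0
    for p in params:
        numel = p.numel()
        p.grad.data.copy_(all_grads[offset : offset + numel].view_as(p.grad.data))
        offset += numel
```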
diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py
index 6c21fc54..44efe782 100644
--- a/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/algorithms/ppo.py
@@ -9,8 +9,9 @@
import torch.nn as nn
import torch.optim as optim
from itertools import chain
+from tensordict import TensorDict
-from rsl_rl.modules import ActorCritic
+from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
from rsl_rl.modules.rnd import RandomNetworkDistillation
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import string_to_callable
@@ -19,36 +20,37 @@
class PPO:
"""Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""
- policy: ActorCritic
+ policy: ActorCritic | ActorCriticRecurrent
"""The actor critic module."""
def __init__(
self,
- policy,
- num_learning_epochs=5,
- num_mini_batches=4,
- clip_param=0.2,
- gamma=0.99,
- lam=0.95,
- value_loss_coef=1.0,
- entropy_coef=0.01,
- learning_rate=0.001,
- max_grad_norm=1.0,
- use_clipped_value_loss=True,
- schedule="adaptive",
- desired_kl=0.01,
- device="cpu",
- normalize_advantage_per_mini_batch=False,
+ policy: ActorCritic | ActorCriticRecurrent,
+ num_learning_epochs: int = 5,
+ num_mini_batches: int = 4,
+ clip_param: float = 0.2,
+ gamma: float = 0.99,
+ lam: float = 0.95,
+ value_loss_coef: float = 1.0,
+ entropy_coef: float = 0.01,
+ learning_rate: float = 0.001,
+ max_grad_norm: float = 1.0,
+ use_clipped_value_loss: bool = True,
+ schedule: str = "adaptive",
+ desired_kl: float = 0.01,
+ device: str = "cpu",
+ normalize_advantage_per_mini_batch: bool = False,
# RND parameters
rnd_cfg: dict | None = None,
# Symmetry parameters
symmetry_cfg: dict | None = None,
# Distributed training parameters
multi_gpu_cfg: dict | None = None,
- ):
- # device-related parameters
+ ) -> None:
+ # Device-related parameters
self.device = device
self.is_multi_gpu = multi_gpu_cfg is not None
+
# Multi-GPU parameters
if multi_gpu_cfg is not None:
self.gpu_global_rank = multi_gpu_cfg["global_rank"]
@@ -94,10 +96,12 @@ def __init__(
# PPO components
self.policy = policy
self.policy.to(self.device)
+
# Create optimizer
self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
+
# Create rollout storage
- self.storage: RolloutStorage = None # type: ignore
+ self.storage: RolloutStorage | None = None
self.transition = RolloutStorage.Transition()
# PPO parameters
@@ -115,8 +119,15 @@ def __init__(
self.learning_rate = learning_rate
self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch
- def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, actions_shape):
- # create rollout storage
+ def init_storage(
+ self,
+ training_type: str,
+ num_envs: int,
+ num_transitions_per_env: int,
+ obs: TensorDict,
+ actions_shape: tuple[int] | list[int],
+ ) -> None:
+ # Create rollout storage
self.storage = RolloutStorage(
training_type,
num_envs,
@@ -126,27 +137,29 @@ def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, ac
self.device,
)
- def act(self, obs):
+ def act(self, obs: TensorDict) -> torch.Tensor:
if self.policy.is_recurrent:
self.transition.hidden_states = self.policy.get_hidden_states()
- # compute the actions and values
+ # Compute the actions and values
self.transition.actions = self.policy.act(obs).detach()
self.transition.values = self.policy.evaluate(obs).detach()
self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach()
self.transition.action_mean = self.policy.action_mean.detach()
self.transition.action_sigma = self.policy.action_std.detach()
- # need to record obs before env.step()
+ # Record observations before env.step()
self.transition.observations = obs
return self.transition.actions
- def process_env_step(self, obs, rewards, dones, extras):
- # update the normalizers
+ def process_env_step(
+ self, obs: TensorDict, rewards: torch.Tensor, dones: torch.Tensor, extras: dict[str, torch.Tensor]
+ ) -> None:
+ # Update the normalizers
self.policy.update_normalization(obs)
if self.rnd:
self.rnd.update_normalization(obs)
# Record the rewards and dones
- # Note: we clone here because later on we bootstrap the rewards based on timeouts
+ # Note: We clone here because later on we bootstrap the rewards based on timeouts
self.transition.rewards = rewards.clone()
self.transition.dones = dones
@@ -163,40 +176,34 @@ def process_env_step(self, obs, rewards, dones, extras):
self.transition.values * extras["time_outs"].unsqueeze(1).to(self.device), 1
)
- # record the transition
+ # Record the transition
self.storage.add_transitions(self.transition)
self.transition.clear()
self.policy.reset(dones)
- def compute_returns(self, obs):
- # compute value for the last step
+ def compute_returns(self, obs: TensorDict) -> None:
+ # Compute value for the last step
last_values = self.policy.evaluate(obs).detach()
self.storage.compute_returns(
last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
)
- def update(self): # noqa: C901
+ def update(self) -> dict[str, float]:
mean_value_loss = 0
mean_surrogate_loss = 0
mean_entropy = 0
- # -- RND loss
- if self.rnd:
- mean_rnd_loss = 0
- else:
- mean_rnd_loss = None
- # -- Symmetry loss
- if self.symmetry:
- mean_symmetry_loss = 0
- else:
- mean_symmetry_loss = None
+ # RND loss
+ mean_rnd_loss = 0 if self.rnd else None
+ # Symmetry loss
+ mean_symmetry_loss = 0 if self.symmetry else None
- # generator for mini batches
+ # Get mini batch generator
if self.policy.is_recurrent:
generator = self.storage.recurrent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
else:
generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
- # iterate over batches
+ # Iterate over batches
for (
obs_batch,
actions_batch,
@@ -206,57 +213,46 @@ def update(self): # noqa: C901
old_actions_log_prob_batch,
old_mu_batch,
old_sigma_batch,
- hid_states_batch,
+ hidden_states_batch,
masks_batch,
) in generator:
-
- # number of augmentations per sample
- # we start with 1 and increase it if we use symmetry augmentation
- num_aug = 1
- # original batch size
- # we assume policy group is always there and needs augmentation
+ num_aug = 1 # Number of augmentations per sample. Starts at 1 for no augmentation.
original_batch_size = obs_batch.batch_size[0]
- # check if we should normalize advantages per mini batch
+ # Check if we should normalize advantages per mini batch
if self.normalize_advantage_per_mini_batch:
with torch.no_grad():
advantages_batch = (advantages_batch - advantages_batch.mean()) / (advantages_batch.std() + 1e-8)
# Perform symmetric augmentation
if self.symmetry and self.symmetry["use_data_augmentation"]:
- # augmentation using symmetry
+ # Augmentation using symmetry
data_augmentation_func = self.symmetry["data_augmentation_func"]
- # returned shape: [batch_size * num_aug, ...]
+ # Returned shape: [batch_size * num_aug, ...]
obs_batch, actions_batch = data_augmentation_func(
obs=obs_batch,
actions=actions_batch,
env=self.symmetry["_env"],
)
- # compute number of augmentations per sample
- # we assume policy group is always there and needs augmentation
+ # Compute number of augmentations per sample
num_aug = int(obs_batch.batch_size[0] / original_batch_size)
- # repeat the rest of the batch
- # -- actor
+ # Repeat the rest of the batch
old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1)
- # -- critic
target_values_batch = target_values_batch.repeat(num_aug, 1)
advantages_batch = advantages_batch.repeat(num_aug, 1)
returns_batch = returns_batch.repeat(num_aug, 1)
# Recompute actions log prob and entropy for current batch of transitions
- # Note: we need to do this because we updated the policy with the new parameters
- # -- actor
- self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
+ # Note: We need to do this because we updated the policy with the new parameters
+ self.policy.act(obs_batch, masks=masks_batch, hidden_state=hidden_states_batch[0])
actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch)
- # -- critic
- value_batch = self.policy.evaluate(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
- # -- entropy
- # we only keep the entropy of the first augmentation (the original one)
+ value_batch = self.policy.evaluate(obs_batch, masks=masks_batch, hidden_state=hidden_states_batch[1])
+ # Note: We only keep the entropy of the first augmentation (the original one)
mu_batch = self.policy.action_mean[:original_batch_size]
sigma_batch = self.policy.action_std[:original_batch_size]
entropy_batch = self.policy.entropy[:original_batch_size]
- # KL
+ # Compute KL divergence and adapt the learning rate
if self.desired_kl is not None and self.schedule == "adaptive":
with torch.inference_mode():
kl = torch.sum(
@@ -273,8 +269,7 @@ def update(self): # noqa: C901
torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
kl_mean /= self.gpu_world_size
- # Update the learning rate
- # Perform this adaptation only on the main process
+ # Update the learning rate only on the main process
# TODO: Is this needed? If KL-divergence is the "same" across all GPUs,
# then the learning rate should be the same across all GPUs.
if self.gpu_global_rank == 0:
@@ -316,70 +311,68 @@ def update(self): # noqa: C901
# Symmetry loss
if self.symmetry:
- # obtain the symmetric actions
- # if we did augmentation before then we don't need to augment again
+ # Obtain the symmetric actions
+ # Note: If we did augmentation before then we don't need to augment again
if not self.symmetry["use_data_augmentation"]:
data_augmentation_func = self.symmetry["data_augmentation_func"]
obs_batch, _ = data_augmentation_func(obs=obs_batch, actions=None, env=self.symmetry["_env"])
- # compute number of augmentations per sample
+ # Compute number of augmentations per sample
num_aug = int(obs_batch.shape[0] / original_batch_size)
- # actions predicted by the actor for symmetrically-augmented observations
+ # Actions predicted by the actor for symmetrically-augmented observations
mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone())
- # compute the symmetrically augmented actions
- # note: we are assuming the first augmentation is the original one.
- # We do not use the action_batch from earlier since that action was sampled from the distribution.
- # However, the symmetry loss is computed using the mean of the distribution.
+ # Compute the symmetrically augmented actions
+ # Note: We are assuming the first augmentation is the original one. We do not use the action_batch from
+ # earlier since that action was sampled from the distribution. However, the symmetry loss is computed
+ # using the mean of the distribution.
action_mean_orig = mean_actions_batch[:original_batch_size]
_, actions_mean_symm_batch = data_augmentation_func(
obs=None, actions=action_mean_orig, env=self.symmetry["_env"]
)
- # compute the loss (we skip the first augmentation as it is the original one)
+ # Compute the loss
mse_loss = torch.nn.MSELoss()
symmetry_loss = mse_loss(
mean_actions_batch[original_batch_size:], actions_mean_symm_batch.detach()[original_batch_size:]
)
- # add the loss to the total loss
+ # Add the loss to the total loss
if self.symmetry["use_mirror_loss"]:
loss += self.symmetry["mirror_loss_coeff"] * symmetry_loss
else:
symmetry_loss = symmetry_loss.detach()
- # Random Network Distillation loss
+ # RND loss
# TODO: Move this processing to inside RND module.
if self.rnd:
- # extract the rnd_state
+ # Extract the rnd_state
# TODO: Check if we still need torch no grad. It is just an affine transformation.
with torch.no_grad():
rnd_state_batch = self.rnd.get_rnd_state(obs_batch[:original_batch_size])
rnd_state_batch = self.rnd.state_normalizer(rnd_state_batch)
- # predict the embedding and the target
+ # Predict the embedding and the target
predicted_embedding = self.rnd.predictor(rnd_state_batch)
target_embedding = self.rnd.target(rnd_state_batch).detach()
- # compute the loss as the mean squared error
+ # Compute the loss as the mean squared error
mseloss = torch.nn.MSELoss()
rnd_loss = mseloss(predicted_embedding, target_embedding)
- # Compute the gradients
- # -- For PPO
+ # Compute the gradients for PPO
self.optimizer.zero_grad()
loss.backward()
- # -- For RND
+ # Compute the gradients for RND
if self.rnd:
- self.rnd_optimizer.zero_grad() # type: ignore
+ self.rnd_optimizer.zero_grad()
rnd_loss.backward()
# Collect gradients from all GPUs
if self.is_multi_gpu:
self.reduce_parameters()
- # Apply the gradients
- # -- For PPO
+ # Apply the gradients for PPO
nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.optimizer.step()
- # -- For RND
+ # Apply the gradients for RND
if self.rnd_optimizer:
self.rnd_optimizer.step()
@@ -387,28 +380,27 @@ def update(self): # noqa: C901
mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
mean_entropy += entropy_batch.mean().item()
- # -- RND loss
+ # RND loss
if mean_rnd_loss is not None:
mean_rnd_loss += rnd_loss.item()
- # -- Symmetry loss
+ # Symmetry loss
if mean_symmetry_loss is not None:
mean_symmetry_loss += symmetry_loss.item()
- # -- For PPO
+ # Divide the losses by the number of updates
num_updates = self.num_learning_epochs * self.num_mini_batches
mean_value_loss /= num_updates
mean_surrogate_loss /= num_updates
mean_entropy /= num_updates
- # -- For RND
if mean_rnd_loss is not None:
mean_rnd_loss /= num_updates
- # -- For Symmetry
if mean_symmetry_loss is not None:
mean_symmetry_loss /= num_updates
- # -- Clear the storage
+
+ # Clear the storage
self.storage.clear()
- # construct the loss dictionary
+ # Construct the loss dictionary
loss_dict = {
"value_function": mean_value_loss,
"surrogate": mean_surrogate_loss,
@@ -421,24 +413,20 @@ def update(self): # noqa: C901
return loss_dict
- """
- Helper functions
- """
-
- def broadcast_parameters(self):
+ def broadcast_parameters(self) -> None:
"""Broadcast model parameters to all GPUs."""
- # obtain the model parameters on current GPU
+ # Obtain the model parameters on current GPU
model_params = [self.policy.state_dict()]
if self.rnd:
model_params.append(self.rnd.predictor.state_dict())
- # broadcast the model parameters
+ # Broadcast the model parameters
torch.distributed.broadcast_object_list(model_params, src=0)
- # load the model parameters on all GPUs from source GPU
+ # Load the model parameters on all GPUs from source GPU
self.policy.load_state_dict(model_params[0])
if self.rnd:
self.rnd.predictor.load_state_dict(model_params[1])
- def reduce_parameters(self):
+ def reduce_parameters(self) -> None:
"""Collect gradients from all GPUs and average them.
This function is called after the backward pass to synchronize the gradients across all GPUs.
@@ -463,7 +451,7 @@ def reduce_parameters(self):
for param in all_params:
if param.grad is not None:
numel = param.numel()
- # copy data back from shared buffer
+ # Copy data back from shared buffer
param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
- # update the offset for the next parameter
+ # Update the offset for the next parameter
offset += numel
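
The copy-back loop above is only the tail of `reduce_parameters`; the earlier part of the function (outside this hunk) concatenates every gradient into a single flat buffer so that one collective call synchronizes the whole model. A minimal sketch of that pattern, assuming an already-initialized `torch.distributed` process group and using a hypothetical helper name:

```python
import torch
import torch.distributed as dist

def average_gradients(params: list[torch.nn.Parameter], world_size: int) -> None:
    """Flatten all gradients, average them across ranks, and copy the result back in place."""
    grads = [p.grad.view(-1) for p in params if p.grad is not None]
    if not grads:
        return
    flat = torch.cat(grads)                      # One contiguous buffer -> one collective call
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)  # Sum the gradients over all GPUs
    flat /= world_size                           # Turn the sum into a mean
    offset = 0
    for p in params:
        if p.grad is not None:
            numel = p.numel()
            # Copy data back from the shared buffer
            p.grad.data.copy_(flat[offset : offset + numel].view_as(p.grad.data))
            offset += numel
```
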
diff --git a/rsl_rl/env/vec_env.py b/rsl_rl/env/vec_env.py
index 3e73e336..86285923 100644
--- a/rsl_rl/env/vec_env.py
+++ b/rsl_rl/env/vec_env.py
@@ -13,9 +13,8 @@
class VecEnv(ABC):
"""Abstract class for a vectorized environment.
- The vectorized environment is a collection of environments that are synchronized. This means that
- the same type of action is applied to all environments and the same type of observation is returned from all
- environments.
+ The vectorized environment is a collection of environments that are synchronized. This means that the same type of
+ action is applied to all environments and the same type of observation is returned from all environments.
"""
num_envs: int
@@ -41,16 +40,12 @@ class VecEnv(ABC):
cfg: dict | object
"""Configuration object."""
- """
- Operations.
- """
-
@abstractmethod
def get_observations(self) -> TensorDict:
"""Return the current observations.
Returns:
- observations (TensorDict): Observations from the environment.
+ The observations from the environment.
"""
raise NotImplementedError
@@ -59,16 +54,15 @@ def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.T
"""Apply input action to the environment.
Args:
- actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions)
+ actions: Input actions to apply. Shape: (num_envs, num_actions)
Returns:
- observations (TensorDict): Observations from the environment.
- rewards (torch.Tensor): Rewards from the environment. Shape: (num_envs,)
- dones (torch.Tensor): Done flags from the environment. Shape: (num_envs,)
- extras (dict): Extra information from the environment.
+ observations: Observations from the environment.
+ rewards: Rewards from the environment. Shape: (num_envs,)
+ dones: Done flags from the environment. Shape: (num_envs,)
+ extras: Extra information from the environment.
Observations:
-
The observations TensorDict usually contains multiple observation groups. The `obs_groups`
dictionary of the runner configuration specifies which observation groups are used for which
purpose, i.e., it maps the available observation groups to observation sets. The observation sets
@@ -83,7 +77,6 @@ def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.T
`rsl_rl/utils/utils.py`.
Extras:
-
The extras dictionary includes metrics such as the episode reward, episode length, etc. The following
dictionary keys are used by rsl_rl:
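
To make the `obs_groups` mapping described above concrete: a hypothetical environment exposing two observation groups could be wired to the `policy` and `critic` observation sets roughly like this (the group names are invented for the example):

```python
# Hypothetical observation groups returned by get_observations(), e.g.
# TensorDict({"proprio": ..., "privileged": ...}, batch_size=[num_envs]).
obs_groups = {
    "policy": ["proprio"],                # the actor only sees onboard measurements
    "critic": ["proprio", "privileged"],  # the critic may additionally use privileged state
}
```
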
diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py
index 0d23fff6..efb8613a 100644
--- a/rsl_rl/modules/__init__.py
+++ b/rsl_rl/modules/__init__.py
@@ -7,14 +7,17 @@
from .actor_critic import ActorCritic
from .actor_critic_recurrent import ActorCriticRecurrent
-from .rnd import *
+from .rnd import RandomNetworkDistillation, resolve_rnd_config
from .student_teacher import StudentTeacher
from .student_teacher_recurrent import StudentTeacherRecurrent
-from .symmetry import *
+from .symmetry import resolve_symmetry_config
__all__ = [
"ActorCritic",
"ActorCriticRecurrent",
+ "RandomNetworkDistillation",
"StudentTeacher",
"StudentTeacherRecurrent",
+ "resolve_rnd_config",
+ "resolve_symmetry_config",
]
diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py
index 89434c3e..9f01b2f4 100644
--- a/rsl_rl/modules/actor_critic.py
+++ b/rsl_rl/modules/actor_critic.py
@@ -7,37 +7,38 @@
import torch
import torch.nn as nn
+from tensordict import TensorDict
from torch.distributions import Normal
+from typing import Any, NoReturn
from rsl_rl.networks import MLP, EmpiricalNormalization
class ActorCritic(nn.Module):
- is_recurrent = False
+ is_recurrent: bool = False
def __init__(
self,
- obs,
- obs_groups,
- num_actions,
- actor_obs_normalization=False,
- critic_obs_normalization=False,
- actor_hidden_dims=[256, 256, 256],
- critic_hidden_dims=[256, 256, 256],
- activation="elu",
- init_noise_std=1.0,
+ obs: TensorDict,
+ obs_groups: dict[str, list[str]],
+ num_actions: int,
+ actor_obs_normalization: bool = False,
+ critic_obs_normalization: bool = False,
+ actor_hidden_dims: tuple[int, ...] | list[int] = [256, 256, 256],
+ critic_hidden_dims: tuple[int, ...] | list[int] = [256, 256, 256],
+ activation: str = "elu",
+ init_noise_std: float = 1.0,
noise_std_type: str = "scalar",
- state_dependent_std=False,
- **kwargs,
- ):
+ state_dependent_std: bool = False,
+ **kwargs: dict[str, Any],
+ ) -> None:
if kwargs:
print(
- "ActorCritic.__init__ got unexpected arguments, which will be ignored: "
- + str([key for key in kwargs.keys()])
+ "ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs])
)
super().__init__()
- # get the observation dimensions
+ # Get the observation dimensions
self.obs_groups = obs_groups
num_actor_obs = 0
for obs_group in obs_groups["policy"]:
@@ -49,28 +50,31 @@ def __init__(
num_critic_obs += obs[obs_group].shape[-1]
self.state_dependent_std = state_dependent_std
- # actor
+
+ # Actor
if self.state_dependent_std:
self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation)
else:
self.actor = MLP(num_actor_obs, num_actions, actor_hidden_dims, activation)
- # actor observation normalization
+ print(f"Actor MLP: {self.actor}")
+
+ # Actor observation normalization
self.actor_obs_normalization = actor_obs_normalization
if actor_obs_normalization:
self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
else:
self.actor_obs_normalizer = torch.nn.Identity()
- print(f"Actor MLP: {self.actor}")
- # critic
+ # Critic
self.critic = MLP(num_critic_obs, 1, critic_hidden_dims, activation)
- # critic observation normalization
+ print(f"Critic MLP: {self.critic}")
+
+ # Critic observation normalization
self.critic_obs_normalization = critic_obs_normalization
if critic_obs_normalization:
self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
else:
self.critic_obs_normalizer = torch.nn.Identity()
- print(f"Critic MLP: {self.critic}")
# Action noise
self.noise_std_type = noise_std_type
@@ -92,32 +96,34 @@ def __init__(
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
- # Action distribution (populated in update_distribution)
+ # Action distribution
+ # Note: Populated in update_distribution
self.distribution = None
- # disable args validation for speedup
+
+ # Disable args validation for speedup
Normal.set_default_validate_args(False)
- def reset(self, dones=None):
+ def reset(self, dones: torch.Tensor | None = None) -> None:
pass
- def forward(self):
+ def forward(self) -> NoReturn:
raise NotImplementedError
@property
- def action_mean(self):
+ def action_mean(self) -> torch.Tensor:
return self.distribution.mean
@property
- def action_std(self):
+ def action_std(self) -> torch.Tensor:
return self.distribution.stddev
@property
- def entropy(self):
+ def entropy(self) -> torch.Tensor:
return self.distribution.entropy().sum(dim=-1)
- def update_distribution(self, obs):
+ def _update_distribution(self, obs: TensorDict) -> None:
if self.state_dependent_std:
- # compute mean and standard deviation
+ # Compute mean and standard deviation
mean_and_std = self.actor(obs)
if self.noise_std_type == "scalar":
mean, std = torch.unbind(mean_and_std, dim=-2)
@@ -127,25 +133,25 @@ def update_distribution(self, obs):
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
else:
- # compute mean
+ # Compute mean
mean = self.actor(obs)
- # compute standard deviation
+ # Compute standard deviation
if self.noise_std_type == "scalar":
std = self.std.expand_as(mean)
elif self.noise_std_type == "log":
std = torch.exp(self.log_std).expand_as(mean)
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
- # create distribution
+ # Create distribution
self.distribution = Normal(mean, std)
- def act(self, obs, **kwargs):
+ def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
obs = self.get_actor_obs(obs)
obs = self.actor_obs_normalizer(obs)
- self.update_distribution(obs)
+ self._update_distribution(obs)
return self.distribution.sample()
- def act_inference(self, obs):
+ def act_inference(self, obs: TensorDict) -> torch.Tensor:
obs = self.get_actor_obs(obs)
obs = self.actor_obs_normalizer(obs)
if self.state_dependent_std:
@@ -153,27 +159,23 @@ def act_inference(self, obs):
else:
return self.actor(obs)
- def evaluate(self, obs, **kwargs):
+ def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
obs = self.get_critic_obs(obs)
obs = self.critic_obs_normalizer(obs)
return self.critic(obs)
- def get_actor_obs(self, obs):
- obs_list = []
- for obs_group in self.obs_groups["policy"]:
- obs_list.append(obs[obs_group])
+ def get_actor_obs(self, obs: TensorDict) -> torch.Tensor:
+ obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]]
return torch.cat(obs_list, dim=-1)
- def get_critic_obs(self, obs):
- obs_list = []
- for obs_group in self.obs_groups["critic"]:
- obs_list.append(obs[obs_group])
+ def get_critic_obs(self, obs: TensorDict) -> torch.Tensor:
+ obs_list = [obs[obs_group] for obs_group in self.obs_groups["critic"]]
return torch.cat(obs_list, dim=-1)
- def get_actions_log_prob(self, actions):
+ def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor:
return self.distribution.log_prob(actions).sum(dim=-1)
- def update_normalization(self, obs):
+ def update_normalization(self, obs: TensorDict) -> None:
if self.actor_obs_normalization:
actor_obs = self.get_actor_obs(obs)
self.actor_obs_normalizer.update(actor_obs)
@@ -181,18 +183,17 @@ def update_normalization(self, obs):
critic_obs = self.get_critic_obs(obs)
self.critic_obs_normalizer.update(critic_obs)
- def load_state_dict(self, state_dict, strict=True):
+ def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool:
"""Load the parameters of the actor-critic model.
Args:
- state_dict (dict): State dictionary of the model.
- strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
- module's state_dict() function.
+ state_dict: State dictionary of the model.
+ strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's
+ :meth:`state_dict` function.
Returns:
- bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
- `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
+ Whether this training resumes a previous training. This flag is used by the :func:`load` function of
+ :class:`OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
"""
-
super().load_state_dict(state_dict, strict=strict)
- return True # training resumes
+ return True
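
As a quick usage sketch of the typed interface above, the following builds an `ActorCritic` from a dummy observation batch and queries it; the observation group names and sizes are invented for the example:

```python
import torch
from tensordict import TensorDict

from rsl_rl.modules import ActorCritic

num_envs, num_actions = 8, 3
obs = TensorDict(
    {"proprio": torch.randn(num_envs, 12), "privileged": torch.randn(num_envs, 5)},
    batch_size=[num_envs],
)
obs_groups = {"policy": ["proprio"], "critic": ["proprio", "privileged"]}

policy = ActorCritic(obs, obs_groups, num_actions, actor_obs_normalization=True)
actions = policy.act(obs)                        # Sampled actions, shape (num_envs, num_actions)
log_prob = policy.get_actions_log_prob(actions)  # Shape (num_envs,)
values = policy.evaluate(obs)                    # Critic values, shape (num_envs, 1)
```
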
diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py
index 03dfc162..509b6821 100644
--- a/rsl_rl/modules/actor_critic_recurrent.py
+++ b/rsl_rl/modules/actor_critic_recurrent.py
@@ -8,32 +8,34 @@
import torch
import torch.nn as nn
import warnings
+from tensordict import TensorDict
from torch.distributions import Normal
+from typing import Any, NoReturn
-from rsl_rl.networks import MLP, EmpiricalNormalization, Memory
+from rsl_rl.networks import MLP, EmpiricalNormalization, HiddenState, Memory
class ActorCriticRecurrent(nn.Module):
- is_recurrent = True
+ is_recurrent: bool = True
def __init__(
self,
- obs,
- obs_groups,
- num_actions,
- actor_obs_normalization=False,
- critic_obs_normalization=False,
- actor_hidden_dims=[256, 256, 256],
- critic_hidden_dims=[256, 256, 256],
- activation="elu",
- init_noise_std=1.0,
+ obs: TensorDict,
+ obs_groups: dict[str, list[str]],
+ num_actions: int,
+ actor_obs_normalization: bool = False,
+ critic_obs_normalization: bool = False,
+ actor_hidden_dims: tuple[int, ...] | list[int] = [256, 256, 256],
+ critic_hidden_dims: tuple[int, ...] | list[int] = [256, 256, 256],
+ activation: str = "elu",
+ init_noise_std: float = 1.0,
noise_std_type: str = "scalar",
- state_dependent_std=False,
- rnn_type="lstm",
- rnn_hidden_dim=256,
- rnn_num_layers=1,
- **kwargs,
- ):
+ state_dependent_std: bool = False,
+ rnn_type: str = "lstm",
+ rnn_hidden_dim: int = 256,
+ rnn_num_layers: int = 1,
+ **kwargs: dict[str, Any],
+ ) -> None:
if "rnn_hidden_size" in kwargs:
warnings.warn(
"The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
@@ -48,7 +50,7 @@ def __init__(
)
super().__init__()
- # get the observation dimensions
+ # Get the observation dimensions
self.obs_groups = obs_groups
num_actor_obs = 0
for obs_group in obs_groups["policy"]:
@@ -60,33 +62,35 @@ def __init__(
num_critic_obs += obs[obs_group].shape[-1]
self.state_dependent_std = state_dependent_std
- # actor
- self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
+
+ # Actor
+ self.memory_a = Memory(num_actor_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
if self.state_dependent_std:
self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation)
else:
self.actor = MLP(rnn_hidden_dim, num_actions, actor_hidden_dims, activation)
+ print(f"Actor RNN: {self.memory_a}")
+ print(f"Actor MLP: {self.actor}")
- # actor observation normalization
+ # Actor observation normalization
self.actor_obs_normalization = actor_obs_normalization
if actor_obs_normalization:
self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
else:
self.actor_obs_normalizer = torch.nn.Identity()
- print(f"Actor RNN: {self.memory_a}")
- print(f"Actor MLP: {self.actor}")
- # critic
- self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
+ # Critic
+ self.memory_c = Memory(num_critic_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
self.critic = MLP(rnn_hidden_dim, 1, critic_hidden_dims, activation)
- # critic observation normalization
+ print(f"Critic RNN: {self.memory_c}")
+ print(f"Critic MLP: {self.critic}")
+
+ # Critic observation normalization
self.critic_obs_normalization = critic_obs_normalization
if critic_obs_normalization:
self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
else:
self.critic_obs_normalizer = torch.nn.Identity()
- print(f"Critic RNN: {self.memory_c}")
- print(f"Critic MLP: {self.critic}")
# Action noise
self.noise_std_type = noise_std_type
@@ -108,33 +112,35 @@ def __init__(
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
- # Action distribution (populated in update_distribution)
+ # Action distribution
+ # Note: Populated in update_distribution
self.distribution = None
- # disable args validation for speedup
+
+ # Disable args validation for speedup
Normal.set_default_validate_args(False)
@property
- def action_mean(self):
+ def action_mean(self) -> torch.Tensor:
return self.distribution.mean
@property
- def action_std(self):
+ def action_std(self) -> torch.Tensor:
return self.distribution.stddev
@property
- def entropy(self):
+ def entropy(self) -> torch.Tensor:
return self.distribution.entropy().sum(dim=-1)
- def reset(self, dones=None):
+ def reset(self, dones: torch.Tensor | None = None) -> None:
self.memory_a.reset(dones)
self.memory_c.reset(dones)
- def forward(self):
+ def forward(self) -> NoReturn:
raise NotImplementedError
- def update_distribution(self, obs):
+ def _update_distribution(self, obs: TensorDict) -> None:
if self.state_dependent_std:
- # compute mean and standard deviation
+ # Compute mean and standard deviation
mean_and_std = self.actor(obs)
if self.noise_std_type == "scalar":
mean, std = torch.unbind(mean_and_std, dim=-2)
@@ -144,26 +150,26 @@ def update_distribution(self, obs):
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
else:
- # compute mean
+ # Compute mean
mean = self.actor(obs)
- # compute standard deviation
+ # Compute standard deviation
if self.noise_std_type == "scalar":
std = self.std.expand_as(mean)
elif self.noise_std_type == "log":
std = torch.exp(self.log_std).expand_as(mean)
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
- # create distribution
+ # Create distribution
self.distribution = Normal(mean, std)
- def act(self, obs, masks=None, hidden_states=None):
+ def act(self, obs: TensorDict, masks: torch.Tensor | None = None, hidden_state: HiddenState = None) -> torch.Tensor:
obs = self.get_actor_obs(obs)
obs = self.actor_obs_normalizer(obs)
- out_mem = self.memory_a(obs, masks, hidden_states).squeeze(0)
- self.update_distribution(out_mem)
+ out_mem = self.memory_a(obs, masks, hidden_state).squeeze(0)
+ self._update_distribution(out_mem)
return self.distribution.sample()
- def act_inference(self, obs):
+ def act_inference(self, obs: TensorDict) -> torch.Tensor:
obs = self.get_actor_obs(obs)
obs = self.actor_obs_normalizer(obs)
out_mem = self.memory_a(obs).squeeze(0)
@@ -172,31 +178,29 @@ def act_inference(self, obs):
else:
return self.actor(out_mem)
- def evaluate(self, obs, masks=None, hidden_states=None):
+ def evaluate(
+ self, obs: TensorDict, masks: torch.Tensor | None = None, hidden_state: HiddenState = None
+ ) -> torch.Tensor:
obs = self.get_critic_obs(obs)
obs = self.critic_obs_normalizer(obs)
- out_mem = self.memory_c(obs, masks, hidden_states).squeeze(0)
+ out_mem = self.memory_c(obs, masks, hidden_state).squeeze(0)
return self.critic(out_mem)
- def get_actor_obs(self, obs):
- obs_list = []
- for obs_group in self.obs_groups["policy"]:
- obs_list.append(obs[obs_group])
+ def get_actor_obs(self, obs: TensorDict) -> torch.Tensor:
+ obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]]
return torch.cat(obs_list, dim=-1)
- def get_critic_obs(self, obs):
- obs_list = []
- for obs_group in self.obs_groups["critic"]:
- obs_list.append(obs[obs_group])
+ def get_critic_obs(self, obs: TensorDict) -> torch.Tensor:
+ obs_list = [obs[obs_group] for obs_group in self.obs_groups["critic"]]
return torch.cat(obs_list, dim=-1)
- def get_actions_log_prob(self, actions):
+ def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor:
return self.distribution.log_prob(actions).sum(dim=-1)
- def get_hidden_states(self):
- return self.memory_a.hidden_states, self.memory_c.hidden_states
+ def get_hidden_states(self) -> tuple[HiddenState, HiddenState]:
+ return self.memory_a.hidden_state, self.memory_c.hidden_state
- def update_normalization(self, obs):
+ def update_normalization(self, obs: TensorDict) -> None:
if self.actor_obs_normalization:
actor_obs = self.get_actor_obs(obs)
self.actor_obs_normalizer.update(actor_obs)
@@ -204,18 +208,17 @@ def update_normalization(self, obs):
critic_obs = self.get_critic_obs(obs)
self.critic_obs_normalizer.update(critic_obs)
- def load_state_dict(self, state_dict, strict=True):
+ def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool:
"""Load the parameters of the actor-critic model.
Args:
- state_dict (dict): State dictionary of the model.
- strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
- module's state_dict() function.
+ state_dict: State dictionary of the model.
+ strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's
+ :meth:`state_dict` function.
Returns:
- bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
- `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
+ Whether this training resumes a previous training. This flag is used by the :func:`load` function of
+ :class:`OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
"""
-
super().load_state_dict(state_dict, strict=strict)
return True
diff --git a/rsl_rl/modules/rnd.py b/rsl_rl/modules/rnd.py
index 8e65c43a..3b536dcd 100644
--- a/rsl_rl/modules/rnd.py
+++ b/rsl_rl/modules/rnd.py
@@ -7,15 +7,18 @@
import torch
import torch.nn as nn
+from tensordict import TensorDict
+from typing import Any, NoReturn
+from rsl_rl.env import VecEnv
from rsl_rl.networks import MLP, EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
class RandomNetworkDistillation(nn.Module):
- """Implementation of Random Network Distillation (RND) [1]
+ """Implementation of Random Network Distillation (RND) [1].
References:
- .. [1] Burda, Yuri, et al. "Exploration by random network distillation." arXiv preprint arXiv:1810.12894 (2018).
+ .. [1] Burda, Yuri, et al. "Exploration by Random Network Distillation." arXiv preprint arXiv:1810.12894 (2018).
"""
def __init__(
@@ -23,55 +26,54 @@ def __init__(
num_states: int,
obs_groups: dict,
num_outputs: int,
- predictor_hidden_dims: list[int],
- target_hidden_dims: list[int],
+ predictor_hidden_dims: tuple[int, ...] | list[int],
+ target_hidden_dims: tuple[int, ...] | list[int],
activation: str = "elu",
weight: float = 0.0,
state_normalization: bool = False,
reward_normalization: bool = False,
device: str = "cpu",
weight_schedule: dict | None = None,
- ):
+ ) -> None:
"""Initialize the RND module.
- - If :attr:`state_normalization` is True, then the input state is normalized using an Empirical Normalization layer.
+ - If :attr:`state_normalization` is True, then the input state is normalized using an Empirical Normalization
+ layer.
- If :attr:`reward_normalization` is True, then the intrinsic reward is normalized using an Empirical Discounted
Variation Normalization layer.
-
- .. note::
- If the hidden dimensions are -1 in the predictor and target networks configuration, then the number of states
- is used as the hidden dimension.
+ - If the hidden dimensions are -1 in the predictor and target networks configuration, then the number of states
+ is used as the hidden dimension.
Args:
num_states: Number of states/inputs to the predictor and target networks.
+ obs_groups: Dictionary of observation groups.
num_outputs: Number of outputs (embedding size) of the predictor and target networks.
predictor_hidden_dims: List of hidden dimensions of the predictor network.
target_hidden_dims: List of hidden dimensions of the target network.
- activation: Activation function. Defaults to "elu".
- weight: Scaling factor of the intrinsic reward. Defaults to 0.0.
- state_normalization: Whether to normalize the input state. Defaults to False.
- reward_normalization: Whether to normalize the intrinsic reward. Defaults to False.
- device: Device to use. Defaults to "cpu".
- weight_schedule: The type of schedule to use for the RND weight parameter.
- Defaults to None, in which case the weight parameter is constant.
+ activation: Activation function.
+ weight: Scaling factor of the intrinsic reward.
+ state_normalization: Whether to normalize the input state.
+ reward_normalization: Whether to normalize the intrinsic reward.
+ device: Device to use.
+ weight_schedule: Type of schedule to use for the RND weight parameter.
It is a dictionary with the following keys:
- - "mode": The type of schedule to use for the RND weight parameter.
+ - "mode": Type of schedule to use for the RND weight parameter.
- "constant": Constant weight schedule.
- "step": Step weight schedule.
- "linear": Linear weight schedule.
For the "step" weight schedule, the following parameters are required:
- - "final_step": The step at which the weight parameter is set to the final value.
- - "final_value": The final value of the weight parameter.
+ - "final_step": Step at which the weight parameter is set to the final value.
+ - "final_value": Final value of the weight parameter.
For the "linear" weight schedule, the following parameters are required:
- - "initial_step": The step at which the weight parameter is set to the initial value.
- - "final_step": The step at which the weight parameter is set to the final value.
- - "final_value": The final value of the weight parameter.
+ - "initial_step": Step at which the weight parameter is set to the initial value.
+ - "final_step": Step at which the weight parameter is set to the final value.
+ - "final_value": Final value of the weight parameter.
"""
- # initialize parent class
+ # Initialize parent class
super().__init__()
# Store parameters
@@ -88,30 +90,32 @@ def __init__(
self.state_normalizer = EmpiricalNormalization(shape=[self.num_states], until=1.0e8).to(self.device)
else:
self.state_normalizer = torch.nn.Identity()
+
# Normalization of intrinsic reward
if reward_normalization:
self.reward_normalizer = EmpiricalDiscountedVariationNormalization(shape=[], until=1.0e8).to(self.device)
else:
self.reward_normalizer = torch.nn.Identity()
- # counter for the number of updates
+ # Counter for the number of updates
self.update_counter = 0
- # resolve weight schedule
+ # Resolve weight schedule
if weight_schedule is not None:
self.weight_scheduler_params = weight_schedule
self.weight_scheduler = getattr(self, f"_{weight_schedule['mode']}_weight_schedule")
else:
self.weight_scheduler = None
+
# Create network architecture
self.predictor = MLP(num_states, num_outputs, predictor_hidden_dims, activation).to(self.device)
self.target = MLP(num_states, num_outputs, target_hidden_dims, activation).to(self.device)
- # make target network not trainable
+ # Make target network not trainable
self.target.eval()
- def get_intrinsic_reward(self, obs) -> torch.Tensor:
- # Note: the counter is updated number of env steps per learning iteration
+ def get_intrinsic_reward(self, obs: TensorDict) -> torch.Tensor:
+ # Note: The counter is incremented once per environment step, i.e., it grows by the number of env steps per learning iteration
self.update_counter += 1
# Extract the rnd state from the observation
rnd_state = self.get_rnd_state(obs)
@@ -123,7 +127,6 @@ def get_intrinsic_reward(self, obs) -> torch.Tensor:
intrinsic_reward = torch.linalg.norm(target_embedding - predictor_embedding, dim=1)
# Normalize intrinsic reward
intrinsic_reward = self.reward_normalizer(intrinsic_reward)
-
# Check the weight schedule
if self.weight_scheduler is not None:
self.weight = self.weight_scheduler(step=self.update_counter, **self.weight_scheduler_params)
@@ -134,11 +137,11 @@ def get_intrinsic_reward(self, obs) -> torch.Tensor:
return intrinsic_reward
- def forward(self, *args, **kwargs):
+ def forward(self, *args: Any, **kwargs: dict[str, Any]) -> NoReturn:
raise RuntimeError("Forward method is not implemented. Use get_intrinsic_reward instead.")
- def train(self, mode: bool = True):
- # sets module into training mode
+ def train(self, mode: bool = True) -> RandomNetworkDistillation:
+ # Set module into training mode
self.predictor.train(mode)
if self.state_normalization:
self.state_normalizer.train(mode)
@@ -146,32 +149,28 @@ def train(self, mode: bool = True):
self.reward_normalizer.train(mode)
return self
- def eval(self):
+ def eval(self) -> RandomNetworkDistillation:
return self.train(False)
- def get_rnd_state(self, obs):
- obs_list = []
- for obs_group in self.obs_groups["rnd_state"]:
- obs_list.append(obs[obs_group])
+ def get_rnd_state(self, obs: TensorDict) -> torch.Tensor:
+ obs_list = [obs[obs_group] for obs_group in self.obs_groups["rnd_state"]]
return torch.cat(obs_list, dim=-1)
- def update_normalization(self, obs):
+ def update_normalization(self, obs: TensorDict) -> None:
# Normalize the state
if self.state_normalization:
rnd_state = self.get_rnd_state(obs)
self.state_normalizer.update(rnd_state)
- """
- Different weight schedules.
- """
-
- def _constant_weight_schedule(self, step: int, **kwargs):
+ def _constant_weight_schedule(self, step: int, **kwargs: dict[str, Any]) -> float:
return self.initial_weight
- def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs):
+ def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs: dict[str, Any]) -> float:
return self.initial_weight if step < final_step else final_value
- def _linear_weight_schedule(self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs):
+ def _linear_weight_schedule(
+ self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs: dict[str, Any]
+ ) -> float:
if step < initial_step:
return self.initial_weight
elif step > final_step:
@@ -182,28 +181,28 @@ def _linear_weight_schedule(self, step: int, initial_step: int, final_step: int,
)
-def resolve_rnd_config(alg_cfg, obs, obs_groups, env):
+def resolve_rnd_config(alg_cfg: dict, obs: TensorDict, obs_groups: dict[str, list[str]], env: VecEnv) -> dict:
"""Resolve the RND configuration.
Args:
- alg_cfg: The algorithm configuration dictionary.
- obs: The observation dictionary.
- obs_groups: The observation groups dictionary.
- env: The environment.
+ alg_cfg: Algorithm configuration dictionary.
+ obs: Observation dictionary.
+ obs_groups: Observation groups dictionary.
+ env: Environment object.
Returns:
The resolved algorithm configuration dictionary.
"""
- # resolve dimension of rnd gated state
+ # Resolve dimension of rnd gated state
if "rnd_cfg" in alg_cfg and alg_cfg["rnd_cfg"] is not None:
- # get dimension of rnd gated state
+ # Get dimension of rnd gated state
num_rnd_state = 0
for obs_group in obs_groups["rnd_state"]:
assert len(obs[obs_group].shape) == 2, "The RND module only supports 1D observations."
num_rnd_state += obs[obs_group].shape[-1]
- # add rnd gated state to config
+ # Add rnd gated state to config
alg_cfg["rnd_cfg"]["num_states"] = num_rnd_state
alg_cfg["rnd_cfg"]["obs_groups"] = obs_groups
- # scale down the rnd weight with timestep
+ # Scale down the rnd weight with timestep
alg_cfg["rnd_cfg"]["weight"] *= env.unwrapped.step_dt
return alg_cfg
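
Tying the docstring's schedule description together, a hypothetical `rnd_cfg` block inside `alg_cfg` might look as follows before `resolve_rnd_config` adds `num_states` and `obs_groups` and rescales `weight` by `env.unwrapped.step_dt`. The values are illustrative, and it is assumed (not shown in this diff) that the remaining keys are passed through to the `RandomNetworkDistillation` constructor:

```python
alg_cfg = {
    # ... other algorithm settings ...
    "rnd_cfg": {
        "num_outputs": 32,              # embedding size of the predictor/target networks
        "predictor_hidden_dims": [-1],  # -1 -> use the number of RND states as the hidden dimension
        "target_hidden_dims": [-1],
        "weight": 0.5,                  # scaled by env.unwrapped.step_dt in resolve_rnd_config
        "state_normalization": True,
        "reward_normalization": False,
        "weight_schedule": {
            "mode": "linear",           # anneal the intrinsic-reward weight over training
            "initial_step": 0,
            "final_step": 10_000,
            "final_value": 0.0,
        },
    },
}
```
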
diff --git a/rsl_rl/modules/student_teacher.py b/rsl_rl/modules/student_teacher.py
index 6bf1380c..82b3ef62 100644
--- a/rsl_rl/modules/student_teacher.py
+++ b/rsl_rl/modules/student_teacher.py
@@ -7,38 +7,40 @@
import torch
import torch.nn as nn
+from tensordict import TensorDict
from torch.distributions import Normal
+from typing import Any, NoReturn
-from rsl_rl.networks import MLP, EmpiricalNormalization
+from rsl_rl.networks import MLP, EmpiricalNormalization, HiddenState
class StudentTeacher(nn.Module):
- is_recurrent = False
+ is_recurrent: bool = False
def __init__(
self,
- obs,
- obs_groups,
- num_actions,
- student_obs_normalization=False,
- teacher_obs_normalization=False,
- student_hidden_dims=[256, 256, 256],
- teacher_hidden_dims=[256, 256, 256],
- activation="elu",
- init_noise_std=0.1,
+ obs: TensorDict,
+ obs_groups: dict[str, list[str]],
+ num_actions: int,
+ student_obs_normalization: bool = False,
+ teacher_obs_normalization: bool = False,
+ student_hidden_dims: tuple[int, ...] | list[int] = [256, 256, 256],
+ teacher_hidden_dims: tuple[int, ...] | list[int] = [256, 256, 256],
+ activation: str = "elu",
+ init_noise_std: float = 0.1,
noise_std_type: str = "scalar",
- **kwargs,
- ):
+ **kwargs: dict[str, Any],
+ ) -> None:
if kwargs:
print(
"StudentTeacher.__init__ got unexpected arguments, which will be ignored: "
- + str([key for key in kwargs.keys()])
+ + str([key for key in kwargs])
)
super().__init__()
- self.loaded_teacher = False # indicates if teacher has been loaded
+ self.loaded_teacher = False # Indicates if teacher has been loaded
- # get the observation dimensions
+ # Get the observation dimensions
self.obs_groups = obs_groups
num_student_obs = 0
for obs_group in obs_groups["policy"]:
@@ -49,32 +51,30 @@ def __init__(
assert len(obs[obs_group].shape) == 2, "The StudentTeacher module only supports 1D observations."
num_teacher_obs += obs[obs_group].shape[-1]
- # student
+ # Student
self.student = MLP(num_student_obs, num_actions, student_hidden_dims, activation)
+ print(f"Student MLP: {self.student}")
- # student observation normalization
+ # Student observation normalization
self.student_obs_normalization = student_obs_normalization
if student_obs_normalization:
self.student_obs_normalizer = EmpiricalNormalization(num_student_obs)
else:
self.student_obs_normalizer = torch.nn.Identity()
- print(f"Student MLP: {self.student}")
-
- # teacher
+ # Teacher
self.teacher = MLP(num_teacher_obs, num_actions, teacher_hidden_dims, activation)
self.teacher.eval()
+ print(f"Teacher MLP: {self.teacher}")
- # teacher observation normalization
+ # Teacher observation normalization
self.teacher_obs_normalization = teacher_obs_normalization
if teacher_obs_normalization:
self.teacher_obs_normalizer = EmpiricalNormalization(num_teacher_obs)
else:
self.teacher_obs_normalizer = torch.nn.Identity()
- print(f"Teacher MLP: {self.teacher}")
-
- # action noise
+ # Action noise
self.noise_std_type = noise_std_type
if self.noise_std_type == "scalar":
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
@@ -83,104 +83,103 @@ def __init__(
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
- # action distribution (populated in update_distribution)
+ # Action distribution
+ # Note: Populated in update_distribution
self.distribution = None
- # disable args validation for speedup
+
+ # Disable args validation for speedup
Normal.set_default_validate_args(False)
- def reset(self, dones=None, hidden_states=None):
+ def reset(
+ self, dones: torch.Tensor | None = None, hidden_states: tuple[HiddenState, HiddenState] = (None, None)
+ ) -> None:
pass
- def forward(self):
+ def forward(self) -> NoReturn:
raise NotImplementedError
@property
- def action_mean(self):
+ def action_mean(self) -> torch.Tensor:
return self.distribution.mean
@property
- def action_std(self):
+ def action_std(self) -> torch.Tensor:
return self.distribution.stddev
@property
- def entropy(self):
+ def entropy(self) -> torch.Tensor:
return self.distribution.entropy().sum(dim=-1)
- def update_distribution(self, obs):
- # compute mean
+ def _update_distribution(self, obs: TensorDict) -> None:
+ # Compute mean
mean = self.student(obs)
- # compute standard deviation
+ # Compute standard deviation
if self.noise_std_type == "scalar":
std = self.std.expand_as(mean)
elif self.noise_std_type == "log":
std = torch.exp(self.log_std).expand_as(mean)
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
- # create distribution
+ # Create distribution
self.distribution = Normal(mean, std)
- def act(self, obs):
+ def act(self, obs: TensorDict) -> torch.Tensor:
obs = self.get_student_obs(obs)
obs = self.student_obs_normalizer(obs)
- self.update_distribution(obs)
+ self._update_distribution(obs)
return self.distribution.sample()
- def act_inference(self, obs):
+ def act_inference(self, obs: TensorDict) -> torch.Tensor:
obs = self.get_student_obs(obs)
obs = self.student_obs_normalizer(obs)
return self.student(obs)
- def evaluate(self, obs):
+ def evaluate(self, obs: TensorDict) -> torch.Tensor:
obs = self.get_teacher_obs(obs)
obs = self.teacher_obs_normalizer(obs)
with torch.no_grad():
return self.teacher(obs)
- def get_student_obs(self, obs):
- obs_list = []
- for obs_group in self.obs_groups["policy"]:
- obs_list.append(obs[obs_group])
+ def get_student_obs(self, obs: TensorDict) -> torch.Tensor:
+ obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]]
return torch.cat(obs_list, dim=-1)
- def get_teacher_obs(self, obs):
- obs_list = []
- for obs_group in self.obs_groups["teacher"]:
- obs_list.append(obs[obs_group])
+ def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor:
+ obs_list = [obs[obs_group] for obs_group in self.obs_groups["teacher"]]
return torch.cat(obs_list, dim=-1)
- def get_hidden_states(self):
- return None
+ def get_hidden_states(self) -> tuple[HiddenState, HiddenState]:
+ return None, None
- def detach_hidden_states(self, dones=None):
+ def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:
pass
- def train(self, mode=True):
+ def train(self, mode: bool = True) -> None:
super().train(mode)
- # make sure teacher is in eval mode
+ # Make sure teacher is in eval mode
self.teacher.eval()
self.teacher_obs_normalizer.eval()
- def update_normalization(self, obs):
+ def update_normalization(self, obs: TensorDict) -> None:
if self.student_obs_normalization:
student_obs = self.get_student_obs(obs)
self.student_obs_normalizer.update(student_obs)
- def load_state_dict(self, state_dict, strict=True):
+ def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool:
"""Load the parameters of the student and teacher networks.
Args:
- state_dict (dict): State dictionary of the model.
- strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
- module's state_dict() function.
+ state_dict: State dictionary of the model.
+ strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's
+ :meth:`state_dict` function.
Returns:
- bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
- `OnPolicyRunner` to determine how to load further parameters.
+ Whether this training resumes a previous training. This flag is used by the :func:`load` function of
+ :class:`OnPolicyRunner` to determine how to load further parameters.
"""
-
- # check if state_dict contains teacher and student or just teacher parameters
- if any("actor" in key for key in state_dict.keys()): # loading parameters from rl training
- # rename keys to match teacher and remove critic parameters
+ # Check if state_dict contains teacher and student or just teacher parameters
+ if any("actor" in key for key in state_dict): # Load parameters from rl training
+ # Rename keys to match teacher and remove critic parameters
teacher_state_dict = {}
teacher_obs_normalizer_state_dict = {}
for key, value in state_dict.items():
@@ -190,17 +189,17 @@ def load_state_dict(self, state_dict, strict=True):
teacher_obs_normalizer_state_dict[key.replace("actor_obs_normalizer.", "")] = value
self.teacher.load_state_dict(teacher_state_dict, strict=strict)
self.teacher_obs_normalizer.load_state_dict(teacher_obs_normalizer_state_dict, strict=strict)
- # set flag for successfully loading the parameters
+ # Set flag for successfully loading the parameters
self.loaded_teacher = True
self.teacher.eval()
self.teacher_obs_normalizer.eval()
- return False # training does not resume
- elif any("student" in key for key in state_dict.keys()): # loading parameters from distillation training
+ return False # Training does not resume
+ elif any("student" in key for key in state_dict): # Load parameters from distillation training
super().load_state_dict(state_dict, strict=strict)
- # set flag for successfully loading the parameters
+ # Set flag for successfully loading the parameters
self.loaded_teacher = True
self.teacher.eval()
self.teacher_obs_normalizer.eval()
- return True # training resumes
+ return True # Training resumes
else:
raise ValueError("state_dict does not contain student or teacher parameters")
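
The branch above either bootstraps the teacher from an RL checkpoint (returning `False`, so training does not resume) or resumes a distillation checkpoint (returning `True`). A compact illustration of the key renaming performed in the first branch; the keys and placeholder values are invented, and the exact filtering in the elided loop body is not quoted here:

```python
# Keys as they might appear in a PPO checkpoint.
rl_state_dict = {
    "actor.0.weight": "tensor-A",
    "actor_obs_normalizer.mean": "tensor-B",
    "critic.0.weight": "tensor-C",  # dropped: the student-teacher model has no critic
}

teacher_state_dict = {
    key.replace("actor.", ""): value for key, value in rl_state_dict.items() if "actor." in key
}
teacher_obs_normalizer_state_dict = {
    key.replace("actor_obs_normalizer.", ""): value
    for key, value in rl_state_dict.items() if "actor_obs_normalizer." in key
}

print(teacher_state_dict)                 # {'0.weight': 'tensor-A'} -> loaded into self.teacher
print(teacher_obs_normalizer_state_dict)  # {'mean': 'tensor-B'}     -> loaded into self.teacher_obs_normalizer
```
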
diff --git a/rsl_rl/modules/student_teacher_recurrent.py b/rsl_rl/modules/student_teacher_recurrent.py
index bba4bd34..2819c560 100644
--- a/rsl_rl/modules/student_teacher_recurrent.py
+++ b/rsl_rl/modules/student_teacher_recurrent.py
@@ -8,32 +8,34 @@
import torch
import torch.nn as nn
import warnings
+from tensordict import TensorDict
from torch.distributions import Normal
+from typing import Any, NoReturn
-from rsl_rl.networks import MLP, EmpiricalNormalization, Memory
+from rsl_rl.networks import MLP, EmpiricalNormalization, HiddenState, Memory
class StudentTeacherRecurrent(nn.Module):
- is_recurrent = True
+ is_recurrent: bool = True
def __init__(
self,
- obs,
- obs_groups,
- num_actions,
- student_obs_normalization=False,
- teacher_obs_normalization=False,
- student_hidden_dims=[256, 256, 256],
- teacher_hidden_dims=[256, 256, 256],
- activation="elu",
- init_noise_std=0.1,
+ obs: TensorDict,
+ obs_groups: dict[str, list[str]],
+ num_actions: int,
+ student_obs_normalization: bool = False,
+ teacher_obs_normalization: bool = False,
+ student_hidden_dims: tuple[int, ...] | list[int] = [256, 256, 256],
+ teacher_hidden_dims: tuple[int, ...] | list[int] = [256, 256, 256],
+ activation: str = "elu",
+ init_noise_std: float = 0.1,
noise_std_type: str = "scalar",
- rnn_type="lstm",
- rnn_hidden_dim=256,
- rnn_num_layers=1,
- teacher_recurrent=False,
- **kwargs,
- ):
+ rnn_type: str = "lstm",
+ rnn_hidden_dim: int = 256,
+ rnn_num_layers: int = 1,
+ teacher_recurrent: bool = False,
+ **kwargs: dict[str, Any],
+ ) -> None:
if "rnn_hidden_size" in kwargs:
warnings.warn(
"The argument `rnn_hidden_size` is deprecated and will be removed in a future version. "
@@ -49,10 +51,10 @@ def __init__(
)
super().__init__()
- self.loaded_teacher = False # indicates if teacher has been loaded
- self.teacher_recurrent = teacher_recurrent # indicates if teacher is recurrent too
+ self.loaded_teacher = False # Indicates if teacher has been loaded
+ self.teacher_recurrent = teacher_recurrent # Indicates if teacher is recurrent too
- # get the observation dimensions
+ # Get the observation dimensions
self.obs_groups = obs_groups
num_student_obs = 0
for obs_group in obs_groups["policy"]:
@@ -63,39 +65,35 @@ def __init__(
assert len(obs[obs_group].shape) == 2, "The StudentTeacher module only supports 1D observations."
num_teacher_obs += obs[obs_group].shape[-1]
- # student
- self.memory_s = Memory(num_student_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
+ # Student
+ self.memory_s = Memory(num_student_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
self.student = MLP(rnn_hidden_dim, num_actions, student_hidden_dims, activation)
+ print(f"Student RNN: {self.memory_s}")
+ print(f"Student MLP: {self.student}")
- # student observation normalization
+ # Student observation normalization
self.student_obs_normalization = student_obs_normalization
if student_obs_normalization:
self.student_obs_normalizer = EmpiricalNormalization(num_student_obs)
else:
self.student_obs_normalizer = torch.nn.Identity()
- print(f"Student RNN: {self.memory_s}")
- print(f"Student MLP: {self.student}")
-
- # teacher
+ # Teacher
if self.teacher_recurrent:
- self.memory_t = Memory(
- num_teacher_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim
- )
+ self.memory_t = Memory(num_teacher_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
self.teacher = MLP(rnn_hidden_dim, num_actions, teacher_hidden_dims, activation)
+ if self.teacher_recurrent:
+ print(f"Teacher RNN: {self.memory_t}")
+ print(f"Teacher MLP: {self.teacher}")
- # teacher observation normalization
+ # Teacher observation normalization
self.teacher_obs_normalization = teacher_obs_normalization
if teacher_obs_normalization:
self.teacher_obs_normalizer = EmpiricalNormalization(num_teacher_obs)
else:
self.teacher_obs_normalizer = torch.nn.Identity()
- if self.teacher_recurrent:
- print(f"Teacher RNN: {self.memory_t}")
- print(f"Teacher MLP: {self.teacher}")
-
- # action noise
+ # Action noise
self.noise_std_type = noise_std_type
if self.noise_std_type == "scalar":
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
@@ -104,60 +102,62 @@ def __init__(
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
- # action distribution (populated in update_distribution)
+ # Action distribution
+ # Note: Populated in update_distribution
self.distribution = None
- # disable args validation for speedup
+
+ # Disable args validation for speedup
Normal.set_default_validate_args(False)
- def reset(self, dones=None, hidden_states=None):
- if hidden_states is None:
- hidden_states = (None, None)
+ def reset(
+ self, dones: torch.Tensor | None = None, hidden_states: tuple[HiddenState, HiddenState] = (None, None)
+ ) -> None:
self.memory_s.reset(dones, hidden_states[0])
if self.teacher_recurrent:
self.memory_t.reset(dones, hidden_states[1])
- def forward(self):
+ def forward(self) -> NoReturn:
raise NotImplementedError
@property
- def action_mean(self):
+ def action_mean(self) -> torch.Tensor:
return self.distribution.mean
@property
- def action_std(self):
+ def action_std(self) -> torch.Tensor:
return self.distribution.stddev
@property
- def entropy(self):
+ def entropy(self) -> torch.Tensor:
return self.distribution.entropy().sum(dim=-1)
- def update_distribution(self, obs):
- # compute mean
+ def _update_distribution(self, obs: TensorDict) -> None:
+ # Compute mean
mean = self.student(obs)
- # compute standard deviation
+ # Compute standard deviation
if self.noise_std_type == "scalar":
std = self.std.expand_as(mean)
elif self.noise_std_type == "log":
std = torch.exp(self.log_std).expand_as(mean)
else:
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
- # create distribution
+ # Create distribution
self.distribution = Normal(mean, std)
- def act(self, obs):
+ def act(self, obs: TensorDict) -> torch.Tensor:
obs = self.get_student_obs(obs)
obs = self.student_obs_normalizer(obs)
out_mem = self.memory_s(obs).squeeze(0)
- self.update_distribution(out_mem)
+ self._update_distribution(out_mem)
return self.distribution.sample()
- def act_inference(self, obs):
+ def act_inference(self, obs: TensorDict) -> torch.Tensor:
obs = self.get_student_obs(obs)
obs = self.student_obs_normalizer(obs)
out_mem = self.memory_s(obs).squeeze(0)
return self.student(out_mem)
- def evaluate(self, obs):
+ def evaluate(self, obs: TensorDict) -> torch.Tensor:
obs = self.get_teacher_obs(obs)
obs = self.teacher_obs_normalizer(obs)
with torch.no_grad():
@@ -166,56 +166,51 @@ def evaluate(self, obs):
obs = self.memory_t(obs).squeeze(0)
return self.teacher(obs)
- def get_student_obs(self, obs):
- obs_list = []
- for obs_group in self.obs_groups["policy"]:
- obs_list.append(obs[obs_group])
+ def get_student_obs(self, obs: TensorDict) -> torch.Tensor:
+ obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]]
return torch.cat(obs_list, dim=-1)
- def get_teacher_obs(self, obs):
- obs_list = []
- for obs_group in self.obs_groups["teacher"]:
- obs_list.append(obs[obs_group])
+ def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor:
+ obs_list = [obs[obs_group] for obs_group in self.obs_groups["teacher"]]
return torch.cat(obs_list, dim=-1)
- def get_hidden_states(self):
+ def get_hidden_states(self) -> tuple[HiddenState, HiddenState]:
if self.teacher_recurrent:
- return self.memory_s.hidden_states, self.memory_t.hidden_states
+ return self.memory_s.hidden_state, self.memory_t.hidden_state
else:
- return self.memory_s.hidden_states, None
+ return self.memory_s.hidden_state, None
- def detach_hidden_states(self, dones=None):
- self.memory_s.detach_hidden_states(dones)
+ def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None:
+ self.memory_s.detach_hidden_state(dones)
if self.teacher_recurrent:
- self.memory_t.detach_hidden_states(dones)
+ self.memory_t.detach_hidden_state(dones)
- def train(self, mode=True):
+ def train(self, mode: bool = True) -> None:
super().train(mode)
- # make sure teacher is in eval mode
+ # Make sure teacher is in eval mode
self.teacher.eval()
self.teacher_obs_normalizer.eval()
- def update_normalization(self, obs):
+ def update_normalization(self, obs: TensorDict) -> None:
if self.student_obs_normalization:
student_obs = self.get_student_obs(obs)
self.student_obs_normalizer.update(student_obs)
- def load_state_dict(self, state_dict, strict=True):
+ def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool:
"""Load the parameters of the student and teacher networks.
Args:
- state_dict (dict): State dictionary of the model.
- strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
- module's state_dict() function.
+ state_dict: State dictionary of the model.
+ strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's
+ :meth:`state_dict` function.
Returns:
- bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
- `OnPolicyRunner` to determine how to load further parameters.
+ Whether this training resumes a previous training. This flag is used by the :func:`load` function of
+ :class:`OnPolicyRunner` to determine how to load further parameters.
"""
-
- # check if state_dict contains teacher and student or just teacher parameters
- if any("actor" in key for key in state_dict.keys()): # loading parameters from rl training
- # rename keys to match teacher and remove critic parameters
+ # Check if state_dict contains teacher and student or just teacher parameters
+ if any("actor" in key for key in state_dict): # Load parameters from rl training
+ # Rename keys to match teacher and remove critic parameters
teacher_state_dict = {}
teacher_obs_normalizer_state_dict = {}
for key, value in state_dict.items():
@@ -225,24 +220,24 @@ def load_state_dict(self, state_dict, strict=True):
teacher_obs_normalizer_state_dict[key.replace("actor_obs_normalizer.", "")] = value
self.teacher.load_state_dict(teacher_state_dict, strict=strict)
self.teacher_obs_normalizer.load_state_dict(teacher_obs_normalizer_state_dict, strict=strict)
- # also load recurrent memory if teacher is recurrent
+ # Also load recurrent memory if teacher is recurrent
if self.teacher_recurrent:
memory_t_state_dict = {}
for key, value in state_dict.items():
if "memory_a." in key:
memory_t_state_dict[key.replace("memory_a.", "")] = value
self.memory_t.load_state_dict(memory_t_state_dict, strict=strict)
- # set flag for successfully loading the parameters
+ # Set flag for successfully loading the parameters
self.loaded_teacher = True
self.teacher.eval()
self.teacher_obs_normalizer.eval()
- return False # training does not resume
- elif any("student" in key for key in state_dict.keys()): # loading parameters from distillation training
+ return False # Training does not resume
+ elif any("student" in key for key in state_dict): # Load parameters from distillation training
super().load_state_dict(state_dict, strict=strict)
- # set flag for successfully loading the parameters
+ # Set flag for successfully loading the parameters
self.loaded_teacher = True
self.teacher.eval()
self.teacher_obs_normalizer.eval()
- return True # training resumes
+ return True # Training resumes
else:
raise ValueError("state_dict does not contain student or teacher parameters")
diff --git a/rsl_rl/modules/symmetry.py b/rsl_rl/modules/symmetry.py
index b0175151..8e21d1da 100644
--- a/rsl_rl/modules/symmetry.py
+++ b/rsl_rl/modules/symmetry.py
@@ -5,20 +5,21 @@
from __future__ import annotations
+from rsl_rl.env import VecEnv
-def resolve_symmetry_config(alg_cfg, env):
+
+def resolve_symmetry_config(alg_cfg: dict, env: VecEnv) -> dict:
"""Resolve the symmetry configuration.
Args:
- alg_cfg: The algorithm configuration dictionary.
- env: The environment.
+ alg_cfg: Algorithm configuration dictionary.
+ env: Environment object.
Returns:
The resolved algorithm configuration dictionary.
"""
-
- # if using symmetry then pass the environment config object
+ # If using symmetry then pass the environment config object
+ # Note: This is used by the symmetry function for handling different observation terms
if "symmetry_cfg" in alg_cfg and alg_cfg["symmetry_cfg"] is not None:
- # this is used by the symmetry function for handling different observation terms
alg_cfg["symmetry_cfg"]["_env"] = env
return alg_cfg
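
The `_env` handle injected above is passed to the `data_augmentation_func` used by the symmetry loss in `ppo.py` earlier in this diff, which expects the original batch to come first in the augmented output. Below is a hypothetical configuration and augmentation function for a single mirrored copy (`num_aug == 2`); the mirroring is a stand-in, and for brevity `obs` is treated as a flat tensor rather than the policy's observation batch:

```python
import torch

def mirror_lr(x: torch.Tensor) -> torch.Tensor:
    """Placeholder left/right mirror: a real function would permute and flip specific dimensions."""
    return -x

def data_augmentation_func(obs=None, actions=None, env=None):
    """Return (augmented_obs, augmented_actions) with the original batch stacked first."""
    aug_obs = None if obs is None else torch.cat([obs, mirror_lr(obs)], dim=0)
    aug_actions = None if actions is None else torch.cat([actions, mirror_lr(actions)], dim=0)
    return aug_obs, aug_actions

symmetry_cfg = {
    "use_data_augmentation": True,   # augment the rollout data before computing the PPO losses
    "use_mirror_loss": True,         # add the auxiliary symmetry (mirror) loss
    "mirror_loss_coeff": 0.5,        # illustrative value
    "data_augmentation_func": data_augmentation_func,
}
```
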
diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py
index c18f487a..7ede0665 100644
--- a/rsl_rl/networks/__init__.py
+++ b/rsl_rl/networks/__init__.py
@@ -5,6 +5,14 @@
"""Definitions for components of modules."""
-from .memory import Memory
+from .memory import HiddenState, Memory
from .mlp import MLP
from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization
+
+__all__ = [
+ "MLP",
+ "EmpiricalDiscountedVariationNormalization",
+ "EmpiricalNormalization",
+ "HiddenState",
+ "Memory",
+]
diff --git a/rsl_rl/networks/memory.py b/rsl_rl/networks/memory.py
index 75773577..dd40afc2 100644
--- a/rsl_rl/networks/memory.py
+++ b/rsl_rl/networks/memory.py
@@ -5,66 +5,76 @@
from __future__ import annotations
+import torch
import torch.nn as nn
from rsl_rl.utils import unpad_trajectories
+HiddenState = torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None
+"""Type alias for the hidden state of RNNs (GRU/LSTM).
+
+For GRUs, this is a single tensor while for LSTMs, this is a tuple of two tensors (hidden state and cell state).
+"""
+
class Memory(nn.Module):
"""Memory module for recurrent networks.
- This module is used to store the hidden states of the policy.
- Currently only supports GRU and LSTM.
+ This module is used to store the hidden state of the policy. It currently supports GRU and LSTM.
"""
- def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256):
+ def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None:
super().__init__()
- # RNN
rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
- self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
- self.hidden_states = None
+ self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers)
+ self.hidden_state = None
- def forward(self, input, masks=None, hidden_states=None):
+ def forward(
+ self,
+ input: torch.Tensor,
+ masks: torch.Tensor | None = None,
+ hidden_state: HiddenState = None,
+ ) -> torch.Tensor:
batch_mode = masks is not None
if batch_mode:
- # batch mode: needs saved hidden states
- if hidden_states is None:
+ # Batch mode needs the saved hidden state
+ if hidden_state is None:
raise ValueError("Hidden states not passed to memory module during policy update")
- out, _ = self.rnn(input, hidden_states)
+ out, _ = self.rnn(input, hidden_state)
out = unpad_trajectories(out, masks)
else:
- # inference/distillation mode: uses hidden states of last step
- out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
+ # Inference/distillation mode uses hidden state of last step
+ out, self.hidden_state = self.rnn(input.unsqueeze(0), self.hidden_state)
return out
- def reset(self, dones=None, hidden_states=None):
- if dones is None: # reset all hidden states
- if hidden_states is None:
- self.hidden_states = None
+ def reset(self, dones: torch.Tensor | None = None, hidden_state: HiddenState = None) -> None:
+ if dones is None:  # Reset the hidden state of all environments
+ if hidden_state is None:
+ self.hidden_state = None
else:
- self.hidden_states = hidden_states
- elif self.hidden_states is not None: # reset hidden states of done environments
- if hidden_states is None:
- if isinstance(self.hidden_states, tuple): # tuple in case of LSTM
- for hidden_state in self.hidden_states:
+ self.hidden_state = hidden_state
+ elif self.hidden_state is not None: # Reset hidden state of done environments
+ if hidden_state is None:
+ if isinstance(self.hidden_state, tuple): # Tuple in case of LSTM
+ for hidden_state in self.hidden_state:
hidden_state[..., dones == 1, :] = 0.0
else:
- self.hidden_states[..., dones == 1, :] = 0.0
+ self.hidden_state[..., dones == 1, :] = 0.0
else:
NotImplementedError(
- "Resetting hidden states of done environments with custom hidden states is not implemented"
+ "Resetting the hidden state of done environments with a custom hidden state is not implemented"
)
- def detach_hidden_states(self, dones=None):
- if self.hidden_states is not None:
- if dones is None: # detach all hidden states
- if isinstance(self.hidden_states, tuple): # tuple in case of LSTM
- self.hidden_states = tuple(hidden_state.detach() for hidden_state in self.hidden_states)
+ def detach_hidden_state(self, dones: torch.Tensor | None = None) -> None:
+ if self.hidden_state is not None:
+ if dones is None:  # Detach the hidden state of all environments
+ if isinstance(self.hidden_state, tuple): # Tuple in case of LSTM
+ self.hidden_state = tuple(hidden_state.detach() for hidden_state in self.hidden_state)
else:
- self.hidden_states = self.hidden_states.detach()
- else: # detach hidden states of done environments
- if isinstance(self.hidden_states, tuple): # tuple in case of LSTM
- for hidden_state in self.hidden_states:
+ self.hidden_state = self.hidden_state.detach()
+ else: # Detach hidden state of done environments
+ if isinstance(self.hidden_state, tuple): # Tuple in case of LSTM
+ for hidden_state in self.hidden_state:
hidden_state[..., dones == 1, :] = hidden_state[..., dones == 1, :].detach()
else:
- self.hidden_states[..., dones == 1, :] = self.hidden_states[..., dones == 1, :].detach()
+ self.hidden_state[..., dones == 1, :] = self.hidden_state[..., dones == 1, :].detach()
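
For reference, a minimal usage sketch of the renamed API (the dimensions and the GRU choice are arbitrary; only methods shown in this patch are used):

```python
import torch

from rsl_rl.networks import HiddenState, Memory

num_envs, obs_dim = 8, 32
memory = Memory(input_size=obs_dim, hidden_dim=64, num_layers=1, type="gru")

# Inference/rollout mode: one step at a time, the hidden state is kept internally.
out = memory(torch.randn(num_envs, obs_dim))  # shape [1, num_envs, 64]

# At episode boundaries: zero the hidden state of finished environments and detach
# it so that gradients do not flow across rollouts.
dones = torch.zeros(num_envs)
dones[2] = 1
memory.reset(dones)
memory.detach_hidden_state()

# The stored state matches the HiddenState alias (a tensor for GRU, a tuple for LSTM).
state: HiddenState = memory.hidden_state
```
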
diff --git a/rsl_rl/networks/mlp.py b/rsl_rl/networks/mlp.py
index e91574ea..f01a7577 100644
--- a/rsl_rl/networks/mlp.py
+++ b/rsl_rl/networks/mlp.py
@@ -15,17 +15,12 @@
class MLP(nn.Sequential):
"""Multi-layer perceptron.
- The MLP network is a sequence of linear layers and activation functions. The
- last layer is a linear layer that outputs the desired dimension unless the
- last activation function is specified.
+ The MLP network is a sequence of linear layers and activation functions. The last layer is a linear layer that
+ outputs the desired dimension unless the last activation function is specified.
It provides additional conveniences:
-
- - If the hidden dimensions have a value of ``-1``, the dimension is inferred
- from the input dimension.
- - If the output dimension is a tuple, the output is reshaped to the desired
- shape.
-
+ - If the hidden dimensions have a value of ``-1``, the dimension is inferred from the input dimension.
+ - If the output dimension is a tuple, the output is reshaped to the desired shape.
"""
def __init__(
@@ -35,27 +30,26 @@ def __init__(
hidden_dims: tuple[int] | list[int],
activation: str = "elu",
last_activation: str | None = None,
- ):
+ ) -> None:
"""Initialize the MLP.
Args:
input_dim: Dimension of the input.
output_dim: Dimension of the output.
- hidden_dims: Dimensions of the hidden layers. A value of ``-1`` indicates
- that the dimension should be inferred from the input dimension.
- activation: Activation function. Defaults to "elu".
- last_activation: Activation function of the last layer. Defaults to None,
- in which case the last layer is linear.
+ hidden_dims: Dimensions of the hidden layers. A value of ``-1`` indicates that the dimension should be
+ inferred from the input dimension.
+ activation: Activation function.
+ last_activation: Activation function of the last layer. None results in a linear last layer.
"""
super().__init__()
- # resolve activation functions
+ # Resolve activation functions
activation_mod = resolve_nn_activation(activation)
last_activation_mod = resolve_nn_activation(last_activation) if last_activation is not None else None
- # resolve number of hidden dims if they are -1
+ # Resolve number of hidden dims if they are -1
hidden_dims_processed = [input_dim if dim == -1 else dim for dim in hidden_dims]
- # create layers sequentially
+ # Create layers sequentially
layers = []
layers.append(nn.Linear(input_dim, hidden_dims_processed[0]))
layers.append(activation_mod)
@@ -64,32 +58,32 @@ def __init__(
layers.append(nn.Linear(hidden_dims_processed[layer_index], hidden_dims_processed[layer_index + 1]))
layers.append(activation_mod)
- # add last layer
+ # Add last layer
if isinstance(output_dim, int):
layers.append(nn.Linear(hidden_dims_processed[-1], output_dim))
else:
- # compute the total output dimension
+ # Compute the total output dimension
total_out_dim = reduce(lambda x, y: x * y, output_dim)
- # add a layer to reshape the output to the desired shape
+ # Add a layer to reshape the output to the desired shape
layers.append(nn.Linear(hidden_dims_processed[-1], total_out_dim))
layers.append(nn.Unflatten(dim=-1, unflattened_size=output_dim))
- # add last activation function if specified
+ # Add last activation function if specified
if last_activation_mod is not None:
layers.append(last_activation_mod)
- # register the layers
+ # Register the layers
for idx, layer in enumerate(layers):
self.add_module(f"{idx}", layer)
- def init_weights(self, scales: float | tuple[float]):
+ def init_weights(self, scales: float | tuple[float]) -> None:
"""Initialize the weights of the MLP.
Args:
scales: Scale factor for the weights.
"""
- def get_scale(idx) -> float:
+ def get_scale(idx: int) -> float:
"""Get the scale factor for the weights of the MLP.
Args:
@@ -97,7 +91,7 @@ def get_scale(idx) -> float:
"""
return scales[idx] if isinstance(scales, (list, tuple)) else scales
- # initialize the weights
+ # Initialize the weights
for idx, module in enumerate(self):
if isinstance(module, nn.Linear):
nn.init.orthogonal_(module.weight, gain=get_scale(idx))
@@ -112,9 +106,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self:
x = layer(x)
return x
-
- def reset(self, dones=None, hidden_states=None):
- pass
-
- def detach_hidden_states(self, dones=None):
- pass
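
A short usage sketch of the two conveniences mentioned in the docstring (shapes are arbitrary):

```python
import torch

from rsl_rl.networks import MLP

# A hidden dimension of -1 is replaced by the input dimension (here 16).
net = MLP(input_dim=16, output_dim=4, hidden_dims=[-1, 64], activation="elu")
y = net(torch.randn(32, 16))  # -> [32, 4]

# A tuple output dimension appends an Unflatten layer that reshapes the output.
head = MLP(input_dim=16, output_dim=(3, 5), hidden_dims=[64])
z = head(torch.randn(32, 16))  # -> [32, 3, 5]

# Orthogonal initialization of all linear layers with a common gain.
net.init_weights(scales=1.0)
```
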
diff --git a/rsl_rl/networks/normalization.py b/rsl_rl/networks/normalization.py
index 5fd96921..5077479f 100644
--- a/rsl_rl/networks/normalization.py
+++ b/rsl_rl/networks/normalization.py
@@ -14,16 +14,15 @@
class EmpiricalNormalization(nn.Module):
"""Normalize mean and variance of values based on empirical values."""
- def __init__(self, shape, eps=1e-2, until=None):
+ def __init__(self, shape: int | tuple[int] | list[int], eps: float = 1e-2, until: int | None = None) -> None:
"""Initialize EmpiricalNormalization module.
- Args:
- shape (int or tuple of int): Shape of input values except batch axis.
- eps (float): Small value for stability.
- until (int or None): If this arg is specified, the module learns input values until the sum of batch sizes
- exceeds it.
+ .. note:: The normalization parameters are computed over the whole batch, not for each environment separately.
- Note: The normalization parameters are computed over the whole batch, not for each environment separately.
+ Args:
+ shape: Shape of input values except batch axis.
+ eps: Small value for stability.
+ until: If this arg is specified, the module learns input values until the sum of batch sizes exceeds it.
"""
super().__init__()
self.eps = eps
@@ -34,22 +33,20 @@ def __init__(self, shape, eps=1e-2, until=None):
self.register_buffer("count", torch.tensor(0, dtype=torch.long))
@property
- def mean(self):
+ def mean(self) -> torch.Tensor:
return self._mean.squeeze(0).clone()
@property
- def std(self):
+ def std(self) -> torch.Tensor:
return self._std.squeeze(0).clone()
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize mean and variance of values based on empirical values."""
-
return (x - self._mean) / (self._std + self.eps)
@torch.jit.unused
- def update(self, x):
- """Learn input values without computing the output values of them"""
-
+ def update(self, x: torch.Tensor) -> None:
+ """Learn input values without computing the corresponding output values."""
if not self.training:
return
if self.until is not None and self.count >= self.until:
@@ -66,45 +63,41 @@ def update(self, x):
self._std = torch.sqrt(self._var)
@torch.jit.unused
- def inverse(self, y):
+ def inverse(self, y: torch.Tensor) -> torch.Tensor:
"""De-normalize values based on empirical values."""
-
return y * (self._std + self.eps) + self._mean
class EmpiricalDiscountedVariationNormalization(nn.Module):
"""Reward normalization from Pathak's large scale study on PPO.
- Reward normalization. Since the reward function is non-stationary, it is useful to normalize
- the scale of the rewards so that the value function can learn quickly. We did this by dividing
- the rewards by a running estimate of the standard deviation of the sum of discounted rewards.
+ Since the reward function is non-stationary, it is useful to normalize the scale of the rewards so that the value
+ function can learn quickly. This is done by dividing the rewards by a running estimate of the standard deviation of
+ the sum of discounted rewards.
"""
- def __init__(self, shape, eps=1e-2, gamma=0.99, until=None):
+ def __init__(
+ self, shape: int | tuple[int] | list[int], eps: float = 1e-2, gamma: float = 0.99, until: int | None = None
+ ) -> None:
super().__init__()
self.emp_norm = EmpiricalNormalization(shape, eps, until)
self.disc_avg = _DiscountedAverage(gamma)
- def forward(self, rew):
+ def forward(self, rew: torch.Tensor) -> torch.Tensor:
if self.training:
- # update discounted rewards
+ # Update discounted rewards
avg = self.disc_avg.update(rew)
- # update moments from discounted rewards
+ # Update moments from discounted rewards
self.emp_norm.update(avg)
- # normalize rewards with the empirical std
+ # Normalize rewards with the empirical std
if self.emp_norm._std > 0:
return rew / self.emp_norm._std
else:
return rew
-"""
-Helper class.
-"""
-
-
class _DiscountedAverage:
r"""Discounted average of rewards.
@@ -113,12 +106,9 @@ class _DiscountedAverage:
.. math::
\bar{R}_t = \gamma \bar{R}_{t-1} + r_t
-
- Args:
- gamma (float): Discount factor.
"""
- def __init__(self, gamma):
+ def __init__(self, gamma: float) -> None:
self.avg = None
self.gamma = gamma
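
A minimal sketch of how the observation normalizer is typically driven (shapes and the `until` budget are arbitrary); `EmpiricalDiscountedVariationNormalization` wraps the same module to scale rewards by a running estimate of the standard deviation of the discounted return.

```python
import torch

from rsl_rl.networks import EmpiricalNormalization

norm = EmpiricalNormalization(shape=(16,), until=1_000_000)
obs = torch.randn(4096, 16)

norm.update(obs)            # learn running mean/std; no-op in eval mode or once `until` is exceeded
x = norm(obs)               # (obs - mean) / (std + eps)
obs_back = norm.inverse(x)  # de-normalize
```
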
diff --git a/rsl_rl/runners/__init__.py b/rsl_rl/runners/__init__.py
index d6f630b6..568ebc3b 100644
--- a/rsl_rl/runners/__init__.py
+++ b/rsl_rl/runners/__init__.py
@@ -5,7 +5,7 @@
"""Implementation of runners for environment-agent interaction."""
-from .on_policy_runner import OnPolicyRunner # isort:skip
+from .on_policy_runner import OnPolicyRunner # noqa: I001
from .distillation_runner import DistillationRunner
-__all__ = ["OnPolicyRunner", "DistillationRunner"]
+__all__ = ["DistillationRunner", "OnPolicyRunner"]
diff --git a/rsl_rl/runners/distillation_runner.py b/rsl_rl/runners/distillation_runner.py
index 9cc6a8b1..6f9c502b 100644
--- a/rsl_rl/runners/distillation_runner.py
+++ b/rsl_rl/runners/distillation_runner.py
@@ -9,6 +9,7 @@
import time
import torch
from collections import deque
+from tensordict import TensorDict
import rsl_rl
from rsl_rl.algorithms import Distillation
@@ -21,29 +22,29 @@
class DistillationRunner(OnPolicyRunner):
"""On-policy runner for training and evaluation of teacher-student training."""
- def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device="cpu"):
+ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device: str = "cpu") -> None:
self.cfg = train_cfg
self.alg_cfg = train_cfg["algorithm"]
self.policy_cfg = train_cfg["policy"]
self.device = device
self.env = env
- # check if multi-gpu is enabled
+ # Check if multi-GPU is enabled
self._configure_multi_gpu()
- # store training configuration
+ # Store training configuration
self.num_steps_per_env = self.cfg["num_steps_per_env"]
self.save_interval = self.cfg["save_interval"]
- # query observations from environment for algorithm construction
+ # Query observations from environment for algorithm construction
obs = self.env.get_observations()
self.cfg["obs_groups"] = resolve_obs_groups(obs, self.cfg["obs_groups"], default_sets=["teacher"])
- # create the algorithm
+ # Create the algorithm
self.alg = self._construct_algorithm(obs)
# Decide whether to disable logging
- # We only log from the process with rank 0 (main process)
+ # Note: We only log from the process with rank 0 (main process)
self.disable_logs = self.is_distributed and self.gpu_global_rank != 0
# Logging
@@ -54,20 +55,20 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev
self.current_learning_iteration = 0
self.git_status_repos = [rsl_rl.__file__]
- def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False): # noqa: C901
- # initialize writer
+ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None:
+ # Initialize writer
self._prepare_logging_writer()
- # check if teacher is loaded
+ # Check if teacher is loaded
if not self.alg.policy.loaded_teacher:
raise ValueError("Teacher model parameters not loaded. Please load a teacher model to distill.")
- # randomize initial episode lengths (for exploration)
+ # Randomize initial episode lengths (for exploration)
if init_at_random_ep_len:
self.env.episode_length_buf = torch.randint_like(
self.env.episode_length_buf, high=int(self.env.max_episode_length)
)
- # start learning
+ # Start learning
obs = self.env.get_observations().to(self.device)
self.train_mode() # switch to train mode (for dropout for example)
@@ -97,9 +98,9 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
obs, rewards, dones, extras = self.env.step(actions.to(self.env.device))
# Move to device
obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device))
- # process the step
+ # Process the step
self.alg.process_env_step(obs, rewards, dones, extras)
- # book keeping
+ # Bookkeeping
if self.log_dir is not None:
if "episode" in extras:
ep_infos.append(extras["episode"])
@@ -120,13 +121,13 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
collection_time = stop - start
start = stop
- # update policy
+ # Update policy
loss_dict = self.alg.update()
stop = time.time()
learn_time = stop - start
self.current_learning_iteration = it
- # log info
+
if self.log_dir is not None and not self.disable_logs:
# Log information
self.log(locals())
@@ -138,9 +139,9 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
ep_infos.clear()
# Save code state
if it == start_iter and not self.disable_logs:
- # obtain all the diff files
+ # Obtain all the diff files
git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
- # if possible store them to wandb
+ # If possible, store them to wandb or neptune
if self.logger_type in ["wandb", "neptune"] and git_file_paths:
for path in git_file_paths:
self.writer.save_file(path)
@@ -149,25 +150,21 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
if self.log_dir is not None and not self.disable_logs:
self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
- """
- Helper methods.
- """
-
- def _construct_algorithm(self, obs) -> Distillation:
+ def _construct_algorithm(self, obs: TensorDict) -> Distillation:
"""Construct the distillation algorithm."""
- # initialize the actor-critic
+ # Initialize the policy
student_teacher_class = eval(self.policy_cfg.pop("class_name"))
student_teacher: StudentTeacher | StudentTeacherRecurrent = student_teacher_class(
obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
).to(self.device)
- # initialize the algorithm
+ # Initialize the algorithm
alg_class = eval(self.alg_cfg.pop("class_name"))
alg: Distillation = alg_class(
student_teacher, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg
)
- # initialize the storage
+ # Initialize the storage
alg.init_storage(
"distillation",
self.env.num_envs,
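
For orientation, the shape of the `train_cfg` dict that this constructor reads (values and observation group names are illustrative; the `policy` and `algorithm` entries additionally carry the keyword arguments of the classes named in `class_name`):

```python
train_cfg = {
    "num_steps_per_env": 24,
    "save_interval": 50,
    "logger": "tensorboard",
    "obs_groups": {"policy": ["policy"], "teacher": ["teacher"]},
    "policy": {"class_name": "StudentTeacher"},    # plus StudentTeacher kwargs
    "algorithm": {"class_name": "Distillation"},   # plus Distillation kwargs
}
```
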
diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
index 36f11f37..46a9b524 100644
--- a/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/runners/on_policy_runner.py
@@ -11,6 +11,7 @@
import torch
import warnings
from collections import deque
+from tensordict import TensorDict
import rsl_rl
from rsl_rl.algorithms import PPO
@@ -22,32 +23,32 @@
class OnPolicyRunner:
"""On-policy runner for training and evaluation of actor-critic methods."""
- def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device="cpu"):
+ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device: str = "cpu") -> None:
self.cfg = train_cfg
self.alg_cfg = train_cfg["algorithm"]
self.policy_cfg = train_cfg["policy"]
self.device = device
self.env = env
- # check if multi-gpu is enabled
+ # Check if multi-GPU is enabled
self._configure_multi_gpu()
- # store training configuration
+ # Store training configuration
self.num_steps_per_env = self.cfg["num_steps_per_env"]
self.save_interval = self.cfg["save_interval"]
- # query observations from environment for algorithm construction
+ # Query observations from environment for algorithm construction
obs = self.env.get_observations()
default_sets = ["critic"]
if "rnd_cfg" in self.alg_cfg and self.alg_cfg["rnd_cfg"] is not None:
default_sets.append("rnd_state")
self.cfg["obs_groups"] = resolve_obs_groups(obs, self.cfg["obs_groups"], default_sets)
- # create the algorithm
+ # Create the algorithm
self.alg = self._construct_algorithm(obs)
# Decide whether to disable logging
- # We only log from the process with rank 0 (main process)
+ # Note: We only log from the process with rank 0 (main process)
self.disable_logs = self.is_distributed and self.gpu_global_rank != 0
# Logging
@@ -58,17 +59,17 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev
self.current_learning_iteration = 0
self.git_status_repos = [rsl_rl.__file__]
- def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False): # noqa: C901
- # initialize writer
+ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None:
+ # Initialize writer
self._prepare_logging_writer()
- # randomize initial episode lengths (for exploration)
+ # Randomize initial episode lengths (for exploration)
if init_at_random_ep_len:
self.env.episode_length_buf = torch.randint_like(
self.env.episode_length_buf, high=int(self.env.max_episode_length)
)
- # start learning
+ # Start learning
obs = self.env.get_observations().to(self.device)
self.train_mode() # switch to train mode (for dropout for example)
@@ -79,7 +80,7 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
- # create buffers for logging extrinsic and intrinsic rewards
+ # Create buffers for logging extrinsic and intrinsic rewards
if self.alg.rnd:
erewbuffer = deque(maxlen=100)
irewbuffer = deque(maxlen=100)
@@ -105,11 +106,11 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
obs, rewards, dones, extras = self.env.step(actions.to(self.env.device))
# Move to device
obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device))
- # process the step
+ # Process the step
self.alg.process_env_step(obs, rewards, dones, extras)
# Extract intrinsic rewards (only for logging)
intrinsic_rewards = self.alg.intrinsic_rewards if self.alg.rnd else None
- # book keeping
+ # Bookkeeping
if self.log_dir is not None:
if "episode" in extras:
ep_infos.append(extras["episode"])
@@ -118,20 +119,18 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
# Update rewards
if self.alg.rnd:
cur_ereward_sum += rewards
- cur_ireward_sum += intrinsic_rewards # type: ignore
+ cur_ireward_sum += intrinsic_rewards
cur_reward_sum += rewards + intrinsic_rewards
else:
cur_reward_sum += rewards
# Update episode length
cur_episode_length += 1
# Clear data for completed episodes
- # -- common
new_ids = (dones > 0).nonzero(as_tuple=False)
rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
cur_reward_sum[new_ids] = 0
cur_episode_length[new_ids] = 0
- # -- intrinsic and extrinsic rewards
if self.alg.rnd:
erewbuffer.extend(cur_ereward_sum[new_ids][:, 0].cpu().numpy().tolist())
irewbuffer.extend(cur_ireward_sum[new_ids][:, 0].cpu().numpy().tolist())
@@ -142,16 +141,16 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
collection_time = stop - start
start = stop
- # compute returns
+ # Compute returns
self.alg.compute_returns(obs)
- # update policy
+ # Update policy
loss_dict = self.alg.update()
stop = time.time()
learn_time = stop - start
self.current_learning_iteration = it
- # log info
+
if self.log_dir is not None and not self.disable_logs:
# Log information
self.log(locals())
@@ -163,9 +162,9 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
ep_infos.clear()
# Save code state
if it == start_iter and not self.disable_logs:
- # obtain all the diff files
+ # Obtain all the diff files
git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
- # if possible store them to wandb
+ # If possible, store them to wandb or neptune
if self.logger_type in ["wandb", "neptune"] and git_file_paths:
for path in git_file_paths:
self.writer.save_file(path)
@@ -174,7 +173,7 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
if self.log_dir is not None and not self.disable_logs:
self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
- def log(self, locs: dict, width: int = 80, pad: int = 35):
+ def log(self, locs: dict, width: int = 80, pad: int = 35) -> None:
# Compute the collection size
collection_size = self.num_steps_per_env * self.env.num_envs * self.gpu_world_size
# Update total time-steps and time
@@ -182,13 +181,13 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
self.tot_time += locs["collection_time"] + locs["learn_time"]
iteration_time = locs["collection_time"] + locs["learn_time"]
- # -- Episode info
+ # Log episode information
ep_string = ""
if locs["ep_infos"]:
for key in locs["ep_infos"][0]:
infotensor = torch.tensor([], device=self.device)
for ep_info in locs["ep_infos"]:
- # handle scalar and zero dimensional tensor infos
+ # Handle scalar and zero-dimensional tensor infos
if key not in ep_info:
continue
if not isinstance(ep_info[key], torch.Tensor):
@@ -197,38 +196,38 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
ep_info[key] = ep_info[key].unsqueeze(0)
infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
value = torch.mean(infotensor)
- # log to logger and terminal
+ # Log to logger and terminal
if "/" in key:
self.writer.add_scalar(key, value, locs["it"])
- ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n"""
+ ep_string += f"""{f"{key}:":>{pad}} {value:.4f}\n"""
else:
self.writer.add_scalar("Episode/" + key, value, locs["it"])
- ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
+ ep_string += f"""{f"Mean episode {key}:":>{pad}} {value:.4f}\n"""
mean_std = self.alg.policy.action_std.mean()
fps = int(collection_size / (locs["collection_time"] + locs["learn_time"]))
- # -- Losses
+ # Log losses
for key, value in locs["loss_dict"].items():
self.writer.add_scalar(f"Loss/{key}", value, locs["it"])
self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])
- # -- Policy
+ # Log noise std
self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"])
- # -- Performance
+ # Log performance
self.writer.add_scalar("Perf/total_fps", fps, locs["it"])
self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"])
self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"])
- # -- Training
+ # Log training
if len(locs["rewbuffer"]) > 0:
- # separate logging for intrinsic and extrinsic rewards
+ # Separate logging for intrinsic and extrinsic rewards
if hasattr(self.alg, "rnd") and self.alg.rnd:
self.writer.add_scalar("Rnd/mean_extrinsic_reward", statistics.mean(locs["erewbuffer"]), locs["it"])
self.writer.add_scalar("Rnd/mean_intrinsic_reward", statistics.mean(locs["irewbuffer"]), locs["it"])
self.writer.add_scalar("Rnd/weight", self.alg.rnd.weight, locs["it"])
- # everything else
+ # Everything else
self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"])
self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"])
if self.logger_type != "wandb": # wandb does not support non-integer x-axis logging
@@ -241,145 +240,144 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
if len(locs["rewbuffer"]) > 0:
log_string = (
- f"""{'#' * width}\n"""
- f"""{str.center(width, ' ')}\n\n"""
- f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
- 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
- f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
+ f"""{"#" * width}\n"""
+ f"""{str.center(width, " ")}\n\n"""
+ f"""{"Computation:":>{pad}} {fps:.0f} steps/s (collection: {locs["collection_time"]:.3f}s, learning {
+ locs["learn_time"]:.3f}s)\n"""
+ f"""{"Mean action noise std:":>{pad}} {mean_std.item():.2f}\n"""
)
- # -- Losses
+ # Print losses
for key, value in locs["loss_dict"].items():
- log_string += f"""{f'Mean {key} loss:':>{pad}} {value:.4f}\n"""
- # -- Rewards
+ log_string += f"""{f"Mean {key} loss:":>{pad}} {value:.4f}\n"""
+ # Print rewards
if hasattr(self.alg, "rnd") and self.alg.rnd:
log_string += (
- f"""{'Mean extrinsic reward:':>{pad}} {statistics.mean(locs['erewbuffer']):.2f}\n"""
- f"""{'Mean intrinsic reward:':>{pad}} {statistics.mean(locs['irewbuffer']):.2f}\n"""
+ f"""{"Mean extrinsic reward:":>{pad}} {statistics.mean(locs["erewbuffer"]):.2f}\n"""
+ f"""{"Mean intrinsic reward:":>{pad}} {statistics.mean(locs["irewbuffer"]):.2f}\n"""
)
- log_string += f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
- # -- episode info
- log_string += f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
+ log_string += f"""{"Mean reward:":>{pad}} {statistics.mean(locs["rewbuffer"]):.2f}\n"""
+ # Print episode information
+ log_string += f"""{"Mean episode length:":>{pad}} {statistics.mean(locs["lenbuffer"]):.2f}\n"""
else:
log_string = (
- f"""{'#' * width}\n"""
- f"""{str.center(width, ' ')}\n\n"""
- f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
- 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
- f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
+ f"""{"#" * width}\n"""
+ f"""{str.center(width, " ")}\n\n"""
+ f"""{"Computation:":>{pad}} {fps:.0f} steps/s (collection: {locs["collection_time"]:.3f}s, learning {
+ locs["learn_time"]:.3f}s)\n"""
+ f"""{"Mean action noise std:":>{pad}} {mean_std.item():.2f}\n"""
)
for key, value in locs["loss_dict"].items():
- log_string += f"""{f'{key}:':>{pad}} {value:.4f}\n"""
+ log_string += f"""{f"{key}:":>{pad}} {value:.4f}\n"""
log_string += ep_string
log_string += (
- f"""{'-' * width}\n"""
- f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
- f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
- f"""{'Time elapsed:':>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time))}\n"""
- f"""{'ETA:':>{pad}} {time.strftime(
- "%H:%M:%S",
- time.gmtime(
- self.tot_time / (locs['it'] - locs['start_iter'] + 1)
- * (locs['start_iter'] + locs['num_learning_iterations'] - locs['it'])
+ f"""{"-" * width}\n"""
+ f"""{"Total timesteps:":>{pad}} {self.tot_timesteps}\n"""
+ f"""{"Iteration time:":>{pad}} {iteration_time:.2f}s\n"""
+ f"""{"Time elapsed:":>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time))}\n"""
+ f"""{"ETA:":>{pad}} {
+ time.strftime(
+ "%H:%M:%S",
+ time.gmtime(
+ self.tot_time
+ / (locs["it"] - locs["start_iter"] + 1)
+ * (locs["start_iter"] + locs["num_learning_iterations"] - locs["it"])
+ ),
)
- )}\n"""
+ }\n"""
)
print(log_string)
- def save(self, path: str, infos=None):
- # -- Save model
+ def save(self, path: str, infos: dict | None = None) -> None:
+ # Save model
saved_dict = {
"model_state_dict": self.alg.policy.state_dict(),
"optimizer_state_dict": self.alg.optimizer.state_dict(),
"iter": self.current_learning_iteration,
"infos": infos,
}
- # -- Save RND model if used
+ # Save RND model if used
if hasattr(self.alg, "rnd") and self.alg.rnd:
saved_dict["rnd_state_dict"] = self.alg.rnd.state_dict()
saved_dict["rnd_optimizer_state_dict"] = self.alg.rnd_optimizer.state_dict()
torch.save(saved_dict, path)
- # upload model to external logging service
+ # Upload model to external logging service
if self.logger_type in ["neptune", "wandb"] and not self.disable_logs:
self.writer.save_model(path, self.current_learning_iteration)
- def load(self, path: str, load_optimizer: bool = True, map_location: str | None = None):
+ def load(self, path: str, load_optimizer: bool = True, map_location: str | None = None) -> dict:
loaded_dict = torch.load(path, weights_only=False, map_location=map_location)
- # -- Load model
+ # Load model
resumed_training = self.alg.policy.load_state_dict(loaded_dict["model_state_dict"])
- # -- Load RND model if used
+ # Load RND model if used
if hasattr(self.alg, "rnd") and self.alg.rnd:
self.alg.rnd.load_state_dict(loaded_dict["rnd_state_dict"])
- # -- load optimizer if used
+ # Load optimizer if used
if load_optimizer and resumed_training:
- # -- algorithm optimizer
+ # Algorithm optimizer
self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"])
- # -- RND optimizer if used
+ # RND optimizer if used
if hasattr(self.alg, "rnd") and self.alg.rnd:
self.alg.rnd_optimizer.load_state_dict(loaded_dict["rnd_optimizer_state_dict"])
- # -- load current learning iteration
+ # Load current learning iteration
if resumed_training:
self.current_learning_iteration = loaded_dict["iter"]
return loaded_dict["infos"]
- def get_inference_policy(self, device=None):
- self.eval_mode() # switch to evaluation mode (dropout for example)
+ def get_inference_policy(self, device: str | None = None) -> callable:
+ self.eval_mode() # Switch to evaluation mode (e.g. for dropout)
if device is not None:
self.alg.policy.to(device)
return self.alg.policy.act_inference
- def train_mode(self):
- # -- PPO
+ def train_mode(self) -> None:
+ # PPO
self.alg.policy.train()
- # -- RND
+ # RND
if hasattr(self.alg, "rnd") and self.alg.rnd:
self.alg.rnd.train()
- def eval_mode(self):
- # -- PPO
+ def eval_mode(self) -> None:
+ # PPO
self.alg.policy.eval()
- # -- RND
+ # RND
if hasattr(self.alg, "rnd") and self.alg.rnd:
self.alg.rnd.eval()
- def add_git_repo_to_log(self, repo_file_path):
+ def add_git_repo_to_log(self, repo_file_path: str) -> None:
self.git_status_repos.append(repo_file_path)
- """
- Helper functions.
- """
-
- def _configure_multi_gpu(self):
+ def _configure_multi_gpu(self) -> None:
"""Configure multi-gpu training."""
- # check if distributed training is enabled
+ # Check if distributed training is enabled
self.gpu_world_size = int(os.getenv("WORLD_SIZE", "1"))
self.is_distributed = self.gpu_world_size > 1
- # if not distributed training, set local and global rank to 0 and return
+ # If not distributed training, set local and global rank to 0 and return
if not self.is_distributed:
self.gpu_local_rank = 0
self.gpu_global_rank = 0
self.multi_gpu_cfg = None
return
- # get rank and world size
+ # Get rank and world size
self.gpu_local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.gpu_global_rank = int(os.getenv("RANK", "0"))
- # make a configuration dictionary
+ # Make a configuration dictionary
self.multi_gpu_cfg = {
- "global_rank": self.gpu_global_rank, # rank of the main process
- "local_rank": self.gpu_local_rank, # rank of the current process
- "world_size": self.gpu_world_size, # total number of processes
+ "global_rank": self.gpu_global_rank, # Rank of the main process
+ "local_rank": self.gpu_local_rank, # Rank of the current process
+ "world_size": self.gpu_world_size, # Total number of processes
}
- # check if user has device specified for local rank
+ # Check if user has device specified for local rank
if self.device != f"cuda:{self.gpu_local_rank}":
raise ValueError(
f"Device '{self.device}' does not match expected device for local rank '{self.gpu_local_rank}'."
)
- # validate multi-gpu configuration
+ # Validate multi-GPU configuration
if self.gpu_local_rank >= self.gpu_world_size:
raise ValueError(
f"Local rank '{self.gpu_local_rank}' is greater than or equal to world size '{self.gpu_world_size}'."
@@ -389,20 +387,20 @@ def _configure_multi_gpu(self):
f"Global rank '{self.gpu_global_rank}' is greater than or equal to world size '{self.gpu_world_size}'."
)
- # initialize torch distributed
+ # Initialize torch distributed
torch.distributed.init_process_group(backend="nccl", rank=self.gpu_global_rank, world_size=self.gpu_world_size)
- # set device to the local rank
+ # Set device to the local rank
torch.cuda.set_device(self.gpu_local_rank)
- def _construct_algorithm(self, obs) -> PPO:
+ def _construct_algorithm(self, obs: TensorDict) -> PPO:
"""Construct the actor-critic algorithm."""
- # resolve RND config
+ # Resolve RND config
self.alg_cfg = resolve_rnd_config(self.alg_cfg, obs, self.cfg["obs_groups"], self.env)
- # resolve symmetry config
+ # Resolve symmetry config
self.alg_cfg = resolve_symmetry_config(self.alg_cfg, self.env)
- # resolve deprecated normalization config
+ # Resolve deprecated normalization config
if self.cfg.get("empirical_normalization") is not None:
warnings.warn(
"The `empirical_normalization` parameter is deprecated. Please set `actor_obs_normalization` and "
@@ -414,17 +412,17 @@ def _construct_algorithm(self, obs) -> PPO:
if self.policy_cfg.get("critic_obs_normalization") is None:
self.policy_cfg["critic_obs_normalization"] = self.cfg["empirical_normalization"]
- # initialize the actor-critic
+ # Initialize the policy
actor_critic_class = eval(self.policy_cfg.pop("class_name"))
actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class(
obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
).to(self.device)
- # initialize the algorithm
+ # Initialize the algorithm
alg_class = eval(self.alg_cfg.pop("class_name"))
alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg)
- # initialize the storage
+ # Initialize the storage
alg.init_storage(
"rl",
self.env.num_envs,
@@ -435,10 +433,10 @@ def _construct_algorithm(self, obs) -> PPO:
return alg
- def _prepare_logging_writer(self):
- """Prepares the logging writers."""
+ def _prepare_logging_writer(self) -> None:
+ """Prepare the logging writers."""
if self.log_dir is not None and self.writer is None and not self.disable_logs:
- # Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard.
+ # Launch the Tensorboard, Neptune, or WandB summary writer (default: Tensorboard).
self.logger_type = self.cfg.get("logger", "tensorboard")
self.logger_type = self.logger_type.lower()
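
A sketch of the multi-GPU contract enforced by `_configure_multi_gpu()`; the environment and config are placeholders supplied by the integrating framework, so this is illustrative rather than directly runnable:

```python
import os

from rsl_rl.runners import OnPolicyRunner

env = ...        # VecEnv created by the integrating framework (e.g. Isaac Lab)
train_cfg = ...  # runner configuration dict as in the sketch above

# Under a distributed launcher (e.g. `torchrun --nproc_per_node=4 train.py`) each process
# sees WORLD_SIZE, RANK and LOCAL_RANK; the runner then requires device "cuda:<LOCAL_RANK>"
# and only global rank 0 writes logs and checkpoints.
device = f"cuda:{os.getenv('LOCAL_RANK', '0')}"
runner = OnPolicyRunner(env, train_cfg, log_dir="logs/ppo", device=device)
runner.learn(num_learning_iterations=1500)
runner.save("logs/ppo/model_final.pt")
policy = runner.get_inference_policy(device=device)  # returns policy.act_inference
```
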
diff --git a/rsl_rl/storage/rollout_storage.py b/rsl_rl/storage/rollout_storage.py
index e9309b3b..539a57bb 100644
--- a/rsl_rl/storage/rollout_storage.py
+++ b/rsl_rl/storage/rollout_storage.py
@@ -6,38 +6,39 @@
from __future__ import annotations
import torch
+from collections.abc import Generator
from tensordict import TensorDict
+from rsl_rl.networks import HiddenState
from rsl_rl.utils import split_and_pad_trajectories
class RolloutStorage:
class Transition:
- def __init__(self):
- self.observations = None
- self.actions = None
- self.privileged_actions = None
- self.rewards = None
- self.dones = None
- self.values = None
- self.actions_log_prob = None
- self.action_mean = None
- self.action_sigma = None
- self.hidden_states = None
-
- def clear(self):
+ def __init__(self) -> None:
+ self.observations: TensorDict | None = None
+ self.actions: torch.Tensor | None = None
+ self.privileged_actions: torch.Tensor | None = None
+ self.rewards: torch.Tensor | None = None
+ self.dones: torch.Tensor | None = None
+ self.values: torch.Tensor | None = None
+ self.actions_log_prob: torch.Tensor | None = None
+ self.action_mean: torch.Tensor | None = None
+ self.action_sigma: torch.Tensor | None = None
+ self.hidden_states: tuple[HiddenState, HiddenState] = (None, None)
+
+ def clear(self) -> None:
self.__init__()
def __init__(
self,
- training_type,
- num_envs,
- num_transitions_per_env,
- obs,
- actions_shape,
- device="cpu",
- ):
- # store inputs
+ training_type: str,
+ num_envs: int,
+ num_transitions_per_env: int,
+ obs: TensorDict,
+ actions_shape: tuple[int] | list[int],
+ device: str = "cpu",
+ ) -> None:
self.training_type = training_type
self.device = device
self.num_transitions_per_env = num_transitions_per_env
@@ -54,11 +55,11 @@ def __init__(
self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()
- # for distillation
+ # For distillation
if training_type == "distillation":
self.privileged_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
- # for reinforcement learning
+ # For reinforcement learning
if training_type == "rl":
self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
@@ -68,14 +69,14 @@ def __init__(
self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
# For RNN networks
- self.saved_hidden_states_a = None
- self.saved_hidden_states_c = None
+ self.saved_hidden_state_a = None
+ self.saved_hidden_state_c = None
- # counter for the number of transitions stored
+ # Counter for the number of transitions stored
self.step = 0
- def add_transitions(self, transition: Transition):
- # check if the transition is valid
+ def add_transitions(self, transition: Transition) -> None:
+ # Check if the transition is valid
if self.step >= self.num_transitions_per_env:
raise OverflowError("Rollout buffer overflow! You should call clear() before adding new transitions.")
@@ -85,11 +86,11 @@ def add_transitions(self, transition: Transition):
self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
self.dones[self.step].copy_(transition.dones.view(-1, 1))
- # for distillation
+ # For distillation
if self.training_type == "distillation":
self.privileged_actions[self.step].copy_(transition.privileged_actions)
- # for reinforcement learning
+ # For reinforcement learning
if self.training_type == "rl":
self.values[self.step].copy_(transition.values)
self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
@@ -99,39 +100,40 @@ def add_transitions(self, transition: Transition):
# For RNN networks
self._save_hidden_states(transition.hidden_states)
- # increment the counter
+ # Increment the counter
self.step += 1
- def _save_hidden_states(self, hidden_states):
- if hidden_states is None or hidden_states == (None, None):
+ def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None:
+ if hidden_states == (None, None):
return
- # make a tuple out of GRU hidden state sto match the LSTM format
- hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
- hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
- # initialize if needed
- if self.saved_hidden_states_a is None:
- self.saved_hidden_states_a = [
- torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))
+ # Make a tuple out of GRU hidden states to match the LSTM format
+ hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
+ hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
+ # Initialize hidden states if needed
+ if self.saved_hidden_state_a is None:
+ self.saved_hidden_state_a = [
+ torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device)
+ for i in range(len(hidden_state_a))
]
- self.saved_hidden_states_c = [
- torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))
+ self.saved_hidden_state_c = [
+ torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device)
+ for i in range(len(hidden_state_c))
]
- # copy the states
- for i in range(len(hid_a)):
- self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
- self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])
+ # Copy the states
+ for i in range(len(hidden_state_a)):
+ self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i])
+ self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i])
- def clear(self):
+ def clear(self) -> None:
self.step = 0
- def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True):
+ def compute_returns(
+ self, last_values: torch.Tensor, gamma: float, lam: float, normalize_advantage: bool = True
+ ) -> None:
advantage = 0
for step in reversed(range(self.num_transitions_per_env)):
- # if we are at the last step, bootstrap the return value
- if step == self.num_transitions_per_env - 1:
- next_values = last_values
- else:
- next_values = self.values[step + 1]
+ # If we are at the last step, bootstrap the return value
+ next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1]
# 1 if we are not in a terminal state, 0 otherwise
next_is_not_terminal = 1.0 - self.dones[step].float()
# TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
@@ -144,20 +146,20 @@ def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = T
# Compute the advantages
self.advantages = self.returns - self.values
# Normalize the advantages if flag is set
- # This is to prevent double normalization (i.e. if per minibatch normalization is used)
+ # Note: This is to prevent double normalization (i.e. if per minibatch normalization is used)
if normalize_advantage:
self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
- # for distillation
- def generator(self):
+ # For distillation
+ def generator(self) -> Generator:
if self.training_type != "distillation":
raise ValueError("This function is only available for distillation training.")
for i in range(self.num_transitions_per_env):
yield self.observations[i], self.actions[i], self.privileged_actions[i], self.dones[i]
- # for reinforcement learning with feedforward networks
- def mini_batch_generator(self, num_mini_batches, num_epochs=8):
+ # For reinforcement learning with feedforward networks
+ def mini_batch_generator(self, num_mini_batches: int, num_epochs: int = 8) -> Generator:
if self.training_type != "rl":
raise ValueError("This function is only available for reinforcement learning training.")
batch_size = self.num_envs * self.num_transitions_per_env
@@ -180,15 +182,12 @@ def mini_batch_generator(self, num_mini_batches, num_epochs=8):
for i in range(num_mini_batches):
# Select the indices for the mini-batch
start = i * mini_batch_size
- end = (i + 1) * mini_batch_size
- batch_idx = indices[start:end]
+ stop = (i + 1) * mini_batch_size
+ batch_idx = indices[start:stop]
# Create the mini-batch
- # -- Core
obs_batch = observations[batch_idx]
actions_batch = actions[batch_idx]
-
- # -- For PPO
target_values_batch = values[batch_idx]
returns_batch = returns[batch_idx]
old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
@@ -196,14 +195,29 @@ def mini_batch_generator(self, num_mini_batches, num_epochs=8):
old_mu_batch = old_mu[batch_idx]
old_sigma_batch = old_sigma[batch_idx]
- # yield the mini-batch
- yield obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
- None,
- None,
- ), None
-
- # for reinfrocement learning with recurrent networks
- def recurrent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
+ hidden_state_a_batch = None
+ hidden_state_c_batch = None
+ masks_batch = None
+
+ # Yield the mini-batch
+ yield (
+ obs_batch,
+ actions_batch,
+ target_values_batch,
+ advantages_batch,
+ returns_batch,
+ old_actions_log_prob_batch,
+ old_mu_batch,
+ old_sigma_batch,
+ (
+ hidden_state_a_batch,
+ hidden_state_c_batch,
+ ),
+ masks_batch,
+ )
+
+ # For reinforcement learning with recurrent networks
+ def recurrent_mini_batch_generator(self, num_mini_batches: int, num_epochs: int = 8) -> Generator:
if self.training_type != "rl":
raise ValueError("This function is only available for reinforcement learning training.")
padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones)
@@ -232,29 +246,46 @@ def recurrent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
values_batch = self.values[:, start:stop]
old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]
- # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim])
- # then take only time steps after dones (flattens num envs and time dimensions),
- # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
+ # Reshape to [num_envs, time, num_layers, hidden_dim]
+ # Original shape: [time, num_layers, num_envs, hidden_dim]
last_was_done = last_was_done.permute(1, 0)
- hid_a_batch = [
- saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
+ # Take only time steps after dones (flattens num envs and time dimensions),
+ # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
+ hidden_state_a_batch = [
+ saved_hidden_state.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
.transpose(1, 0)
.contiguous()
- for saved_hidden_states in self.saved_hidden_states_a
+ for saved_hidden_state in self.saved_hidden_state_a
]
- hid_c_batch = [
- saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
+ hidden_state_c_batch = [
+ saved_hidden_state.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
.transpose(1, 0)
.contiguous()
- for saved_hidden_states in self.saved_hidden_states_c
+ for saved_hidden_state in self.saved_hidden_state_c
]
- # remove the tuple for GRU
- hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch
- hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch
-
- yield obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
- hid_a_batch,
- hid_c_batch,
- ), masks_batch
+ # Remove the tuple for GRU
+ hidden_state_a_batch = (
+ hidden_state_a_batch[0] if len(hidden_state_a_batch) == 1 else hidden_state_a_batch
+ )
+ hidden_state_c_batch = (
+ hidden_state_c_batch[0] if len(hidden_state_c_batch) == 1 else hidden_state_c_batch
+ )
+
+ # Yield the mini-batch
+ yield (
+ obs_batch,
+ actions_batch,
+ values_batch,
+ advantages_batch,
+ returns_batch,
+ old_actions_log_prob_batch,
+ old_mu_batch,
+ old_sigma_batch,
+ (
+ hidden_state_a_batch,
+ hidden_state_c_batch,
+ ),
+ masks_batch,
+ )
first_traj = last_traj
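
Both generators now yield the same ten-field tuple, so the PPO update loop can unpack feedforward and recurrent mini-batches uniformly (the hidden states and masks are simply `None` in the feedforward case). A consumption sketch, with `storage` standing in for the `RolloutStorage` built by the algorithm:

```python
storage = ...  # RolloutStorage created via alg.init_storage(...)

for (
    obs_batch,
    actions_batch,
    values_batch,
    advantages_batch,
    returns_batch,
    old_log_prob_batch,
    old_mu_batch,
    old_sigma_batch,
    (hidden_state_a, hidden_state_c),
    masks_batch,
) in storage.mini_batch_generator(num_mini_batches=4, num_epochs=5):
    ...  # evaluate the policy on obs_batch and compute the surrogate and value losses
```
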
diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py
index 0a7deab6..a11074e0 100644
--- a/rsl_rl/utils/__init__.py
+++ b/rsl_rl/utils/__init__.py
@@ -5,4 +5,22 @@
"""Helper functions."""
-from .utils import *
+from .utils import (
+ resolve_nn_activation,
+ resolve_obs_groups,
+ resolve_optimizer,
+ split_and_pad_trajectories,
+ store_code_state,
+ string_to_callable,
+ unpad_trajectories,
+)
+
+__all__ = [
+ "resolve_nn_activation",
+ "resolve_obs_groups",
+ "resolve_optimizer",
+ "split_and_pad_trajectories",
+ "store_code_state",
+ "string_to_callable",
+ "unpad_trajectories",
+]
diff --git a/rsl_rl/utils/neptune_utils.py b/rsl_rl/utils/neptune_utils.py
index 3796ec8b..90bcd7ef 100644
--- a/rsl_rl/utils/neptune_utils.py
+++ b/rsl_rl/utils/neptune_utils.py
@@ -12,14 +12,14 @@
try:
import neptune
except ModuleNotFoundError:
- raise ModuleNotFoundError("neptune-client is required to log to Neptune.")
+ raise ModuleNotFoundError("neptune-client is required to log to Neptune.") from None
class NeptuneLogger:
- def __init__(self, project, token):
+ def __init__(self, project: str, token: str) -> None:
self.run = neptune.init_run(project=project, api_token=token)
- def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
+ def store_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
self.run["runner_cfg"] = runner_cfg
self.run["policy_cfg"] = policy_cfg
self.run["alg_cfg"] = alg_cfg
@@ -29,48 +29,45 @@ def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
class NeptuneSummaryWriter(SummaryWriter):
"""Summary writer for Neptune."""
- def __init__(self, log_dir: str, flush_secs: int, cfg):
+ def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None:
super().__init__(log_dir, flush_secs)
try:
project = cfg["neptune_project"]
except KeyError:
- raise KeyError("Please specify neptune_project in the runner config, e.g. legged_gym.")
+ raise KeyError("Please specify neptune_project in the runner config, e.g. legged_gym.") from None
try:
token = os.environ["NEPTUNE_API_TOKEN"]
except KeyError:
raise KeyError(
"Neptune api token not found. Please run or add to ~/.bashrc: export NEPTUNE_API_TOKEN=YOUR_API_TOKEN"
- )
+ ) from None
try:
entity = os.environ["NEPTUNE_USERNAME"]
except KeyError:
raise KeyError(
"Neptune username not found. Please run or add to ~/.bashrc: export NEPTUNE_USERNAME=YOUR_USERNAME"
- )
+ ) from None
neptune_project = entity + "/" + project
-
self.neptune_logger = NeptuneLogger(neptune_project, token)
-
self.name_map = {
"Train/mean_reward/time": "Train/mean_reward_time",
"Train/mean_episode_length/time": "Train/mean_episode_length_time",
}
-
run_name = os.path.split(log_dir)[-1]
-
self.neptune_logger.run["log_dir"].log(run_name)
- def _map_path(self, path):
- if path in self.name_map:
- return self.name_map[path]
- else:
- return path
-
- def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False):
+ def add_scalar(
+ self,
+ tag: str,
+ scalar_value: float,
+ global_step: int | None = None,
+ walltime: float | None = None,
+ new_style: bool = False,
+ ) -> None:
super().add_scalar(
tag,
scalar_value,
@@ -80,15 +77,21 @@ def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_sty
)
self.neptune_logger.run[self._map_path(tag)].log(scalar_value, step=global_step)
- def stop(self):
+ def stop(self) -> None:
self.neptune_logger.run.stop()
- def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
+ def log_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
self.neptune_logger.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)
- def save_model(self, model_path, iter):
+ def save_model(self, model_path: str, iter: int) -> None:
self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path)
- def save_file(self, path, iter=None):
+ def save_file(self, path: str) -> None:
name = path.rsplit("/", 1)[-1].split(".")[0]
self.neptune_logger.run["git_diff/" + name].upload(path)
+
+ def _map_path(self, path: str) -> str:
+ if path in self.name_map:
+ return self.name_map[path]
+ else:
+ return path
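
For completeness, the environment variables and config entries that `NeptuneSummaryWriter` expects (values are placeholders):

```python
import os

# Credentials read at writer construction; set these in your shell or ~/.bashrc.
os.environ["NEPTUNE_API_TOKEN"] = "<your-api-token>"
os.environ["NEPTUNE_USERNAME"] = "<your-username>"

# Runner config entries selecting the Neptune writer; the project path resolves to
# "<NEPTUNE_USERNAME>/<neptune_project>".
train_cfg = {
    "logger": "neptune",
    "neptune_project": "legged_gym",
    # ... remaining runner configuration
}
```
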
diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py
index da19a0b7..7a044e83 100644
--- a/rsl_rl/utils/utils.py
+++ b/rsl_rl/utils/utils.py
@@ -16,10 +16,10 @@
def resolve_nn_activation(act_name: str) -> torch.nn.Module:
- """Resolves the activation function from the name.
+ """Resolve the activation function from the name.
Args:
- act_name: The name of the activation function.
+ act_name: Name of the activation function.
Returns:
The activation function.
@@ -50,10 +50,10 @@ def resolve_nn_activation(act_name: str) -> torch.nn.Module:
def resolve_optimizer(optimizer_name: str) -> torch.optim.Optimizer:
- """Resolves the optimizer from the name.
+ """Resolve the optimizer from the name.
Args:
- optimizer_name: The name of the optimizer.
+ optimizer_name: Name of the optimizer.
Returns:
The optimizer.
@@ -78,10 +78,12 @@ def resolve_optimizer(optimizer_name: str) -> torch.optim.Optimizer:
def split_and_pad_trajectories(
tensor: torch.Tensor | TensorDict, dones: torch.Tensor
) -> tuple[torch.Tensor | TensorDict, torch.Tensor]:
- """Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length of the longest
- trajectory. Returns masks corresponding to valid parts of the trajectories.
+ """Split trajectories at done indices.
- Example:
+ Split trajectories, concatenate them and pad with zeros up to the length of the longest trajectory. Return masks
+ corresponding to valid parts of the trajectories.
+
+ Example (transposed for readability):
Input: [[a1, a2, a3, a4 | a5, a6],
[b1, b2 | b3, b4, b5 | b6]]
@@ -93,10 +95,9 @@ def split_and_pad_trajectories(
Assumes that the input has the following order of dimensions: [time, number of envs, additional dimensions]
"""
-
dones = dones.clone()
dones[-1] = 1
- # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping
+ # Permute the buffers to have the order (num_envs, num_transitions_per_env, ...) for correct reshaping
flat_dones = dones.transpose(1, 0).reshape(-1, 1)
# Get length of trajectory by counting the number of successive not done elements
done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
@@ -106,33 +107,33 @@ def split_and_pad_trajectories(
if isinstance(tensor, TensorDict):
padded_trajectories = {}
for k, v in tensor.items():
- # split the tensor into trajectories
+ # Split the tensor into trajectories
trajectories = torch.split(v.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
- # add at least one full length trajectory
- trajectories = trajectories + (torch.zeros(v.shape[0], *v.shape[2:], device=v.device),)
- # pad the trajectories to the length of the longest trajectory
+ # Add at least one full length trajectory
+ trajectories = (*trajectories, torch.zeros(v.shape[0], *v.shape[2:], device=v.device))
+ # Pad the trajectories to the length of the longest trajectory
padded_trajectories[k] = torch.nn.utils.rnn.pad_sequence(trajectories)
- # remove the added tensor
+ # Remove the added trajectory
padded_trajectories[k] = padded_trajectories[k][:, :-1]
padded_trajectories = TensorDict(
padded_trajectories, batch_size=[tensor.batch_size[0], len(trajectory_lengths_list)]
)
else:
- # split the tensor into trajectories
+ # Split the tensor into trajectories
trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
- # add at least one full length trajectory
- trajectories = trajectories + (torch.zeros(tensor.shape[0], *tensor.shape[2:], device=tensor.device),)
- # pad the trajectories to the length of the longest trajectory
+ # Add at least one full length trajectory
+ trajectories = (*trajectories, torch.zeros(tensor.shape[0], *tensor.shape[2:], device=tensor.device))
+ # Pad the trajectories to the length of the longest trajectory
padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
- # remove the added tensor
+ # Remove the added trajectory
padded_trajectories = padded_trajectories[:, :-1]
- # create masks for the valid parts of the trajectories
+ # Create masks for the valid parts of the trajectories
trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
return padded_trajectories, trajectory_masks
-def unpad_trajectories(trajectories, masks):
- """Does the inverse operation of split_and_pad_trajectories()"""
+def unpad_trajectories(trajectories: torch.Tensor | TensorDict, masks: torch.Tensor) -> torch.Tensor | TensorDict:
+ """Do the inverse operation of `split_and_pad_trajectories()`."""
# Need to transpose before and after the masking to have proper reshaping
return (
trajectories.transpose(1, 0)[masks.transpose(1, 0)]
@@ -141,7 +142,7 @@ def unpad_trajectories(trajectories, masks):
)
-def store_code_state(logdir, repositories) -> list:
+def store_code_state(logdir: str, repositories: list[str]) -> list[str]:
git_log_dir = os.path.join(logdir, "git")
os.makedirs(git_log_dir, exist_ok=True)
file_paths = []
@@ -151,58 +152,58 @@ def store_code_state(logdir, repositories) -> list:
t = repo.head.commit.tree
except Exception:
print(f"Could not find git repository in {repository_file_path}. Skipping.")
- # skip if not a git repository
+ # Skip if not a git repository
continue
- # get the name of the repository
+ # Get the name of the repository
repo_name = pathlib.Path(repo.working_dir).name
diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff")
- # check if the diff file already exists
+ # Check if the diff file already exists
if os.path.isfile(diff_file_name):
continue
- # write the diff file
+ # Write the diff file
print(f"Storing git diff for '{repo_name}' in: {diff_file_name}")
with open(diff_file_name, "x", encoding="utf-8") as f:
content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
f.write(content)
- # add the file path to the list of files to be uploaded
+ # Add the file path to the list of files to be uploaded
file_paths.append(diff_file_name)
return file_paths
def string_to_callable(name: str) -> Callable:
- """Resolves the module and function names to return the function.
+ """Resolve the module and function names to return the function.
Args:
- name: The function name. The format should be 'module:attribute_name'.
+ name: Function name. The format should be 'module:attribute_name'.
+
+ Returns:
+ The function loaded from the module.
Raises:
ValueError: When the resolved attribute is not a function.
ValueError: When unable to resolve the attribute.
-
- Returns:
- The function loaded from the module.
"""
try:
mod_name, attr_name = name.split(":")
mod = importlib.import_module(mod_name)
callable_object = getattr(mod, attr_name)
- # check if attribute is callable
+ # Check if attribute is callable
if callable(callable_object):
return callable_object
else:
raise ValueError(f"The imported object is not callable: '{name}'")
- except AttributeError as e:
+ except AttributeError as err:
msg = (
"We could not interpret the entry as a callable object. The format of input should be"
- f" 'module:attribute_name'\nWhile processing input '{name}', received the error:\n {e}."
+ f" 'module:attribute_name'\nWhile processing input '{name}'."
)
- raise ValueError(msg)
+ raise ValueError(msg) from err
def resolve_obs_groups(
obs: TensorDict, obs_groups: dict[str, list[str]], default_sets: list[str]
) -> dict[str, list[str]]:
- """Validates the observation configuration and defaults missing observation sets.
+ """Validate the observation configuration and defaults missing observation sets.
The input is an observation dictionary `obs` containing observation groups and a configuration dictionary
`obs_groups` where the keys are the observation sets and the values are lists of observation groups.
@@ -213,13 +214,13 @@ def resolve_obs_groups(
"critic": ["group_1", "group_3"]
}
- This means that the 'policy' observation set will contain the observations "group_1" and "group_2" and the
- 'critic' observation set will contain the observations "group_1" and "group_3". This function will check that all
- the observations in the 'policy' and 'critic' observation sets are present in the observation dictionary from the
+ This means that the 'policy' observation set will contain the observations "group_1" and "group_2" and the 'critic'
+ observation set will contain the observations "group_1" and "group_3". This function will check that all the
+ observations in the 'policy' and 'critic' observation sets are present in the observation dictionary from the
environment.
- Additionally, if one of the `default_sets`, e.g. "critic", is not present in the configuration dictionary,
- this function will:
+ Additionally, if one of the `default_sets`, e.g. "critic", is not present in the configuration dictionary, this
+ function will:
1. Check if a group with the same name exists in the observations and assign this group to the observation set.
2. If 1. fails, it will assign the observations from the 'policy' observation set to the default observation set.
@@ -227,8 +228,8 @@ def resolve_obs_groups(
Args:
obs: Observations from the environment in the form of a dictionary.
obs_groups: Observation sets configuration.
- default_sets: Reserved observation set names used by the algorithm (besides 'policy').
- If not provided in 'obs_groups', a default behavior gets triggered.
+ default_sets: Reserved observation set names used by the algorithm (besides 'policy'). If not provided in
+ 'obs_groups', a default behavior gets triggered.
Returns:
The resolved observation groups.
@@ -237,8 +238,8 @@ def resolve_obs_groups(
ValueError: If any observation set is an empty list.
ValueError: If any observation set contains an observation term that is not present in the observations.
"""
- # check if policy observation set exists
- if "policy" not in obs_groups.keys():
+ # Check if policy observation set exists
+ if "policy" not in obs_groups:
if "policy" in obs:
obs_groups["policy"] = ["policy"]
warnings.warn(
@@ -253,9 +254,9 @@ def resolve_obs_groups(
f" Found keys: {list(obs_groups.keys())}"
)
- # check all observation sets for valid observation groups
+ # Check all observation sets for valid observation groups
for set_name, groups in obs_groups.items():
- # check if the list is empty
+ # Check if the list is empty
if len(groups) == 0:
msg = f"The '{set_name}' key in the 'obs_groups' dictionary can not be an empty list."
if set_name in default_sets:
@@ -266,7 +267,7 @@ def resolve_obs_groups(
f" Consider removing the key to default to the observation '{set_name}' from the environment."
)
raise ValueError(msg)
- # check groups exist inside the observations from the environment
+ # Check groups exist inside the observations from the environment
for group in groups:
if group not in obs:
raise ValueError(
@@ -274,9 +275,9 @@ def resolve_obs_groups(
f" environment. Available observations from the environment: {list(obs.keys())}"
)
- # fill missing observation sets
+ # Fill missing observation sets
for default_set_name in default_sets:
- if default_set_name not in obs_groups.keys():
+ if default_set_name not in obs_groups:
if default_set_name in obs:
obs_groups[default_set_name] = [default_set_name]
warnings.warn(
@@ -294,7 +295,7 @@ def resolve_obs_groups(
" clarity. This behavior will be removed in a future version."
)
- # print the final parsed observation sets
+ # Print the final parsed observation sets
print("-" * 80)
print("Resolved observation sets: ")
for set_name, groups in obs_groups.items():
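
For orientation, here is a minimal, hedged sketch of how the padding helpers above are intended to be used. It assumes `split_and_pad_trajectories` and `unpad_trajectories` are re-exported from `rsl_rl.utils`; the shapes are purely illustrative.

import torch

from rsl_rl.utils import split_and_pad_trajectories, unpad_trajectories

# Illustrative rollout: 6 time steps, 2 environments, 3 features per step.
obs = torch.randn(6, 2, 3)
dones = torch.zeros(6, 2)
dones[3, 0] = 1  # env 0 finishes an episode after its 4th step
dones[1, 1] = 1  # env 1 finishes an episode after its 2nd step

# Pad the per-episode trajectories to the longest length and get validity masks.
padded, masks = split_and_pad_trajectories(obs, dones)

# The inverse operation restores the original [time, num_envs, ...] layout.
recovered = unpad_trajectories(padded, masks)
assert torch.equal(recovered, obs)
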
diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py
index 243e82d4..b1bf7559 100644
--- a/rsl_rl/utils/wandb_utils.py
+++ b/rsl_rl/utils/wandb_utils.py
@@ -12,13 +12,13 @@
try:
import wandb
except ModuleNotFoundError:
- raise ModuleNotFoundError("Wandb is required to log to Weights and Biases.")
+ raise ModuleNotFoundError("Wandb is required to log to Weights and Biases.") from None
class WandbSummaryWriter(SummaryWriter):
"""Summary writer for Weights and Biases."""
- def __init__(self, log_dir: str, flush_secs: int, cfg):
+ def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None:
super().__init__(log_dir, flush_secs)
# Get the run name
@@ -27,7 +27,7 @@ def __init__(self, log_dir: str, flush_secs: int, cfg):
try:
project = cfg["wandb_project"]
except KeyError:
- raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.")
+ raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.") from None
try:
entity = os.environ["WANDB_USERNAME"]
@@ -40,12 +40,7 @@ def __init__(self, log_dir: str, flush_secs: int, cfg):
# Add log directory to wandb
wandb.config.update({"log_dir": log_dir})
- self.name_map = {
- "Train/mean_reward/time": "Train/mean_reward_time",
- "Train/mean_episode_length/time": "Train/mean_episode_length_time",
- }
-
- def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
+ def store_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
wandb.config.update({"runner_cfg": runner_cfg})
wandb.config.update({"policy_cfg": policy_cfg})
wandb.config.update({"alg_cfg": alg_cfg})
@@ -54,7 +49,14 @@ def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
except Exception:
wandb.config.update({"env_cfg": asdict(env_cfg)})
- def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False):
+ def add_scalar(
+ self,
+ tag: str,
+ scalar_value: float,
+ global_step: int | None = None,
+ walltime: float | None = None,
+ new_style: bool = False,
+ ) -> None:
super().add_scalar(
tag,
scalar_value,
@@ -62,26 +64,16 @@ def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_sty
walltime=walltime,
new_style=new_style,
)
- wandb.log({self._map_path(tag): scalar_value}, step=global_step)
+ wandb.log({tag: scalar_value}, step=global_step)
- def stop(self):
+ def stop(self) -> None:
wandb.finish()
- def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
+ def log_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None:
self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)
- def save_model(self, model_path, iter):
+ def save_model(self, model_path: str, iter: int) -> None:
wandb.save(model_path, base_path=os.path.dirname(model_path))
- def save_file(self, path, iter=None):
+ def save_file(self, path: str) -> None:
wandb.save(path, base_path=os.path.dirname(path))
-
- """
- Private methods.
- """
-
- def _map_path(self, path):
- if path in self.name_map:
- return self.name_map[path]
- else:
- return path
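
A short, hedged usage sketch of the simplified writer follows. It assumes `wandb` is installed and you are logged in; the log directory and file path are hypothetical, and real runner configs carry more keys than the single one checked in the constructor above.

from rsl_rl.utils.wandb_utils import WandbSummaryWriter

cfg = {"wandb_project": "legged_gym"}  # only the key required by the constructor is shown
writer = WandbSummaryWriter(log_dir="logs/example_run", flush_secs=10, cfg=cfg)
writer.add_scalar("Train/mean_reward", 0.0, global_step=0)  # tags are now logged to wandb unmapped
writer.save_file("logs/example_run/git/rsl_rl.diff")  # hypothetical path; no longer takes an `iter` argument
writer.stop()
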
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 00000000..95ae3f32
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,71 @@
+line-length = 120
+target-version = "py39"
+preview = true
+
+[lint]
+select = [
+ # pycodestyle
+ "E", "W",
+ # pydocstyle
+ "D",
+ # pylint (to be enabled later)
+ # "PL",
+ # pyflakes
+ "F",
+ # pyupgrade
+ "UP",
+ # pep8-naming
+ "N",
+ # flake8-bugbear
+ "B",
+ # flake8-simplify
+ "SIM",
+ # flake8-tidy-imports
+ "TID",
+ # flake8-annotations
+ "ANN",
+ # isort
+ "I",
+ # perflint
+ "PERF",
+ # ruff
+ "RUF",
+]
+ignore = ["B006",
+ "B007",
+ "B028",
+ "ANN401",
+ "D100",
+ "D101",
+ "D102",
+ "D103",
+ "D104",
+ "D105",
+ "D106",
+ "D107",
+ "D203",
+ "D213",
+ "D413",
+]
+per-file-ignores = {"*/__init__.py" = ["F401"]}
+
+[lint.isort]
+# Order of imports
+section-order = [
+ "future",
+ "standard-library",
+ "third-party",
+ "first-party",
+ "local-folder",
+]
+# Extra standard libraries considered as part of Python (permissive licenses)
+extra-standard-library = [
+ "numpy",
+ "torch",
+ "tensordict",
+ "warp",
+ "typing_extensions",
+ "git",
+]
+# Imports from this repository
+known-first-party = ["rsl_rl"]
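
To make the isort settings above concrete, here is a hedged sketch of the import grouping they produce; `wandb` stands in for a typical third-party dependency, and the imported names are otherwise arbitrary.

from __future__ import annotations

import os
import torch  # grouped with the standard library via extra-standard-library

import wandb  # regular third-party section

import rsl_rl  # first-party section per known-first-party
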