diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 26574e0c..00000000 --- a/.flake8 +++ /dev/null @@ -1,22 +0,0 @@ -[flake8] -show-source=True -statistics=True -per-file-ignores=*/__init__.py:F401 -# E402: Module level import not at top of file -# E501: Line too long -# W503: Line break before binary operator -# E203: Whitespace before ':' -> conflicts with black -# D401: First line should be in imperative mood -# R504: Unnecessary variable assignment before return statement. -# R505: Unnecessary elif after return statement -# SIM102: Use a single if-statement instead of nested if-statements -# SIM117: Merge with statements for context managers that have same scope. -ignore=E402,E501,W503,E203,D401,R504,R505,SIM102,SIM117 -max-line-length = 120 -max-complexity = 18 -exclude=_*,.vscode,.git,docs/** -# docstrings -docstring-convention=google -# annotations -suppress-none-returning=True -allow-star-arg-any=True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 74e58528..b40a9353 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,40 +1,21 @@ repos: - - repo: https://github.com/python/black - rev: 23.10.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.0 hooks: - - id: black - args: ["--line-length", "120", "--preview"] - - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 - hooks: - - id: flake8 - additional_dependencies: [flake8-simplify, flake8-return] + - id: ruff-check + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: - - id: trailing-whitespace - id: check-symlinks - id: destroyed-symlinks - id: check-yaml + - id: check-toml - id: check-merge-conflict - id: check-case-conflict - id: check-executables-have-shebangs - - id: check-toml - - id: end-of-file-fixer - id: check-shebang-scripts-are-executable - id: detect-private-key - - id: debug-statements - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - name: isort (python) - args: ["--profile", "black", "--filter-files"] - - repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 - hooks: - - id: pyupgrade - args: ["--py37-plus"] - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index ca06f311..00692ba7 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -17,12 +17,14 @@ Please keep the lists sorted alphabetically. --- -* Mayank Mittal * Clemens Schwarke +* Mayank Mittal ## Authors +* Clemens Schwarke * David Hoeller +* Mayank Mittal * Nikita Rudin ## Contributors diff --git a/README.md b/README.md index 0ea9bebb..d00385b7 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,14 @@ -# RSL RL +# RSL-RL -A fast and simple implementation of RL algorithms, designed to run fully on GPU. -This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac Gym. +A fast and simple implementation of learning algorithms for robotics. For an overview of the library please have a look at https://arxiv.org/pdf/2509.10771. 
Environment repositories using the framework: * **`Isaac Lab`** (built on top of NVIDIA Isaac Sim): https://github.com/isaac-sim/IsaacLab -* **`Legged-Gym`** (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/ +* **`Legged Gym`** (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/ * **`MuJoCo Playground`** (built on top of MuJoCo MJX and Warp): https://github.com/google-deepmind/mujoco_playground/ -The main branch supports **PPO** and **Student-Teacher Distillation** with additional features from our research. These include: +The library currently supports **PPO** and **Student-Teacher Distillation** with additional features from our research. These include: * [Random Network Distillation (RND)](https://proceedings.mlr.press/v229/schwarke23a.html) - Encourages exploration by adding a curiosity driven intrinsic reward. @@ -22,8 +21,6 @@ information. **Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA
**Contact**: cschwarke@ethz.ch -> **Note:** The `algorithms` branch supports additional algorithms (SAC, DDPG, DSAC, and more). However, it isn't currently actively maintained. - ## Setup @@ -57,8 +54,7 @@ For documentation, we adopt the [Google Style Guide](https://sphinxcontrib-napol We use the following tools for maintaining code quality: - [pre-commit](https://pre-commit.com/): Runs a list of formatters and linters over the codebase. -- [black](https://black.readthedocs.io/en/stable/): The uncompromising code formatter. -- [flake8](https://flake8.pycqa.org/en/latest/): A wrapper around PyFlakes, pycodestyle, and McCabe complexity checker. +- [ruff](https://github.com/astral-sh/ruff): An extremely fast Python linter and code formatter, written in Rust. Please check [here](https://pre-commit.com/#install) for instructions to set these up. To run over the entire repository, please execute the following command in the terminal: diff --git a/config/example_config.yaml b/config/example_config.yaml index 329cb1e5..00a9483f 100644 --- a/config/example_config.yaml +++ b/config/example_config.yaml @@ -1,21 +1,21 @@ runner: class_name: OnPolicyRunner - # -- general - num_steps_per_env: 24 # number of steps per environment per iteration - max_iterations: 1500 # number of policy updates + # General + num_steps_per_env: 24 # Number of steps per environment per iteration + max_iterations: 1500 # Number of policy updates seed: 1 - # -- observations - obs_groups: {"policy": ["policy"], "critic": ["policy", "privileged"]} # maps observation groups to types. See `vec_env.py` for more information - # -- logging parameters - save_interval: 50 # check for potential saves every `save_interval` iterations + # Observations + obs_groups: {"policy": ["policy"], "critic": ["policy", "privileged"]} # Maps observation groups to sets. See `vec_env.py` for more information + # Logging parameters + save_interval: 50 # Check for potential saves every `save_interval` iterations experiment_name: walking_experiment run_name: "" - # -- logging writer + # Logging writer logger: tensorboard # tensorboard, neptune, wandb neptune_project: legged_gym wandb_project: legged_gym - # -- policy + # Policy policy: class_name: ActorCritic activation: elu @@ -25,45 +25,46 @@ runner: critic_hidden_dims: [256, 256, 256] init_noise_std: 1.0 noise_std_type: "scalar" # 'scalar' or 'log' + state_dependent_std: false - # -- algorithm + # Algorithm algorithm: class_name: PPO - # -- training + # Training learning_rate: 0.001 num_learning_epochs: 5 num_mini_batches: 4 # mini batch size = num_envs * num_steps / num_mini_batches schedule: adaptive # adaptive, fixed - # -- value function + # Value function value_loss_coef: 1.0 clip_param: 0.2 use_clipped_value_loss: true - # -- surrogate loss + # Surrogate loss desired_kl: 0.01 entropy_coef: 0.01 gamma: 0.99 lam: 0.95 max_grad_norm: 1.0 - # -- miscellaneous + # Miscellaneous normalize_advantage_per_mini_batch: false - # -- random network distillation + # Random network distillation rnd_cfg: - weight: 0.0 # initial weight of the RND reward - weight_schedule: null # note: this is a dictionary with a required key called "mode". Please check the RND module for more information - reward_normalization: false # whether to normalize RND reward - # -- learning parameters - learning_rate: 0.001 # learning rate for RND - # -- network parameters - num_outputs: 1 # number of outputs of RND network. 
Note: if -1, then the network will use dimensions of the observation - predictor_hidden_dims: [-1] # hidden dimensions of predictor network - target_hidden_dims: [-1] # hidden dimensions of target network + weight: 0.0 # Initial weight of the RND reward + weight_schedule: null # This is a dictionary with a required key called "mode". Please check the RND module for more information + reward_normalization: false # Whether to normalize RND reward + # Learning parameters + learning_rate: 0.001 # Learning rate for RND + # Network parameters + num_outputs: 1 # Number of outputs of RND network. Note: if -1, then the network will use dimensions of the observation + predictor_hidden_dims: [-1] # Hidden dimensions of predictor network + target_hidden_dims: [-1] # Hidden dimensions of target network - # -- symmetry augmentation + # Symmetry augmentation symmetry_cfg: - use_data_augmentation: true # this adds symmetric trajectories to the batch - use_mirror_loss: false # this adds symmetry loss term to the loss function - data_augmentation_func: null # string containing the module and function name to import + use_data_augmentation: true # This adds symmetric trajectories to the batch + use_mirror_loss: false # This adds symmetry loss term to the loss function + data_augmentation_func: null # String containing the module and function name to import # Example: "legged_gym.envs.locomotion.anymal_c.symmetry:get_symmetric_states" # # .. code-block:: python @@ -73,4 +74,4 @@ runner: # obs: Optional[torch.Tensor] = None, actions: Optional[torch.Tensor] = None, cfg: "BaseEnvCfg" = None, obs_type: str = "policy" # ) -> Tuple[torch.Tensor, torch.Tensor]: # - mirror_loss_coeff: 0.0 #coefficient for symmetry loss term. If 0, no symmetry loss is used + mirror_loss_coeff: 0.0 # Coefficient for symmetry loss term. If 0, no symmetry loss is used diff --git a/licenses/dependencies/black-license.txt b/licenses/dependencies/black-license.txt deleted file mode 100644 index 7a9b891f..00000000 --- a/licenses/dependencies/black-license.txt +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2018 Łukasz Langa - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
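The deleted `.flake8` file and the updated pre-commit hooks earlier in this diff replace the Black/Flake8 toolchain with Ruff, which is why the Black license file above is removed. The Ruff configuration that accompanies this change is not visible in these hunks; purely as an illustrative sketch, the options from the deleted `.flake8` could be expressed in `pyproject.toml` roughly as follows (key names follow Ruff's documented configuration schema, values copied from the deleted file, not necessarily what this change actually adds):

    [tool.ruff]
    line-length = 120
    exclude = ["_*", ".vscode", ".git", "docs"]

    [tool.ruff.lint]
    # Counterparts of the old flake8 ignore list; codes Ruff does not implement
    # (e.g. W503) are dropped, since its formatter already follows that style.
    ignore = ["E402", "E501", "D401"]

    [tool.ruff.lint.per-file-ignores]
    "__init__.py" = ["F401"]  # Allow re-exports in package __init__ files

    [tool.ruff.lint.pydocstyle]
    convention = "google"

    [tool.ruff.lint.mccabe]
    max-complexity = 18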
diff --git a/licenses/dependencies/flake8-license.txt b/licenses/dependencies/flake8-license.txt deleted file mode 100644 index e5e3d6f9..00000000 --- a/licenses/dependencies/flake8-license.txt +++ /dev/null @@ -1,22 +0,0 @@ -== Flake8 License (MIT) == - -Copyright (C) 2011-2013 Tarek Ziade -Copyright (C) 2012-2016 Ian Cordasco - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/licenses/dependencies/isort-license.txt b/licenses/dependencies/isort-license.txt deleted file mode 100644 index b5083a50..00000000 --- a/licenses/dependencies/isort-license.txt +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2013 Timothy Edmund Crosley - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. 
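The Flake8 and isort licenses removed above (together with the `[tool.isort]` table dropped from `pyproject.toml` later in this diff) are likewise superseded by Ruff's built-in lint and import-sorting rules. A minimal sketch of the import-sorting equivalent, again for illustration only:

    [tool.ruff.lint]
    extend-select = ["I"]  # Enable isort-style import sorting (rules I001/I002)

    [tool.ruff.lint.isort]
    known-first-party = ["rsl_rl"]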
diff --git a/licenses/dependencies/numpy_license.txt b/licenses/dependencies/numpy-license.txt similarity index 100% rename from licenses/dependencies/numpy_license.txt rename to licenses/dependencies/numpy-license.txt diff --git a/licenses/dependencies/pyupgrade-license.txt b/licenses/dependencies/pyupgrade-license.txt deleted file mode 100644 index 522fbe20..00000000 --- a/licenses/dependencies/pyupgrade-license.txt +++ /dev/null @@ -1,19 +0,0 @@ -Copyright (c) 2017 Anthony Sottile - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. diff --git a/licenses/dependencies/ruff-license.txt b/licenses/dependencies/ruff-license.txt new file mode 100644 index 00000000..d779ee9e --- /dev/null +++ b/licenses/dependencies/ruff-license.txt @@ -0,0 +1,430 @@ +MIT License + +Copyright (c) 2022 Charles Marsh + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +end of terms and conditions + +The externally maintained libraries from which parts of the Software is derived +are: + +- autoflake, licensed as follows: + """ + Copyright (C) 2012-2018 Steven Myint + + Permission is hereby granted, free of charge, to any person obtaining a copy of + this software and associated documentation files (the "Software"), to deal in + the Software without restriction, including without limitation the rights to + use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + of the Software, and to permit persons to whom the Software is furnished to do + so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + +- autotyping, licensed as follows: + """ + MIT License + + Copyright (c) 2023 Jelle Zijlstra + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + +- Flake8, licensed as follows: + """ + == Flake8 License (MIT) == + + Copyright (C) 2011-2013 Tarek Ziade + Copyright (C) 2012-2016 Ian Cordasco + + Permission is hereby granted, free of charge, to any person obtaining a copy of + this software and associated documentation files (the "Software"), to deal in + the Software without restriction, including without limitation the rights to + use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + of the Software, and to permit persons to whom the Software is furnished to do + so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + +- flake8-eradicate, licensed as follows: + """ + MIT License + + Copyright (c) 2018 Nikita Sobolev + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + +- flake8-pyi, licensed as follows: + """ + The MIT License (MIT) + + Copyright (c) 2016 Łukasz Langa + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + +- flake8-simplify, licensed as follows: + """ + MIT License + + Copyright (c) 2020 Martin Thoma + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + +- isort, licensed as follows: + """ + The MIT License (MIT) + + Copyright (c) 2013 Timothy Edmund Crosley + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. + """ + +- pygrep-hooks, licensed as follows: + """ + Copyright (c) 2018 Anthony Sottile + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. + """ + +- pycodestyle, licensed as follows: + """ + Copyright © 2006-2009 Johann C. Rocholl + Copyright © 2009-2014 Florent Xicluna + Copyright © 2014-2020 Ian Lee + + Licensed under the terms of the Expat License + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation files + (the "Software"), to deal in the Software without restriction, + including without limitation the rights to use, copy, modify, merge, + publish, distribute, sublicense, and/or sell copies of the Software, + and to permit persons to whom the Software is furnished to do so, + subject to the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. 
+ + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + +- pydocstyle, licensed as follows: + """ + Copyright (c) 2012 GreenSteam, + + Copyright (c) 2014-2020 Amir Rachum, + + Copyright (c) 2020 Sambhav Kothari, + + Permission is hereby granted, free of charge, to any person obtaining a copy of + this software and associated documentation files (the "Software"), to deal in + the Software without restriction, including without limitation the rights to + use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + of the Software, and to permit persons to whom the Software is furnished to do + so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + +- Pyflakes, licensed as follows: + """ + Copyright 2005-2011 Divmod, Inc. + Copyright 2013-2014 Florent Xicluna + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + """ + +- Pyright, licensed as follows: + """ + MIT License + + Pyright - A static type checker for the Python language + Copyright (c) Microsoft Corporation. All rights reserved. 
+ + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE + """ + +- pyupgrade, licensed as follows: + """ + Copyright (c) 2017 Anthony Sottile + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE. + """ + +- rome/tools, licensed under the MIT license: + """ + MIT License + + Copyright (c) Rome Tools, Inc. and its affiliates. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. 
+ """ + +- RustPython, licensed as follows: + """ + MIT License + + Copyright (c) 2020 RustPython Team + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + +- rust-analyzer/text-size, licensed under the MIT license: + """ + Permission is hereby granted, free of charge, to any + person obtaining a copy of this software and associated + documentation files (the "Software"), to deal in the + Software without restriction, including without + limitation the rights to use, copy, modify, merge, + publish, distribute, sublicense, and/or sell copies of + the Software, and to permit persons to whom the Software + is furnished to do so, subject to the following + conditions: + + The above copyright notice and this permission notice + shall be included in all copies or substantial portions + of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF + ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED + TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A + PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT + SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR + IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. 
+ """ \ No newline at end of file diff --git a/licenses/dependencies/tensordict.txt b/licenses/dependencies/tensordict-license.txt similarity index 100% rename from licenses/dependencies/tensordict.txt rename to licenses/dependencies/tensordict-license.txt diff --git a/licenses/dependencies/torch_license.txt b/licenses/dependencies/torch-license.txt similarity index 100% rename from licenses/dependencies/torch_license.txt rename to licenses/dependencies/torch-license.txt diff --git a/pyproject.toml b/pyproject.toml index 1f1817d4..f9b10ce4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "rsl-rl-lib" version = "3.1.1" -keywords = ["reinforcement-learning", "isaac", "leggedrobotics", "rl-pytorch"] +keywords = ["reinforcement-learning", "robotics"] maintainers = [ { name="Clemens Schwarke", email="cschwarke@ethz.ch" }, { name="Mayank Mittal", email="mittalma@ethz.ch" }, @@ -17,10 +17,9 @@ authors = [ { name="David Hoeller", email="holler.david78@gmail.com" }, ] description = "Fast and simple RL algorithms implemented in PyTorch" -readme = { file = "README.md", content-type = "text/markdown"} -license = { text = "BSD-3-Clause" } - -requires-python = ">=3.8" +readme = { file = "README.md", content-type = "text/markdown" } +license = "BSD-3-Clause" +requires-python = ">=3.9" classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", @@ -45,51 +44,16 @@ include = ["rsl_rl*"] [tool.setuptools.package-data] "rsl_rl" = ["config/*", "licenses/*"] -[tool.isort] - -py_version = 38 -line_length = 120 -group_by_package = true - -# Files to skip -skip_glob = [".vscode/*"] - -# Order of imports -sections = [ - "FUTURE", - "STDLIB", - "THIRDPARTY", - "FIRSTPARTY", - "LOCALFOLDER", -] - -# Extra standard libraries considered as part of python (permissive licenses) -extra_standard_library = [ - "numpy", - "torch", - "tensordict", - "warp", - "typing_extensions", - "git", -] -# Imports from this repository -known_first_party = "rsl_rl" - [tool.pyright] - include = ["rsl_rl"] - typeCheckingMode = "basic" -pythonVersion = "3.8" +pythonVersion = "3.9" pythonPlatform = "Linux" enableTypeIgnoreComments = true - # This is required as the CI pre-commit does not download the module (i.e. numpy, torch, prettytable) -# Therefore, we have to ignore missing imports reportMissingImports = "none" -# This is required to ignore for type checks of modules with stubs missing. -reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers - -reportGeneralTypeIssues = "none" # -> raises 218 errors (usage of literal MISSING in dataclasses) -reportOptionalMemberAccess = "warning" # -> raises 8 errors +# This is required to ignore type checks of modules with stubs missing. 
+reportMissingModuleSource = "none" # -> most common: prettytable in mdp managers +reportGeneralTypeIssues = "none" # -> usage of literal MISSING in dataclasses +reportOptionalMemberAccess = "warning" reportPrivateUsage = "warning" diff --git a/rsl_rl/algorithms/__init__.py b/rsl_rl/algorithms/__init__.py index effcaa28..b1686144 100644 --- a/rsl_rl/algorithms/__init__.py +++ b/rsl_rl/algorithms/__init__.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implementation of different RL agents.""" +"""Implementation of different learning algorithms.""" from .distillation import Distillation from .ppo import PPO diff --git a/rsl_rl/algorithms/distillation.py b/rsl_rl/algorithms/distillation.py index 3a86e002..5da039a9 100644 --- a/rsl_rl/algorithms/distillation.py +++ b/rsl_rl/algorithms/distillation.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn +from tensordict import TensorDict from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent from rsl_rl.storage import RolloutStorage @@ -19,20 +20,21 @@ class Distillation: def __init__( self, - policy, - num_learning_epochs=1, - gradient_length=15, - learning_rate=1e-3, - max_grad_norm=None, - loss_type="mse", - optimizer="adam", - device="cpu", + policy: StudentTeacher | StudentTeacherRecurrent, + num_learning_epochs: int = 1, + gradient_length: int = 15, + learning_rate: float = 1e-3, + max_grad_norm: float | None = None, + loss_type: str = "mse", + optimizer: str = "adam", + device: str = "cpu", # Distributed training parameters multi_gpu_cfg: dict | None = None, - ): - # device-related parameters + ) -> None: + # Device-related parameters self.device = device self.is_multi_gpu = multi_gpu_cfg is not None + # Multi-GPU parameters if multi_gpu_cfg is not None: self.gpu_global_rank = multi_gpu_cfg["global_rank"] @@ -41,25 +43,25 @@ def __init__( self.gpu_global_rank = 0 self.gpu_world_size = 1 - # distillation components + # Distillation components self.policy = policy self.policy.to(self.device) - self.storage = None # initialized later + self.storage = None # Initialized later - # initialize the optimizer + # Initialize the optimizer self.optimizer = resolve_optimizer(optimizer)(self.policy.parameters(), lr=learning_rate) - # initialize the transition + # Initialize the transition self.transition = RolloutStorage.Transition() - self.last_hidden_states = None + self.last_hidden_states = (None, None) - # distillation parameters + # Distillation parameters self.num_learning_epochs = num_learning_epochs self.gradient_length = gradient_length self.learning_rate = learning_rate self.max_grad_norm = max_grad_norm - # initialize the loss function + # Initialize the loss function loss_fn_dict = { "mse": nn.functional.mse_loss, "huber": nn.functional.huber_loss, @@ -71,8 +73,15 @@ def __init__( self.num_updates = 0 - def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, actions_shape): - # create rollout storage + def init_storage( + self, + training_type: str, + num_envs: int, + num_transitions_per_env: int, + obs: TensorDict, + actions_shape: tuple[int], + ) -> None: + # Create rollout storage self.storage = RolloutStorage( training_type, num_envs, @@ -82,27 +91,29 @@ def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, ac self.device, ) - def act(self, obs): - # compute the actions + def act(self, obs: TensorDict) -> torch.Tensor: + # Compute the actions self.transition.actions = self.policy.act(obs).detach() self.transition.privileged_actions = 
self.policy.evaluate(obs).detach() - # record the observations + # Record the observations self.transition.observations = obs return self.transition.actions - def process_env_step(self, obs, rewards, dones, extras): - # update the normalizers + def process_env_step( + self, obs: TensorDict, rewards: torch.Tensor, dones: torch.Tensor, extras: dict[str, torch.Tensor] + ) -> None: + # Update the normalizers self.policy.update_normalization(obs) - # record the rewards and dones + # Record the rewards and dones self.transition.rewards = rewards self.transition.dones = dones - # record the transition + # Record the transition self.storage.add_transitions(self.transition) self.transition.clear() self.policy.reset(dones) - def update(self): + def update(self) -> dict[str, float]: self.num_updates += 1 mean_behavior_loss = 0 loss = 0 @@ -112,19 +123,18 @@ def update(self): self.policy.reset(hidden_states=self.last_hidden_states) self.policy.detach_hidden_states() for obs, _, privileged_actions, dones in self.storage.generator(): - - # inference the student for gradient computation + # Inference of the student for gradient computation actions = self.policy.act_inference(obs) - # behavior cloning loss + # Behavior cloning loss behavior_loss = self.loss_fn(actions, privileged_actions) - # total loss + # Total loss loss = loss + behavior_loss mean_behavior_loss += behavior_loss.item() cnt += 1 - # gradient step + # Gradient step if cnt % self.gradient_length == 0: self.optimizer.zero_grad() loss.backward() @@ -136,7 +146,7 @@ def update(self): self.policy.detach_hidden_states() loss = 0 - # reset dones + # Reset dones self.policy.reset(dones.view(-1)) self.policy.detach_hidden_states(dones.view(-1)) @@ -145,25 +155,21 @@ def update(self): self.last_hidden_states = self.policy.get_hidden_states() self.policy.detach_hidden_states() - # construct the loss dictionary + # Construct the loss dictionary loss_dict = {"behavior": mean_behavior_loss} return loss_dict - """ - Helper functions - """ - - def broadcast_parameters(self): + def broadcast_parameters(self) -> None: """Broadcast model parameters to all GPUs.""" - # obtain the model parameters on current GPU + # Obtain the model parameters on current GPU model_params = [self.policy.state_dict()] - # broadcast the model parameters + # Broadcast the model parameters torch.distributed.broadcast_object_list(model_params, src=0) - # load the model parameters on all GPUs from source GPU + # Load the model parameters on all GPUs from source GPU self.policy.load_state_dict(model_params[0]) - def reduce_parameters(self): + def reduce_parameters(self) -> None: """Collect gradients from all GPUs and average them. This function is called after the backward pass to synchronize the gradients across all GPUs. 
@@ -179,7 +185,7 @@ def reduce_parameters(self): for param in self.policy.parameters(): if param.grad is not None: numel = param.numel() - # copy data back from shared buffer + # Copy data back from shared buffer param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data)) - # update the offset for the next parameter + # Update the offset for the next parameter offset += numel diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py index 6c21fc54..44efe782 100644 --- a/rsl_rl/algorithms/ppo.py +++ b/rsl_rl/algorithms/ppo.py @@ -9,8 +9,9 @@ import torch.nn as nn import torch.optim as optim from itertools import chain +from tensordict import TensorDict -from rsl_rl.modules import ActorCritic +from rsl_rl.modules import ActorCritic, ActorCriticRecurrent from rsl_rl.modules.rnd import RandomNetworkDistillation from rsl_rl.storage import RolloutStorage from rsl_rl.utils import string_to_callable @@ -19,36 +20,37 @@ class PPO: """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347).""" - policy: ActorCritic + policy: ActorCritic | ActorCriticRecurrent """The actor critic module.""" def __init__( self, - policy, - num_learning_epochs=5, - num_mini_batches=4, - clip_param=0.2, - gamma=0.99, - lam=0.95, - value_loss_coef=1.0, - entropy_coef=0.01, - learning_rate=0.001, - max_grad_norm=1.0, - use_clipped_value_loss=True, - schedule="adaptive", - desired_kl=0.01, - device="cpu", - normalize_advantage_per_mini_batch=False, + policy: ActorCritic | ActorCriticRecurrent, + num_learning_epochs: int = 5, + num_mini_batches: int = 4, + clip_param: float = 0.2, + gamma: float = 0.99, + lam: float = 0.95, + value_loss_coef: float = 1.0, + entropy_coef: float = 0.01, + learning_rate: float = 0.001, + max_grad_norm: float = 1.0, + use_clipped_value_loss: bool = True, + schedule: str = "adaptive", + desired_kl: float = 0.01, + device: str = "cpu", + normalize_advantage_per_mini_batch: bool = False, # RND parameters rnd_cfg: dict | None = None, # Symmetry parameters symmetry_cfg: dict | None = None, # Distributed training parameters multi_gpu_cfg: dict | None = None, - ): - # device-related parameters + ) -> None: + # Device-related parameters self.device = device self.is_multi_gpu = multi_gpu_cfg is not None + # Multi-GPU parameters if multi_gpu_cfg is not None: self.gpu_global_rank = multi_gpu_cfg["global_rank"] @@ -94,10 +96,12 @@ def __init__( # PPO components self.policy = policy self.policy.to(self.device) + # Create optimizer self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate) + # Create rollout storage - self.storage: RolloutStorage = None # type: ignore + self.storage: RolloutStorage | None = None self.transition = RolloutStorage.Transition() # PPO parameters @@ -115,8 +119,15 @@ def __init__( self.learning_rate = learning_rate self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch - def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, actions_shape): - # create rollout storage + def init_storage( + self, + training_type: str, + num_envs: int, + num_transitions_per_env: int, + obs: TensorDict, + actions_shape: tuple[int] | list[int], + ) -> None: + # Create rollout storage self.storage = RolloutStorage( training_type, num_envs, @@ -126,27 +137,29 @@ def init_storage(self, training_type, num_envs, num_transitions_per_env, obs, ac self.device, ) - def act(self, obs): + def act(self, obs: TensorDict) -> torch.Tensor: if self.policy.is_recurrent: self.transition.hidden_states = 
self.policy.get_hidden_states() - # compute the actions and values + # Compute the actions and values self.transition.actions = self.policy.act(obs).detach() self.transition.values = self.policy.evaluate(obs).detach() self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach() self.transition.action_mean = self.policy.action_mean.detach() self.transition.action_sigma = self.policy.action_std.detach() - # need to record obs before env.step() + # Record observations before env.step() self.transition.observations = obs return self.transition.actions - def process_env_step(self, obs, rewards, dones, extras): - # update the normalizers + def process_env_step( + self, obs: TensorDict, rewards: torch.Tensor, dones: torch.Tensor, extras: dict[str, torch.Tensor] + ) -> None: + # Update the normalizers self.policy.update_normalization(obs) if self.rnd: self.rnd.update_normalization(obs) # Record the rewards and dones - # Note: we clone here because later on we bootstrap the rewards based on timeouts + # Note: We clone here because later on we bootstrap the rewards based on timeouts self.transition.rewards = rewards.clone() self.transition.dones = dones @@ -163,40 +176,34 @@ def process_env_step(self, obs, rewards, dones, extras): self.transition.values * extras["time_outs"].unsqueeze(1).to(self.device), 1 ) - # record the transition + # Record the transition self.storage.add_transitions(self.transition) self.transition.clear() self.policy.reset(dones) - def compute_returns(self, obs): - # compute value for the last step + def compute_returns(self, obs: TensorDict) -> None: + # Compute value for the last step last_values = self.policy.evaluate(obs).detach() self.storage.compute_returns( last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch ) - def update(self): # noqa: C901 + def update(self) -> dict[str, float]: mean_value_loss = 0 mean_surrogate_loss = 0 mean_entropy = 0 - # -- RND loss - if self.rnd: - mean_rnd_loss = 0 - else: - mean_rnd_loss = None - # -- Symmetry loss - if self.symmetry: - mean_symmetry_loss = 0 - else: - mean_symmetry_loss = None + # RND loss + mean_rnd_loss = 0 if self.rnd else None + # Symmetry loss + mean_symmetry_loss = 0 if self.symmetry else None - # generator for mini batches + # Get mini batch generator if self.policy.is_recurrent: generator = self.storage.recurrent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) else: generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) - # iterate over batches + # Iterate over batches for ( obs_batch, actions_batch, @@ -206,57 +213,46 @@ def update(self): # noqa: C901 old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, - hid_states_batch, + hidden_states_batch, masks_batch, ) in generator: - - # number of augmentations per sample - # we start with 1 and increase it if we use symmetry augmentation - num_aug = 1 - # original batch size - # we assume policy group is always there and needs augmentation + num_aug = 1 # Number of augmentations per sample. Starts at 1 for no augmentation. 
original_batch_size = obs_batch.batch_size[0] - # check if we should normalize advantages per mini batch + # Check if we should normalize advantages per mini batch if self.normalize_advantage_per_mini_batch: with torch.no_grad(): advantages_batch = (advantages_batch - advantages_batch.mean()) / (advantages_batch.std() + 1e-8) # Perform symmetric augmentation if self.symmetry and self.symmetry["use_data_augmentation"]: - # augmentation using symmetry + # Augmentation using symmetry data_augmentation_func = self.symmetry["data_augmentation_func"] - # returned shape: [batch_size * num_aug, ...] + # Returned shape: [batch_size * num_aug, ...] obs_batch, actions_batch = data_augmentation_func( obs=obs_batch, actions=actions_batch, env=self.symmetry["_env"], ) - # compute number of augmentations per sample - # we assume policy group is always there and needs augmentation + # Compute number of augmentations per sample num_aug = int(obs_batch.batch_size[0] / original_batch_size) - # repeat the rest of the batch - # -- actor + # Repeat the rest of the batch old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1) - # -- critic target_values_batch = target_values_batch.repeat(num_aug, 1) advantages_batch = advantages_batch.repeat(num_aug, 1) returns_batch = returns_batch.repeat(num_aug, 1) # Recompute actions log prob and entropy for current batch of transitions - # Note: we need to do this because we updated the policy with the new parameters - # -- actor - self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0]) + # Note: We need to do this because we updated the policy with the new parameters + self.policy.act(obs_batch, masks=masks_batch, hidden_state=hidden_states_batch[0]) actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch) - # -- critic - value_batch = self.policy.evaluate(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]) - # -- entropy - # we only keep the entropy of the first augmentation (the original one) + value_batch = self.policy.evaluate(obs_batch, masks=masks_batch, hidden_state=hidden_states_batch[1]) + # Note: We only keep the entropy of the first augmentation (the original one) mu_batch = self.policy.action_mean[:original_batch_size] sigma_batch = self.policy.action_std[:original_batch_size] entropy_batch = self.policy.entropy[:original_batch_size] - # KL + # Compute KL divergence and adapt the learning rate if self.desired_kl is not None and self.schedule == "adaptive": with torch.inference_mode(): kl = torch.sum( @@ -273,8 +269,7 @@ def update(self): # noqa: C901 torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM) kl_mean /= self.gpu_world_size - # Update the learning rate - # Perform this adaptation only on the main process + # Update the learning rate only on the main process # TODO: Is this needed? If KL-divergence is the "same" across all GPUs, # then the learning rate should be the same across all GPUs. 
if self.gpu_global_rank == 0: @@ -316,70 +311,68 @@ def update(self): # noqa: C901 # Symmetry loss if self.symmetry: - # obtain the symmetric actions - # if we did augmentation before then we don't need to augment again + # Obtain the symmetric actions + # Note: If we did augmentation before then we don't need to augment again if not self.symmetry["use_data_augmentation"]: data_augmentation_func = self.symmetry["data_augmentation_func"] obs_batch, _ = data_augmentation_func(obs=obs_batch, actions=None, env=self.symmetry["_env"]) - # compute number of augmentations per sample + # Compute number of augmentations per sample num_aug = int(obs_batch.shape[0] / original_batch_size) - # actions predicted by the actor for symmetrically-augmented observations + # Actions predicted by the actor for symmetrically-augmented observations mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone()) - # compute the symmetrically augmented actions - # note: we are assuming the first augmentation is the original one. - # We do not use the action_batch from earlier since that action was sampled from the distribution. - # However, the symmetry loss is computed using the mean of the distribution. + # Compute the symmetrically augmented actions + # Note: We are assuming the first augmentation is the original one. We do not use the action_batch from + # earlier since that action was sampled from the distribution. However, the symmetry loss is computed + # using the mean of the distribution. action_mean_orig = mean_actions_batch[:original_batch_size] _, actions_mean_symm_batch = data_augmentation_func( obs=None, actions=action_mean_orig, env=self.symmetry["_env"] ) - # compute the loss (we skip the first augmentation as it is the original one) + # Compute the loss mse_loss = torch.nn.MSELoss() symmetry_loss = mse_loss( mean_actions_batch[original_batch_size:], actions_mean_symm_batch.detach()[original_batch_size:] ) - # add the loss to the total loss + # Add the loss to the total loss if self.symmetry["use_mirror_loss"]: loss += self.symmetry["mirror_loss_coeff"] * symmetry_loss else: symmetry_loss = symmetry_loss.detach() - # Random Network Distillation loss + # RND loss # TODO: Move this processing to inside RND module. if self.rnd: - # extract the rnd_state + # Extract the rnd_state # TODO: Check if we still need torch no grad. It is just an affine transformation. 
with torch.no_grad(): rnd_state_batch = self.rnd.get_rnd_state(obs_batch[:original_batch_size]) rnd_state_batch = self.rnd.state_normalizer(rnd_state_batch) - # predict the embedding and the target + # Predict the embedding and the target predicted_embedding = self.rnd.predictor(rnd_state_batch) target_embedding = self.rnd.target(rnd_state_batch).detach() - # compute the loss as the mean squared error + # Compute the loss as the mean squared error mseloss = torch.nn.MSELoss() rnd_loss = mseloss(predicted_embedding, target_embedding) - # Compute the gradients - # -- For PPO + # Compute the gradients for PPO self.optimizer.zero_grad() loss.backward() - # -- For RND + # Compute the gradients for RND if self.rnd: - self.rnd_optimizer.zero_grad() # type: ignore + self.rnd_optimizer.zero_grad() rnd_loss.backward() # Collect gradients from all GPUs if self.is_multi_gpu: self.reduce_parameters() - # Apply the gradients - # -- For PPO + # Apply the gradients for PPO nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.optimizer.step() - # -- For RND + # Apply the gradients for RND if self.rnd_optimizer: self.rnd_optimizer.step() @@ -387,28 +380,27 @@ def update(self): # noqa: C901 mean_value_loss += value_loss.item() mean_surrogate_loss += surrogate_loss.item() mean_entropy += entropy_batch.mean().item() - # -- RND loss + # RND loss if mean_rnd_loss is not None: mean_rnd_loss += rnd_loss.item() - # -- Symmetry loss + # Symmetry loss if mean_symmetry_loss is not None: mean_symmetry_loss += symmetry_loss.item() - # -- For PPO + # Divide the losses by the number of updates num_updates = self.num_learning_epochs * self.num_mini_batches mean_value_loss /= num_updates mean_surrogate_loss /= num_updates mean_entropy /= num_updates - # -- For RND if mean_rnd_loss is not None: mean_rnd_loss /= num_updates - # -- For Symmetry if mean_symmetry_loss is not None: mean_symmetry_loss /= num_updates - # -- Clear the storage + + # Clear the storage self.storage.clear() - # construct the loss dictionary + # Construct the loss dictionary loss_dict = { "value_function": mean_value_loss, "surrogate": mean_surrogate_loss, @@ -421,24 +413,20 @@ def update(self): # noqa: C901 return loss_dict - """ - Helper functions - """ - - def broadcast_parameters(self): + def broadcast_parameters(self) -> None: """Broadcast model parameters to all GPUs.""" - # obtain the model parameters on current GPU + # Obtain the model parameters on current GPU model_params = [self.policy.state_dict()] if self.rnd: model_params.append(self.rnd.predictor.state_dict()) - # broadcast the model parameters + # Broadcast the model parameters torch.distributed.broadcast_object_list(model_params, src=0) - # load the model parameters on all GPUs from source GPU + # Load the model parameters on all GPUs from source GPU self.policy.load_state_dict(model_params[0]) if self.rnd: self.rnd.predictor.load_state_dict(model_params[1]) - def reduce_parameters(self): + def reduce_parameters(self) -> None: """Collect gradients from all GPUs and average them. This function is called after the backward pass to synchronize the gradients across all GPUs. 
@@ -463,7 +451,7 @@ def reduce_parameters(self): for param in all_params: if param.grad is not None: numel = param.numel() - # copy data back from shared buffer + # Copy data back from shared buffer param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data)) - # update the offset for the next parameter + # Update the offset for the next parameter offset += numel diff --git a/rsl_rl/env/vec_env.py b/rsl_rl/env/vec_env.py index 3e73e336..86285923 100644 --- a/rsl_rl/env/vec_env.py +++ b/rsl_rl/env/vec_env.py @@ -13,9 +13,8 @@ class VecEnv(ABC): """Abstract class for a vectorized environment. - The vectorized environment is a collection of environments that are synchronized. This means that - the same type of action is applied to all environments and the same type of observation is returned from all - environments. + The vectorized environment is a collection of environments that are synchronized. This means that the same type of + action is applied to all environments and the same type of observation is returned from all environments. """ num_envs: int @@ -41,16 +40,12 @@ class VecEnv(ABC): cfg: dict | object """Configuration object.""" - """ - Operations. - """ - @abstractmethod def get_observations(self) -> TensorDict: """Return the current observations. Returns: - observations (TensorDict): Observations from the environment. + The observations from the environment. """ raise NotImplementedError @@ -59,16 +54,15 @@ def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.T """Apply input action to the environment. Args: - actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions) + actions: Input actions to apply. Shape: (num_envs, num_actions) Returns: - observations (TensorDict): Observations from the environment. - rewards (torch.Tensor): Rewards from the environment. Shape: (num_envs,) - dones (torch.Tensor): Done flags from the environment. Shape: (num_envs,) - extras (dict): Extra information from the environment. + observations: Observations from the environment. + rewards: Rewards from the environment. Shape: (num_envs,) + dones: Done flags from the environment. Shape: (num_envs,) + extras: Extra information from the environment. Observations: - The observations TensorDict usually contains multiple observation groups. The `obs_groups` dictionary of the runner configuration specifies which observation groups are used for which purpose, i.e., it maps the available observation groups to observation sets. The observation sets @@ -83,7 +77,6 @@ def step(self, actions: torch.Tensor) -> tuple[TensorDict, torch.Tensor, torch.T `rsl_rl/utils/utils.py`. Extras: - The extras dictionary includes metrics such as the episode reward, episode length, etc. 
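
As a concrete illustration of the observation-group mechanism described above, the following sketch builds a `TensorDict` with two hypothetical groups and concatenates them per observation set, mirroring what `get_actor_obs`/`get_critic_obs` do in the modules. The group names (`proprio`, `privileged`) and their shapes are made up for the example; only the `"policy"`/`"critic"` set names follow the runner configuration.

```python
import torch
from tensordict import TensorDict

# Hypothetical observation groups returned by an environment.
obs = TensorDict(
    {
        "proprio": torch.randn(16, 48),     # e.g. joint states and base velocity
        "privileged": torch.randn(16, 12),  # e.g. contact forces, available only in simulation
    },
    batch_size=[16],
)

# The runner's obs_groups maps each observation set to the groups it consumes.
obs_groups = {"policy": ["proprio"], "critic": ["proprio", "privileged"]}

# Modules concatenate the groups of their set along the last dimension.
actor_obs = torch.cat([obs[g] for g in obs_groups["policy"]], dim=-1)
critic_obs = torch.cat([obs[g] for g in obs_groups["critic"]], dim=-1)
print(actor_obs.shape, critic_obs.shape)  # torch.Size([16, 48]) torch.Size([16, 60])
```
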
The following dictionary keys are used by rsl_rl: diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py index 0d23fff6..efb8613a 100644 --- a/rsl_rl/modules/__init__.py +++ b/rsl_rl/modules/__init__.py @@ -7,14 +7,17 @@ from .actor_critic import ActorCritic from .actor_critic_recurrent import ActorCriticRecurrent -from .rnd import * +from .rnd import RandomNetworkDistillation, resolve_rnd_config from .student_teacher import StudentTeacher from .student_teacher_recurrent import StudentTeacherRecurrent -from .symmetry import * +from .symmetry import resolve_symmetry_config __all__ = [ "ActorCritic", "ActorCriticRecurrent", + "RandomNetworkDistillation", "StudentTeacher", "StudentTeacherRecurrent", + "resolve_rnd_config", + "resolve_symmetry_config", ] diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py index 89434c3e..9f01b2f4 100644 --- a/rsl_rl/modules/actor_critic.py +++ b/rsl_rl/modules/actor_critic.py @@ -7,37 +7,38 @@ import torch import torch.nn as nn +from tensordict import TensorDict from torch.distributions import Normal +from typing import Any, NoReturn from rsl_rl.networks import MLP, EmpiricalNormalization class ActorCritic(nn.Module): - is_recurrent = False + is_recurrent: bool = False def __init__( self, - obs, - obs_groups, - num_actions, - actor_obs_normalization=False, - critic_obs_normalization=False, - actor_hidden_dims=[256, 256, 256], - critic_hidden_dims=[256, 256, 256], - activation="elu", - init_noise_std=1.0, + obs: TensorDict, + obs_groups: dict[str, list[str]], + num_actions: int, + actor_obs_normalization: bool = False, + critic_obs_normalization: bool = False, + actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256], + critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256], + activation: str = "elu", + init_noise_std: float = 1.0, noise_std_type: str = "scalar", - state_dependent_std=False, - **kwargs, - ): + state_dependent_std: bool = False, + **kwargs: dict[str, Any], + ) -> None: if kwargs: print( - "ActorCritic.__init__ got unexpected arguments, which will be ignored: " - + str([key for key in kwargs.keys()]) + "ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs]) ) super().__init__() - # get the observation dimensions + # Get the observation dimensions self.obs_groups = obs_groups num_actor_obs = 0 for obs_group in obs_groups["policy"]: @@ -49,28 +50,31 @@ def __init__( num_critic_obs += obs[obs_group].shape[-1] self.state_dependent_std = state_dependent_std - # actor + + # Actor if self.state_dependent_std: self.actor = MLP(num_actor_obs, [2, num_actions], actor_hidden_dims, activation) else: self.actor = MLP(num_actor_obs, num_actions, actor_hidden_dims, activation) - # actor observation normalization + print(f"Actor MLP: {self.actor}") + + # Actor observation normalization self.actor_obs_normalization = actor_obs_normalization if actor_obs_normalization: self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs) else: self.actor_obs_normalizer = torch.nn.Identity() - print(f"Actor MLP: {self.actor}") - # critic + # Critic self.critic = MLP(num_critic_obs, 1, critic_hidden_dims, activation) - # critic observation normalization + print(f"Critic MLP: {self.critic}") + + # Critic observation normalization self.critic_obs_normalization = critic_obs_normalization if critic_obs_normalization: self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs) else: self.critic_obs_normalizer = torch.nn.Identity() - print(f"Critic MLP: {self.critic}") # 
Action noise self.noise_std_type = noise_std_type @@ -92,32 +96,34 @@ def __init__( else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") - # Action distribution (populated in update_distribution) + # Action distribution + # Note: Populated in update_distribution self.distribution = None - # disable args validation for speedup + + # Disable args validation for speedup Normal.set_default_validate_args(False) - def reset(self, dones=None): + def reset(self, dones: torch.Tensor | None = None) -> None: pass - def forward(self): + def forward(self) -> NoReturn: raise NotImplementedError @property - def action_mean(self): + def action_mean(self) -> torch.Tensor: return self.distribution.mean @property - def action_std(self): + def action_std(self) -> torch.Tensor: return self.distribution.stddev @property - def entropy(self): + def entropy(self) -> torch.Tensor: return self.distribution.entropy().sum(dim=-1) - def update_distribution(self, obs): + def _update_distribution(self, obs: TensorDict) -> None: if self.state_dependent_std: - # compute mean and standard deviation + # Compute mean and standard deviation mean_and_std = self.actor(obs) if self.noise_std_type == "scalar": mean, std = torch.unbind(mean_and_std, dim=-2) @@ -127,25 +133,25 @@ def update_distribution(self, obs): else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") else: - # compute mean + # Compute mean mean = self.actor(obs) - # compute standard deviation + # Compute standard deviation if self.noise_std_type == "scalar": std = self.std.expand_as(mean) elif self.noise_std_type == "log": std = torch.exp(self.log_std).expand_as(mean) else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. 
Should be 'scalar' or 'log'") - # create distribution + # Create distribution self.distribution = Normal(mean, std) - def act(self, obs, **kwargs): + def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: obs = self.get_actor_obs(obs) obs = self.actor_obs_normalizer(obs) - self.update_distribution(obs) + self._update_distribution(obs) return self.distribution.sample() - def act_inference(self, obs): + def act_inference(self, obs: TensorDict) -> torch.Tensor: obs = self.get_actor_obs(obs) obs = self.actor_obs_normalizer(obs) if self.state_dependent_std: @@ -153,27 +159,23 @@ def act_inference(self, obs): else: return self.actor(obs) - def evaluate(self, obs, **kwargs): + def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor: obs = self.get_critic_obs(obs) obs = self.critic_obs_normalizer(obs) return self.critic(obs) - def get_actor_obs(self, obs): - obs_list = [] - for obs_group in self.obs_groups["policy"]: - obs_list.append(obs[obs_group]) + def get_actor_obs(self, obs: TensorDict) -> torch.Tensor: + obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]] return torch.cat(obs_list, dim=-1) - def get_critic_obs(self, obs): - obs_list = [] - for obs_group in self.obs_groups["critic"]: - obs_list.append(obs[obs_group]) + def get_critic_obs(self, obs: TensorDict) -> torch.Tensor: + obs_list = [obs[obs_group] for obs_group in self.obs_groups["critic"]] return torch.cat(obs_list, dim=-1) - def get_actions_log_prob(self, actions): + def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor: return self.distribution.log_prob(actions).sum(dim=-1) - def update_normalization(self, obs): + def update_normalization(self, obs: TensorDict) -> None: if self.actor_obs_normalization: actor_obs = self.get_actor_obs(obs) self.actor_obs_normalizer.update(actor_obs) @@ -181,18 +183,17 @@ def update_normalization(self, obs): critic_obs = self.get_critic_obs(obs) self.critic_obs_normalizer.update(critic_obs) - def load_state_dict(self, state_dict, strict=True): + def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool: """Load the parameters of the actor-critic model. Args: - state_dict (dict): State dictionary of the model. - strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this - module's state_dict() function. + state_dict: State dictionary of the model. + strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's + :meth:`state_dict` function. Returns: - bool: Whether this training resumes a previous training. This flag is used by the `load()` function of - `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation). + Whether this training resumes a previous training. This flag is used by the :func:`load` function of + :class:`OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation). 
""" - super().load_state_dict(state_dict, strict=strict) - return True # training resumes + return True diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py index 03dfc162..509b6821 100644 --- a/rsl_rl/modules/actor_critic_recurrent.py +++ b/rsl_rl/modules/actor_critic_recurrent.py @@ -8,32 +8,34 @@ import torch import torch.nn as nn import warnings +from tensordict import TensorDict from torch.distributions import Normal +from typing import Any, NoReturn -from rsl_rl.networks import MLP, EmpiricalNormalization, Memory +from rsl_rl.networks import MLP, EmpiricalNormalization, HiddenState, Memory class ActorCriticRecurrent(nn.Module): - is_recurrent = True + is_recurrent: bool = True def __init__( self, - obs, - obs_groups, - num_actions, - actor_obs_normalization=False, - critic_obs_normalization=False, - actor_hidden_dims=[256, 256, 256], - critic_hidden_dims=[256, 256, 256], - activation="elu", - init_noise_std=1.0, + obs: TensorDict, + obs_groups: dict[str, list[str]], + num_actions: int, + actor_obs_normalization: bool = False, + critic_obs_normalization: bool = False, + actor_hidden_dims: tuple[int] | list[int] = [256, 256, 256], + critic_hidden_dims: tuple[int] | list[int] = [256, 256, 256], + activation: str = "elu", + init_noise_std: float = 1.0, noise_std_type: str = "scalar", - state_dependent_std=False, - rnn_type="lstm", - rnn_hidden_dim=256, - rnn_num_layers=1, - **kwargs, - ): + state_dependent_std: bool = False, + rnn_type: str = "lstm", + rnn_hidden_dim: int = 256, + rnn_num_layers: int = 1, + **kwargs: dict[str, Any], + ) -> None: if "rnn_hidden_size" in kwargs: warnings.warn( "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. " @@ -48,7 +50,7 @@ def __init__( ) super().__init__() - # get the observation dimensions + # Get the observation dimensions self.obs_groups = obs_groups num_actor_obs = 0 for obs_group in obs_groups["policy"]: @@ -60,33 +62,35 @@ def __init__( num_critic_obs += obs[obs_group].shape[-1] self.state_dependent_std = state_dependent_std - # actor - self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim) + + # Actor + self.memory_a = Memory(num_actor_obs, rnn_hidden_dim, rnn_num_layers, rnn_type) if self.state_dependent_std: self.actor = MLP(rnn_hidden_dim, [2, num_actions], actor_hidden_dims, activation) else: self.actor = MLP(rnn_hidden_dim, num_actions, actor_hidden_dims, activation) + print(f"Actor RNN: {self.memory_a}") + print(f"Actor MLP: {self.actor}") - # actor observation normalization + # Actor observation normalization self.actor_obs_normalization = actor_obs_normalization if actor_obs_normalization: self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs) else: self.actor_obs_normalizer = torch.nn.Identity() - print(f"Actor RNN: {self.memory_a}") - print(f"Actor MLP: {self.actor}") - # critic - self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim) + # Critic + self.memory_c = Memory(num_critic_obs, rnn_hidden_dim, rnn_num_layers, rnn_type) self.critic = MLP(rnn_hidden_dim, 1, critic_hidden_dims, activation) - # critic observation normalization + print(f"Critic RNN: {self.memory_c}") + print(f"Critic MLP: {self.critic}") + + # Critic observation normalization self.critic_obs_normalization = critic_obs_normalization if critic_obs_normalization: self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs) else: self.critic_obs_normalizer = 
torch.nn.Identity() - print(f"Critic RNN: {self.memory_c}") - print(f"Critic MLP: {self.critic}") # Action noise self.noise_std_type = noise_std_type @@ -108,33 +112,35 @@ def __init__( else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") - # Action distribution (populated in update_distribution) + # Action distribution + # Note: Populated in update_distribution self.distribution = None - # disable args validation for speedup + + # Disable args validation for speedup Normal.set_default_validate_args(False) @property - def action_mean(self): + def action_mean(self) -> torch.Tensor: return self.distribution.mean @property - def action_std(self): + def action_std(self) -> torch.Tensor: return self.distribution.stddev @property - def entropy(self): + def entropy(self) -> torch.Tensor: return self.distribution.entropy().sum(dim=-1) - def reset(self, dones=None): + def reset(self, dones: torch.Tensor | None = None) -> None: self.memory_a.reset(dones) self.memory_c.reset(dones) - def forward(self): + def forward(self) -> NoReturn: raise NotImplementedError - def update_distribution(self, obs): + def _update_distribution(self, obs: TensorDict) -> None: if self.state_dependent_std: - # compute mean and standard deviation + # Compute mean and standard deviation mean_and_std = self.actor(obs) if self.noise_std_type == "scalar": mean, std = torch.unbind(mean_and_std, dim=-2) @@ -144,26 +150,26 @@ def update_distribution(self, obs): else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") else: - # compute mean + # Compute mean mean = self.actor(obs) - # compute standard deviation + # Compute standard deviation if self.noise_std_type == "scalar": std = self.std.expand_as(mean) elif self.noise_std_type == "log": std = torch.exp(self.log_std).expand_as(mean) else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. 
Should be 'scalar' or 'log'") - # create distribution + # Create distribution self.distribution = Normal(mean, std) - def act(self, obs, masks=None, hidden_states=None): + def act(self, obs: TensorDict, masks: torch.Tensor | None = None, hidden_state: HiddenState = None) -> torch.Tensor: obs = self.get_actor_obs(obs) obs = self.actor_obs_normalizer(obs) - out_mem = self.memory_a(obs, masks, hidden_states).squeeze(0) - self.update_distribution(out_mem) + out_mem = self.memory_a(obs, masks, hidden_state).squeeze(0) + self._update_distribution(out_mem) return self.distribution.sample() - def act_inference(self, obs): + def act_inference(self, obs: TensorDict) -> torch.Tensor: obs = self.get_actor_obs(obs) obs = self.actor_obs_normalizer(obs) out_mem = self.memory_a(obs).squeeze(0) @@ -172,31 +178,29 @@ def act_inference(self, obs): else: return self.actor(out_mem) - def evaluate(self, obs, masks=None, hidden_states=None): + def evaluate( + self, obs: TensorDict, masks: torch.Tensor | None = None, hidden_state: HiddenState = None + ) -> torch.Tensor: obs = self.get_critic_obs(obs) obs = self.critic_obs_normalizer(obs) - out_mem = self.memory_c(obs, masks, hidden_states).squeeze(0) + out_mem = self.memory_c(obs, masks, hidden_state).squeeze(0) return self.critic(out_mem) - def get_actor_obs(self, obs): - obs_list = [] - for obs_group in self.obs_groups["policy"]: - obs_list.append(obs[obs_group]) + def get_actor_obs(self, obs: TensorDict) -> torch.Tensor: + obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]] return torch.cat(obs_list, dim=-1) - def get_critic_obs(self, obs): - obs_list = [] - for obs_group in self.obs_groups["critic"]: - obs_list.append(obs[obs_group]) + def get_critic_obs(self, obs: TensorDict) -> torch.Tensor: + obs_list = [obs[obs_group] for obs_group in self.obs_groups["critic"]] return torch.cat(obs_list, dim=-1) - def get_actions_log_prob(self, actions): + def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor: return self.distribution.log_prob(actions).sum(dim=-1) - def get_hidden_states(self): - return self.memory_a.hidden_states, self.memory_c.hidden_states + def get_hidden_states(self) -> tuple[HiddenState, HiddenState]: + return self.memory_a.hidden_state, self.memory_c.hidden_state - def update_normalization(self, obs): + def update_normalization(self, obs: TensorDict) -> None: if self.actor_obs_normalization: actor_obs = self.get_actor_obs(obs) self.actor_obs_normalizer.update(actor_obs) @@ -204,18 +208,17 @@ def update_normalization(self, obs): critic_obs = self.get_critic_obs(obs) self.critic_obs_normalizer.update(critic_obs) - def load_state_dict(self, state_dict, strict=True): + def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool: """Load the parameters of the actor-critic model. Args: - state_dict (dict): State dictionary of the model. - strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this - module's state_dict() function. + state_dict: State dictionary of the model. + strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's + :meth:`state_dict` function. Returns: - bool: Whether this training resumes a previous training. This flag is used by the `load()` function of - `OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation). + Whether this training resumes a previous training. 
This flag is used by the :func:`load` function of + :class:`OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation). """ - super().load_state_dict(state_dict, strict=strict) return True diff --git a/rsl_rl/modules/rnd.py b/rsl_rl/modules/rnd.py index 8e65c43a..3b536dcd 100644 --- a/rsl_rl/modules/rnd.py +++ b/rsl_rl/modules/rnd.py @@ -7,15 +7,18 @@ import torch import torch.nn as nn +from tensordict import TensorDict +from typing import Any, NoReturn +from rsl_rl.env import VecEnv from rsl_rl.networks import MLP, EmpiricalDiscountedVariationNormalization, EmpiricalNormalization class RandomNetworkDistillation(nn.Module): - """Implementation of Random Network Distillation (RND) [1] + """Implementation of Random Network Distillation (RND) [1]. References: - .. [1] Burda, Yuri, et al. "Exploration by random network distillation." arXiv preprint arXiv:1810.12894 (2018). + .. [1] Burda, Yuri, et al. "Exploration by Random Network Distillation." arXiv preprint arXiv:1810.12894 (2018). """ def __init__( @@ -23,55 +26,54 @@ def __init__( num_states: int, obs_groups: dict, num_outputs: int, - predictor_hidden_dims: list[int], - target_hidden_dims: list[int], + predictor_hidden_dims: tuple[int] | list[int], + target_hidden_dims: tuple[int] | list[int], activation: str = "elu", weight: float = 0.0, state_normalization: bool = False, reward_normalization: bool = False, device: str = "cpu", weight_schedule: dict | None = None, - ): + ) -> None: """Initialize the RND module. - - If :attr:`state_normalization` is True, then the input state is normalized using an Empirical Normalization layer. + - If :attr:`state_normalization` is True, then the input state is normalized using an Empirical Normalization + layer. - If :attr:`reward_normalization` is True, then the intrinsic reward is normalized using an Empirical Discounted Variation Normalization layer. - - .. note:: - If the hidden dimensions are -1 in the predictor and target networks configuration, then the number of states - is used as the hidden dimension. + - If the hidden dimensions are -1 in the predictor and target networks configuration, then the number of states + is used as the hidden dimension. Args: num_states: Number of states/inputs to the predictor and target networks. + obs_groups: Dictionary of observation groups. num_outputs: Number of outputs (embedding size) of the predictor and target networks. predictor_hidden_dims: List of hidden dimensions of the predictor network. target_hidden_dims: List of hidden dimensions of the target network. - activation: Activation function. Defaults to "elu". - weight: Scaling factor of the intrinsic reward. Defaults to 0.0. - state_normalization: Whether to normalize the input state. Defaults to False. - reward_normalization: Whether to normalize the intrinsic reward. Defaults to False. - device: Device to use. Defaults to "cpu". - weight_schedule: The type of schedule to use for the RND weight parameter. - Defaults to None, in which case the weight parameter is constant. + activation: Activation function. + weight: Scaling factor of the intrinsic reward. + state_normalization: Whether to normalize the input state. + reward_normalization: Whether to normalize the intrinsic reward. + device: Device to use. + weight_schedule: Type of schedule to use for the RND weight parameter. It is a dictionary with the following keys: - - "mode": The type of schedule to use for the RND weight parameter. + - "mode": Type of schedule to use for the RND weight parameter. 
- "constant": Constant weight schedule. - "step": Step weight schedule. - "linear": Linear weight schedule. For the "step" weight schedule, the following parameters are required: - - "final_step": The step at which the weight parameter is set to the final value. - - "final_value": The final value of the weight parameter. + - "final_step": Step at which the weight parameter is set to the final value. + - "final_value": Final value of the weight parameter. For the "linear" weight schedule, the following parameters are required: - - "initial_step": The step at which the weight parameter is set to the initial value. - - "final_step": The step at which the weight parameter is set to the final value. - - "final_value": The final value of the weight parameter. + - "initial_step": Step at which the weight parameter is set to the initial value. + - "final_step": Step at which the weight parameter is set to the final value. + - "final_value": Final value of the weight parameter. """ - # initialize parent class + # Initialize parent class super().__init__() # Store parameters @@ -88,30 +90,32 @@ def __init__( self.state_normalizer = EmpiricalNormalization(shape=[self.num_states], until=1.0e8).to(self.device) else: self.state_normalizer = torch.nn.Identity() + # Normalization of intrinsic reward if reward_normalization: self.reward_normalizer = EmpiricalDiscountedVariationNormalization(shape=[], until=1.0e8).to(self.device) else: self.reward_normalizer = torch.nn.Identity() - # counter for the number of updates + # Counter for the number of updates self.update_counter = 0 - # resolve weight schedule + # Resolve weight schedule if weight_schedule is not None: self.weight_scheduler_params = weight_schedule self.weight_scheduler = getattr(self, f"_{weight_schedule['mode']}_weight_schedule") else: self.weight_scheduler = None + # Create network architecture self.predictor = MLP(num_states, num_outputs, predictor_hidden_dims, activation).to(self.device) self.target = MLP(num_states, num_outputs, target_hidden_dims, activation).to(self.device) - # make target network not trainable + # Make target network not trainable self.target.eval() - def get_intrinsic_reward(self, obs) -> torch.Tensor: - # Note: the counter is updated number of env steps per learning iteration + def get_intrinsic_reward(self, obs: TensorDict) -> torch.Tensor: + # Note: The counter is updated number of env steps per learning iteration self.update_counter += 1 # Extract the rnd state from the observation rnd_state = self.get_rnd_state(obs) @@ -123,7 +127,6 @@ def get_intrinsic_reward(self, obs) -> torch.Tensor: intrinsic_reward = torch.linalg.norm(target_embedding - predictor_embedding, dim=1) # Normalize intrinsic reward intrinsic_reward = self.reward_normalizer(intrinsic_reward) - # Check the weight schedule if self.weight_scheduler is not None: self.weight = self.weight_scheduler(step=self.update_counter, **self.weight_scheduler_params) @@ -134,11 +137,11 @@ def get_intrinsic_reward(self, obs) -> torch.Tensor: return intrinsic_reward - def forward(self, *args, **kwargs): + def forward(self, *args: Any, **kwargs: dict[str, Any]) -> NoReturn: raise RuntimeError("Forward method is not implemented. 
Use get_intrinsic_reward instead.") - def train(self, mode: bool = True): - # sets module into training mode + def train(self, mode: bool = True) -> RandomNetworkDistillation: + # Set module into training mode self.predictor.train(mode) if self.state_normalization: self.state_normalizer.train(mode) @@ -146,32 +149,28 @@ def train(self, mode: bool = True): self.reward_normalizer.train(mode) return self - def eval(self): + def eval(self) -> RandomNetworkDistillation: return self.train(False) - def get_rnd_state(self, obs): - obs_list = [] - for obs_group in self.obs_groups["rnd_state"]: - obs_list.append(obs[obs_group]) + def get_rnd_state(self, obs: TensorDict) -> torch.Tensor: + obs_list = [obs[obs_group] for obs_group in self.obs_groups["rnd_state"]] return torch.cat(obs_list, dim=-1) - def update_normalization(self, obs): + def update_normalization(self, obs: TensorDict) -> None: # Normalize the state if self.state_normalization: rnd_state = self.get_rnd_state(obs) self.state_normalizer.update(rnd_state) - """ - Different weight schedules. - """ - - def _constant_weight_schedule(self, step: int, **kwargs): + def _constant_weight_schedule(self, step: int, **kwargs: dict[str, Any]) -> float: return self.initial_weight - def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs): + def _step_weight_schedule(self, step: int, final_step: int, final_value: float, **kwargs: dict[str, Any]) -> float: return self.initial_weight if step < final_step else final_value - def _linear_weight_schedule(self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs): + def _linear_weight_schedule( + self, step: int, initial_step: int, final_step: int, final_value: float, **kwargs: dict[str, Any] + ) -> float: if step < initial_step: return self.initial_weight elif step > final_step: @@ -182,28 +181,28 @@ def _linear_weight_schedule(self, step: int, initial_step: int, final_step: int, ) -def resolve_rnd_config(alg_cfg, obs, obs_groups, env): +def resolve_rnd_config(alg_cfg: dict, obs: TensorDict, obs_groups: dict[str, list[str]], env: VecEnv) -> dict: """Resolve the RND configuration. Args: - alg_cfg: The algorithm configuration dictionary. - obs: The observation dictionary. - obs_groups: The observation groups dictionary. - env: The environment. + alg_cfg: Algorithm configuration dictionary. + obs: Observation dictionary. + obs_groups: Observation groups dictionary. + env: Environment object. Returns: The resolved algorithm configuration dictionary. """ - # resolve dimension of rnd gated state + # Resolve dimension of rnd gated state if "rnd_cfg" in alg_cfg and alg_cfg["rnd_cfg"] is not None: - # get dimension of rnd gated state + # Get dimension of rnd gated state num_rnd_state = 0 for obs_group in obs_groups["rnd_state"]: assert len(obs[obs_group].shape) == 2, "The RND module only supports 1D observations." 
num_rnd_state += obs[obs_group].shape[-1] - # add rnd gated state to config + # Add rnd gated state to config alg_cfg["rnd_cfg"]["num_states"] = num_rnd_state alg_cfg["rnd_cfg"]["obs_groups"] = obs_groups - # scale down the rnd weight with timestep + # Scale down the rnd weight with timestep alg_cfg["rnd_cfg"]["weight"] *= env.unwrapped.step_dt return alg_cfg diff --git a/rsl_rl/modules/student_teacher.py b/rsl_rl/modules/student_teacher.py index 6bf1380c..82b3ef62 100644 --- a/rsl_rl/modules/student_teacher.py +++ b/rsl_rl/modules/student_teacher.py @@ -7,38 +7,40 @@ import torch import torch.nn as nn +from tensordict import TensorDict from torch.distributions import Normal +from typing import Any, NoReturn -from rsl_rl.networks import MLP, EmpiricalNormalization +from rsl_rl.networks import MLP, EmpiricalNormalization, HiddenState class StudentTeacher(nn.Module): - is_recurrent = False + is_recurrent: bool = False def __init__( self, - obs, - obs_groups, - num_actions, - student_obs_normalization=False, - teacher_obs_normalization=False, - student_hidden_dims=[256, 256, 256], - teacher_hidden_dims=[256, 256, 256], - activation="elu", - init_noise_std=0.1, + obs: TensorDict, + obs_groups: dict[str, list[str]], + num_actions: int, + student_obs_normalization: bool = False, + teacher_obs_normalization: bool = False, + student_hidden_dims: tuple[int] | list[int] = [256, 256, 256], + teacher_hidden_dims: tuple[int] | list[int] = [256, 256, 256], + activation: str = "elu", + init_noise_std: float = 0.1, noise_std_type: str = "scalar", - **kwargs, - ): + **kwargs: dict[str, Any], + ) -> None: if kwargs: print( "StudentTeacher.__init__ got unexpected arguments, which will be ignored: " - + str([key for key in kwargs.keys()]) + + str([key for key in kwargs]) ) super().__init__() - self.loaded_teacher = False # indicates if teacher has been loaded + self.loaded_teacher = False # Indicates if teacher has been loaded - # get the observation dimensions + # Get the observation dimensions self.obs_groups = obs_groups num_student_obs = 0 for obs_group in obs_groups["policy"]: @@ -49,32 +51,30 @@ def __init__( assert len(obs[obs_group].shape) == 2, "The StudentTeacher module only supports 1D observations." num_teacher_obs += obs[obs_group].shape[-1] - # student + # Student self.student = MLP(num_student_obs, num_actions, student_hidden_dims, activation) + print(f"Student MLP: {self.student}") - # student observation normalization + # Student observation normalization self.student_obs_normalization = student_obs_normalization if student_obs_normalization: self.student_obs_normalizer = EmpiricalNormalization(num_student_obs) else: self.student_obs_normalizer = torch.nn.Identity() - print(f"Student MLP: {self.student}") - - # teacher + # Teacher self.teacher = MLP(num_teacher_obs, num_actions, teacher_hidden_dims, activation) self.teacher.eval() + print(f"Teacher MLP: {self.teacher}") - # teacher observation normalization + # Teacher observation normalization self.teacher_obs_normalization = teacher_obs_normalization if teacher_obs_normalization: self.teacher_obs_normalizer = EmpiricalNormalization(num_teacher_obs) else: self.teacher_obs_normalizer = torch.nn.Identity() - print(f"Teacher MLP: {self.teacher}") - - # action noise + # Action noise self.noise_std_type = noise_std_type if self.noise_std_type == "scalar": self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) @@ -83,104 +83,103 @@ def __init__( else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. 
Should be 'scalar' or 'log'") - # action distribution (populated in update_distribution) + # Action distribution + # Note: Populated in update_distribution self.distribution = None - # disable args validation for speedup + + # Disable args validation for speedup Normal.set_default_validate_args(False) - def reset(self, dones=None, hidden_states=None): + def reset( + self, dones: torch.Tensor | None = None, hidden_states: tuple[HiddenState, HiddenState] = (None, None) + ) -> None: pass - def forward(self): + def forward(self) -> NoReturn: raise NotImplementedError @property - def action_mean(self): + def action_mean(self) -> torch.Tensor: return self.distribution.mean @property - def action_std(self): + def action_std(self) -> torch.Tensor: return self.distribution.stddev @property - def entropy(self): + def entropy(self) -> torch.Tensor: return self.distribution.entropy().sum(dim=-1) - def update_distribution(self, obs): - # compute mean + def _update_distribution(self, obs: TensorDict) -> None: + # Compute mean mean = self.student(obs) - # compute standard deviation + # Compute standard deviation if self.noise_std_type == "scalar": std = self.std.expand_as(mean) elif self.noise_std_type == "log": std = torch.exp(self.log_std).expand_as(mean) else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") - # create distribution + # Create distribution self.distribution = Normal(mean, std) - def act(self, obs): + def act(self, obs: TensorDict) -> torch.Tensor: obs = self.get_student_obs(obs) obs = self.student_obs_normalizer(obs) - self.update_distribution(obs) + self._update_distribution(obs) return self.distribution.sample() - def act_inference(self, obs): + def act_inference(self, obs: TensorDict) -> torch.Tensor: obs = self.get_student_obs(obs) obs = self.student_obs_normalizer(obs) return self.student(obs) - def evaluate(self, obs): + def evaluate(self, obs: TensorDict) -> torch.Tensor: obs = self.get_teacher_obs(obs) obs = self.teacher_obs_normalizer(obs) with torch.no_grad(): return self.teacher(obs) - def get_student_obs(self, obs): - obs_list = [] - for obs_group in self.obs_groups["policy"]: - obs_list.append(obs[obs_group]) + def get_student_obs(self, obs: TensorDict) -> torch.Tensor: + obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]] return torch.cat(obs_list, dim=-1) - def get_teacher_obs(self, obs): - obs_list = [] - for obs_group in self.obs_groups["teacher"]: - obs_list.append(obs[obs_group]) + def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor: + obs_list = [obs[obs_group] for obs_group in self.obs_groups["teacher"]] return torch.cat(obs_list, dim=-1) - def get_hidden_states(self): - return None + def get_hidden_states(self) -> tuple[HiddenState, HiddenState]: + return None, None - def detach_hidden_states(self, dones=None): + def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None: pass - def train(self, mode=True): + def train(self, mode: bool = True) -> None: super().train(mode) - # make sure teacher is in eval mode + # Make sure teacher is in eval mode self.teacher.eval() self.teacher_obs_normalizer.eval() - def update_normalization(self, obs): + def update_normalization(self, obs: TensorDict) -> None: if self.student_obs_normalization: student_obs = self.get_student_obs(obs) self.student_obs_normalizer.update(student_obs) - def load_state_dict(self, state_dict, strict=True): + def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool: """Load the parameters of 
the student and teacher networks. Args: - state_dict (dict): State dictionary of the model. - strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this - module's state_dict() function. + state_dict: State dictionary of the model. + strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's + :meth:`state_dict` function. Returns: - bool: Whether this training resumes a previous training. This flag is used by the `load()` function of - `OnPolicyRunner` to determine how to load further parameters. + Whether this training resumes a previous training. This flag is used by the :func:`load` function of + :class:`OnPolicyRunner` to determine how to load further parameters. """ - - # check if state_dict contains teacher and student or just teacher parameters - if any("actor" in key for key in state_dict.keys()): # loading parameters from rl training - # rename keys to match teacher and remove critic parameters + # Check if state_dict contains teacher and student or just teacher parameters + if any("actor" in key for key in state_dict): # Load parameters from rl training + # Rename keys to match teacher and remove critic parameters teacher_state_dict = {} teacher_obs_normalizer_state_dict = {} for key, value in state_dict.items(): @@ -190,17 +189,17 @@ def load_state_dict(self, state_dict, strict=True): teacher_obs_normalizer_state_dict[key.replace("actor_obs_normalizer.", "")] = value self.teacher.load_state_dict(teacher_state_dict, strict=strict) self.teacher_obs_normalizer.load_state_dict(teacher_obs_normalizer_state_dict, strict=strict) - # set flag for successfully loading the parameters + # Set flag for successfully loading the parameters self.loaded_teacher = True self.teacher.eval() self.teacher_obs_normalizer.eval() - return False # training does not resume - elif any("student" in key for key in state_dict.keys()): # loading parameters from distillation training + return False # Training does not resume + elif any("student" in key for key in state_dict): # Load parameters from distillation training super().load_state_dict(state_dict, strict=strict) - # set flag for successfully loading the parameters + # Set flag for successfully loading the parameters self.loaded_teacher = True self.teacher.eval() self.teacher_obs_normalizer.eval() - return True # training resumes + return True # Training resumes else: raise ValueError("state_dict does not contain student or teacher parameters") diff --git a/rsl_rl/modules/student_teacher_recurrent.py b/rsl_rl/modules/student_teacher_recurrent.py index bba4bd34..2819c560 100644 --- a/rsl_rl/modules/student_teacher_recurrent.py +++ b/rsl_rl/modules/student_teacher_recurrent.py @@ -8,32 +8,34 @@ import torch import torch.nn as nn import warnings +from tensordict import TensorDict from torch.distributions import Normal +from typing import Any, NoReturn -from rsl_rl.networks import MLP, EmpiricalNormalization, Memory +from rsl_rl.networks import MLP, EmpiricalNormalization, HiddenState, Memory class StudentTeacherRecurrent(nn.Module): - is_recurrent = True + is_recurrent: bool = True def __init__( self, - obs, - obs_groups, - num_actions, - student_obs_normalization=False, - teacher_obs_normalization=False, - student_hidden_dims=[256, 256, 256], - teacher_hidden_dims=[256, 256, 256], - activation="elu", - init_noise_std=0.1, + obs: TensorDict, + obs_groups: dict[str, list[str]], + num_actions: int, + student_obs_normalization: bool = False, + 
teacher_obs_normalization: bool = False, + student_hidden_dims: tuple[int] | list[int] = [256, 256, 256], + teacher_hidden_dims: tuple[int] | list[int] = [256, 256, 256], + activation: str = "elu", + init_noise_std: float = 0.1, noise_std_type: str = "scalar", - rnn_type="lstm", - rnn_hidden_dim=256, - rnn_num_layers=1, - teacher_recurrent=False, - **kwargs, - ): + rnn_type: str = "lstm", + rnn_hidden_dim: int = 256, + rnn_num_layers: int = 1, + teacher_recurrent: bool = False, + **kwargs: dict[str, Any], + ) -> None: if "rnn_hidden_size" in kwargs: warnings.warn( "The argument `rnn_hidden_size` is deprecated and will be removed in a future version. " @@ -49,10 +51,10 @@ def __init__( ) super().__init__() - self.loaded_teacher = False # indicates if teacher has been loaded - self.teacher_recurrent = teacher_recurrent # indicates if teacher is recurrent too + self.loaded_teacher = False # Indicates if teacher has been loaded + self.teacher_recurrent = teacher_recurrent # Indicates if teacher is recurrent too - # get the observation dimensions + # Get the observation dimensions self.obs_groups = obs_groups num_student_obs = 0 for obs_group in obs_groups["policy"]: @@ -63,39 +65,35 @@ def __init__( assert len(obs[obs_group].shape) == 2, "The StudentTeacher module only supports 1D observations." num_teacher_obs += obs[obs_group].shape[-1] - # student - self.memory_s = Memory(num_student_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim) + # Student + self.memory_s = Memory(num_student_obs, rnn_hidden_dim, rnn_num_layers, rnn_type) self.student = MLP(rnn_hidden_dim, num_actions, student_hidden_dims, activation) + print(f"Student RNN: {self.memory_s}") + print(f"Student MLP: {self.student}") - # student observation normalization + # Student observation normalization self.student_obs_normalization = student_obs_normalization if student_obs_normalization: self.student_obs_normalizer = EmpiricalNormalization(num_student_obs) else: self.student_obs_normalizer = torch.nn.Identity() - print(f"Student RNN: {self.memory_s}") - print(f"Student MLP: {self.student}") - - # teacher + # Teacher if self.teacher_recurrent: - self.memory_t = Memory( - num_teacher_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim - ) + self.memory_t = Memory(num_teacher_obs, rnn_hidden_dim, rnn_num_layers, rnn_type) self.teacher = MLP(rnn_hidden_dim, num_actions, teacher_hidden_dims, activation) + if self.teacher_recurrent: + print(f"Teacher RNN: {self.memory_t}") + print(f"Teacher MLP: {self.teacher}") - # teacher observation normalization + # Teacher observation normalization self.teacher_obs_normalization = teacher_obs_normalization if teacher_obs_normalization: self.teacher_obs_normalizer = EmpiricalNormalization(num_teacher_obs) else: self.teacher_obs_normalizer = torch.nn.Identity() - if self.teacher_recurrent: - print(f"Teacher RNN: {self.memory_t}") - print(f"Teacher MLP: {self.teacher}") - - # action noise + # Action noise self.noise_std_type = noise_std_type if self.noise_std_type == "scalar": self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) @@ -104,60 +102,62 @@ def __init__( else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. 
Should be 'scalar' or 'log'") - # action distribution (populated in update_distribution) + # Action distribution + # Note: Populated in update_distribution self.distribution = None - # disable args validation for speedup + + # Disable args validation for speedup Normal.set_default_validate_args(False) - def reset(self, dones=None, hidden_states=None): - if hidden_states is None: - hidden_states = (None, None) + def reset( + self, dones: torch.Tensor | None = None, hidden_states: tuple[HiddenState, HiddenState] = (None, None) + ) -> None: self.memory_s.reset(dones, hidden_states[0]) if self.teacher_recurrent: self.memory_t.reset(dones, hidden_states[1]) - def forward(self): + def forward(self) -> NoReturn: raise NotImplementedError @property - def action_mean(self): + def action_mean(self) -> torch.Tensor: return self.distribution.mean @property - def action_std(self): + def action_std(self) -> torch.Tensor: return self.distribution.stddev @property - def entropy(self): + def entropy(self) -> torch.Tensor: return self.distribution.entropy().sum(dim=-1) - def update_distribution(self, obs): - # compute mean + def _update_distribution(self, obs: TensorDict) -> None: + # Compute mean mean = self.student(obs) - # compute standard deviation + # Compute standard deviation if self.noise_std_type == "scalar": std = self.std.expand_as(mean) elif self.noise_std_type == "log": std = torch.exp(self.log_std).expand_as(mean) else: raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'") - # create distribution + # Create distribution self.distribution = Normal(mean, std) - def act(self, obs): + def act(self, obs: TensorDict) -> torch.Tensor: obs = self.get_student_obs(obs) obs = self.student_obs_normalizer(obs) out_mem = self.memory_s(obs).squeeze(0) - self.update_distribution(out_mem) + self._update_distribution(out_mem) return self.distribution.sample() - def act_inference(self, obs): + def act_inference(self, obs: TensorDict) -> torch.Tensor: obs = self.get_student_obs(obs) obs = self.student_obs_normalizer(obs) out_mem = self.memory_s(obs).squeeze(0) return self.student(out_mem) - def evaluate(self, obs): + def evaluate(self, obs: TensorDict) -> torch.Tensor: obs = self.get_teacher_obs(obs) obs = self.teacher_obs_normalizer(obs) with torch.no_grad(): @@ -166,56 +166,51 @@ def evaluate(self, obs): obs = self.memory_t(obs).squeeze(0) return self.teacher(obs) - def get_student_obs(self, obs): - obs_list = [] - for obs_group in self.obs_groups["policy"]: - obs_list.append(obs[obs_group]) + def get_student_obs(self, obs: TensorDict) -> torch.Tensor: + obs_list = [obs[obs_group] for obs_group in self.obs_groups["policy"]] return torch.cat(obs_list, dim=-1) - def get_teacher_obs(self, obs): - obs_list = [] - for obs_group in self.obs_groups["teacher"]: - obs_list.append(obs[obs_group]) + def get_teacher_obs(self, obs: TensorDict) -> torch.Tensor: + obs_list = [obs[obs_group] for obs_group in self.obs_groups["teacher"]] return torch.cat(obs_list, dim=-1) - def get_hidden_states(self): + def get_hidden_states(self) -> tuple[HiddenState, HiddenState]: if self.teacher_recurrent: - return self.memory_s.hidden_states, self.memory_t.hidden_states + return self.memory_s.hidden_state, self.memory_t.hidden_state else: - return self.memory_s.hidden_states, None + return self.memory_s.hidden_state, None - def detach_hidden_states(self, dones=None): - self.memory_s.detach_hidden_states(dones) + def detach_hidden_states(self, dones: torch.Tensor | None = None) -> None: + 
self.memory_s.detach_hidden_state(dones) if self.teacher_recurrent: - self.memory_t.detach_hidden_states(dones) + self.memory_t.detach_hidden_state(dones) - def train(self, mode=True): + def train(self, mode: bool = True) -> None: super().train(mode) - # make sure teacher is in eval mode + # Make sure teacher is in eval mode self.teacher.eval() self.teacher_obs_normalizer.eval() - def update_normalization(self, obs): + def update_normalization(self, obs: TensorDict) -> None: if self.student_obs_normalization: student_obs = self.get_student_obs(obs) self.student_obs_normalizer.update(student_obs) - def load_state_dict(self, state_dict, strict=True): + def load_state_dict(self, state_dict: dict, strict: bool = True) -> bool: """Load the parameters of the student and teacher networks. Args: - state_dict (dict): State dictionary of the model. - strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this - module's state_dict() function. + state_dict: State dictionary of the model. + strict: Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's + :meth:`state_dict` function. Returns: - bool: Whether this training resumes a previous training. This flag is used by the `load()` function of - `OnPolicyRunner` to determine how to load further parameters. + Whether this training resumes a previous training. This flag is used by the :func:`load` function of + :class:`OnPolicyRunner` to determine how to load further parameters. """ - - # check if state_dict contains teacher and student or just teacher parameters - if any("actor" in key for key in state_dict.keys()): # loading parameters from rl training - # rename keys to match teacher and remove critic parameters + # Check if state_dict contains teacher and student or just teacher parameters + if any("actor" in key for key in state_dict): # Load parameters from rl training + # Rename keys to match teacher and remove critic parameters teacher_state_dict = {} teacher_obs_normalizer_state_dict = {} for key, value in state_dict.items(): @@ -225,24 +220,24 @@ def load_state_dict(self, state_dict, strict=True): teacher_obs_normalizer_state_dict[key.replace("actor_obs_normalizer.", "")] = value self.teacher.load_state_dict(teacher_state_dict, strict=strict) self.teacher_obs_normalizer.load_state_dict(teacher_obs_normalizer_state_dict, strict=strict) - # also load recurrent memory if teacher is recurrent + # Also load recurrent memory if teacher is recurrent if self.teacher_recurrent: memory_t_state_dict = {} for key, value in state_dict.items(): if "memory_a." 
in key: memory_t_state_dict[key.replace("memory_a.", "")] = value self.memory_t.load_state_dict(memory_t_state_dict, strict=strict) - # set flag for successfully loading the parameters + # Set flag for successfully loading the parameters self.loaded_teacher = True self.teacher.eval() self.teacher_obs_normalizer.eval() - return False # training does not resume - elif any("student" in key for key in state_dict.keys()): # loading parameters from distillation training + return False # Training does not resume + elif any("student" in key for key in state_dict): # Load parameters from distillation training super().load_state_dict(state_dict, strict=strict) - # set flag for successfully loading the parameters + # Set flag for successfully loading the parameters self.loaded_teacher = True self.teacher.eval() self.teacher_obs_normalizer.eval() - return True # training resumes + return True # Training resumes else: raise ValueError("state_dict does not contain student or teacher parameters") diff --git a/rsl_rl/modules/symmetry.py b/rsl_rl/modules/symmetry.py index b0175151..8e21d1da 100644 --- a/rsl_rl/modules/symmetry.py +++ b/rsl_rl/modules/symmetry.py @@ -5,20 +5,21 @@ from __future__ import annotations +from rsl_rl.env import VecEnv -def resolve_symmetry_config(alg_cfg, env): + +def resolve_symmetry_config(alg_cfg: dict, env: VecEnv) -> dict: """Resolve the symmetry configuration. Args: - alg_cfg: The algorithm configuration dictionary. - env: The environment. + alg_cfg: Algorithm configuration dictionary. + env: Environment object. Returns: The resolved algorithm configuration dictionary. """ - - # if using symmetry then pass the environment config object + # If using symmetry then pass the environment config object + # Note: This is used by the symmetry function for handling different observation terms if "symmetry_cfg" in alg_cfg and alg_cfg["symmetry_cfg"] is not None: - # this is used by the symmetry function for handling different observation terms alg_cfg["symmetry_cfg"]["_env"] = env return alg_cfg diff --git a/rsl_rl/networks/__init__.py b/rsl_rl/networks/__init__.py index c18f487a..7ede0665 100644 --- a/rsl_rl/networks/__init__.py +++ b/rsl_rl/networks/__init__.py @@ -5,6 +5,14 @@ """Definitions for components of modules.""" -from .memory import Memory +from .memory import HiddenState, Memory from .mlp import MLP from .normalization import EmpiricalDiscountedVariationNormalization, EmpiricalNormalization + +__all__ = [ + "MLP", + "EmpiricalDiscountedVariationNormalization", + "EmpiricalNormalization", + "HiddenState", + "Memory", +] diff --git a/rsl_rl/networks/memory.py b/rsl_rl/networks/memory.py index 75773577..dd40afc2 100644 --- a/rsl_rl/networks/memory.py +++ b/rsl_rl/networks/memory.py @@ -5,66 +5,76 @@ from __future__ import annotations +import torch import torch.nn as nn from rsl_rl.utils import unpad_trajectories +HiddenState = torch.Tensor | tuple[torch.Tensor, torch.Tensor] | None +"""Type alias for the hidden state of RNNs (GRU/LSTM). + +For GRUs, this is a single tensor while for LSTMs, this is a tuple of two tensors (hidden state and cell state). +""" + class Memory(nn.Module): """Memory module for recurrent networks. - This module is used to store the hidden states of the policy. - Currently only supports GRU and LSTM. + This module is used to store the hidden state of the policy. It currently supports GRU and LSTM. 
""" - def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256): + def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None: super().__init__() - # RNN rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM - self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers) - self.hidden_states = None + self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers) + self.hidden_state = None - def forward(self, input, masks=None, hidden_states=None): + def forward( + self, + input: torch.Tensor, + masks: torch.Tensor | None = None, + hidden_state: HiddenState = None, + ) -> torch.Tensor: batch_mode = masks is not None if batch_mode: - # batch mode: needs saved hidden states - if hidden_states is None: + # Batch mode needs saved hidden states + if hidden_state is None: raise ValueError("Hidden states not passed to memory module during policy update") - out, _ = self.rnn(input, hidden_states) + out, _ = self.rnn(input, hidden_state) out = unpad_trajectories(out, masks) else: - # inference/distillation mode: uses hidden states of last step - out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states) + # Inference/distillation mode uses hidden state of last step + out, self.hidden_state = self.rnn(input.unsqueeze(0), self.hidden_state) return out - def reset(self, dones=None, hidden_states=None): - if dones is None: # reset all hidden states - if hidden_states is None: - self.hidden_states = None + def reset(self, dones: torch.Tensor | None = None, hidden_state: HiddenState = None) -> None: + if dones is None: # Reset hidden state + if hidden_state is None: + self.hidden_state = None else: - self.hidden_states = hidden_states - elif self.hidden_states is not None: # reset hidden states of done environments - if hidden_states is None: - if isinstance(self.hidden_states, tuple): # tuple in case of LSTM - for hidden_state in self.hidden_states: + self.hidden_state = hidden_state + elif self.hidden_state is not None: # Reset hidden state of done environments + if hidden_state is None: + if isinstance(self.hidden_state, tuple): # Tuple in case of LSTM + for hidden_state in self.hidden_state: hidden_state[..., dones == 1, :] = 0.0 else: - self.hidden_states[..., dones == 1, :] = 0.0 + self.hidden_state[..., dones == 1, :] = 0.0 else: NotImplementedError( - "Resetting hidden states of done environments with custom hidden states is not implemented" + "Resetting the hidden state of done environments with a custom hidden state is not implemented" ) - def detach_hidden_states(self, dones=None): - if self.hidden_states is not None: - if dones is None: # detach all hidden states - if isinstance(self.hidden_states, tuple): # tuple in case of LSTM - self.hidden_states = tuple(hidden_state.detach() for hidden_state in self.hidden_states) + def detach_hidden_state(self, dones: torch.Tensor | None = None) -> None: + if self.hidden_state is not None: + if dones is None: # Detach hidden state + if isinstance(self.hidden_state, tuple): # Tuple in case of LSTM + self.hidden_state = tuple(hidden_state.detach() for hidden_state in self.hidden_state) else: - self.hidden_states = self.hidden_states.detach() - else: # detach hidden states of done environments - if isinstance(self.hidden_states, tuple): # tuple in case of LSTM - for hidden_state in self.hidden_states: + self.hidden_state = self.hidden_state.detach() + else: # Detach hidden state of done environments + if 
isinstance(self.hidden_state, tuple): # Tuple in case of LSTM + for hidden_state in self.hidden_state: hidden_state[..., dones == 1, :] = hidden_state[..., dones == 1, :].detach() else: - self.hidden_states[..., dones == 1, :] = self.hidden_states[..., dones == 1, :].detach() + self.hidden_state[..., dones == 1, :] = self.hidden_state[..., dones == 1, :].detach() diff --git a/rsl_rl/networks/mlp.py b/rsl_rl/networks/mlp.py index e91574ea..f01a7577 100644 --- a/rsl_rl/networks/mlp.py +++ b/rsl_rl/networks/mlp.py @@ -15,17 +15,12 @@ class MLP(nn.Sequential): """Multi-layer perceptron. - The MLP network is a sequence of linear layers and activation functions. The - last layer is a linear layer that outputs the desired dimension unless the - last activation function is specified. + The MLP network is a sequence of linear layers and activation functions. The last layer is a linear layer that + outputs the desired dimension unless the last activation function is specified. It provides additional conveniences: - - - If the hidden dimensions have a value of ``-1``, the dimension is inferred - from the input dimension. - - If the output dimension is a tuple, the output is reshaped to the desired - shape. - + - If the hidden dimensions have a value of ``-1``, the dimension is inferred from the input dimension. + - If the output dimension is a tuple, the output is reshaped to the desired shape. """ def __init__( @@ -35,27 +30,26 @@ def __init__( hidden_dims: tuple[int] | list[int], activation: str = "elu", last_activation: str | None = None, - ): + ) -> None: """Initialize the MLP. Args: input_dim: Dimension of the input. output_dim: Dimension of the output. - hidden_dims: Dimensions of the hidden layers. A value of ``-1`` indicates - that the dimension should be inferred from the input dimension. - activation: Activation function. Defaults to "elu". - last_activation: Activation function of the last layer. Defaults to None, - in which case the last layer is linear. + hidden_dims: Dimensions of the hidden layers. A value of ``-1`` indicates that the dimension should be + inferred from the input dimension. + activation: Activation function. + last_activation: Activation function of the last layer. None results in a linear last layer. 
""" super().__init__() - # resolve activation functions + # Resolve activation functions activation_mod = resolve_nn_activation(activation) last_activation_mod = resolve_nn_activation(last_activation) if last_activation is not None else None - # resolve number of hidden dims if they are -1 + # Resolve number of hidden dims if they are -1 hidden_dims_processed = [input_dim if dim == -1 else dim for dim in hidden_dims] - # create layers sequentially + # Create layers sequentially layers = [] layers.append(nn.Linear(input_dim, hidden_dims_processed[0])) layers.append(activation_mod) @@ -64,32 +58,32 @@ def __init__( layers.append(nn.Linear(hidden_dims_processed[layer_index], hidden_dims_processed[layer_index + 1])) layers.append(activation_mod) - # add last layer + # Add last layer if isinstance(output_dim, int): layers.append(nn.Linear(hidden_dims_processed[-1], output_dim)) else: - # compute the total output dimension + # Compute the total output dimension total_out_dim = reduce(lambda x, y: x * y, output_dim) - # add a layer to reshape the output to the desired shape + # Add a layer to reshape the output to the desired shape layers.append(nn.Linear(hidden_dims_processed[-1], total_out_dim)) layers.append(nn.Unflatten(dim=-1, unflattened_size=output_dim)) - # add last activation function if specified + # Add last activation function if specified if last_activation_mod is not None: layers.append(last_activation_mod) - # register the layers + # Register the layers for idx, layer in enumerate(layers): self.add_module(f"{idx}", layer) - def init_weights(self, scales: float | tuple[float]): + def init_weights(self, scales: float | tuple[float]) -> None: """Initialize the weights of the MLP. Args: scales: Scale factor for the weights. """ - def get_scale(idx) -> float: + def get_scale(idx: int) -> float: """Get the scale factor for the weights of the MLP. Args: @@ -97,7 +91,7 @@ def get_scale(idx) -> float: """ return scales[idx] if isinstance(scales, (list, tuple)) else scales - # initialize the weights + # Initialize the weights for idx, module in enumerate(self): if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight, gain=get_scale(idx)) @@ -112,9 +106,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for layer in self: x = layer(x) return x - - def reset(self, dones=None, hidden_states=None): - pass - - def detach_hidden_states(self, dones=None): - pass diff --git a/rsl_rl/networks/normalization.py b/rsl_rl/networks/normalization.py index 5fd96921..5077479f 100644 --- a/rsl_rl/networks/normalization.py +++ b/rsl_rl/networks/normalization.py @@ -14,16 +14,15 @@ class EmpiricalNormalization(nn.Module): """Normalize mean and variance of values based on empirical values.""" - def __init__(self, shape, eps=1e-2, until=None): + def __init__(self, shape: int | tuple[int] | list[int], eps: float = 1e-2, until: int | None = None) -> None: """Initialize EmpiricalNormalization module. - Args: - shape (int or tuple of int): Shape of input values except batch axis. - eps (float): Small value for stability. - until (int or None): If this arg is specified, the module learns input values until the sum of batch sizes - exceeds it. + .. note:: The normalization parameters are computed over the whole batch, not for each environment separately. - Note: The normalization parameters are computed over the whole batch, not for each environment separately. + Args: + shape: Shape of input values except batch axis. + eps: Small value for stability. 
+ until: If this arg is specified, the module learns input values until the sum of batch sizes exceeds it. """ super().__init__() self.eps = eps @@ -34,22 +33,20 @@ def __init__(self, shape, eps=1e-2, until=None): self.register_buffer("count", torch.tensor(0, dtype=torch.long)) @property - def mean(self): + def mean(self) -> torch.Tensor: return self._mean.squeeze(0).clone() @property - def std(self): + def std(self) -> torch.Tensor: return self._std.squeeze(0).clone() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """Normalize mean and variance of values based on empirical values.""" - return (x - self._mean) / (self._std + self.eps) @torch.jit.unused - def update(self, x): - """Learn input values without computing the output values of them""" - + def update(self, x: torch.Tensor) -> None: + """Learn input values without computing the output values of them.""" if not self.training: return if self.until is not None and self.count >= self.until: @@ -66,45 +63,41 @@ def update(self, x): self._std = torch.sqrt(self._var) @torch.jit.unused - def inverse(self, y): + def inverse(self, y: torch.Tensor) -> torch.Tensor: """De-normalize values based on empirical values.""" - return y * (self._std + self.eps) + self._mean class EmpiricalDiscountedVariationNormalization(nn.Module): """Reward normalization from Pathak's large scale study on PPO. - Reward normalization. Since the reward function is non-stationary, it is useful to normalize - the scale of the rewards so that the value function can learn quickly. We did this by dividing - the rewards by a running estimate of the standard deviation of the sum of discounted rewards. + Reward normalization. Since the reward function is non-stationary, it is useful to normalize the scale of the + rewards so that the value function can learn quickly. We did this by dividing the rewards by a running estimate of + the standard deviation of the sum of discounted rewards. """ - def __init__(self, shape, eps=1e-2, gamma=0.99, until=None): + def __init__( + self, shape: int | tuple[int] | list[int], eps: float = 1e-2, gamma: float = 0.99, until: int | None = None + ) -> None: super().__init__() self.emp_norm = EmpiricalNormalization(shape, eps, until) self.disc_avg = _DiscountedAverage(gamma) - def forward(self, rew): + def forward(self, rew: torch.Tensor) -> torch.Tensor: if self.training: - # update discounted rewards + # Update discounted rewards avg = self.disc_avg.update(rew) - # update moments from discounted rewards + # Update moments from discounted rewards self.emp_norm.update(avg) - # normalize rewards with the empirical std + # Normalize rewards with the empirical std if self.emp_norm._std > 0: return rew / self.emp_norm._std else: return rew -""" -Helper class. -""" - - class _DiscountedAverage: r"""Discounted average of rewards. @@ -113,12 +106,9 @@ class _DiscountedAverage: .. math:: \bar{R}_t = \gamma \bar{R}_{t-1} + r_t - - Args: - gamma (float): Discount factor. 
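To make the scheme above concrete, a minimal standalone sketch (not the module's own implementation) of dividing each reward by a running estimate of the standard deviation of the discounted reward sum R_t = gamma * R_{t-1} + r_t, with dummy data:

import torch

gamma = 0.99
disc_return = torch.zeros(16)  # One running discounted sum per environment
seen = []                      # Discounted sums observed so far

for _ in range(100):
    rew = torch.randn(16)      # Rewards for one step (dummy data)
    disc_return = gamma * disc_return + rew
    seen.append(disc_return.clone())
    std = torch.cat(seen).std()  # The module keeps running moments instead of storing the history
    rew_normalized = rew / std if std > 0 else rew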
""" - def __init__(self, gamma): + def __init__(self, gamma: float) -> None: self.avg = None self.gamma = gamma diff --git a/rsl_rl/runners/__init__.py b/rsl_rl/runners/__init__.py index d6f630b6..568ebc3b 100644 --- a/rsl_rl/runners/__init__.py +++ b/rsl_rl/runners/__init__.py @@ -5,7 +5,7 @@ """Implementation of runners for environment-agent interaction.""" -from .on_policy_runner import OnPolicyRunner # isort:skip +from .on_policy_runner import OnPolicyRunner # noqa: I001 from .distillation_runner import DistillationRunner -__all__ = ["OnPolicyRunner", "DistillationRunner"] +__all__ = ["DistillationRunner", "OnPolicyRunner"] diff --git a/rsl_rl/runners/distillation_runner.py b/rsl_rl/runners/distillation_runner.py index 9cc6a8b1..6f9c502b 100644 --- a/rsl_rl/runners/distillation_runner.py +++ b/rsl_rl/runners/distillation_runner.py @@ -9,6 +9,7 @@ import time import torch from collections import deque +from tensordict import TensorDict import rsl_rl from rsl_rl.algorithms import Distillation @@ -21,29 +22,29 @@ class DistillationRunner(OnPolicyRunner): """On-policy runner for training and evaluation of teacher-student training.""" - def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device="cpu"): + def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device: str = "cpu") -> None: self.cfg = train_cfg self.alg_cfg = train_cfg["algorithm"] self.policy_cfg = train_cfg["policy"] self.device = device self.env = env - # check if multi-gpu is enabled + # Check if multi-GPU is enabled self._configure_multi_gpu() - # store training configuration + # Store training configuration self.num_steps_per_env = self.cfg["num_steps_per_env"] self.save_interval = self.cfg["save_interval"] - # query observations from environment for algorithm construction + # Query observations from environment for algorithm construction obs = self.env.get_observations() self.cfg["obs_groups"] = resolve_obs_groups(obs, self.cfg["obs_groups"], default_sets=["teacher"]) - # create the algorithm + # Create the algorithm self.alg = self._construct_algorithm(obs) # Decide whether to disable logging - # We only log from the process with rank 0 (main process) + # Note: We only log from the process with rank 0 (main process) self.disable_logs = self.is_distributed and self.gpu_global_rank != 0 # Logging @@ -54,20 +55,20 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev self.current_learning_iteration = 0 self.git_status_repos = [rsl_rl.__file__] - def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False): # noqa: C901 - # initialize writer + def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None: + # Initialize writer self._prepare_logging_writer() - # check if teacher is loaded + # Check if teacher is loaded if not self.alg.policy.loaded_teacher: raise ValueError("Teacher model parameters not loaded. 
Please load a teacher model to distill.") - # randomize initial episode lengths (for exploration) + # Randomize initial episode lengths (for exploration) if init_at_random_ep_len: self.env.episode_length_buf = torch.randint_like( self.env.episode_length_buf, high=int(self.env.max_episode_length) ) - # start learning + # Start learning obs = self.env.get_observations().to(self.device) self.train_mode() # switch to train mode (for dropout for example) @@ -97,9 +98,9 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals obs, rewards, dones, extras = self.env.step(actions.to(self.env.device)) # Move to device obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device)) - # process the step + # Process the step self.alg.process_env_step(obs, rewards, dones, extras) - # book keeping + # Book keeping if self.log_dir is not None: if "episode" in extras: ep_infos.append(extras["episode"]) @@ -120,13 +121,13 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals collection_time = stop - start start = stop - # update policy + # Update policy loss_dict = self.alg.update() stop = time.time() learn_time = stop - start self.current_learning_iteration = it - # log info + if self.log_dir is not None and not self.disable_logs: # Log information self.log(locals()) @@ -138,9 +139,9 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals ep_infos.clear() # Save code state if it == start_iter and not self.disable_logs: - # obtain all the diff files + # Obtain all the diff files git_file_paths = store_code_state(self.log_dir, self.git_status_repos) - # if possible store them to wandb + # If possible store them to wandb or neptune if self.logger_type in ["wandb", "neptune"] and git_file_paths: for path in git_file_paths: self.writer.save_file(path) @@ -149,25 +150,21 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals if self.log_dir is not None and not self.disable_logs: self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt")) - """ - Helper methods. 
- """ - - def _construct_algorithm(self, obs) -> Distillation: + def _construct_algorithm(self, obs: TensorDict) -> Distillation: """Construct the distillation algorithm.""" - # initialize the actor-critic + # Initialize the policy student_teacher_class = eval(self.policy_cfg.pop("class_name")) student_teacher: StudentTeacher | StudentTeacherRecurrent = student_teacher_class( obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg ).to(self.device) - # initialize the algorithm + # Initialize the algorithm alg_class = eval(self.alg_cfg.pop("class_name")) alg: Distillation = alg_class( student_teacher, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg ) - # initialize the storage + # Initialize the storage alg.init_storage( "distillation", self.env.num_envs, diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 36f11f37..46a9b524 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -11,6 +11,7 @@ import torch import warnings from collections import deque +from tensordict import TensorDict import rsl_rl from rsl_rl.algorithms import PPO @@ -22,32 +23,32 @@ class OnPolicyRunner: """On-policy runner for training and evaluation of actor-critic methods.""" - def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device="cpu"): + def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, device: str = "cpu") -> None: self.cfg = train_cfg self.alg_cfg = train_cfg["algorithm"] self.policy_cfg = train_cfg["policy"] self.device = device self.env = env - # check if multi-gpu is enabled + # Check if multi-GPU is enabled self._configure_multi_gpu() - # store training configuration + # Store training configuration self.num_steps_per_env = self.cfg["num_steps_per_env"] self.save_interval = self.cfg["save_interval"] - # query observations from environment for algorithm construction + # Query observations from environment for algorithm construction obs = self.env.get_observations() default_sets = ["critic"] if "rnd_cfg" in self.alg_cfg and self.alg_cfg["rnd_cfg"] is not None: default_sets.append("rnd_state") self.cfg["obs_groups"] = resolve_obs_groups(obs, self.cfg["obs_groups"], default_sets) - # create the algorithm + # Create the algorithm self.alg = self._construct_algorithm(obs) # Decide whether to disable logging - # We only log from the process with rank 0 (main process) + # Note: We only log from the process with rank 0 (main process) self.disable_logs = self.is_distributed and self.gpu_global_rank != 0 # Logging @@ -58,17 +59,17 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev self.current_learning_iteration = 0 self.git_status_repos = [rsl_rl.__file__] - def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False): # noqa: C901 - # initialize writer + def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False) -> None: + # Initialize writer self._prepare_logging_writer() - # randomize initial episode lengths (for exploration) + # Randomize initial episode lengths (for exploration) if init_at_random_ep_len: self.env.episode_length_buf = torch.randint_like( self.env.episode_length_buf, high=int(self.env.max_episode_length) ) - # start learning + # Start learning obs = self.env.get_observations().to(self.device) self.train_mode() # switch to train mode (for dropout for example) @@ -79,7 +80,7 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals 
cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) - # create buffers for logging extrinsic and intrinsic rewards + # Create buffers for logging extrinsic and intrinsic rewards if self.alg.rnd: erewbuffer = deque(maxlen=100) irewbuffer = deque(maxlen=100) @@ -105,11 +106,11 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals obs, rewards, dones, extras = self.env.step(actions.to(self.env.device)) # Move to device obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device)) - # process the step + # Process the step self.alg.process_env_step(obs, rewards, dones, extras) # Extract intrinsic rewards (only for logging) intrinsic_rewards = self.alg.intrinsic_rewards if self.alg.rnd else None - # book keeping + # Book keeping if self.log_dir is not None: if "episode" in extras: ep_infos.append(extras["episode"]) @@ -118,20 +119,18 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals # Update rewards if self.alg.rnd: cur_ereward_sum += rewards - cur_ireward_sum += intrinsic_rewards # type: ignore + cur_ireward_sum += intrinsic_rewards cur_reward_sum += rewards + intrinsic_rewards else: cur_reward_sum += rewards # Update episode length cur_episode_length += 1 # Clear data for completed episodes - # -- common new_ids = (dones > 0).nonzero(as_tuple=False) rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()) lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist()) cur_reward_sum[new_ids] = 0 cur_episode_length[new_ids] = 0 - # -- intrinsic and extrinsic rewards if self.alg.rnd: erewbuffer.extend(cur_ereward_sum[new_ids][:, 0].cpu().numpy().tolist()) irewbuffer.extend(cur_ireward_sum[new_ids][:, 0].cpu().numpy().tolist()) @@ -142,16 +141,16 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals collection_time = stop - start start = stop - # compute returns + # Compute returns self.alg.compute_returns(obs) - # update policy + # Update policy loss_dict = self.alg.update() stop = time.time() learn_time = stop - start self.current_learning_iteration = it - # log info + if self.log_dir is not None and not self.disable_logs: # Log information self.log(locals()) @@ -163,9 +162,9 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals ep_infos.clear() # Save code state if it == start_iter and not self.disable_logs: - # obtain all the diff files + # Obtain all the diff files git_file_paths = store_code_state(self.log_dir, self.git_status_repos) - # if possible store them to wandb + # If possible store them to wandb or neptune if self.logger_type in ["wandb", "neptune"] and git_file_paths: for path in git_file_paths: self.writer.save_file(path) @@ -174,7 +173,7 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals if self.log_dir is not None and not self.disable_logs: self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt")) - def log(self, locs: dict, width: int = 80, pad: int = 35): + def log(self, locs: dict, width: int = 80, pad: int = 35) -> None: # Compute the collection size collection_size = self.num_steps_per_env * self.env.num_envs * self.gpu_world_size # Update total time-steps and time @@ -182,13 +181,13 @@ def log(self, locs: dict, width: int = 80, pad: int = 35): self.tot_time += locs["collection_time"] + locs["learn_time"] 
iteration_time = locs["collection_time"] + locs["learn_time"] - # -- Episode info + # Log episode information ep_string = "" if locs["ep_infos"]: for key in locs["ep_infos"][0]: infotensor = torch.tensor([], device=self.device) for ep_info in locs["ep_infos"]: - # handle scalar and zero dimensional tensor infos + # Handle scalar and zero dimensional tensor infos if key not in ep_info: continue if not isinstance(ep_info[key], torch.Tensor): @@ -197,38 +196,38 @@ def log(self, locs: dict, width: int = 80, pad: int = 35): ep_info[key] = ep_info[key].unsqueeze(0) infotensor = torch.cat((infotensor, ep_info[key].to(self.device))) value = torch.mean(infotensor) - # log to logger and terminal + # Log to logger and terminal if "/" in key: self.writer.add_scalar(key, value, locs["it"]) - ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n""" + ep_string += f"""{f"{key}:":>{pad}} {value:.4f}\n""" else: self.writer.add_scalar("Episode/" + key, value, locs["it"]) - ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n""" + ep_string += f"""{f"Mean episode {key}:":>{pad}} {value:.4f}\n""" mean_std = self.alg.policy.action_std.mean() fps = int(collection_size / (locs["collection_time"] + locs["learn_time"])) - # -- Losses + # Log losses for key, value in locs["loss_dict"].items(): self.writer.add_scalar(f"Loss/{key}", value, locs["it"]) self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"]) - # -- Policy + # Log noise std self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"]) - # -- Performance + # Log performance self.writer.add_scalar("Perf/total_fps", fps, locs["it"]) self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"]) self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"]) - # -- Training + # Log training if len(locs["rewbuffer"]) > 0: - # separate logging for intrinsic and extrinsic rewards + # Separate logging for intrinsic and extrinsic rewards if hasattr(self.alg, "rnd") and self.alg.rnd: self.writer.add_scalar("Rnd/mean_extrinsic_reward", statistics.mean(locs["erewbuffer"]), locs["it"]) self.writer.add_scalar("Rnd/mean_intrinsic_reward", statistics.mean(locs["irewbuffer"]), locs["it"]) self.writer.add_scalar("Rnd/weight", self.alg.rnd.weight, locs["it"]) - # everything else + # Everything else self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"]) self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"]) if self.logger_type != "wandb": # wandb does not support non-integer x-axis logging @@ -241,145 +240,144 @@ def log(self, locs: dict, width: int = 80, pad: int = 35): if len(locs["rewbuffer"]) > 0: log_string = ( - f"""{'#' * width}\n""" - f"""{str.center(width, ' ')}\n\n""" - f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ - 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" - f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" + f"""{"#" * width}\n""" + f"""{str.center(width, " ")}\n\n""" + f"""{"Computation:":>{pad}} {fps:.0f} steps/s (collection: {locs["collection_time"]:.3f}s, learning { + locs["learn_time"]:.3f}s)\n""" + f"""{"Mean action noise std:":>{pad}} {mean_std.item():.2f}\n""" ) - # -- Losses + # Print losses for key, value in locs["loss_dict"].items(): - log_string += f"""{f'Mean {key} loss:':>{pad}} {value:.4f}\n""" - # -- Rewards + log_string += f"""{f"Mean {key} loss:":>{pad}} {value:.4f}\n""" + # Print rewards if hasattr(self.alg, "rnd") and 
self.alg.rnd: log_string += ( - f"""{'Mean extrinsic reward:':>{pad}} {statistics.mean(locs['erewbuffer']):.2f}\n""" - f"""{'Mean intrinsic reward:':>{pad}} {statistics.mean(locs['irewbuffer']):.2f}\n""" + f"""{"Mean extrinsic reward:":>{pad}} {statistics.mean(locs["erewbuffer"]):.2f}\n""" + f"""{"Mean intrinsic reward:":>{pad}} {statistics.mean(locs["irewbuffer"]):.2f}\n""" ) - log_string += f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" - # -- episode info - log_string += f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""" + log_string += f"""{"Mean reward:":>{pad}} {statistics.mean(locs["rewbuffer"]):.2f}\n""" + # Print episode information + log_string += f"""{"Mean episode length:":>{pad}} {statistics.mean(locs["lenbuffer"]):.2f}\n""" else: log_string = ( - f"""{'#' * width}\n""" - f"""{str.center(width, ' ')}\n\n""" - f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ - 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" - f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" + f"""{"#" * width}\n""" + f"""{str.center(width, " ")}\n\n""" + f"""{"Computation:":>{pad}} {fps:.0f} steps/s (collection: {locs["collection_time"]:.3f}s, learning { + locs["learn_time"]:.3f}s)\n""" + f"""{"Mean action noise std:":>{pad}} {mean_std.item():.2f}\n""" ) for key, value in locs["loss_dict"].items(): - log_string += f"""{f'{key}:':>{pad}} {value:.4f}\n""" + log_string += f"""{f"{key}:":>{pad}} {value:.4f}\n""" log_string += ep_string log_string += ( - f"""{'-' * width}\n""" - f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n""" - f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n""" - f"""{'Time elapsed:':>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time))}\n""" - f"""{'ETA:':>{pad}} {time.strftime( - "%H:%M:%S", - time.gmtime( - self.tot_time / (locs['it'] - locs['start_iter'] + 1) - * (locs['start_iter'] + locs['num_learning_iterations'] - locs['it']) + f"""{"-" * width}\n""" + f"""{"Total timesteps:":>{pad}} {self.tot_timesteps}\n""" + f"""{"Iteration time:":>{pad}} {iteration_time:.2f}s\n""" + f"""{"Time elapsed:":>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time))}\n""" + f"""{"ETA:":>{pad}} { + time.strftime( + "%H:%M:%S", + time.gmtime( + self.tot_time + / (locs["it"] - locs["start_iter"] + 1) + * (locs["start_iter"] + locs["num_learning_iterations"] - locs["it"]) + ), ) - )}\n""" + }\n""" ) print(log_string) - def save(self, path: str, infos=None): - # -- Save model + def save(self, path: str, infos: dict | None = None) -> None: + # Save model saved_dict = { "model_state_dict": self.alg.policy.state_dict(), "optimizer_state_dict": self.alg.optimizer.state_dict(), "iter": self.current_learning_iteration, "infos": infos, } - # -- Save RND model if used + # Save RND model if used if hasattr(self.alg, "rnd") and self.alg.rnd: saved_dict["rnd_state_dict"] = self.alg.rnd.state_dict() saved_dict["rnd_optimizer_state_dict"] = self.alg.rnd_optimizer.state_dict() torch.save(saved_dict, path) - # upload model to external logging service + # Upload model to external logging service if self.logger_type in ["neptune", "wandb"] and not self.disable_logs: self.writer.save_model(path, self.current_learning_iteration) - def load(self, path: str, load_optimizer: bool = True, map_location: str | None = None): + def load(self, path: str, load_optimizer: bool = True, map_location: str | None = None) -> dict: loaded_dict = torch.load(path, weights_only=False, map_location=map_location) - # 
-- Load model + # Load model resumed_training = self.alg.policy.load_state_dict(loaded_dict["model_state_dict"]) - # -- Load RND model if used + # Load RND model if used if hasattr(self.alg, "rnd") and self.alg.rnd: self.alg.rnd.load_state_dict(loaded_dict["rnd_state_dict"]) - # -- load optimizer if used + # Load optimizer if used if load_optimizer and resumed_training: - # -- algorithm optimizer + # Algorithm optimizer self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"]) - # -- RND optimizer if used + # RND optimizer if used if hasattr(self.alg, "rnd") and self.alg.rnd: self.alg.rnd_optimizer.load_state_dict(loaded_dict["rnd_optimizer_state_dict"]) - # -- load current learning iteration + # Load current learning iteration if resumed_training: self.current_learning_iteration = loaded_dict["iter"] return loaded_dict["infos"] - def get_inference_policy(self, device=None): - self.eval_mode() # switch to evaluation mode (dropout for example) + def get_inference_policy(self, device: str | None = None) -> callable: + self.eval_mode() # Switch to evaluation mode (e.g. for dropout) if device is not None: self.alg.policy.to(device) return self.alg.policy.act_inference - def train_mode(self): - # -- PPO + def train_mode(self) -> None: + # PPO self.alg.policy.train() - # -- RND + # RND if hasattr(self.alg, "rnd") and self.alg.rnd: self.alg.rnd.train() - def eval_mode(self): - # -- PPO + def eval_mode(self) -> None: + # PPO self.alg.policy.eval() - # -- RND + # RND if hasattr(self.alg, "rnd") and self.alg.rnd: self.alg.rnd.eval() - def add_git_repo_to_log(self, repo_file_path): + def add_git_repo_to_log(self, repo_file_path: str) -> None: self.git_status_repos.append(repo_file_path) - """ - Helper functions. - """ - - def _configure_multi_gpu(self): + def _configure_multi_gpu(self) -> None: """Configure multi-gpu training.""" - # check if distributed training is enabled + # Check if distributed training is enabled self.gpu_world_size = int(os.getenv("WORLD_SIZE", "1")) self.is_distributed = self.gpu_world_size > 1 - # if not distributed training, set local and global rank to 0 and return + # If not distributed training, set local and global rank to 0 and return if not self.is_distributed: self.gpu_local_rank = 0 self.gpu_global_rank = 0 self.multi_gpu_cfg = None return - # get rank and world size + # Get rank and world size self.gpu_local_rank = int(os.getenv("LOCAL_RANK", "0")) self.gpu_global_rank = int(os.getenv("RANK", "0")) - # make a configuration dictionary + # Make a configuration dictionary self.multi_gpu_cfg = { - "global_rank": self.gpu_global_rank, # rank of the main process - "local_rank": self.gpu_local_rank, # rank of the current process - "world_size": self.gpu_world_size, # total number of processes + "global_rank": self.gpu_global_rank, # Rank of the main process + "local_rank": self.gpu_local_rank, # Rank of the current process + "world_size": self.gpu_world_size, # Total number of processes } - # check if user has device specified for local rank + # Check if user has device specified for local rank if self.device != f"cuda:{self.gpu_local_rank}": raise ValueError( f"Device '{self.device}' does not match expected device for local rank '{self.gpu_local_rank}'." ) - # validate multi-gpu configuration + # Validate multi-gpu configuration if self.gpu_local_rank >= self.gpu_world_size: raise ValueError( f"Local rank '{self.gpu_local_rank}' is greater than or equal to world size '{self.gpu_world_size}'." 
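For reference, a small sketch of the environment variables this method reads (normally injected by the launcher, e.g. torchrun); the values below are illustrative for a single-node, two-GPU run:

import os

os.environ["WORLD_SIZE"] = "2"   # Total number of processes; > 1 enables distributed mode
os.environ["RANK"] = "0"         # Global rank; only rank 0 keeps logging enabled
os.environ["LOCAL_RANK"] = "0"   # Rank on this node; the runner expects device == f"cuda:{LOCAL_RANK}"

world_size = int(os.getenv("WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
print(world_size > 1, f"cuda:{local_rank}")  # True cuda:0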
@@ -389,20 +387,20 @@ def _configure_multi_gpu(self): f"Global rank '{self.gpu_global_rank}' is greater than or equal to world size '{self.gpu_world_size}'." ) - # initialize torch distributed + # Initialize torch distributed torch.distributed.init_process_group(backend="nccl", rank=self.gpu_global_rank, world_size=self.gpu_world_size) - # set device to the local rank + # Set device to the local rank torch.cuda.set_device(self.gpu_local_rank) - def _construct_algorithm(self, obs) -> PPO: + def _construct_algorithm(self, obs: TensorDict) -> PPO: """Construct the actor-critic algorithm.""" - # resolve RND config + # Resolve RND config self.alg_cfg = resolve_rnd_config(self.alg_cfg, obs, self.cfg["obs_groups"], self.env) - # resolve symmetry config + # Resolve symmetry config self.alg_cfg = resolve_symmetry_config(self.alg_cfg, self.env) - # resolve deprecated normalization config + # Resolve deprecated normalization config if self.cfg.get("empirical_normalization") is not None: warnings.warn( "The `empirical_normalization` parameter is deprecated. Please set `actor_obs_normalization` and " @@ -414,17 +412,17 @@ def _construct_algorithm(self, obs) -> PPO: if self.policy_cfg.get("critic_obs_normalization") is None: self.policy_cfg["critic_obs_normalization"] = self.cfg["empirical_normalization"] - # initialize the actor-critic + # Initialize the policy actor_critic_class = eval(self.policy_cfg.pop("class_name")) actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class( obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg ).to(self.device) - # initialize the algorithm + # Initialize the algorithm alg_class = eval(self.alg_cfg.pop("class_name")) alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg) - # initialize the storage + # Initialize the storage alg.init_storage( "rl", self.env.num_envs, @@ -435,10 +433,10 @@ def _construct_algorithm(self, obs) -> PPO: return alg - def _prepare_logging_writer(self): - """Prepares the logging writers.""" + def _prepare_logging_writer(self) -> None: + """Prepare the logging writers.""" if self.log_dir is not None and self.writer is None and not self.disable_logs: - # Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard. + # Launch the Tensorboard summary writer, optionally wrapped by the Neptune or Wandb writer; default: Tensorboard.
self.logger_type = self.cfg.get("logger", "tensorboard") self.logger_type = self.logger_type.lower() diff --git a/rsl_rl/storage/rollout_storage.py b/rsl_rl/storage/rollout_storage.py index e9309b3b..539a57bb 100644 --- a/rsl_rl/storage/rollout_storage.py +++ b/rsl_rl/storage/rollout_storage.py @@ -6,38 +6,39 @@ from __future__ import annotations import torch +from collections.abc import Generator from tensordict import TensorDict +from rsl_rl.networks import HiddenState from rsl_rl.utils import split_and_pad_trajectories class RolloutStorage: class Transition: - def __init__(self): - self.observations = None - self.actions = None - self.privileged_actions = None - self.rewards = None - self.dones = None - self.values = None - self.actions_log_prob = None - self.action_mean = None - self.action_sigma = None - self.hidden_states = None - - def clear(self): + def __init__(self) -> None: + self.observations: TensorDict | None = None + self.actions: torch.Tensor | None = None + self.privileged_actions: torch.Tensor | None = None + self.rewards: torch.Tensor | None = None + self.dones: torch.Tensor | None = None + self.values: torch.Tensor | None = None + self.actions_log_prob: torch.Tensor + self.action_mean: torch.Tensor | None = None + self.action_sigma: torch.Tensor | None = None + self.hidden_states: tuple[HiddenState, HiddenState] = (None, None) + + def clear(self) -> None: self.__init__() def __init__( self, - training_type, - num_envs, - num_transitions_per_env, - obs, - actions_shape, - device="cpu", - ): - # store inputs + training_type: str, + num_envs: int, + num_transitions_per_env: int, + obs: TensorDict, + actions_shape: tuple[int] | list[int], + device: str = "cpu", + ) -> None: self.training_type = training_type self.device = device self.num_transitions_per_env = num_transitions_per_env @@ -54,11 +55,11 @@ def __init__( self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte() - # for distillation + # For distillation if training_type == "distillation": self.privileged_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) - # for reinforcement learning + # For reinforcement learning if training_type == "rl": self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) @@ -68,14 +69,14 @@ def __init__( self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) # For RNN networks - self.saved_hidden_states_a = None - self.saved_hidden_states_c = None + self.saved_hidden_state_a = None + self.saved_hidden_state_c = None - # counter for the number of transitions stored + # Counter for the number of transitions stored self.step = 0 - def add_transitions(self, transition: Transition): - # check if the transition is valid + def add_transitions(self, transition: Transition) -> None: + # Check if the transition is valid if self.step >= self.num_transitions_per_env: raise OverflowError("Rollout buffer overflow! 
You should call clear() before adding new transitions.") @@ -85,11 +86,11 @@ def add_transitions(self, transition: Transition): self.rewards[self.step].copy_(transition.rewards.view(-1, 1)) self.dones[self.step].copy_(transition.dones.view(-1, 1)) - # for distillation + # For distillation if self.training_type == "distillation": self.privileged_actions[self.step].copy_(transition.privileged_actions) - # for reinforcement learning + # For reinforcement learning if self.training_type == "rl": self.values[self.step].copy_(transition.values) self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1)) @@ -99,39 +100,40 @@ def add_transitions(self, transition: Transition): # For RNN networks self._save_hidden_states(transition.hidden_states) - # increment the counter + # Increment the counter self.step += 1 - def _save_hidden_states(self, hidden_states): - if hidden_states is None or hidden_states == (None, None): + def _save_hidden_states(self, hidden_states: tuple[HiddenState, HiddenState]) -> None: + if hidden_states == (None, None): return - # make a tuple out of GRU hidden state sto match the LSTM format - hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],) - hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],) - # initialize if needed - if self.saved_hidden_states_a is None: - self.saved_hidden_states_a = [ - torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a)) + # Make a tuple out of GRU hidden states to match the LSTM format + hidden_state_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],) + hidden_state_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],) + # Initialize hidden states if needed + if self.saved_hidden_state_a is None: + self.saved_hidden_state_a = [ + torch.zeros(self.observations.shape[0], *hidden_state_a[i].shape, device=self.device) + for i in range(len(hidden_state_a)) ] - self.saved_hidden_states_c = [ - torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c)) + self.saved_hidden_state_c = [ + torch.zeros(self.observations.shape[0], *hidden_state_c[i].shape, device=self.device) + for i in range(len(hidden_state_c)) ] - # copy the states - for i in range(len(hid_a)): - self.saved_hidden_states_a[i][self.step].copy_(hid_a[i]) - self.saved_hidden_states_c[i][self.step].copy_(hid_c[i]) + # Copy the states + for i in range(len(hidden_state_a)): + self.saved_hidden_state_a[i][self.step].copy_(hidden_state_a[i]) + self.saved_hidden_state_c[i][self.step].copy_(hidden_state_c[i]) - def clear(self): + def clear(self) -> None: self.step = 0 - def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True): + def compute_returns( + self, last_values: torch.Tensor, gamma: float, lam: float, normalize_advantage: bool = True + ) -> None: advantage = 0 for step in reversed(range(self.num_transitions_per_env)): - # if we are at the last step, bootstrap the return value - if step == self.num_transitions_per_env - 1: - next_values = last_values - else: - next_values = self.values[step + 1] + # If we are at the last step, bootstrap the return value + next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1] # 1 if we are not in a terminal state, 0 otherwise next_is_not_terminal = 1.0 - self.dones[step].float() # TD error: r_t + gamma * V(s_{t+1}) - V(s_t) @@ -144,20 +146,20 
@@ def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = T # Compute the advantages self.advantages = self.returns - self.values # Normalize the advantages if flag is set - # This is to prevent double normalization (i.e. if per minibatch normalization is used) + # Note: This is to prevent double normalization (i.e. if per minibatch normalization is used) if normalize_advantage: self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8) - # for distillation - def generator(self): + # For distillation + def generator(self) -> Generator: if self.training_type != "distillation": raise ValueError("This function is only available for distillation training.") for i in range(self.num_transitions_per_env): yield self.observations[i], self.actions[i], self.privileged_actions[i], self.dones[i] - # for reinforcement learning with feedforward networks - def mini_batch_generator(self, num_mini_batches, num_epochs=8): + # For reinforcement learning with feedforward networks + def mini_batch_generator(self, num_mini_batches: int, num_epochs: int = 8) -> Generator: if self.training_type != "rl": raise ValueError("This function is only available for reinforcement learning training.") batch_size = self.num_envs * self.num_transitions_per_env @@ -180,15 +182,12 @@ def mini_batch_generator(self, num_mini_batches, num_epochs=8): for i in range(num_mini_batches): # Select the indices for the mini-batch start = i * mini_batch_size - end = (i + 1) * mini_batch_size - batch_idx = indices[start:end] + stop = (i + 1) * mini_batch_size + batch_idx = indices[start:stop] # Create the mini-batch - # -- Core obs_batch = observations[batch_idx] actions_batch = actions[batch_idx] - - # -- For PPO target_values_batch = values[batch_idx] returns_batch = returns[batch_idx] old_actions_log_prob_batch = old_actions_log_prob[batch_idx] @@ -196,14 +195,29 @@ def mini_batch_generator(self, num_mini_batches, num_epochs=8): old_mu_batch = old_mu[batch_idx] old_sigma_batch = old_sigma[batch_idx] - # yield the mini-batch - yield obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( - None, - None, - ), None - - # for reinfrocement learning with recurrent networks - def recurrent_mini_batch_generator(self, num_mini_batches, num_epochs=8): + hidden_state_a_batch = None + hidden_state_c_batch = None + masks_batch = None + + # Yield the mini-batch + yield ( + obs_batch, + actions_batch, + target_values_batch, + advantages_batch, + returns_batch, + old_actions_log_prob_batch, + old_mu_batch, + old_sigma_batch, + ( + hidden_state_a_batch, + hidden_state_c_batch, + ), + masks_batch, + ) + + # For reinforcement learning with recurrent networks + def recurrent_mini_batch_generator(self, num_mini_batches: int, num_epochs: int = 8) -> Generator: if self.training_type != "rl": raise ValueError("This function is only available for reinforcement learning training.") padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones) @@ -232,29 +246,46 @@ def recurrent_mini_batch_generator(self, num_mini_batches, num_epochs=8): values_batch = self.values[:, start:stop] old_actions_log_prob_batch = self.actions_log_prob[:, start:stop] - # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim]) - # then take only time steps after dones (flattens num envs and time dimensions), - # take a batch of trajectories and finally 
reshape back to [num_layers, batch, hidden_dim] + # Reshape to [num_envs, time, num layers, hidden dim] + # Original shape: [time, num_layers, num_envs, hidden_dim]) last_was_done = last_was_done.permute(1, 0) - hid_a_batch = [ - saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] + # Take only time steps after dones (flattens num envs and time dimensions), + # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim] + hidden_state_a_batch = [ + saved_hidden_state.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] .transpose(1, 0) .contiguous() - for saved_hidden_states in self.saved_hidden_states_a + for saved_hidden_state in self.saved_hidden_state_a ] - hid_c_batch = [ - saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] + hidden_state_c_batch = [ + saved_hidden_state.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] .transpose(1, 0) .contiguous() - for saved_hidden_states in self.saved_hidden_states_c + for saved_hidden_state in self.saved_hidden_state_c ] - # remove the tuple for GRU - hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch - hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch - - yield obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( - hid_a_batch, - hid_c_batch, - ), masks_batch + # Remove the tuple for GRU + hidden_state_a_batch = ( + hidden_state_a_batch[0] if len(hidden_state_a_batch) == 1 else hidden_state_a_batch + ) + hidden_state_c_batch = ( + hidden_state_c_batch[0] if len(hidden_state_c_batch) == 1 else hidden_state_c_batch + ) + + # Yield the mini-batch + yield ( + obs_batch, + actions_batch, + values_batch, + advantages_batch, + returns_batch, + old_actions_log_prob_batch, + old_mu_batch, + old_sigma_batch, + ( + hidden_state_a_batch, + hidden_state_c_batch, + ), + masks_batch, + ) first_traj = last_traj diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py index 0a7deab6..a11074e0 100644 --- a/rsl_rl/utils/__init__.py +++ b/rsl_rl/utils/__init__.py @@ -5,4 +5,22 @@ """Helper functions.""" -from .utils import * +from .utils import ( + resolve_nn_activation, + resolve_obs_groups, + resolve_optimizer, + split_and_pad_trajectories, + store_code_state, + string_to_callable, + unpad_trajectories, +) + +__all__ = [ + "resolve_nn_activation", + "resolve_obs_groups", + "resolve_optimizer", + "split_and_pad_trajectories", + "store_code_state", + "string_to_callable", + "unpad_trajectories", +] diff --git a/rsl_rl/utils/neptune_utils.py b/rsl_rl/utils/neptune_utils.py index 3796ec8b..90bcd7ef 100644 --- a/rsl_rl/utils/neptune_utils.py +++ b/rsl_rl/utils/neptune_utils.py @@ -12,14 +12,14 @@ try: import neptune except ModuleNotFoundError: - raise ModuleNotFoundError("neptune-client is required to log to Neptune.") + raise ModuleNotFoundError("neptune-client is required to log to Neptune.") from None class NeptuneLogger: - def __init__(self, project, token): + def __init__(self, project: str, token: str) -> None: self.run = neptune.init_run(project=project, api_token=token) - def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): + def store_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None: self.run["runner_cfg"] = runner_cfg self.run["policy_cfg"] = policy_cfg self.run["alg_cfg"] = alg_cfg @@ -29,48 +29,45 @@ def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): class 
NeptuneSummaryWriter(SummaryWriter): """Summary writer for Neptune.""" - def __init__(self, log_dir: str, flush_secs: int, cfg): + def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: super().__init__(log_dir, flush_secs) try: project = cfg["neptune_project"] except KeyError: - raise KeyError("Please specify neptune_project in the runner config, e.g. legged_gym.") + raise KeyError("Please specify neptune_project in the runner config, e.g. legged_gym.") from None try: token = os.environ["NEPTUNE_API_TOKEN"] except KeyError: raise KeyError( "Neptune api token not found. Please run or add to ~/.bashrc: export NEPTUNE_API_TOKEN=YOUR_API_TOKEN" - ) + ) from None try: entity = os.environ["NEPTUNE_USERNAME"] except KeyError: raise KeyError( "Neptune username not found. Please run or add to ~/.bashrc: export NEPTUNE_USERNAME=YOUR_USERNAME" - ) + ) from None neptune_project = entity + "/" + project - self.neptune_logger = NeptuneLogger(neptune_project, token) - self.name_map = { "Train/mean_reward/time": "Train/mean_reward_time", "Train/mean_episode_length/time": "Train/mean_episode_length_time", } - run_name = os.path.split(log_dir)[-1] - self.neptune_logger.run["log_dir"].log(run_name) - def _map_path(self, path): - if path in self.name_map: - return self.name_map[path] - else: - return path - - def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False): + def add_scalar( + self, + tag: str, + scalar_value: float, + global_step: int | None = None, + walltime: float | None = None, + new_style: bool = False, + ) -> None: super().add_scalar( tag, scalar_value, @@ -80,15 +77,21 @@ def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_sty ) self.neptune_logger.run[self._map_path(tag)].log(scalar_value, step=global_step) - def stop(self): + def stop(self) -> None: self.neptune_logger.run.stop() - def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): + def log_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None: self.neptune_logger.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg) - def save_model(self, model_path, iter): + def save_model(self, model_path: str, iter: int) -> None: self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path) - def save_file(self, path, iter=None): + def save_file(self, path: str) -> None: name = path.rsplit("/", 1)[-1].split(".")[0] self.neptune_logger.run["git_diff/" + name].upload(path) + + def _map_path(self, path: str) -> str: + if path in self.name_map: + return self.name_map[path] + else: + return path diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py index da19a0b7..7a044e83 100644 --- a/rsl_rl/utils/utils.py +++ b/rsl_rl/utils/utils.py @@ -16,10 +16,10 @@ def resolve_nn_activation(act_name: str) -> torch.nn.Module: - """Resolves the activation function from the name. + """Resolve the activation function from the name. Args: - act_name: The name of the activation function. + act_name: Name of the activation function. Returns: The activation function. @@ -50,10 +50,10 @@ def resolve_nn_activation(act_name: str) -> torch.nn.Module: def resolve_optimizer(optimizer_name: str) -> torch.optim.Optimizer: - """Resolves the optimizer from the name. + """Resolve the optimizer from the name. Args: - optimizer_name: The name of the optimizer. + optimizer_name: Name of the optimizer. Returns: The optimizer. 
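A hedged usage sketch of this helper; "adam" is assumed to be one of the accepted names and the helper is assumed to return the optimizer class rather than an instance (neither detail is visible in this hunk):

import torch
from rsl_rl.utils import resolve_optimizer  # Re-exported in rsl_rl/utils/__init__.py above

model = torch.nn.Linear(8, 2)
optimizer_class = resolve_optimizer("adam")               # Assumed name, presumably torch.optim.Adam
optimizer = optimizer_class(model.parameters(), lr=1e-3)  # Instantiate like any torch optimizer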
@@ -78,10 +78,12 @@ def resolve_optimizer(optimizer_name: str) -> torch.optim.Optimizer: def split_and_pad_trajectories( tensor: torch.Tensor | TensorDict, dones: torch.Tensor ) -> tuple[torch.Tensor | TensorDict, torch.Tensor]: - """Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length of the longest - trajectory. Returns masks corresponding to valid parts of the trajectories. + """Split trajectories at done indices. - Example: + Split trajectories, concatenate them and pad with zeros up to the length of the longest trajectory. Return masks + corresponding to valid parts of the trajectories. + + Example (transposed for readability): Input: [[a1, a2, a3, a4 | a5, a6], [b1, b2 | b3, b4, b5 | b6]] @@ -93,10 +95,9 @@ def split_and_pad_trajectories( Assumes that the input has the following order of dimensions: [time, number of envs, additional dimensions] """ - dones = dones.clone() dones[-1] = 1 - # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping + # Permute the buffers to have the order (num_envs, num_transitions_per_env, ...) for correct reshaping flat_dones = dones.transpose(1, 0).reshape(-1, 1) # Get length of trajectory by counting the number of successive not done elements done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0])) @@ -106,33 +107,33 @@ def split_and_pad_trajectories( if isinstance(tensor, TensorDict): padded_trajectories = {} for k, v in tensor.items(): - # split the tensor into trajectories + # Split the tensor into trajectories trajectories = torch.split(v.transpose(1, 0).flatten(0, 1), trajectory_lengths_list) - # add at least one full length trajectory - trajectories = trajectories + (torch.zeros(v.shape[0], *v.shape[2:], device=v.device),) - # pad the trajectories to the length of the longest trajectory + # Add at least one full length trajectory + trajectories = (*trajectories, torch.zeros(v.shape[0], *v.shape[2:], device=v.device)) + # Pad the trajectories to the length of the longest trajectory padded_trajectories[k] = torch.nn.utils.rnn.pad_sequence(trajectories) - # remove the added tensor + # Remove the added trajectory padded_trajectories[k] = padded_trajectories[k][:, :-1] padded_trajectories = TensorDict( padded_trajectories, batch_size=[tensor.batch_size[0], len(trajectory_lengths_list)] ) else: - # split the tensor into trajectories + # Split the tensor into trajectories trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list) - # add at least one full length trajectory - trajectories = trajectories + (torch.zeros(tensor.shape[0], *tensor.shape[2:], device=tensor.device),) - # pad the trajectories to the length of the longest trajectory + # Add at least one full length trajectory + trajectories = (*trajectories, torch.zeros(tensor.shape[0], *tensor.shape[2:], device=tensor.device)) + # Pad the trajectories to the length of the longest trajectory padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories) - # remove the added tensor + # Remove the added trajectory padded_trajectories = padded_trajectories[:, :-1] - # create masks for the valid parts of the trajectories + # Create masks for the valid parts of the trajectories trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1) return padded_trajectories, trajectory_masks -def unpad_trajectories(trajectories, masks): - """Does the inverse operation of 
split_and_pad_trajectories()""" +def unpad_trajectories(trajectories: torch.Tensor | TensorDict, masks: torch.Tensor) -> torch.Tensor | TensorDict: + """Do the inverse operation of `split_and_pad_trajectories()`.""" # Need to transpose before and after the masking to have proper reshaping return ( trajectories.transpose(1, 0)[masks.transpose(1, 0)] @@ -141,7 +142,7 @@ def unpad_trajectories(trajectories, masks): ) -def store_code_state(logdir, repositories) -> list: +def store_code_state(logdir: str, repositories: list[str]) -> list[str]: git_log_dir = os.path.join(logdir, "git") os.makedirs(git_log_dir, exist_ok=True) file_paths = [] @@ -151,58 +152,58 @@ def store_code_state(logdir, repositories) -> list: t = repo.head.commit.tree except Exception: print(f"Could not find git repository in {repository_file_path}. Skipping.") - # skip if not a git repository + # Skip if not a git repository continue - # get the name of the repository + # Get the name of the repository repo_name = pathlib.Path(repo.working_dir).name diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff") - # check if the diff file already exists + # Check if the diff file already exists if os.path.isfile(diff_file_name): continue - # write the diff file + # Write the diff file print(f"Storing git diff for '{repo_name}' in: {diff_file_name}") with open(diff_file_name, "x", encoding="utf-8") as f: content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}" f.write(content) - # add the file path to the list of files to be uploaded + # Add the file path to the list of files to be uploaded file_paths.append(diff_file_name) return file_paths def string_to_callable(name: str) -> Callable: - """Resolves the module and function names to return the function. + """Resolve the module and function names to return the function. Args: - name: The function name. The format should be 'module:attribute_name'. + name: Function name. The format should be 'module:attribute_name'. + + Returns: + The function loaded from the module. Raises: ValueError: When the resolved attribute is not a function. ValueError: When unable to resolve the attribute. - - Returns: - The function loaded from the module. """ try: mod_name, attr_name = name.split(":") mod = importlib.import_module(mod_name) callable_object = getattr(mod, attr_name) - # check if attribute is callable + # Check if attribute is callable if callable(callable_object): return callable_object else: raise ValueError(f"The imported object is not callable: '{name}'") - except AttributeError as e: + except AttributeError as err: msg = ( "We could not interpret the entry as a callable object. The format of input should be" - f" 'module:attribute_name'\nWhile processing input '{name}', received the error:\n {e}." + f" 'module:attribute_name'\nWhile processing input '{name}'." ) - raise ValueError(msg) + raise ValueError(msg) from err def resolve_obs_groups( obs: TensorDict, obs_groups: dict[str, list[str]], default_sets: list[str] ) -> dict[str, list[str]]: - """Validates the observation configuration and defaults missing observation sets. + """Validate the observation configuration and defaults missing observation sets. The input is an observation dictionary `obs` containing observation groups and a configuration dictionary `obs_groups` where the keys are the observation sets and the values are lists of observation groups. 
@@ -213,13 +214,13 @@ def resolve_obs_groups( "critic": ["group_1", "group_3"] } - This means that the 'policy' observation set will contain the observations "group_1" and "group_2" and the - 'critic' observation set will contain the observations "group_1" and "group_3". This function will check that all - the observations in the 'policy' and 'critic' observation sets are present in the observation dictionary from the + This means that the 'policy' observation set will contain the observations "group_1" and "group_2" and the 'critic' + observation set will contain the observations "group_1" and "group_3". This function will check that all the + observations in the 'policy' and 'critic' observation sets are present in the observation dictionary from the environment. - Additionally, if one of the `default_sets`, e.g. "critic", is not present in the configuration dictionary, - this function will: + Additionally, if one of the `default_sets`, e.g. "critic", is not present in the configuration dictionary, this + function will: 1. Check if a group with the same name exists in the observations and assign this group to the observation set. 2. If 1. fails, it will assign the observations from the 'policy' observation set to the default observation set. @@ -227,8 +228,8 @@ def resolve_obs_groups( Args: obs: Observations from the environment in the form of a dictionary. obs_groups: Observation sets configuration. - default_sets: Reserved observation set names used by the algorithm (besides 'policy'). - If not provided in 'obs_groups', a default behavior gets triggered. + default_sets: Reserved observation set names used by the algorithm (besides 'policy'). If not provided in + 'obs_groups', a default behavior gets triggered. Returns: The resolved observation groups. @@ -237,8 +238,8 @@ def resolve_obs_groups( ValueError: If any observation set is an empty list. ValueError: If any observation set contains an observation term that is not present in the observations. """ - # check if policy observation set exists - if "policy" not in obs_groups.keys(): + # Check if policy observation set exists + if "policy" not in obs_groups: if "policy" in obs: obs_groups["policy"] = ["policy"] warnings.warn( @@ -253,9 +254,9 @@ def resolve_obs_groups( f" Found keys: {list(obs_groups.keys())}" ) - # check all observation sets for valid observation groups + # Check all observation sets for valid observation groups for set_name, groups in obs_groups.items(): - # check if the list is empty + # Check if the list is empty if len(groups) == 0: msg = f"The '{set_name}' key in the 'obs_groups' dictionary can not be an empty list." if set_name in default_sets: @@ -266,7 +267,7 @@ def resolve_obs_groups( f" Consider removing the key to default to the observation '{set_name}' from the environment." ) raise ValueError(msg) - # check groups exist inside the observations from the environment + # Check groups exist inside the observations from the environment for group in groups: if group not in obs: raise ValueError( @@ -274,9 +275,9 @@ def resolve_obs_groups( f" environment. Available observations from the environment: {list(obs.keys())}" ) - # fill missing observation sets + # Fill missing observation sets for default_set_name in default_sets: - if default_set_name not in obs_groups.keys(): + if default_set_name not in obs_groups: if default_set_name in obs: obs_groups[default_set_name] = [default_set_name] warnings.warn( @@ -294,7 +295,7 @@ def resolve_obs_groups( " clarity. This behavior will be removed in a future version." 
) - # print the final parsed observation sets + # Print the final parsed observation sets print("-" * 80) print("Resolved observation sets: ") for set_name, groups in obs_groups.items(): diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index 243e82d4..b1bf7559 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -12,13 +12,13 @@ try: import wandb except ModuleNotFoundError: - raise ModuleNotFoundError("Wandb is required to log to Weights and Biases.") + raise ModuleNotFoundError("Wandb is required to log to Weights and Biases.") from None class WandbSummaryWriter(SummaryWriter): """Summary writer for Weights and Biases.""" - def __init__(self, log_dir: str, flush_secs: int, cfg): + def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None: super().__init__(log_dir, flush_secs) # Get the run name @@ -27,7 +27,7 @@ def __init__(self, log_dir: str, flush_secs: int, cfg): try: project = cfg["wandb_project"] except KeyError: - raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.") + raise KeyError("Please specify wandb_project in the runner config, e.g. legged_gym.") from None try: entity = os.environ["WANDB_USERNAME"] @@ -40,12 +40,7 @@ def __init__(self, log_dir: str, flush_secs: int, cfg): # Add log directory to wandb wandb.config.update({"log_dir": log_dir}) - self.name_map = { - "Train/mean_reward/time": "Train/mean_reward_time", - "Train/mean_episode_length/time": "Train/mean_episode_length_time", - } - - def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): + def store_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None: wandb.config.update({"runner_cfg": runner_cfg}) wandb.config.update({"policy_cfg": policy_cfg}) wandb.config.update({"alg_cfg": alg_cfg}) @@ -54,7 +49,14 @@ def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): except Exception: wandb.config.update({"env_cfg": asdict(env_cfg)}) - def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False): + def add_scalar( + self, + tag: str, + scalar_value: float, + global_step: int | None = None, + walltime: float | None = None, + new_style: bool = False, + ) -> None: super().add_scalar( tag, scalar_value, @@ -62,26 +64,16 @@ def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_sty walltime=walltime, new_style=new_style, ) - wandb.log({self._map_path(tag): scalar_value}, step=global_step) + wandb.log({tag: scalar_value}, step=global_step) - def stop(self): + def stop(self) -> None: wandb.finish() - def log_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg): + def log_config(self, env_cfg: dict | object, runner_cfg: dict, alg_cfg: dict, policy_cfg: dict) -> None: self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg) - def save_model(self, model_path, iter): + def save_model(self, model_path: str, iter: int) -> None: wandb.save(model_path, base_path=os.path.dirname(model_path)) - def save_file(self, path, iter=None): + def save_file(self, path: str) -> None: wandb.save(path, base_path=os.path.dirname(path)) - - """ - Private methods. 
- """ - - def _map_path(self, path): - if path in self.name_map: - return self.name_map[path] - else: - return path diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..95ae3f32 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,71 @@ +line-length = 120 +target-version = "py39" +preview = true + +[lint] +select = [ + # pycodestyle + "E", "W", + # pydocstyle + "D", + # pylint for later + # "PL", + # pyflakes + "F", + # pyupgrade + "UP", + # pep8-naming + "N", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # flake8-tidy-imports + "TID", + # flake8-annotations + "ANN", + # isort + "I", + # perflint + "PERF", + # ruff + "RUF", +] +ignore = ["B006", + "B007", + "B028", + "ANN401", + "D100", + "D101", + "D102", + "D103", + "D104", + "D105", + "D106", + "D107", + "D203", + "D213", + "D413", +] +per-file-ignores = {"*/__init__.py" = ["F401"]} + +[lint.isort] +# Order of imports +section-order = [ + "future", + "standard-library", + "third-party", + "first-party", + "local-folder", +] +# Extra standard libraries considered as part of python (permissive licenses) +extra-standard-library = [ + "numpy", + "torch", + "tensordict", + "warp", + "typing_extensions", + "git", +] +# Imports from this repository +known-first-party = ["rsl_rl"] \ No newline at end of file