Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 14 additions & 50 deletions scripts/list_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,60 +45,24 @@ def seen(p, m={}):
except Exception:
if onerror is not None:
onerror(info.name)
else:
raise
else:
path = getattr(sys.modules[info.name], "__path__", None) or []

# don't traverse path items we've seen before
path = [p for p in path if not seen(p)]

yield from _walk_packages(path, info.name + ".", onerror)


def import_packages():
sys.path.insert(0, f"{pathlib.Path(__file__).parent.parent}/source/unitree_rl_lab/unitree_rl_lab/tasks/")
for package in ["locomotion.robots", "mimic.robots"]:
package = importlib.import_module(package)
for _ in _walk_packages(package.__path__, package.__name__ + "."):
pass
sys.path.pop(0)


import_packages()

"""Rest everything follows."""

import gymnasium as gym
from prettytable import PrettyTable
"""Import all packages in the unitree_rl_lab module."""
# get the path to the unitree_rl_lab module
unitree_rl_lab_path = pathlib.Path(__file__).parent.parent / "unitree_rl_lab"
# import all packages
for info in _walk_packages([str(unitree_rl_lab_path)], "unitree_rl_lab."):
importlib.import_module(info.name)


def main():
"""Print all environments registered in `unitree_rl_lab` extension."""
# print all the available environments
table = PrettyTable(["S. No.", "Task Name", "Entry Point", "Config"])
table.title = "Available Environments in Unitree RL Lab"
# set alignment of table columns
table.align["Task Name"] = "l"
table.align["Entry Point"] = "l"
table.align["Config"] = "l"

# count of environments
index = 0
# acquire all Isaac environments names
if __name__ == "__main__":
# import packages
import_packages()
# print all environments
for task_spec in gym.registry.values():
if "Unitree" in task_spec.id and "Isaac" not in task_spec.id:
# add details to table
table.add_row([index + 1, task_spec.id, task_spec.entry_point, task_spec.kwargs["env_cfg_entry_point"]])
# increment count
index += 1

print(table)


if __name__ == "__main__":
try:
# run the main function
main()
except Exception as e:
raise e
print(f"Task: {task_spec.id}")
print(f"\tEntry point: {task_spec.entry_point}")
print(f"\tConfig: {task_spec.kwargs.get('cfg_entry_point', 'N/A')}")
print()
94 changes: 35 additions & 59 deletions scripts/rsl_rl/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,69 +29,45 @@ def add_rsl_rl_args(parser: argparse.ArgumentParser):
# -- load arguments
arg_group.add_argument("--resume", action="store_true", default=False, help="Whether to resume from a checkpoint.")
arg_group.add_argument("--load_run", type=str, default=None, help="Name of the run folder to resume from.")
arg_group.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file to resume from.")
# -- logger arguments
arg_group.add_argument(
"--logger", type=str, default=None, choices={"wandb", "tensorboard", "neptune"}, help="Logger module to use."
)
arg_group.add_argument(
"--log_project_name", type=str, default=None, help="Name of the logging project when using wandb or neptune."
)


def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPolicyRunnerCfg:
"""Parse configuration for RSL-RL agent based on inputs.

Args:
task_name: The name of the environment.
args_cli: The command line arguments.

Returns:
The parsed configuration for RSL-RL agent based on inputs.
"""
from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry
arg_group.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file name to resume from.")

# load the default configuration
rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
if rslrl_cfg.experiment_name == "":
rslrl_cfg.experiment_name = task_name.lower().replace("-", "_").removesuffix("_play")
rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli)
return rslrl_cfg


def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace):
"""Update configuration for RSL-RL agent based on inputs.
def parse_rsl_rl_cfg(task_name: str, args: argparse.Namespace) -> RslRlOnPolicyRunnerCfg:
"""Parse configuration file for RSL-RL agent.

Args:
agent_cfg: The configuration for RSL-RL agent.
args_cli: The command line arguments.
task_name: Name of the task.
args: Command line arguments.

Returns:
The updated configuration for RSL-RL agent based on inputs.
The configuration class for RSL-RL agent.
"""
# override the default configuration with CLI arguments
if hasattr(args_cli, "seed") and args_cli.seed is not None:
# randomly sample a seed if seed = -1
if args_cli.seed == -1:
args_cli.seed = random.randint(0, 10000)
agent_cfg.seed = args_cli.seed
if args_cli.resume is not None:
agent_cfg.resume = args_cli.resume
if args_cli.load_run is not None:
agent_cfg.load_run = args_cli.load_run
if args_cli.checkpoint is not None:
agent_cfg.load_checkpoint = args_cli.checkpoint
if args_cli.run_name is not None:
agent_cfg.run_name = args_cli.run_name
if args_cli.logger is not None:
agent_cfg.logger = args_cli.logger
# set the project name for wandb and neptune
if agent_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name:
agent_cfg.wandb_project = args_cli.log_project_name
agent_cfg.neptune_project = args_cli.log_project_name

if agent_cfg.experiment_name == "":
task_name = args_cli.task
agent_cfg.experiment_name = task_name.lower().replace("-", "_").removesuffix("_play")

return agent_cfg
# import configuration
if task_name.startswith("Unitree"):
from unitree_rl_lab.tasks import unitree_a1_cfgs

# check if the task name is provided
if task_name == "UnitreeA1TerrainEnv-v0":
cfg = unitree_a1_cfgs.UnitreeA1RoughCfg()
cfg_class = unitree_a1_cfgs.UnitreeA1RoughCfgPPO
else:
raise ValueError(f"Task {task_name} not found!")
else:
raise ValueError(f"Task {task_name} not supported!")

# update runner configuration with command line arguments
cfg_class.experiment_name = args.experiment_name
cfg_class.run_name = args.run_name
cfg_class.resume = args.resume
cfg_class.load_run = args.load_run
cfg_class.load_checkpoint = args.checkpoint

# create runner configuration
runner_cfg = cfg_class()
runner_cfg.policy = cfg

# set maximum iterations if provided
if args.max_iterations is not None:
runner_cfg.max_iterations = args.max_iterations

return runner_cfg
Loading