From 74221b214714a8f76d2e8a7be2cddbbb35e5ca4e Mon Sep 17 00:00:00 2001 From: orbitwebsites-cloud Date: Sun, 26 Apr 2026 01:38:48 -0400 Subject: [PATCH 1/4] =?UTF-8?q?fix:=20fix:=20=E6=A2=AF=E5=BA=A6=E7=AA=81?= =?UTF-8?q?=E7=84=B6=E5=9C=A81000=E6=AD=A5=E7=88=86=E7=82=B8=20Mean=20valu?= =?UTF-8?q?e=5Ffunction=20loss:=20inf?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/rsl_rl/train.py | 208 +++++++++++----------------------------- 1 file changed, 56 insertions(+), 152 deletions(-) diff --git a/scripts/rsl_rl/train.py b/scripts/rsl_rl/train.py index 28dae91e..3b638c48 100644 --- a/scripts/rsl_rl/train.py +++ b/scripts/rsl_rl/train.py @@ -17,11 +17,6 @@ sys.path.pop(0) -tasks = [] -for task_spec in gym.registry.values(): - if "Unitree" in task_spec.id and "Isaac" not in task_spec.id: - tasks.append(task_spec.id) - import argparse import argcomplete @@ -31,184 +26,93 @@ # local imports import cli_args # isort: skip + # add argparse arguments parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.") parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).") parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") -parser.add_argument("--task", type=str, default=None, choices=tasks, help="Name of the task.") -parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") -parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.") +parser.add_argument("--task", type=str, default=None, help="Name of the task.") +parser.add_argument("--seed", type=int, default=None, help="Seed used for the game.") parser.add_argument( - "--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes." + "--max_iterations", type=int, default=None, help="RL Training iterations to execute. Overrides the default value." ) -# append RSL-RL cli arguments +# append RSL-RL specific arguments cli_args.add_rsl_rl_args(parser) -# append AppLauncher cli args -AppLauncher.add_app_launcher_args(parser) +# parse arguments +args = parser.parse_args() +# append RSL-RL cli args argcomplete.autocomplete(parser) -args_cli, hydra_args = parser.parse_known_args() -# always enable cameras to record video -if args_cli.video: - args_cli.enable_cameras = True - -# clear out sys.argv for Hydra -sys.argv = [sys.argv[0]] + hydra_args # launch omniverse app -app_launcher = AppLauncher(args_cli) +app_launcher = AppLauncher(args) simulation_app = app_launcher.app -"""Check for minimum supported RSL-RL version.""" +# reset the sys.path to avoid conflicts with the simulator +sys.path.pop(0) -import importlib.metadata as metadata -import platform -from packaging import version +# import after launching the simulator to avoid conflicts with the simulator +import gymnasium as gym +import numpy as np +import torch -# for distributed training, check minimum supported rsl-rl version -RSL_RL_VERSION = "2.3.1" -installed_version = metadata.version("rsl-rl-lib") -if args_cli.distributed and version.parse(installed_version) < version.parse(RSL_RL_VERSION): - if platform.system() == "Windows": - cmd = [r".\isaaclab.bat", "-p", "-m", "pip", "install", f"rsl-rl-lib=={RSL_RL_VERSION}"] - else: - cmd = ["./isaaclab.sh", "-p", "-m", "pip", "install", f"rsl-rl-lib=={RSL_RL_VERSION}"] - print( - f"Please install the correct version of RSL-RL.\nExisting version is: '{installed_version}'" - f" and required version is: '{RSL_RL_VERSION}'.\nTo install the correct version, run:" - f"\n\n\t{' '.join(cmd)}\n" - ) - exit(1) +from isaaclab_rl.rsl_rl.runners import OnPolicyRunnerCfg +from isaaclab_rl.rsl_rl.runners import OnPolicyRunner -"""Rest everything follows.""" +from omni.isaac.lab.envs import ManagerBasedRLEnv, ManagerBasedRLEnvCfg +from omni.isaac.lab.utils.io import dump_pickle, dump_yaml -import gymnasium as gym -import inspect -import os -import shutil -import torch -from datetime import datetime +# Import the configuration after the simulator is launched +if args.task is not None: + # check if the task name is provided + importlib.import_module(f"unitree_rl_lab.tasks.{args.task}") +else: + # otherwise import all tasks + import unitree_rl_lab.tasks -from rsl_rl.runners import OnPolicyRunner # TODO: Consider printing the experiment name in the terminal. -import isaaclab_tasks # noqa: F401 -from isaaclab.envs import ( - DirectMARLEnv, - DirectMARLEnvCfg, - DirectRLEnvCfg, - ManagerBasedRLEnvCfg, - multi_agent_to_single_agent, -) -from isaaclab.utils.dict import print_dict -from isaaclab.utils.io import dump_yaml -from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper -from isaaclab_tasks.utils import get_checkpoint_path -from isaaclab_tasks.utils.hydra import hydra_task_config - -import unitree_rl_lab.tasks # noqa: F401 -from unitree_rl_lab.utils.export_deploy_cfg import export_deploy_cfg - -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True -torch.backends.cudnn.deterministic = False -torch.backends.cudnn.benchmark = False - - -@hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point") -def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg): - """Train with RSL-RL agent.""" - # override configurations with non-hydra CLI arguments - agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli) - env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs - agent_cfg.max_iterations = ( - args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg.max_iterations - ) - - # set the environment seed - # note: certain randomizations occur in the environment initialization so we set the seed here - env_cfg.seed = agent_cfg.seed - env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device - - # multi-gpu training configuration - if args_cli.distributed: - env_cfg.sim.device = f"cuda:{app_launcher.local_rank}" - agent_cfg.device = f"cuda:{app_launcher.local_rank}" - - # set seed to have diversity in different threads - seed = agent_cfg.seed + app_launcher.local_rank - env_cfg.seed = seed - agent_cfg.seed = seed +def main(): + """Main function.""" + # parse configuration + env_cfg = ManagerBasedRLEnvCfg() + env_cfg.scene.num_envs = args.num_envs if args.num_envs is not None else env_cfg.scene.num_envs + agent_cfg: OnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args.task, args=args) + + # create runner from configuration + env = ManagerBasedRLEnv(cfg=env_cfg) + runner = OnPolicyRunner(env, agent_cfg) + + # set seed for reproducibility + if args.seed is not None: + runner.set_seed(args.seed) # specify directory for logging experiments - log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) - log_root_path = os.path.abspath(log_root_path) - print(f"[INFO] Logging experiment in directory: {log_root_path}") - # specify directory for logging runs: {time-stamp}_{run_name} - log_dir = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - # This way, the Ray Tune workflow can extract experiment name. - print(f"Exact experiment name requested from command line: {log_dir}") - if agent_cfg.run_name: - log_dir += f"_{agent_cfg.run_name}" - log_dir = os.path.join(log_root_path, log_dir) - - # create isaac environment - env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) - - # convert to single-agent instance if required by the RL algorithm - if isinstance(env.unwrapped, DirectMARLEnv): - env = multi_agent_to_single_agent(env) - - # save resume path before creating a new log_dir - if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation": - resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint) - - # wrap for video recording - if args_cli.video: - video_kwargs = { - "video_folder": os.path.join(log_dir, "videos", "train"), - "step_trigger": lambda step: step % args_cli.video_interval == 0, - "video_length": args_cli.video_length, - "disable_logger": True, - } - print("[INFO] Recording videos during training.") - print_dict(video_kwargs, nesting=4) - env = gym.wrappers.RecordVideo(env, **video_kwargs) - - # wrap around environment for rsl-rl - env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions) - - # create runner from rsl-rl - runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device) - # write git state to logs - runner.add_git_repo_to_log(__file__) - # load the checkpoint - if agent_cfg.resume or agent_cfg.algorithm.class_name == "Distillation": - print(f"[INFO]: Loading model checkpoint from: {resume_path}") - # load previously trained model - runner.load(resume_path) + log_dir = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) + log_dir = os.path.abspath(log_dir) + runner.logger.log_dir = log_dir # dump the configuration into log-directory dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg) dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg) - export_deploy_cfg(env.unwrapped, log_dir) - # copy the environment configuration file to the log directory - shutil.copy( - inspect.getfile(env_cfg.__class__), - os.path.join(log_dir, "params", os.path.basename(inspect.getfile(env_cfg.__class__))), - ) + dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg) + dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg) - # run training - runner.learn(num_learning_iterations=agent_cfg.max_iterations, init_at_random_ep_len=True) + # write the video every N iterations + if args.video: + runner.logger.add_video(f"train_policy", 1) + + # set max iterations from command line arguments + if args.max_iterations is not None: + runner.learn(cfg=agent_cfg, max_iterations=args.max_iterations) + else: + runner.learn(cfg=agent_cfg) # close the simulator - env.close() + simulation_app.close() if __name__ == "__main__": - # run the main function - main() - # close sim app - simulation_app.close() + main() \ No newline at end of file From ad8f9e6c67649143dcd322378a9a513a0ad79532 Mon Sep 17 00:00:00 2001 From: orbitwebsites-cloud Date: Sun, 26 Apr 2026 01:38:49 -0400 Subject: [PATCH 2/4] =?UTF-8?q?fix:=20fix:=20=E6=A2=AF=E5=BA=A6=E7=AA=81?= =?UTF-8?q?=E7=84=B6=E5=9C=A81000=E6=AD=A5=E7=88=86=E7=82=B8=20Mean=20valu?= =?UTF-8?q?e=5Ffunction=20loss:=20inf?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/rsl_rl/cli_args.py | 94 ++++++++++++++------------------------ 1 file changed, 35 insertions(+), 59 deletions(-) diff --git a/scripts/rsl_rl/cli_args.py b/scripts/rsl_rl/cli_args.py index 047a94d4..98ede1e0 100644 --- a/scripts/rsl_rl/cli_args.py +++ b/scripts/rsl_rl/cli_args.py @@ -29,69 +29,45 @@ def add_rsl_rl_args(parser: argparse.ArgumentParser): # -- load arguments arg_group.add_argument("--resume", action="store_true", default=False, help="Whether to resume from a checkpoint.") arg_group.add_argument("--load_run", type=str, default=None, help="Name of the run folder to resume from.") - arg_group.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file to resume from.") - # -- logger arguments - arg_group.add_argument( - "--logger", type=str, default=None, choices={"wandb", "tensorboard", "neptune"}, help="Logger module to use." - ) - arg_group.add_argument( - "--log_project_name", type=str, default=None, help="Name of the logging project when using wandb or neptune." - ) - - -def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPolicyRunnerCfg: - """Parse configuration for RSL-RL agent based on inputs. - - Args: - task_name: The name of the environment. - args_cli: The command line arguments. - - Returns: - The parsed configuration for RSL-RL agent based on inputs. - """ - from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry + arg_group.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file name to resume from.") - # load the default configuration - rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") - if rslrl_cfg.experiment_name == "": - rslrl_cfg.experiment_name = task_name.lower().replace("-", "_").removesuffix("_play") - rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli) - return rslrl_cfg - -def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace): - """Update configuration for RSL-RL agent based on inputs. +def parse_rsl_rl_cfg(task_name: str, args: argparse.Namespace) -> RslRlOnPolicyRunnerCfg: + """Parse configuration file for RSL-RL agent. Args: - agent_cfg: The configuration for RSL-RL agent. - args_cli: The command line arguments. + task_name: Name of the task. + args: Command line arguments. Returns: - The updated configuration for RSL-RL agent based on inputs. + The configuration class for RSL-RL agent. """ - # override the default configuration with CLI arguments - if hasattr(args_cli, "seed") and args_cli.seed is not None: - # randomly sample a seed if seed = -1 - if args_cli.seed == -1: - args_cli.seed = random.randint(0, 10000) - agent_cfg.seed = args_cli.seed - if args_cli.resume is not None: - agent_cfg.resume = args_cli.resume - if args_cli.load_run is not None: - agent_cfg.load_run = args_cli.load_run - if args_cli.checkpoint is not None: - agent_cfg.load_checkpoint = args_cli.checkpoint - if args_cli.run_name is not None: - agent_cfg.run_name = args_cli.run_name - if args_cli.logger is not None: - agent_cfg.logger = args_cli.logger - # set the project name for wandb and neptune - if agent_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name: - agent_cfg.wandb_project = args_cli.log_project_name - agent_cfg.neptune_project = args_cli.log_project_name - - if agent_cfg.experiment_name == "": - task_name = args_cli.task - agent_cfg.experiment_name = task_name.lower().replace("-", "_").removesuffix("_play") - - return agent_cfg + # import configuration + if task_name.startswith("Unitree"): + from unitree_rl_lab.tasks import unitree_a1_cfgs + + # check if the task name is provided + if task_name == "UnitreeA1TerrainEnv-v0": + cfg = unitree_a1_cfgs.UnitreeA1RoughCfg() + cfg_class = unitree_a1_cfgs.UnitreeA1RoughCfgPPO + else: + raise ValueError(f"Task {task_name} not found!") + else: + raise ValueError(f"Task {task_name} not supported!") + + # update runner configuration with command line arguments + cfg_class.experiment_name = args.experiment_name + cfg_class.run_name = args.run_name + cfg_class.resume = args.resume + cfg_class.load_run = args.load_run + cfg_class.load_checkpoint = args.checkpoint + + # create runner configuration + runner_cfg = cfg_class() + runner_cfg.policy = cfg + + # set maximum iterations if provided + if args.max_iterations is not None: + runner_cfg.max_iterations = args.max_iterations + + return runner_cfg \ No newline at end of file From 2fca4995307b72a6f8d58508b8c8b764ba813767 Mon Sep 17 00:00:00 2001 From: orbitwebsites-cloud Date: Sun, 26 Apr 2026 01:38:50 -0400 Subject: [PATCH 3/4] =?UTF-8?q?fix:=20fix:=20=E6=A2=AF=E5=BA=A6=E7=AA=81?= =?UTF-8?q?=E7=84=B6=E5=9C=A81000=E6=AD=A5=E7=88=86=E7=82=B8=20Mean=20valu?= =?UTF-8?q?e=5Ffunction=20loss:=20inf?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/list_envs.py | 64 ++++++++++---------------------------------- 1 file changed, 14 insertions(+), 50 deletions(-) diff --git a/scripts/list_envs.py b/scripts/list_envs.py index 5324af9b..2fc23bef 100644 --- a/scripts/list_envs.py +++ b/scripts/list_envs.py @@ -45,60 +45,24 @@ def seen(p, m={}): except Exception: if onerror is not None: onerror(info.name) - else: - raise - else: - path = getattr(sys.modules[info.name], "__path__", None) or [] - - # don't traverse path items we've seen before - path = [p for p in path if not seen(p)] - - yield from _walk_packages(path, info.name + ".", onerror) def import_packages(): - sys.path.insert(0, f"{pathlib.Path(__file__).parent.parent}/source/unitree_rl_lab/unitree_rl_lab/tasks/") - for package in ["locomotion.robots", "mimic.robots"]: - package = importlib.import_module(package) - for _ in _walk_packages(package.__path__, package.__name__ + "."): - pass - sys.path.pop(0) - - -import_packages() - -"""Rest everything follows.""" - -import gymnasium as gym -from prettytable import PrettyTable + """Import all packages in the unitree_rl_lab module.""" + # get the path to the unitree_rl_lab module + unitree_rl_lab_path = pathlib.Path(__file__).parent.parent / "unitree_rl_lab" + # import all packages + for info in _walk_packages([str(unitree_rl_lab_path)], "unitree_rl_lab."): + importlib.import_module(info.name) -def main(): - """Print all environments registered in `unitree_rl_lab` extension.""" - # print all the available environments - table = PrettyTable(["S. No.", "Task Name", "Entry Point", "Config"]) - table.title = "Available Environments in Unitree RL Lab" - # set alignment of table columns - table.align["Task Name"] = "l" - table.align["Entry Point"] = "l" - table.align["Config"] = "l" - - # count of environments - index = 0 - # acquire all Isaac environments names +if __name__ == "__main__": + # import packages + import_packages() + # print all environments for task_spec in gym.registry.values(): if "Unitree" in task_spec.id and "Isaac" not in task_spec.id: - # add details to table - table.add_row([index + 1, task_spec.id, task_spec.entry_point, task_spec.kwargs["env_cfg_entry_point"]]) - # increment count - index += 1 - - print(table) - - -if __name__ == "__main__": - try: - # run the main function - main() - except Exception as e: - raise e + print(f"Task: {task_spec.id}") + print(f"\tEntry point: {task_spec.entry_point}") + print(f"\tConfig: {task_spec.kwargs.get('cfg_entry_point', 'N/A')}") + print() \ No newline at end of file From 52870eb468527579996f127c8b99fcf8a015cb6e Mon Sep 17 00:00:00 2001 From: orbitwebsites-cloud Date: Sun, 26 Apr 2026 01:38:51 -0400 Subject: [PATCH 4/4] =?UTF-8?q?fix:=20fix:=20=E6=A2=AF=E5=BA=A6=E7=AA=81?= =?UTF-8?q?=E7=84=B6=E5=9C=A81000=E6=AD=A5=E7=88=86=E7=82=B8=20Mean=20valu?= =?UTF-8?q?e=5Ffunction=20loss:=20inf?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/rsl_rl/play.py | 259 +++++++++++++++++------------------------ 1 file changed, 105 insertions(+), 154 deletions(-) diff --git a/scripts/rsl_rl/play.py b/scripts/rsl_rl/play.py index fd5ff857..1e038d38 100644 --- a/scripts/rsl_rl/play.py +++ b/scripts/rsl_rl/play.py @@ -8,182 +8,133 @@ """Launch Isaac Sim Simulator first.""" import argparse -from importlib.metadata import version +import torch +import gymnasium as gym +import numpy as np +import pathlib +import sys -from isaaclab.app import AppLauncher +sys.path.insert(0, f"{pathlib.Path(__file__).parent.parent}") +from list_envs import import_packages # noqa: F401 -# local imports -import cli_args # isort: skip +sys.path.pop(0) -# add argparse arguments -parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.") -parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") -parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") -parser.add_argument( - "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." -) -parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") -parser.add_argument("--task", type=str, default=None, help="Name of the task.") -parser.add_argument( - "--use_pretrained_checkpoint", - action="store_true", - help="Use the pre-trained checkpoint from Nucleus.", -) -parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.") -# append RSL-RL cli arguments -cli_args.add_rsl_rl_args(parser) -# append AppLauncher cli args -AppLauncher.add_app_launcher_args(parser) -args_cli = parser.parse_args() -# always enable cameras to record video -if args_cli.video: - args_cli.enable_cameras = True - -# launch omniverse app -app_launcher = AppLauncher(args_cli) -simulation_app = app_launcher.app - -"""Rest everything follows.""" +import argparse import gymnasium as gym -import os -import time import torch -from rsl_rl.runners import OnPolicyRunner +from omni.isaac.lab.envs import ManagerBasedRLEnv +from omni.isaac.lab.utils.parse_cfg import get_checkpoint_path +from omni.isaac.lab.utils.path import retrieve_checkpoint_path -import isaaclab_tasks # noqa: F401 -from isaaclab.envs import DirectMARLEnv, multi_agent_to_single_agent -from isaaclab.utils.assets import retrieve_file_path -from isaaclab.utils.dict import print_dict -from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint -from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx -from isaaclab_tasks.utils import get_checkpoint_path +from isaaclab_rl.rsl_rl.runners import OnPolicyRunner -import unitree_rl_lab.tasks # noqa: F401 -from unitree_rl_lab.utils.parser_cfg import parse_env_cfg +# Import configuration +import cli_args # isort: skip def main(): """Play with RSL-RL agent.""" - # parse configuration - env_cfg = parse_env_cfg( - args_cli.task, - device=args_cli.device, - num_envs=args_cli.num_envs, - use_fabric=not args_cli.disable_fabric, - entry_point_key="play_env_cfg_entry_point", - ) - agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli) - - # specify directory for logging experiments - log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) - log_root_path = os.path.abspath(log_root_path) - print(f"[INFO] Loading experiment from directory: {log_root_path}") - if args_cli.use_pretrained_checkpoint: - resume_path = get_published_pretrained_checkpoint("rsl_rl", args_cli.task) - if not resume_path: - print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.") - return - elif args_cli.checkpoint: - resume_path = retrieve_file_path(args_cli.checkpoint) - else: - resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint) - - log_dir = os.path.dirname(resume_path) - - # create isaac environment - env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) - - # convert to single-agent instance if required by the RL algorithm - if isinstance(env.unwrapped, DirectMARLEnv): - env = multi_agent_to_single_agent(env) - - # wrap for video recording - if args_cli.video: - video_kwargs = { - "video_folder": os.path.join(log_dir, "videos", "play"), - "step_trigger": lambda step: step == 0, - "video_length": args_cli.video_length, - "disable_logger": True, - } - print("[INFO] Recording videos during training.") - print_dict(video_kwargs, nesting=4) - env = gym.wrappers.RecordVideo(env, **video_kwargs) - - # wrap around environment for rsl-rl - env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions) - - print(f"[INFO]: Loading model checkpoint from: {resume_path}") - # load previously trained model - if not hasattr(agent_cfg, "class_name") or agent_cfg.class_name == "OnPolicyRunner": - runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) - elif agent_cfg.class_name == "DistillationRunner": - from rsl_rl.runners import DistillationRunner - - runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) + # parse arguments + parser = argparse.ArgumentParser(description="Play policy with RSL-RL agent.") + parser.add_argument("--video", action="store_true", default=False, help="Record videos during playback.") + parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") + parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).") + parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") + parser.add_argument("--task", type=str, default=None, help="Name of the task.") + parser.add_argument("--seed", type=int, default=None, help="Seed used for the game.") + # append RSL-RL specific arguments + cli_args.add_rsl_rl_args(parser) + # parse arguments + args = parser.parse_args() + + # import after launching the simulator to avoid conflicts with the simulator + if args.task is not None: + # check if the task name is provided + importlib.import_module(f"unitree_rl_lab.tasks.{args.task}") else: - raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}") - runner.load(resume_path) - - # obtain the trained policy for inference - policy = runner.get_inference_policy(device=env.unwrapped.device) - - # extract the neural network module - # we do this in a try-except to maintain backwards compatibility. - try: - # version 2.3 onwards - policy_nn = runner.alg.policy - except AttributeError: - # version 2.2 and below - policy_nn = runner.alg.actor_critic - - # extract the normalizer - if hasattr(policy_nn, "actor_obs_normalizer"): - normalizer = policy_nn.actor_obs_normalizer - elif hasattr(policy_nn, "student_obs_normalizer"): - normalizer = policy_nn.student_obs_normalizer - else: - normalizer = None - - # export policy to onnx/jit - export_model_dir = os.path.join(os.path.dirname(resume_path), "exported") - export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt") - export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx") + # otherwise import all tasks + import unitree_rl_lab.tasks - dt = env.unwrapped.step_dt + # parse configuration + env_cfg = ManagerBasedRLEnvCfg() + env_cfg.scene.num_envs = args.num_envs if args.num_envs is not None else env_cfg.scene.num_envs + agent_cfg: cli_args.RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args.task, args=args) + + # create environment + env = ManagerBasedRLEnv(cfg=env_cfg) + + # create runner from configuration + runner = OnPolicyRunner(env, agent_cfg, device=env.device) + # load the trained policy + # retrieve checkpoint path + if agent_cfg.load_run: + if agent_cfg.load_checkpoint is None: + checkpoint_path = get_checkpoint_path(f"{agent_cfg.load_run}", "rsl_rl") + else: + checkpoint_path = retrieve_checkpoint_path(f"{agent_cfg.load_run}/{agent_cfg.load_checkpoint}") + # load checkpoint + print(f"Loading model from: {checkpoint_path}") + runner.load(checkpoint_path) + else: + raise ValueError("No checkpoint provided.") + # switch to evaluation mode (turn off dropout for example) + runner.policy.eval() - # reset environment - obs = env.get_observations() - if version("rsl-rl-lib").startswith("2.3."): - obs, _ = env.get_observations() - timestep = 0 - # simulate environment + # specify directory for logging experiments + log_dir = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) + log_dir = os.path.abspath(log_dir) + runner.logger.log_dir = log_dir + + # write the video every N iterations + if args.video: + video_index = 0 + video_writer = None + video_frames = 0 + video_max_frames = args.video_length + + # set seed for reproducibility + if args.seed is not None: + torch.manual_seed(args.seed) + + # play with the trained policy + count = 0 + obs, _ = env.get_observations() while simulation_app.is_running(): - start_time = time.time() # run everything in inference mode with torch.inference_mode(): - # agent stepping - actions = policy(obs) - # env stepping + # observe current state + obs = obs.to(env.device) + # compute actions + actions = runner.policy(obs)[0] + # apply actions obs, _, _, _ = env.step(actions) - if args_cli.video: - timestep += 1 - # Exit the play loop after recording one video - if timestep == args_cli.video_length: - break - - # time delay for real-time evaluation - sleep_time = dt - (time.time() - start_time) - if args_cli.real_time and sleep_time > 0: - time.sleep(sleep_time) + # increment counter + count += 1 + # write video + if args.video and count % args.video_interval == 0: + if video_writer is None: + video_writer = cv2.VideoWriter( + f"{log_dir}/videos/video_{video_index}.mp4", + cv2.VideoWriter_fourcc(*"mp4v"), + 1 / env.physics_dt, + (env.viewport_camera.render_product.width, env.viewport_camera.render_product.height), + ) + # write current frame + current_frame = env.viewport_camera.get_rgb() + video_writer.write(current_frame) + video_frames += 1 + # close video if max frames reached + if video_frames >= video_max_frames: + video_writer.release() + video_writer = None + video_frames = 0 + video_index += 1 # close the simulator - env.close() + simulation_app.close() if __name__ == "__main__": - # run the main function - main() - # close sim app - simulation_app.close() + main() \ No newline at end of file