Skip to content

Commit 754d20e

Browse files
committed
[Feature] IsaacLab wrapper
ghstack-source-id: bec4441 Pull-Request-resolved: #2937
1 parent 3dbd84c commit 754d20e

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

torchrl/envs/libs/isaaclab.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import torch
6+
from torchrl.envs import GymLikeEnv
7+
8+
9+
class IsaacLabEnv(GymLikeEnv):
10+
def __init__(
11+
self,
12+
env: "ManagerBasedRLEnv",
13+
categorical_action_encoding=False,
14+
allow_done_after_reset=True,
15+
convert_actions_to_numpy=False,
16+
**kwargs
17+
):
18+
"""
19+
Here we are setting some parameters that are what we need for IsaacLab.
20+
"""
21+
super().__init__(
22+
env,
23+
device=torch.device("cuda:0"),
24+
categorical_action_encoding=categorical_action_encoding,
25+
allow_done_after_reset=allow_done_after_reset,
26+
convert_actions_to_numpy=convert_actions_to_numpy,
27+
**kwargs
28+
)
29+
30+
@classmethod
31+
def from_cfg(
32+
cls,
33+
cfg,
34+
categorical_action_encoding=False,
35+
allow_done_after_reset=True,
36+
convert_actions_to_numpy=False,
37+
**kwargs
38+
):
39+
from isaaclab.envs import ManagerBasedRLEnv
40+
41+
env = ManagerBasedRLEnv(cfg=cfg)
42+
return cls(
43+
env=env,
44+
categorical_action_encoding=categorical_action_encoding,
45+
allow_done_after_reset=allow_done_after_reset,
46+
convert_actions_to_numpy=convert_actions_to_numpy,
47+
**kwargs
48+
)
49+
50+
def seed(self, seed: int | None):
51+
self._set_seed(seed)
52+
53+
def _output_transform(self, step_outputs_tuple): # noqa: F811
54+
"""
55+
We discovered the IsaacLab will modify the `terminated` and `truncated` tensors
56+
in place. Clone them here to make sure data doesn't inadvertently get modified.
57+
58+
This is a PR in torchRL:
59+
Once we update to the version with this PR, we can delete this.
60+
"""
61+
# The variable naming follows torchrl's convention here.
62+
observations, reward, terminated, truncated, info = step_outputs_tuple
63+
done = terminated | truncated
64+
reward = reward.unsqueeze(-1) # to get to (num_envs, 1)
65+
return (
66+
observations,
67+
reward,
68+
terminated.clone(),
69+
truncated.clone(),
70+
done.clone(),
71+
info,
72+
)
73+
74+
75+
if __name__ == "__main__":
76+
from isaaclab_tasks.manager_based.classic.cartpole.cartpole_env_cfg import CartpoleEnvCfg
77+
78+
env = IsaacLabEnv.from_cfg(CartpoleEnvCfg())
79+
env.check_env_specs(break_when_any_done="both")

torchrl/envs/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import re
1515
import warnings
1616
from enum import Enum
17-
from typing import Any
17+
from typing import Any, Literal
1818

1919
import torch
2020

@@ -687,7 +687,7 @@ def check_env_specs(
687687
check_dtype=True,
688688
seed: int | None = None,
689689
tensordict: TensorDictBase | None = None,
690-
break_when_any_done: bool | str = None,
690+
break_when_any_done: bool | Literal["both"] = None,
691691
):
692692
"""Tests an environment specs against the results of short rollout.
693693

0 commit comments

Comments
 (0)