1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import torch
6
+ from torchrl .envs import GymLikeEnv
7
+
8
+
9
class IsaacLabEnv(GymLikeEnv):
    """TorchRL environment wrapper around an IsaacLab ``ManagerBasedRLEnv``.

    The constructor defaults (``categorical_action_encoding=False``,
    ``allow_done_after_reset=True``, ``convert_actions_to_numpy=False`` and a
    ``cuda:0`` device) are the parameter values this wrapper needs for
    IsaacLab.
    """

    def __init__(
        self,
        env: "ManagerBasedRLEnv",  # noqa: F821
        categorical_action_encoding: bool = False,
        allow_done_after_reset: bool = True,
        convert_actions_to_numpy: bool = False,
        **kwargs,
    ) -> None:
        """Wrap an already constructed IsaacLab environment.

        Args:
            env: The IsaacLab environment instance to wrap.
            categorical_action_encoding: Forwarded to the parent constructor.
            allow_done_after_reset: Forwarded to the parent constructor.
            convert_actions_to_numpy: Forwarded to the parent constructor.
            **kwargs: Additional keyword arguments forwarded to the parent
                constructor. ``device`` may be given here; when it is omitted
                or ``None`` it defaults to ``torch.device("cuda:0")``.
        """
        # Only fall back to cuda:0 when the caller did not pick a device.
        # Hard-coding ``device=`` next to ``**kwargs`` in the call below would
        # raise "got multiple values for keyword argument 'device'" as soon as
        # a caller supplies one.
        if kwargs.get("device") is None:
            kwargs["device"] = torch.device("cuda:0")
        super().__init__(
            env,
            categorical_action_encoding=categorical_action_encoding,
            allow_done_after_reset=allow_done_after_reset,
            convert_actions_to_numpy=convert_actions_to_numpy,
            **kwargs,
        )

    @classmethod
    def from_cfg(
        cls,
        cfg,
        categorical_action_encoding: bool = False,
        allow_done_after_reset: bool = True,
        convert_actions_to_numpy: bool = False,
        **kwargs,
    ) -> "IsaacLabEnv":
        """Build a ``ManagerBasedRLEnv`` from ``cfg`` and wrap it.

        Args:
            cfg: Configuration object passed to ``ManagerBasedRLEnv(cfg=cfg)``.
            categorical_action_encoding: Forwarded to ``__init__``.
            allow_done_after_reset: Forwarded to ``__init__``.
            convert_actions_to_numpy: Forwarded to ``__init__``.
            **kwargs: Additional keyword arguments forwarded to ``__init__``
                (for example ``device``).

        Returns:
            A new instance of ``cls`` wrapping the freshly built environment.
        """
        # Local import: isaaclab is only required when building from a config.
        from isaaclab.envs import ManagerBasedRLEnv

        env = ManagerBasedRLEnv(cfg=cfg)
        return cls(
            env=env,
            categorical_action_encoding=categorical_action_encoding,
            allow_done_after_reset=allow_done_after_reset,
            convert_actions_to_numpy=convert_actions_to_numpy,
            **kwargs,
        )

    def seed(self, seed: int | None) -> None:
        """Seed the environment by delegating to ``self._set_seed``."""
        self._set_seed(seed)

    def _output_transform(self, step_outputs_tuple):  # noqa: F811
        """Turn IsaacLab step outputs into a tuple that also carries ``done``.

        IsaacLab modifies the ``terminated`` and ``truncated`` tensors in
        place, so they are cloned here to make sure the returned data does not
        inadvertently change afterwards.

        TODO: remove this override once the TorchRL version in use handles
        the cloning itself.

        Args:
            step_outputs_tuple: ``(observations, reward, terminated,
                truncated, info)`` as produced by the wrapped environment.

        Returns:
            ``(observations, reward, terminated, truncated, done, info)``,
            where ``done`` is ``terminated | truncated`` and ``reward`` has a
            trailing singleton dimension appended.
        """
        # The variable naming follows torchrl's convention here.
        observations, reward, terminated, truncated, info = step_outputs_tuple
        # ``|`` allocates a new tensor, so ``done`` never aliases IsaacLab's
        # buffers and needs no extra clone.
        done = terminated | truncated
        reward = reward.unsqueeze(-1)  # to get to (num_envs, 1)
        return (
            observations,
            reward,
            terminated.clone(),
            truncated.clone(),
            done,
            info,
        )
73
+
74
+
75
if __name__ == "__main__":
    # Manual smoke test: build the cartpole task and validate the wrapper's
    # specs against what the environment actually produces.
    from isaaclab_tasks.manager_based.classic.cartpole.cartpole_env_cfg import (
        CartpoleEnvCfg,
    )

    cartpole_cfg = CartpoleEnvCfg()
    env = IsaacLabEnv.from_cfg(cartpole_cfg)
    env.check_env_specs(break_when_any_done="both")
0 commit comments