1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ #
3+ # This source code is licensed under the MIT license found in the
4+ # LICENSE file in the root directory of this source tree.
5+ import torch
6+ from torchrl .envs import GymLikeEnv
7+
8+
class IsaacLabEnv(GymLikeEnv):
    """TorchRL environment wrapper around an IsaacLab ``ManagerBasedRLEnv``.

    The constructor forwards a set of defaults chosen for IsaacLab to the
    parent class, and ``_output_transform`` protects the done flags against
    IsaacLab's in-place updates.

    NOTE(review): ``categorical_action_encoding`` and
    ``convert_actions_to_numpy`` look like ``GymWrapper`` options -- confirm
    that ``GymLikeEnv`` (rather than ``GymWrapper``) is the intended base.
    """

    def __init__(
        self,
        env: "ManagerBasedRLEnv",  # noqa: F821
        categorical_action_encoding: bool = False,
        allow_done_after_reset: bool = True,
        convert_actions_to_numpy: bool = False,
        *,
        device: torch.device | str | None = None,
        **kwargs,
    ):
        """Wrap an already-built IsaacLab environment.

        Args:
            env: the IsaacLab environment instance to wrap.
            categorical_action_encoding: forwarded to the parent constructor.
            allow_done_after_reset: forwarded to the parent constructor.
            convert_actions_to_numpy: forwarded to the parent constructor.
            device: device handed to the parent constructor. ``None`` (the
                default) keeps the historical behaviour of using ``"cuda:0"``.
            **kwargs: any additional keyword arguments for the parent
                constructor.
        """
        # The device used to be hard-coded, which also made ``device=...`` in
        # ``kwargs`` fail with a duplicate-keyword TypeError. It is now an
        # explicit, optional argument with the same default as before.
        if device is None:
            device = torch.device("cuda:0")
        else:
            device = torch.device(device)
        super().__init__(
            env,
            device=device,
            categorical_action_encoding=categorical_action_encoding,
            allow_done_after_reset=allow_done_after_reset,
            convert_actions_to_numpy=convert_actions_to_numpy,
            **kwargs,
        )

    @classmethod
    def from_cfg(
        cls,
        cfg,
        categorical_action_encoding: bool = False,
        allow_done_after_reset: bool = True,
        convert_actions_to_numpy: bool = False,
        *,
        device: torch.device | str | None = None,
        **kwargs,
    ):
        """Build a ``ManagerBasedRLEnv`` from ``cfg`` and wrap it.

        Args:
            cfg: configuration object passed to ``ManagerBasedRLEnv(cfg=cfg)``.
            categorical_action_encoding: forwarded to ``__init__``.
            allow_done_after_reset: forwarded to ``__init__``.
            convert_actions_to_numpy: forwarded to ``__init__``.
            device: forwarded to ``__init__``; ``None`` selects ``"cuda:0"``.
            **kwargs: forwarded to ``__init__``.

        Returns:
            A new instance of ``cls`` wrapping the freshly created environment.
        """
        # Imported lazily so this module can be imported without isaaclab.
        from isaaclab.envs import ManagerBasedRLEnv

        env = ManagerBasedRLEnv(cfg=cfg)
        return cls(
            env=env,
            categorical_action_encoding=categorical_action_encoding,
            allow_done_after_reset=allow_done_after_reset,
            convert_actions_to_numpy=convert_actions_to_numpy,
            device=device,
            **kwargs,
        )

    def seed(self, seed: int | None) -> None:
        """Seed the environment by delegating to ``self._set_seed``."""
        self._set_seed(seed)

    def _output_transform(self, step_outputs_tuple):  # noqa: F811
        """Convert a raw 5-tuple step output into the 6-tuple returned here.

        IsaacLab modifies the ``terminated`` and ``truncated`` tensors in
        place, so they are cloned to make sure the returned data does not
        inadvertently change afterwards.

        A fix for this is pending upstream in TorchRL; once this code runs
        against a TorchRL version that includes it, this override can be
        deleted.

        Args:
            step_outputs_tuple: ``(observations, reward, terminated,
                truncated, info)`` as produced by the wrapped environment.

        Returns:
            ``(observations, reward, terminated, truncated, done, info)``
            where ``done`` is ``terminated | truncated``.
        """
        # The variable naming follows torchrl's convention here.
        observations, reward, terminated, truncated, info = step_outputs_tuple
        # ``|`` allocates a new tensor, so ``done`` needs no extra clone.
        done = terminated | truncated
        reward = reward.unsqueeze(-1)  # to get to (num_envs, 1)
        return (
            observations,
            reward,
            terminated.clone(),
            truncated.clone(),
            done,
            info,
        )
73+
74+
if __name__ == "__main__":
    # Manual smoke test: build the Cartpole task from its default config,
    # wrap it, and run the env-spec check on the wrapped environment.
    from isaaclab_tasks.manager_based.classic.cartpole import cartpole_env_cfg

    env = IsaacLabEnv.from_cfg(cartpole_env_cfg.CartpoleEnvCfg())
    env.check_env_specs(break_when_any_done="both")
0 commit comments