Skip to content

Commit 51c2026

Browse files
authored
Fix unpickling Box2D and MuJoCo envs (#3025)
* Try to fix car racing unpickling
* Fix EzPickle for BipedalWalker and LunarLander
* Shamelessly steal the pickle-unpickle test from Mark, with slight modifications
* CarRacing EzPickle fix
* MuJoCo EzPickle fix
1 parent f54319e commit 51c2026

32 files changed

+203
-34
lines changed

gym/envs/box2d/bipedal_walker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ class BipedalWalker(gym.Env, EzPickle):
169169
}
170170

171171
def __init__(self, render_mode: Optional[str] = None, hardcore: bool = False):
172-
EzPickle.__init__(self)
172+
EzPickle.__init__(self, render_mode, hardcore)
173173
self.isopen = True
174174

175175
self.world = Box2D.b2World()

gym/envs/box2d/car_racing.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,20 @@ def __init__(
200200
domain_randomize: bool = False,
201201
continuous: bool = True,
202202
):
203-
EzPickle.__init__(self)
203+
EzPickle.__init__(
204+
self,
205+
render_mode,
206+
verbose,
207+
lap_complete_percent,
208+
domain_randomize,
209+
continuous,
210+
)
204211
self.continuous = continuous
205212
self.domain_randomize = domain_randomize
213+
self.lap_complete_percent = lap_complete_percent
206214
self._init_colors()
207215

208-
self.contactListener_keepref = FrictionDetector(self, lap_complete_percent)
216+
self.contactListener_keepref = FrictionDetector(self, self.lap_complete_percent)
209217
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
210218
self.screen: Optional[pygame.Surface] = None
211219
self.surf = None
@@ -480,6 +488,10 @@ def reset(
480488
):
481489
super().reset(seed=seed)
482490
self._destroy()
491+
self.world.contactListener_bug_workaround = FrictionDetector(
492+
self, self.lap_complete_percent
493+
)
494+
self.world.contactListener = self.world.contactListener_bug_workaround
483495
self.reward = 0.0
484496
self.prev_reward = 0.0
485497
self.tile_visited_count = 0

gym/envs/box2d/lunar_lander.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,15 @@ def __init__(
192192
wind_power: float = 15.0,
193193
turbulence_power: float = 1.5,
194194
):
195-
EzPickle.__init__(self)
195+
EzPickle.__init__(
196+
self,
197+
render_mode,
198+
continuous,
199+
gravity,
200+
enable_wind,
201+
wind_power,
202+
turbulence_power,
203+
)
196204

197205
assert (
198206
-12.0 < gravity and gravity < 0.0

gym/envs/mujoco/ant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, **kwargs):
2424
MuJocoPyEnv.__init__(
2525
self, "ant.xml", 5, observation_space=observation_space, **kwargs
2626
)
27-
utils.EzPickle.__init__(self)
27+
utils.EzPickle.__init__(self, **kwargs)
2828

2929
def step(self, a):
3030
xposbefore = self.get_body_com("torso")[0]

gym/envs/mujoco/ant_v3.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,19 @@ def __init__(
3434
exclude_current_positions_from_observation=True,
3535
**kwargs
3636
):
37-
utils.EzPickle.__init__(**locals())
37+
utils.EzPickle.__init__(
38+
self,
39+
xml_file,
40+
ctrl_cost_weight,
41+
contact_cost_weight,
42+
healthy_reward,
43+
terminate_when_unhealthy,
44+
healthy_z_range,
45+
contact_force_range,
46+
reset_noise_scale,
47+
exclude_current_positions_from_observation,
48+
**kwargs
49+
)
3850

3951
self._ctrl_cost_weight = ctrl_cost_weight
4052
self._contact_cost_weight = contact_cost_weight

gym/envs/mujoco/ant_v4.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,20 @@ def __init__(
197197
exclude_current_positions_from_observation=True,
198198
**kwargs
199199
):
200-
utils.EzPickle.__init__(**locals())
200+
utils.EzPickle.__init__(
201+
self,
202+
xml_file,
203+
ctrl_cost_weight,
204+
use_contact_forces,
205+
contact_cost_weight,
206+
healthy_reward,
207+
terminate_when_unhealthy,
208+
healthy_z_range,
209+
contact_force_range,
210+
reset_noise_scale,
211+
exclude_current_positions_from_observation,
212+
**kwargs
213+
)
201214

202215
self._ctrl_cost_weight = ctrl_cost_weight
203216
self._contact_cost_weight = contact_cost_weight

gym/envs/mujoco/half_cheetah.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, **kwargs):
2222
MuJocoPyEnv.__init__(
2323
self, "half_cheetah.xml", 5, observation_space=observation_space, **kwargs
2424
)
25-
utils.EzPickle.__init__(self)
25+
utils.EzPickle.__init__(self, **kwargs)
2626

2727
def step(self, action):
2828
xposbefore = self.sim.data.qpos[0]

gym/envs/mujoco/half_cheetah_v3.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,15 @@ def __init__(
3232
exclude_current_positions_from_observation=True,
3333
**kwargs
3434
):
35-
utils.EzPickle.__init__(**locals())
35+
utils.EzPickle.__init__(
36+
self,
37+
xml_file,
38+
forward_reward_weight,
39+
ctrl_cost_weight,
40+
reset_noise_scale,
41+
exclude_current_positions_from_observation,
42+
**kwargs
43+
)
3644

3745
self._forward_reward_weight = forward_reward_weight
3846

gym/envs/mujoco/half_cheetah_v4.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,14 @@ def __init__(
151151
exclude_current_positions_from_observation=True,
152152
**kwargs
153153
):
154-
utils.EzPickle.__init__(**locals())
154+
utils.EzPickle.__init__(
155+
self,
156+
forward_reward_weight,
157+
ctrl_cost_weight,
158+
reset_noise_scale,
159+
exclude_current_positions_from_observation,
160+
**kwargs
161+
)
155162

156163
self._forward_reward_weight = forward_reward_weight
157164

gym/envs/mujoco/hopper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, **kwargs):
2222
MuJocoPyEnv.__init__(
2323
self, "hopper.xml", 4, observation_space=observation_space, **kwargs
2424
)
25-
utils.EzPickle.__init__(self)
25+
utils.EzPickle.__init__(self, **kwargs)
2626

2727
def step(self, a):
2828
posbefore = self.sim.data.qpos[0]

gym/envs/mujoco/hopper_v3.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,20 @@ def __init__(
4040
exclude_current_positions_from_observation=True,
4141
**kwargs
4242
):
43-
utils.EzPickle.__init__(**locals())
43+
utils.EzPickle.__init__(
44+
self,
45+
xml_file,
46+
forward_reward_weight,
47+
ctrl_cost_weight,
48+
healthy_reward,
49+
terminate_when_unhealthy,
50+
healthy_state_range,
51+
healthy_z_range,
52+
healthy_angle_range,
53+
reset_noise_scale,
54+
exclude_current_positions_from_observation,
55+
**kwargs
56+
)
4457

4558
self._forward_reward_weight = forward_reward_weight
4659

gym/envs/mujoco/hopper_v4.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,19 @@ def __init__(
162162
exclude_current_positions_from_observation=True,
163163
**kwargs
164164
):
165-
utils.EzPickle.__init__(**locals())
165+
utils.EzPickle.__init__(
166+
self,
167+
forward_reward_weight,
168+
ctrl_cost_weight,
169+
healthy_reward,
170+
terminate_when_unhealthy,
171+
healthy_state_range,
172+
healthy_z_range,
173+
healthy_angle_range,
174+
reset_noise_scale,
175+
exclude_current_positions_from_observation,
176+
**kwargs
177+
)
166178

167179
self._forward_reward_weight = forward_reward_weight
168180

gym/envs/mujoco/humanoid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, **kwargs):
3030
MuJocoPyEnv.__init__(
3131
self, "humanoid.xml", 5, observation_space=observation_space, **kwargs
3232
)
33-
utils.EzPickle.__init__(self)
33+
utils.EzPickle.__init__(self, **kwargs)
3434

3535
def _get_obs(self):
3636
data = self.sim.data

gym/envs/mujoco/humanoid_v3.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,20 @@ def __init__(
4444
exclude_current_positions_from_observation=True,
4545
**kwargs
4646
):
47-
utils.EzPickle.__init__(**locals())
47+
utils.EzPickle.__init__(
48+
self,
49+
xml_file,
50+
forward_reward_weight,
51+
ctrl_cost_weight,
52+
contact_cost_weight,
53+
contact_cost_range,
54+
healthy_reward,
55+
terminate_when_unhealthy,
56+
healthy_z_range,
57+
reset_noise_scale,
58+
exclude_current_positions_from_observation,
59+
**kwargs
60+
)
4861

4962
self._forward_reward_weight = forward_reward_weight
5063
self._ctrl_cost_weight = ctrl_cost_weight

gym/envs/mujoco/humanoid_v4.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,17 @@ def __init__(
234234
exclude_current_positions_from_observation=True,
235235
**kwargs
236236
):
237-
utils.EzPickle.__init__(**locals())
237+
utils.EzPickle.__init__(
238+
self,
239+
forward_reward_weight,
240+
ctrl_cost_weight,
241+
healthy_reward,
242+
terminate_when_unhealthy,
243+
healthy_z_range,
244+
reset_noise_scale,
245+
exclude_current_positions_from_observation,
246+
**kwargs
247+
)
238248

239249
self._forward_reward_weight = forward_reward_weight
240250
self._ctrl_cost_weight = ctrl_cost_weight

gym/envs/mujoco/humanoidstandup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, **kwargs):
2828
observation_space=observation_space,
2929
**kwargs
3030
)
31-
utils.EzPickle.__init__(self)
31+
utils.EzPickle.__init__(self, **kwargs)
3232

3333
def _get_obs(self):
3434
data = self.sim.data

gym/envs/mujoco/humanoidstandup_v4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def __init__(self, **kwargs):
200200
observation_space=observation_space,
201201
**kwargs
202202
)
203-
utils.EzPickle.__init__(self)
203+
utils.EzPickle.__init__(self, **kwargs)
204204

205205
def _get_obs(self):
206206
data = self.data

gym/envs/mujoco/inverted_double_pendulum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, **kwargs):
2626
observation_space=observation_space,
2727
**kwargs
2828
)
29-
utils.EzPickle.__init__(self)
29+
utils.EzPickle.__init__(self, **kwargs)
3030

3131
def step(self, action):
3232
self.do_simulation(action, self.frame_skip)

gym/envs/mujoco/inverted_double_pendulum_v4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(self, **kwargs):
132132
observation_space=observation_space,
133133
**kwargs
134134
)
135-
utils.EzPickle.__init__(self)
135+
utils.EzPickle.__init__(self, **kwargs)
136136

137137
def step(self, action):
138138
self.do_simulation(action, self.frame_skip)

gym/envs/mujoco/inverted_pendulum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
1818
}
1919

2020
def __init__(self, **kwargs):
21-
utils.EzPickle.__init__(self)
21+
utils.EzPickle.__init__(self, **kwargs)
2222
observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
2323
MuJocoPyEnv.__init__(
2424
self,

gym/envs/mujoco/inverted_pendulum_v4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
9595
}
9696

9797
def __init__(self, **kwargs):
98-
utils.EzPickle.__init__(self)
98+
utils.EzPickle.__init__(self, **kwargs)
9999
observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
100100
MujocoEnv.__init__(
101101
self,

gym/envs/mujoco/pusher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle):
1818
}
1919

2020
def __init__(self, **kwargs):
21-
utils.EzPickle.__init__(self)
21+
utils.EzPickle.__init__(self, **kwargs)
2222
observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
2323
MuJocoPyEnv.__init__(
2424
self, "pusher.xml", 5, observation_space=observation_space, **kwargs

gym/envs/mujoco/pusher_v4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
140140
}
141141

142142
def __init__(self, **kwargs):
143-
utils.EzPickle.__init__(self)
143+
utils.EzPickle.__init__(self, **kwargs)
144144
observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
145145
MujocoEnv.__init__(
146146
self, "pusher.xml", 5, observation_space=observation_space, **kwargs

gym/envs/mujoco/reacher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
1818
}
1919

2020
def __init__(self, **kwargs):
21-
utils.EzPickle.__init__(self)
21+
utils.EzPickle.__init__(self, **kwargs)
2222
observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
2323
MuJocoPyEnv.__init__(
2424
self, "reacher.xml", 2, observation_space=observation_space, **kwargs

gym/envs/mujoco/reacher_v4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
130130
}
131131

132132
def __init__(self, **kwargs):
133-
utils.EzPickle.__init__(self)
133+
utils.EzPickle.__init__(self, **kwargs)
134134
observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
135135
MujocoEnv.__init__(
136136
self, "reacher.xml", 2, observation_space=observation_space, **kwargs

gym/envs/mujoco/swimmer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, **kwargs):
2222
MuJocoPyEnv.__init__(
2323
self, "swimmer.xml", 4, observation_space=observation_space, **kwargs
2424
)
25-
utils.EzPickle.__init__(self)
25+
utils.EzPickle.__init__(self, **kwargs)
2626

2727
def step(self, a):
2828
ctrl_cost_coeff = 0.0001

gym/envs/mujoco/swimmer_v3.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,15 @@ def __init__(
3030
exclude_current_positions_from_observation=True,
3131
**kwargs
3232
):
33-
utils.EzPickle.__init__(**locals())
33+
utils.EzPickle.__init__(
34+
self,
35+
xml_file,
36+
forward_reward_weight,
37+
ctrl_cost_weight,
38+
reset_noise_scale,
39+
exclude_current_positions_from_observation,
40+
**kwargs
41+
)
3442

3543
self._forward_reward_weight = forward_reward_weight
3644
self._ctrl_cost_weight = ctrl_cost_weight

gym/envs/mujoco/swimmer_v4.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,14 @@ def __init__(
143143
exclude_current_positions_from_observation=True,
144144
**kwargs
145145
):
146-
utils.EzPickle.__init__(**locals())
146+
utils.EzPickle.__init__(
147+
self,
148+
forward_reward_weight,
149+
ctrl_cost_weight,
150+
reset_noise_scale,
151+
exclude_current_positions_from_observation,
152+
**kwargs
153+
)
147154

148155
self._forward_reward_weight = forward_reward_weight
149156
self._ctrl_cost_weight = ctrl_cost_weight

gym/envs/mujoco/walker2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, **kwargs):
2222
MuJocoPyEnv.__init__(
2323
self, "walker2d.xml", 4, observation_space=observation_space, **kwargs
2424
)
25-
utils.EzPickle.__init__(self)
25+
utils.EzPickle.__init__(self, **kwargs)
2626

2727
def step(self, a):
2828
posbefore = self.sim.data.qpos[0]

0 commit comments

Comments
 (0)