Skip to content

Commit ce3e5ad

Browse files
committed
[Test] Add PEnv tests for devices
ghstack-source-id: ad6dcce87bcde2a22f619b9742324073eb67cd5b Pull Request resolved: #2843
1 parent 6e40548 commit ce3e5ad

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

test/_utils_internal.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -267,21 +267,23 @@ def _make_envs(
267267
transformed_in,
268268
transformed_out,
269269
N,
270-
device="cpu",
270+
p_env_device=None,
271+
env_device=None,
272+
# device="cpu",
271273
kwargs=None,
272274
local_mp_ctx=mp_ctx,
273275
):
274276
torch.manual_seed(0)
275277
if not transformed_in:
276278

277279
def create_env_fn():
278-
return GymEnv(env_name, frame_skip=frame_skip, device=device)
280+
return GymEnv(env_name, frame_skip=frame_skip, device=env_device)
279281

280282
else:
281283
if env_name == PONG_VERSIONED():
282284

283285
def create_env_fn():
284-
base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
286+
base_env = GymEnv(env_name, frame_skip=frame_skip, device=env_device)
285287
in_keys = list(base_env.observation_spec.keys(True, True))[:1]
286288
return TransformedEnv(
287289
base_env,
@@ -292,7 +294,7 @@ def create_env_fn():
292294

293295
def create_env_fn():
294296

295-
base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
297+
base_env = GymEnv(env_name, frame_skip=frame_skip, device=env_device)
296298
in_keys = list(base_env.observation_spec.keys(True, True))[:1]
297299

298300
return TransformedEnv(
@@ -305,9 +307,15 @@ def create_env_fn():
305307

306308
env0 = create_env_fn()
307309
env_parallel = ParallelEnv(
308-
N, create_env_fn, create_env_kwargs=kwargs, mp_start_method=local_mp_ctx
310+
N,
311+
create_env_fn,
312+
create_env_kwargs=kwargs,
313+
mp_start_method=local_mp_ctx,
314+
device=p_env_device,
315+
)
316+
env_serial = SerialEnv(
317+
N, create_env_fn, create_env_kwargs=kwargs, device=p_env_device
309318
)
310-
env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs)
311319

312320
for key in env0.observation_spec.keys(True, True):
313321
obs_key = key

test/test_env.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -1464,12 +1464,29 @@ def make_env():
14641464
"transformed_in,transformed_out", [[True, True], [False, False]]
14651465
) # 1226: effociency
14661466
@pytest.mark.parametrize("static_seed", [False, True])
1467+
@pytest.mark.parametrize("penv_device", ["cpu", None])
1468+
@pytest.mark.parametrize("env_device", ["cpu", None])
1469+
@pytest.mark.parametrize("bwad", [True, False])
14671470
def test_parallel_env_seed(
1468-
self, env_name, frame_skip, transformed_in, transformed_out, static_seed
1471+
self,
1472+
env_name,
1473+
frame_skip,
1474+
transformed_in,
1475+
transformed_out,
1476+
static_seed,
1477+
penv_device,
1478+
env_device,
1479+
bwad,
14691480
):
14701481
env_name = env_name()
14711482
env_parallel, env_serial, _, _ = _make_envs(
1472-
env_name, frame_skip, transformed_in, transformed_out, 5
1483+
env_name,
1484+
frame_skip,
1485+
transformed_in,
1486+
transformed_out,
1487+
5,
1488+
p_env_device=penv_device,
1489+
env_device=env_device,
14731490
)
14741491
try:
14751492
out_seed_serial = env_serial.set_seed(0, static_seed=static_seed)
@@ -1479,7 +1496,10 @@ def test_parallel_env_seed(
14791496
torch.manual_seed(0)
14801497

14811498
td_serial = env_serial.rollout(
1482-
max_steps=10, auto_reset=False, tensordict=td0_serial
1499+
max_steps=10,
1500+
auto_reset=False,
1501+
tensordict=td0_serial,
1502+
break_when_any_done=bwad,
14831503
).contiguous()
14841504
key = "pixels" if "pixels" in td_serial.keys() else "observation"
14851505
torch.testing.assert_close(
@@ -1494,7 +1514,10 @@ def test_parallel_env_seed(
14941514
torch.manual_seed(0)
14951515
assert out_seed_parallel == out_seed_serial
14961516
td_parallel = env_parallel.rollout(
1497-
max_steps=10, auto_reset=False, tensordict=td0_parallel
1517+
max_steps=10,
1518+
auto_reset=False,
1519+
tensordict=td0_parallel,
1520+
break_when_any_done=bwad,
14981521
).contiguous()
14991522
torch.testing.assert_close(
15001523
td_parallel[:, :-1].get(("next", key)), td_parallel[:, 1:].get(key)
@@ -1670,7 +1693,7 @@ def test_parallel_env_device(
16701693
frame_skip,
16711694
transformed_in=transformed_in,
16721695
transformed_out=transformed_out,
1673-
device=device,
1696+
env_device=device,
16741697
N=N,
16751698
local_mp_ctx="spawn",
16761699
)

0 commit comments

Comments
 (0)