@@ -1464,12 +1464,29 @@ def make_env():
1464
1464
"transformed_in,transformed_out" , [[True , True ], [False , False ]]
1465
1465
) # 1226: effociency
1466
1466
@pytest .mark .parametrize ("static_seed" , [False , True ])
1467
+ @pytest .mark .parametrize ("penv_device" , ["cpu" , None ])
1468
+ @pytest .mark .parametrize ("env_device" , ["cpu" , None ])
1469
+ @pytest .mark .parametrize ("bwad" , [True , False ])
1467
1470
def test_parallel_env_seed (
1468
- self , env_name , frame_skip , transformed_in , transformed_out , static_seed
1471
+ self ,
1472
+ env_name ,
1473
+ frame_skip ,
1474
+ transformed_in ,
1475
+ transformed_out ,
1476
+ static_seed ,
1477
+ penv_device ,
1478
+ env_device ,
1479
+ bwad ,
1469
1480
):
1470
1481
env_name = env_name ()
1471
1482
env_parallel , env_serial , _ , _ = _make_envs (
1472
- env_name , frame_skip , transformed_in , transformed_out , 5
1483
+ env_name ,
1484
+ frame_skip ,
1485
+ transformed_in ,
1486
+ transformed_out ,
1487
+ 5 ,
1488
+ p_env_device = penv_device ,
1489
+ env_device = env_device ,
1473
1490
)
1474
1491
try :
1475
1492
out_seed_serial = env_serial .set_seed (0 , static_seed = static_seed )
@@ -1479,7 +1496,10 @@ def test_parallel_env_seed(
1479
1496
torch .manual_seed (0 )
1480
1497
1481
1498
td_serial = env_serial .rollout (
1482
- max_steps = 10 , auto_reset = False , tensordict = td0_serial
1499
+ max_steps = 10 ,
1500
+ auto_reset = False ,
1501
+ tensordict = td0_serial ,
1502
+ break_when_any_done = bwad ,
1483
1503
).contiguous ()
1484
1504
key = "pixels" if "pixels" in td_serial .keys () else "observation"
1485
1505
torch .testing .assert_close (
@@ -1494,7 +1514,10 @@ def test_parallel_env_seed(
1494
1514
torch .manual_seed (0 )
1495
1515
assert out_seed_parallel == out_seed_serial
1496
1516
td_parallel = env_parallel .rollout (
1497
- max_steps = 10 , auto_reset = False , tensordict = td0_parallel
1517
+ max_steps = 10 ,
1518
+ auto_reset = False ,
1519
+ tensordict = td0_parallel ,
1520
+ break_when_any_done = bwad ,
1498
1521
).contiguous ()
1499
1522
torch .testing .assert_close (
1500
1523
td_parallel [:, :- 1 ].get (("next" , key )), td_parallel [:, 1 :].get (key )
@@ -1670,7 +1693,7 @@ def test_parallel_env_device(
1670
1693
frame_skip ,
1671
1694
transformed_in = transformed_in ,
1672
1695
transformed_out = transformed_out ,
1673
- device = device ,
1696
+ env_device = device ,
1674
1697
N = N ,
1675
1698
local_mp_ctx = "spawn" ,
1676
1699
)
0 commit comments