diff --git a/pufferlib/vector.py b/pufferlib/vector.py index 5746b98487..7cf1504e5c 100644 --- a/pufferlib/vector.py +++ b/pufferlib/vector.py @@ -350,6 +350,9 @@ def __init__(self, env_creators, env_args, env_kwargs, self.zero_copy = zero_copy self.sync_traj = sync_traj + self.ready_workers = [] + self.waiting_workers = [] + def recv(self): recv_precheck(self) while True: @@ -450,6 +453,17 @@ def send(self, actions): self.buf['semaphores'][idxs] = STEP def async_reset(self, seed=0): + # Flush any waiting workers + while self.waiting_workers: + worker = self.waiting_workers.pop(0) + sem = self.buf['semaphores'][worker] + if sem >= MAIN: + self.ready_workers.append(worker) + if sem == INFO: + self.recv_pipes[worker].recv() + else: + self.waiting_workers.append(worker) + self.flag = RECV self.prev_env_id = [] self.flag = RECV