## 🐛 Bug Description

In `torchrl.collectors.distributed.ray.RayCollector._async_iterator`, the `update_policy_weights_()` method is called with `worker_ids=collector_index + 1`. However, `collector_index` is a 0-based index into the `self.remote_collectors` list, so adding 1 produces an out-of-range id for the last collector and leads to an `IndexError` downstream when `RayWeightUpdater._skip_update()` is evaluated.
## 🔍 Relevant Code (TorchRL v0.8.1)

```python
# Inside RayCollector._async_iterator()
if self.update_after_each_batch or self.max_weight_update_interval > -1:
    torchrl_logger.info(f"Updating weights on worker {collector_index}")
    self.update_policy_weights_(worker_ids=collector_index + 1)
```
## Why This is a Bug

`collector_index` comes directly from `enumerate(self.remote_collectors)`, so it lies in the range `[0, len(self.remote_collectors) - 1]`. Adding 1 pushes the value out of range for the last collector. The shifted id is passed to `self.update_policy_weights_()` and eventually reaches `RayWeightUpdater._skip_update()`, which uses it to index into `self._batches_since_weight_update`, a list of length `len(self.remote_collectors)`.
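The failure can be shown mechanically. Below is a minimal sketch, not TorchRL's actual implementation, using hypothetical simplified names that only mirror the indexing described above:

```python
# Minimal sketch of the off-by-one (hypothetical names, not TorchRL code).
class FakeWeightUpdater:
    def __init__(self, num_collectors: int):
        # One counter per remote collector, as described above.
        self._batches_since_weight_update = [0] * num_collectors

    def _skip_update(self, worker_id: int) -> bool:
        # Valid worker_id values are 0 .. num_collectors - 1.
        return self._batches_since_weight_update[worker_id] > 0


updater = FakeWeightUpdater(num_collectors=2)

for collector_index in range(2):  # what enumerate(remote_collectors) yields
    try:
        updater._skip_update(collector_index + 1)  # the buggy `+ 1`
    except IndexError as e:
        print(f"collector_index={collector_index}: {e}")
# Only collector_index=1 fails: 1 + 1 == 2 is out of range for a list of length 2.
```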
This likely stems from a transition of the argument from `worker_rank` to `worker_id`. See this earlier merged PR. Previously, `update_policy_weights_` accepted a `worker_rank` argument whose value was `>= 1`; see L883-907 in `torchrl/collectors/distributed/generic.py`.
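The two conventions differ by exactly one, which matches the stray `+ 1`. A hypothetical illustration (not the actual TorchRL signatures):

```python
# Hypothetical illustration of the convention change (not TorchRL's real API).
remote_collectors = ["collector_a", "collector_b"]

def by_rank(worker_rank: int) -> str:
    # Old convention: worker_rank is 1-based, so it must be shifted down.
    return remote_collectors[worker_rank - 1]

def by_id(worker_id: int) -> str:
    # New convention: worker_id is 0-based and indexes the list directly.
    return remote_collectors[worker_id]

assert by_rank(1) == by_id(0) == "collector_a"
# Code migrated from ranks to ids must drop the `+ 1`, otherwise:
# by_id(1 + 1)  # IndexError: list index out of range
```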
## Suggested Fix

Replace

```python
self.update_policy_weights_(worker_ids=collector_index + 1)
```

with

```python
self.update_policy_weights_(worker_ids=collector_index)
```
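Beyond the one-line fix, a bounds check in the updater would surface bad ids with a descriptive error instead of a bare `IndexError` deep in `_skip_update`. A minimal sketch, assuming an updater shaped like the hypothetical one above (not TorchRL's actual code):

```python
# Hypothetical defensive check (sketch, not TorchRL code).
def _skip_update(self, worker_id: int) -> bool:
    n = len(self._batches_since_weight_update)
    if not 0 <= worker_id < n:
        raise ValueError(
            f"worker_id={worker_id} is out of range for {n} remote collectors"
        )
    return self._batches_since_weight_update[worker_id] > 0
```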
## Reproduction Instructions

Create a clean virtualenv and install `torchrl==0.8.1`, `ray==2.47.0`, and `gymnasium==1.1.1`.

Run the following script:
```python
import random
import time

import ray
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors.distributed.ray import RayCollector
from torchrl.envs import GymEnv


def create_env():
    return GymEnv("CartPole-v1")


# A simple policy compatible with TorchRL
def create_policy():
    module = nn.Sequential(
        nn.Linear(4, 32),
        nn.ReLU(),
        nn.Linear(32, 2),
    )
    # Random delay so that either worker may be the first to deliver a batch.
    time.sleep(random.randint(0, 3))
    return TensorDictModule(
        module=module,
        in_keys=["observation"],
        out_keys=["action"],
    )


# Ray setup
ray.init(ignore_reinit_error=True, include_dashboard=False, log_to_driver=False)

# === Trigger the bug ===
collector = RayCollector(
    create_env_fn=create_env,
    policy_factory=create_policy,
    frames_per_batch=8,
    total_frames=64,
    update_after_each_batch=True,  # This triggers the problematic weight sync
    num_collectors=2,
    sync=False,  # Async mode so one collector finishes before the other
)

try:
    for _ in collector:
        print("Collected batch")
except IndexError as e:
    print("\n🔥🔥🔥 Caught IndexError as expected due to off-by-one bug! 🔥🔥🔥")
    print(e)
finally:
    collector.shutdown()
    ray.shutdown()
```
The random sleep makes it likely that the second worker (index 1) is sometimes the first to finish; the update then passes `worker_ids=2` while the only valid ids are 0 and 1, so the `IndexError` fires. It is a bit of a pain to get this to trigger consistently: if worker 0 finishes first, the issue does not appear.
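If the race proves too flaky, the same failure can be provoked directly. This is a sketch that assumes, per the analysis above, that `update_policy_weights_` forwards `worker_ids` unchanged to `RayWeightUpdater._skip_update`:

```python
# Hypothetical direct trigger (assumes worker_ids reaches _skip_update unchanged):
# with num_collectors=2, valid ids are 0 and 1, so 2 should raise IndexError.
collector.update_policy_weights_(worker_ids=len(collector.remote_collectors))
```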