
[BUG] RayCollector calls update_policy_weights_() with off-by-one index causing intermittent IndexError #3000

Open
@kjchavez

Description

🐛 Bug Description

In torchrl.collectors.distributed.ray.RayCollector._async_iterator, the update_policy_weights_() method is called with worker_ids=collector_index + 1. However, collector_index is already a 0-based index into the self.remote_collectors list, so adding 1 can produce an out-of-range value that later raises an IndexError in RayWeightUpdater._skip_update().


🔍 Relevant Code (TorchRL v0.8.1)

# Inside RayCollector._async_iterator()
if self.update_after_each_batch or self.max_weight_update_interval > -1:
    torchrl_logger.info(f"Updating weights on worker {collector_index}")
    self.update_policy_weights_(worker_ids=collector_index + 1)

Why This is a Bug

collector_index comes directly from enumerate(self.remote_collectors), so it lies in the range [0, len(self.remote_collectors) - 1]. Adding 1 can push it out of range: the value is passed through self.update_policy_weights_() down to RayWeightUpdater._skip_update(), which uses it to index self._batches_since_weight_update, a list of length len(self.remote_collectors).
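A minimal, self-contained sketch (not TorchRL code) of how the off-by-one surfaces; the list below merely stands in for self._batches_since_weight_update:

# Sketch only: a per-worker list of length num_collectors, as described above.
num_collectors = 2
batches_since_weight_update = [0] * num_collectors

for collector_index in range(num_collectors):   # mirrors enumerate(self.remote_collectors)
    worker_id = collector_index + 1              # the problematic "+ 1"
    try:
        batches_since_weight_update[worker_id] += 1   # what _skip_update effectively does
    except IndexError:
        print(f"collector_index={collector_index}: worker_id={worker_id} "
              f"is out of range for {num_collectors} workers")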

This likely stems from the argument's transition from worker_rank to worker_id (see this earlier merged PR). update_policy_weights_() previously accepted a worker_rank argument whose values were >= 1; see L883-907 in torchrl/collectors/distributed/generic.py.
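For illustration only (the historical signatures may differ slightly), the change of convention looks roughly like this:

# Approximate illustration, not the exact TorchRL signatures: the old
# worker_rank was 1-based, so "+ 1" was the correct translation from a
# 0-based enumerate() index, while the new worker_ids argument is 0-based.
remote_collectors = ["collector_a", "collector_b"]  # placeholder workers

for collector_index, _ in enumerate(remote_collectors):
    worker_rank = collector_index + 1  # old convention: ranks 1..N
    worker_id = collector_index        # new convention: ids 0..N-1
    print(f"index={collector_index} -> worker_rank={worker_rank}, worker_id={worker_id}")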

Suggested Fix

Replace

self.update_policy_weights_(worker_ids=collector_index + 1)

with

self.update_policy_weights_(worker_ids=collector_index)

Reproduction Instructions

Create a clean virtualenv and install torchrl==0.8.1, ray==2.47.0, and gymnasium==1.1.1.

Run the following script:

import ray
from torchrl.envs import GymEnv
from torch import nn
from torchrl.collectors.distributed.ray import RayCollector
from tensordict.nn import TensorDictModule
import random
import time

def create_env():
    return GymEnv("CartPole-v1")

# A simple policy compatible with TorchRL
def create_policy():
    module = nn.Sequential(
        nn.Linear(4, 32),
        nn.ReLU(),
        nn.Linear(32, 2),
    )
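    # Random stagger so that which worker finishes first varies between runs (see note below).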
    time.sleep(random.randint(0, 3))
    return TensorDictModule(
        module=module,
        in_keys=["observation"],
        out_keys=["action"],
    )

# Ray setup
ray.init(ignore_reinit_error=True, include_dashboard=False, log_to_driver=False)

# === Trigger the bug ===
collector = RayCollector(
    create_env_fn=create_env,
    policy_factory=create_policy,
    frames_per_batch=8,
    total_frames=64,
    update_after_each_batch=True,  # This will trigger the problematic weight sync
    num_collectors=2,
    sync=False,  # Async mode so one collector finishes before the other
)

try:
    for _ in collector:
        print("Collected batch")
except IndexError as e:
    print("\n🔥🔥🔥 Caught IndexError as expected due to off-by-one bug! 🔥🔥🔥")
    print(e)
finally:
    collector.shutdown()
    ray.shutdown()

The random sleep makes it so that the second worker is sometimes the first to finish (it is a bit of a pain to trigger consistently). If worker zero finishes first, the issue does not appear: collector_index + 1 == 1 is still a valid index in a two-collector run, so no exception is raised (though the update presumably targets the wrong worker).
