fix(agent_loop): handle batch size smaller than num_workers #5231
aoshen524 wants to merge 1 commit into verl-project:main
Conversation
…te_sequences

When the batch size is smaller than the number of agent loop workers, `prompts.chunk(len(self.agent_loop_workers))` produces fewer chunks than workers, causing `zip(..., strict=True)` to raise a `ValueError`. This fix caps the chunk count at `min(len(prompts), len(self.agent_loop_workers))` and uses index-based worker dispatch instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
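For context, the failure mode is easy to reproduce in isolation. Everything in the sketch below is a stand-in: plain lists replace `DataProto` and the Ray actor handles, and the even-split `chunk` only approximates `prompts.chunk` (these are assumptions, not verl's real implementations):

```python
def chunk(items, n):
    # torch.chunk-style split (a stand-in for DataProto.chunk, not the
    # real implementation): never yields empty chunks, so fewer than n
    # chunks come back when len(items) < n
    size = -(-len(items) // n)  # ceil division
    return [items[i : i + size] for i in range(0, len(items), size)]

workers = ["w0", "w1", "w2", "w3"]  # 4 agent loop workers
prompts = ["p0", "p1"]              # batch of 2 < 4 workers

chunks = chunk(prompts, len(workers))  # only 2 chunks come back
try:
    list(zip(workers, chunks, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```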
Code Review
This pull request fixes a crash that occurs when the batch size is smaller than the number of agent loop workers. The fix correctly caps the number of chunks at the number of workers needed and adjusts the chunking of prompts accordingly. However, the changed code can still crash if no workers are available or if the input `prompts` is empty. I've added a suggestion to handle this edge case gracefully.
```diff
+        num_workers_needed = min(len(prompts), len(self.agent_loop_workers))
+        chunkes = prompts.chunk(num_workers_needed)
         outputs = ray.get(
             [
-                worker.generate_sequences.remote(chunk)
-                for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True)
+                self.agent_loop_workers[i % len(self.agent_loop_workers)].generate_sequences.remote(chunk)
+                for i, chunk in enumerate(chunkes)
             ]
         )
```
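The fixed dispatch pattern can be seen end to end in a self-contained sketch. `StubWorker` and `dispatch` are hypothetical stand-ins (no Ray involved), and `chunk` is the stand-in from the repro above; only the `min(...)` cap and the index-based dispatch mirror the actual patch:

```python
class StubWorker:
    """Hypothetical stand-in for an AgentLoopWorker Ray actor."""

    def __init__(self, name):
        self.name = name

    def generate_sequences(self, batch):
        return f"{self.name} handled {batch}"

def dispatch(prompts, workers):
    # Core of the fix: never request more chunks than there are prompts.
    num_workers_needed = min(len(prompts), len(workers))
    chunks = chunk(prompts, num_workers_needed)  # chunk() from the repro above
    # Index-based dispatch replaces zip(..., strict=True): each chunk
    # picks its worker by index, so a short batch simply leaves some
    # workers idle instead of raising ValueError.
    return [workers[i % len(workers)].generate_sequences(c) for i, c in enumerate(chunks)]

workers = [StubWorker(f"w{i}") for i in range(4)]
print(dispatch(["p0", "p1"], workers))
# ["w0 handled ['p0']", "w1 handled ['p1']"] -- w2 and w3 stay idle
```

With the modulo index, the dispatch would also stay correct if there were ever more chunks than workers, though the `min(...)` cap prevents that case here.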
The logic here can lead to a crash if `num_workers_needed` is 0, which happens when `prompts` or `self.agent_loop_workers` is empty. In that case `prompts.chunk(0)` is called, which raises an error because the number of chunks must be positive.

To prevent this, we should add a check that handles `num_workers_needed == 0` by returning an empty `DataProto`.
```diff
         num_workers_needed = min(len(prompts), len(self.agent_loop_workers))
+        if num_workers_needed == 0:
+            return DataProto.concat([])
         chunkes = prompts.chunk(num_workers_needed)
         outputs = ray.get(
             [
                 self.agent_loop_workers[i % len(self.agent_loop_workers)].generate_sequences.remote(chunk)
                 for i, chunk in enumerate(chunkes)
             ]
         )
```
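To make the concern concrete, here is the empty-input case in the same stub setting as the earlier sketches (again with the stand-in `chunk`, not `DataProto`'s real method, which would raise its own error for a non-positive chunk count):

```python
prompts = []
num_workers_needed = min(len(prompts), len(workers))  # 0
try:
    chunk(prompts, num_workers_needed)  # chunk count must be positive
except ZeroDivisionError as err:
    print("chunk(0) fails:", err)
```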
Both cases (empty workers / empty prompts) are pre-existing: the original code would also crash with `chunk(0)` or produce undefined behavior on empty input. These are initialization-time invariants guaranteed by the training loop, so adding a guard here would be over-engineering.
Summary
- When `batch_size < num_workers`, `prompts.chunk(len(self.agent_loop_workers))` produces fewer chunks than workers, causing `zip(..., strict=True)` to raise `ValueError`
- Caps the chunk count at `min(len(prompts), len(self.agent_loop_workers))` and uses index-based worker dispatch instead of a strict zip

Test plan
🤖 Generated with Claude Code