🐛 Describe the bug
Summary
The GPTRewardModel.forward() method crashes with a division by zero error when processing batches with 0 or 1 samples.
Environment
- File: examples/summarize_rlhf/reward_model/reward_model.py
- Class: GPTRewardModel
- Method: forward()
Problem Description
When the input batch has 0 or 1 samples, the batch size calculation results in bs = 0:
bs = input_ids.shape[0] // 2 # If shape[0] = 1, then bs = 0
This causes two critical crashes:
Crash #1: Division by Zero (Line 100)
loss = loss / bs # ZeroDivisionError when bs = 0
Crash #2: Empty Tensor Stack (Lines 103-104)
chosen_end_scores = torch.stack(chosen_end_scores)  # RuntimeError: stack expects a non-empty TensorList
rejected_end_scores = torch.stack(rejected_end_scores)
Reproduction Steps
from reward_model import GPTRewardModel
import torch
model = GPTRewardModel("EleutherAI/gpt-j-6B")
# Test with single sample (bs will be 0)
input_ids = torch.randint(0, 50000, (1, 512))
output = model(input_ids) # 💥 CRASH
Error output:
ZeroDivisionError: division by zero
at line 100: loss = loss / bs
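You do not need to download GPT-J 6B to see the failure. The pure-Python sketch below mirrors only the batch-splitting control flow of forward(). `simplified_forward` and `stack` are hypothetical helpers (not trlX code), plain lists stand in for tensors, and the per-pair loss is a placeholder of the form -log(sigmoid(chosen - rejected)). The sketch assumes that `loss` is still a plain Python 0 when it is divided by `bs`. The ZeroDivisionError above is consistent with that, because dividing a tensor by 0 gives inf/nan instead of raising.

```python
import math


def stack(values):
    # Hypothetical stand-in for torch.stack, which rejects an empty list.
    if not values:
        raise RuntimeError("stack expects a non-empty TensorList")
    return list(values)


def simplified_forward(rewards):
    """Hypothetical mirror of the pairwise split in GPTRewardModel.forward().

    `rewards` holds one scalar per row: chosen rows first, then rejected rows.
    """
    bs = len(rewards) // 2  # 0 or 1 rows -> bs = 0
    chosen, rejected = rewards[:bs], rewards[bs:]
    chosen_end_scores, rejected_end_scores = [], []
    loss = 0  # stays a plain int if the loop below never runs
    for i in range(bs):  # zero iterations when bs == 0
        c, r = chosen[i], rejected[i]
        loss += -math.log(1.0 / (1.0 + math.exp(-(c - r))))  # placeholder loss
        chosen_end_scores.append(c)
        rejected_end_scores.append(r)
    loss = loss / bs  # Crash #1: ZeroDivisionError when bs == 0
    return {
        "loss": loss,
        # Crash #2: only reached if the division above is guarded on its own
        "chosen_end_scores": stack(chosen_end_scores),
        "rejected_end_scores": stack(rejected_end_scores),
    }


for batch in ([], [0.5]):
    try:
        simplified_forward(batch)
    except ZeroDivisionError as exc:
        print(f"{len(batch)} row(s): ZeroDivisionError: {exc}")
```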
Root Cause
The model expects paired inputs (chosen + rejected) and splits them:
- Input shape: (batch_size, seq_len)
- After split: bs = batch_size // 2
- If batch_size ∈ {0, 1}: bs = 0 → crash
The code doesn't handle the edge case where the split leaves no chosen/rejected pair (zero samples in each half).
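The split arithmetic can be checked on its own. The sketch below assumes that forward() treats the first bs rows of input_ids as chosen and the remaining rows as rejected. `split_pairs` is a hypothetical helper, and strings stand in for token rows. With 0 or 1 rows the chosen half is empty, so the loss loop has no pair to iterate over.

```python
def split_pairs(rows):
    # Hypothetical helper: first half is "chosen", the remainder is "rejected",
    # assuming forward() slices input_ids[:bs] and input_ids[bs:].
    bs = len(rows) // 2
    return bs, rows[:bs], rows[bs:]


for n in range(3):
    rows = [f"row{i}" for i in range(n)]
    print(n, *split_pairs(rows))
# 0 0 [] []
# 1 0 [] ['row0']
# 2 1 ['row0'] ['row1']
```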
Proposed Solution
Add an early return guard after the batch size calculation:
bs = input_ids.shape[0] // 2
# Handle batches too small to form a chosen/rejected pair (0 or 1 rows)
if bs == 0:
    return {
        "loss": torch.tensor(0.0, device=input_ids.device),
        "chosen_end_scores": torch.tensor([], device=input_ids.device),
        "rejected_end_scores": torch.tensor([], device=input_ids.device),
    }
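The guard's behavior can also be checked without a GPU or model weights. The sketch below is a hypothetical pure-Python mirror (`guarded_forward` is illustrative, not trlX code). Plain lists and floats stand in for tensors, and the pairwise loss is a placeholder of the form -log(sigmoid(chosen - rejected)). When no chosen/rejected pair exists, it returns the same three keys as the proposed fix. Otherwise it averages the per-pair loss over bs.

```python
import math


def guarded_forward(rewards):
    """Hypothetical mirror of forward() with the proposed early-return guard.

    `rewards` holds one scalar per row: chosen rows first, then rejected rows.
    """
    bs = len(rewards) // 2
    # Guard: with fewer than 2 rows there is no chosen/rejected pair to score.
    if bs == 0:
        return {"loss": 0.0, "chosen_end_scores": [], "rejected_end_scores": []}

    chosen, rejected = rewards[:bs], rewards[bs:]
    loss = 0.0
    for c, r in zip(chosen, rejected):
        # Placeholder pairwise loss: -log(sigmoid(c - r))
        loss += -math.log(1.0 / (1.0 + math.exp(-(c - r))))
    return {
        "loss": loss / bs,  # safe: bs >= 1 past the guard
        "chosen_end_scores": chosen,
        "rejected_end_scores": rejected[:bs],
    }


print(guarded_forward([0.7]))
# {'loss': 0.0, 'chosen_end_scores': [], 'rejected_end_scores': []}
```

In the tensor version, torch.tensor(0.0) is created without a grad_fn. A training step that calls loss.backward() on the guarded output would therefore still raise a RuntimeError, so training code may need to skip such batches rather than back-propagate through them.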
Benefits:
✅ Prevents division by zero
✅ Prevents empty tensor stack errors
✅ Returns consistent output format
✅ Maintains proper device placement
✅ Backward compatible (no breaking changes)
Impact
Severity: Medium-High
- Crashes training/inference on edge cases
- May occur during:
- Dataset iteration with uneven batches
- Distributed training with small batch sizes
- Testing/debugging with single samples
Affected Users: Anyone using GPTRewardModel with small or variable batch sizes
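Here is one way an uneven final batch can reach bs = 0. Assume rows are batched individually (not as chosen/rejected pairs) with a fixed batch size, and that the final partial batch is kept, as with PyTorch's DataLoader and drop_last=False. The last batch then holds n_rows % batch_size rows. `batch_row_counts` is a hypothetical helper that only does this arithmetic.

```python
def batch_row_counts(n_rows, batch_size):
    """Rows per batch when n_rows are split into fixed-size batches and the
    final partial batch is kept (no drop_last)."""
    full, rest = divmod(n_rows, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


counts = batch_row_counts(n_rows=9, batch_size=4)
print(counts)  # [4, 4, 1]
print([rows // 2 for rows in counts])  # [2, 2, 0] -> last batch gives bs = 0
```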
Additional Context
The fix is straightforward and adds only 10 lines of defensive code. I have an implementation ready and tested, and can submit a PR if you'd like.
Affected code location: examples/summarize_rlhf/reward_model/reward_model.py, lines 51-57
Checklist
Which trlX version are you using?
latest
Additional system and package information
Python=3.12.3, transformers=4.57.1, NVIDIA GB10, 580.95.05, CUDA=12.1