Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dspy/predict/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,15 @@ def process_pair(pair):
# Execute the processing function over the execution pairs
results = executor.execute(process_pair, exec_pairs)

# Populate failed examples and exceptions from the executor
if self.return_failed_examples:
for failed_idx in executor.failed_indices:
if failed_idx < len(exec_pairs):
_, original_example = exec_pairs[failed_idx]
self.failed_examples.append(original_example)
if exception := executor.exceptions_map.get(failed_idx):
self.exceptions.append(exception)

return results, self.failed_examples, self.exceptions
else:
return results
Expand Down
13 changes: 11 additions & 2 deletions dspy/utils/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __init__(
self.error_count = 0
self.error_lock = threading.Lock()
self.cancel_jobs = threading.Event()
self.failed_indices = []
self.exceptions_map = {}

def execute(self, function, data):
tqdm.tqdm._instances.clear()
Expand All @@ -62,7 +64,7 @@ def safe_func(item):
logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
else:
logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.")
return None
return e

return safe_func

Expand Down Expand Up @@ -155,7 +157,14 @@ def all_done():
pass
else:
if outcome != job_cancelled and results[index] is None:
results[index] = outcome
# Check if this is an exception
if isinstance(outcome, Exception):
with self.error_lock:
self.failed_indices.append(index)
self.exceptions_map[index] = outcome
results[index] = None # Keep None for failed examples
else:
results[index] = outcome

# Update progress
if self.compare_results:
Expand Down
31 changes: 31 additions & 0 deletions tests/predict/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,34 @@ def forward(self, input):
"test output 3",
"test output 4",
}


def test_batch_with_failed_examples():
    """`Module.batch(return_failed_examples=True)` must surface the failing
    example and the original exception while keeping a None placeholder in
    the results list."""

    class ExplodingModule(dspy.Module):
        def forward(self, value: int) -> str:
            # 42 is the sentinel input that triggers a deliberate failure.
            if value != 42:
                return f"success-{value}"
            raise ValueError("test error")

    # Middle input (42) is the one expected to fail.
    batch_inputs = (1, 42, 3)
    batch_examples = [dspy.Example(value=v).with_inputs("value") for v in batch_inputs]

    outputs, failures, errors = ExplodingModule().batch(
        batch_examples,
        return_failed_examples=True,
        provide_traceback=True,
    )

    # Successes keep positional order; the failed slot stays None.
    assert outputs == ["success-1", None, "success-3"]

    # Exactly the value==42 example is reported back as failed.
    assert len(failures) == 1
    assert failures[0].inputs()["value"] == 42

    # The original exception object (type and message) is propagated.
    assert len(errors) == 1
    assert isinstance(errors[0], ValueError)
    assert str(errors[0]) == "test error"
24 changes: 24 additions & 0 deletions tests/utils/test_parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,27 @@ def task(item):

# Verify that the results exclude the failed task
assert results == [1, 2, None, 4, 5]


def test_parallel_executor_tracks_failed_indices_and_exceptions():
    """ParallelExecutor must record the zero-based index and the exception
    object for every item whose task raised, while returning None in the
    corresponding result slots."""

    def worker(item):
        # Items 3 and 5 are designated failures; everything else echoes back.
        if item == 3:
            raise ValueError("test error for 3")
        if item == 5:
            raise RuntimeError("test error for 5")
        return item

    runner = ParallelExecutor(num_threads=3, max_errors=3)
    outcome = runner.execute(worker, [1, 2, 3, 4, 5])

    # Failed items yield None in their original positions.
    assert outcome == [1, 2, None, 4, None]

    # Indices 2 and 4 (items 3 and 5) are tracked as failures.
    assert sorted(runner.failed_indices) == [2, 4]

    # Each failed index maps to the exact exception that was raised.
    assert len(runner.exceptions_map) == 2
    assert isinstance(runner.exceptions_map[2], ValueError)
    assert str(runner.exceptions_map[2]) == "test error for 3"
    assert isinstance(runner.exceptions_map[4], RuntimeError)
    assert str(runner.exceptions_map[4]) == "test error for 5"