diff --git a/dspy/predict/parallel.py b/dspy/predict/parallel.py index 5c183885e4..91f5c1812b 100644 --- a/dspy/predict/parallel.py +++ b/dspy/predict/parallel.py @@ -64,7 +64,15 @@ def process_pair(pair): # Execute the processing function over the execution pairs results = executor.execute(process_pair, exec_pairs) + # Populate failed examples and exceptions from the executor if self.return_failed_examples: + for failed_idx in executor.failed_indices: + if failed_idx < len(exec_pairs): + _, original_example = exec_pairs[failed_idx] + self.failed_examples.append(original_example) + if exception := executor.exceptions_map.get(failed_idx): + self.exceptions.append(exception) + return results, self.failed_examples, self.exceptions else: return results diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py index d45df534ed..c32f5e3ebb 100644 --- a/dspy/utils/parallelizer.py +++ b/dspy/utils/parallelizer.py @@ -41,6 +41,8 @@ def __init__( self.error_count = 0 self.error_lock = threading.Lock() self.cancel_jobs = threading.Event() + self.failed_indices = [] + self.exceptions_map = {} def execute(self, function, data): tqdm.tqdm._instances.clear() @@ -62,7 +64,7 @@ def safe_func(item): logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}") else: logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.") - return None + return e return safe_func @@ -155,7 +157,14 @@ def all_done(): pass else: if outcome != job_cancelled and results[index] is None: - results[index] = outcome + # Check if this is an exception + if isinstance(outcome, Exception): + with self.error_lock: + self.failed_indices.append(index) + self.exceptions_map[index] = outcome + results[index] = None # Keep None for failed examples + else: + results[index] = outcome # Update progress if self.compare_results: diff --git a/tests/predict/test_parallel.py b/tests/predict/test_parallel.py index 8b4e862713..fe8479079b 100644 --- a/tests/predict/test_parallel.py +++ b/tests/predict/test_parallel.py @@ -163,3 +163,34 @@ def forward(self, input): "test output 3", "test output 4", } + + +def test_batch_with_failed_examples(): + class FailingModule(dspy.Module): + def forward(self, value: int) -> str: + if value == 42: + raise ValueError("test error") + return f"success-{value}" + + module = FailingModule() + + examples = [ + dspy.Example(value=1).with_inputs("value"), + dspy.Example(value=42).with_inputs("value"), # This will fail + dspy.Example(value=3).with_inputs("value"), + ] + + results, failed_examples, exceptions = module.batch( + examples, + return_failed_examples=True, + provide_traceback=True, + ) + + assert results == ["success-1", None, "success-3"] + + assert len(failed_examples) == 1 + assert failed_examples[0].inputs()["value"] == 42 + + assert len(exceptions) == 1 + assert isinstance(exceptions[0], ValueError) + assert str(exceptions[0]) == "test error" diff --git a/tests/utils/test_parallelizer.py b/tests/utils/test_parallelizer.py index 28307e4ea8..128614ffc8 100644 --- a/tests/utils/test_parallelizer.py +++ b/tests/utils/test_parallelizer.py @@ -59,3 +59,27 @@ def task(item): # Verify that the results exclude the failed task assert results == [1, 2, None, 4, 5] + + +def test_parallel_executor_tracks_failed_indices_and_exceptions(): + def task(item): + if item == 3: + raise ValueError("test error for 3") + if item == 5: + raise RuntimeError("test error for 5") + return item + + data = [1, 2, 3, 4, 5] + executor = ParallelExecutor(num_threads=3, max_errors=3) + + results = executor.execute(task, data) + + assert results == [1, 2, None, 4, None] + + assert sorted(executor.failed_indices) == [2, 4] + + assert len(executor.exceptions_map) == 2 + assert isinstance(executor.exceptions_map[2], ValueError) + assert str(executor.exceptions_map[2]) == "test error for 3" + assert isinstance(executor.exceptions_map[4], RuntimeError) + assert str(executor.exceptions_map[4]) == "test error for 5"