Skip to content

Commit 1d0ed47

Browse files
committed
Move pipeline started event notification so that we have both started and failure event - add unit test
1 parent 59fb457 commit 1d0ed47

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

src/neo4j_graphrag/experimental/pipeline/orchestrator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,5 @@ async def run(self, data: dict[str, Any]) -> None:
264264
(node without any parent). Then the callback on_task_complete
265265
will handle the task dependencies.
266266
"""
267-
await self.event_notifier.notify_pipeline_started(self.run_id, data)
268267
tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
269268
await asyncio.gather(*tasks)
270-
await self.event_notifier.notify_pipeline_finished(
271-
self.run_id, await self.pipeline.get_final_results(self.run_id)
272-
)

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,11 +571,16 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
571571
return res
572572

573573
async def _run(self, run_id: str, data: dict[str, Any]) -> PipelineResult:
574+
await self.event_notifier.notify_pipeline_started(run_id, data)
574575
self.invalidate()
575576
self.validate_input_data(data)
576577
orchestrator = Orchestrator(self, run_id)
577578
await orchestrator.run(data)
578-
return PipelineResult(
579+
result = PipelineResult(
579580
run_id=orchestrator.run_id,
580581
result=await self.get_final_results(orchestrator.run_id),
581582
)
583+
await self.event_notifier.notify_pipeline_finished(
584+
run_id, await self.get_final_results(run_id),
585+
)
586+
return result

tests/unit/experimental/pipeline/test_pipeline.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,23 @@ async def test_pipeline_event_notification() -> None:
483483
previous_ts = actual_event.timestamp
484484

485485

486+
@pytest.mark.asyncio
487+
async def test_pipeline_event_notification_error_in_pipeline_run() -> None:
488+
callback = AsyncMock(spec=EventCallbackProtocol)
489+
pipe = Pipeline(callback=callback)
490+
component_a = ComponentAdd()
491+
component_b = ComponentAdd()
492+
pipe.add_component(component_a, "a")
493+
pipe.add_component(component_b, "b")
494+
pipe.connect("a", "b", {"number1": "a.result"})
495+
496+
with pytest.raises(PipelineDefinitionError):
497+
await pipe.run({"a": {"number1": 1, "number2": 2}})
498+
assert len(callback.await_args_list) == 2
499+
assert callback.await_args_list[0][0][0].event_type == EventType.PIPELINE_STARTED
500+
assert callback.await_args_list[1][0][0].event_type == EventType.PIPELINE_FAILED
501+
502+
486503
def test_event_model_no_warning(recwarn: Sized) -> None:
487504
event = Event(
488505
event_type=EventType.PIPELINE_STARTED,
@@ -557,10 +574,9 @@ async def test_pipeline_streaming_error_in_pipeline_definition() -> None:
557574
with pytest.raises(PipelineDefinitionError):
558575
async for e in pipe.stream({"a": {"number1": 1, "number2": 2}}):
559576
events.append(e)
560-
# validation happens before pipeline run actually starts
561-
# but we have the PIPELINE_FAILED event
562-
assert len(events) == 1
563-
assert events[0].event_type == EventType.PIPELINE_FAILED
577+
assert len(events) == 2
578+
assert events[0].event_type == EventType.PIPELINE_STARTED
579+
assert events[1].event_type == EventType.PIPELINE_FAILED
564580

565581

566582
@pytest.mark.asyncio

0 commit comments

Comments
 (0)