diff --git a/src/neo4j_graphrag/experimental/pipeline/notification.py b/src/neo4j_graphrag/experimental/pipeline/notification.py
index 94665fe2c..8a0763dac 100644
--- a/src/neo4j_graphrag/experimental/pipeline/notification.py
+++ b/src/neo4j_graphrag/experimental/pipeline/notification.py
@@ -86,6 +86,12 @@ class EventNotifier:
     def __init__(self, callbacks: list[EventCallbackProtocol]) -> None:
         self.callbacks = callbacks

+    def add_callback(self, callback: EventCallbackProtocol) -> None:
+        self.callbacks.append(callback)
+
+    def remove_callback(self, callback: EventCallbackProtocol) -> None:
+        self.callbacks.remove(callback)
+
     async def notify(self, event: Event) -> None:
         await asyncio.gather(
             *[c(event) for c in self.callbacks],
@@ -114,6 +120,17 @@ async def notify_pipeline_finished(
         )
         await self.notify(event)

+    async def notify_pipeline_failed(
+        self, run_id: str, message: Optional[str] = None
+    ) -> None:
+        event = PipelineEvent(
+            event_type=EventType.PIPELINE_FAILED,
+            run_id=run_id,
+            message=message,
+            payload=None,
+        )
+        await self.notify(event)
+
     async def notify_task_started(
         self,
         run_id: str,
diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py
index b5536b537..bacfdf89a 100644
--- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py
+++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py
@@ -26,7 +26,6 @@
     PipelineMissingDependencyError,
     PipelineStatusUpdateError,
 )
-from neo4j_graphrag.experimental.pipeline.notification import EventNotifier
 from neo4j_graphrag.experimental.pipeline.types.context import RunContext
 from neo4j_graphrag.experimental.pipeline.types.orchestration import (
     RunResult,
@@ -52,10 +51,10 @@ class Orchestrator:
     (checking that all dependencies are met), and run them.
     """

-    def __init__(self, pipeline: Pipeline):
+    def __init__(self, pipeline: Pipeline, run_id: Optional[str] = None):
         self.pipeline = pipeline
-        self.event_notifier = EventNotifier(pipeline.callbacks)
-        self.run_id = str(uuid.uuid4())
+        self.event_notifier = self.pipeline.event_notifier
+        self.run_id = run_id or str(uuid.uuid4())

     async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
         """Get inputs and run a specific task. Once the task is done,
@@ -265,9 +264,5 @@ async def run(self, data: dict[str, Any]) -> None:
         (node without any parent). Then the callback on_task_complete
         will handle the task dependencies.
         """
-        await self.event_notifier.notify_pipeline_started(self.run_id, data)
         tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
         await asyncio.gather(*tasks)
-        await self.event_notifier.notify_pipeline_finished(
-            self.run_id, await self.pipeline.get_final_results(self.run_id)
-        )
diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
index f6ede6b51..e0687c613 100644
--- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py
+++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -21,6 +21,8 @@
 from timeit import default_timer
 from typing import Any, AsyncGenerator, Optional

+import uuid
+
 from neo4j_graphrag.utils.logging import prettify

 try:
@@ -39,8 +41,7 @@
 from neo4j_graphrag.experimental.pipeline.notification import (
     Event,
     EventCallbackProtocol,
-    EventType,
-    PipelineEvent,
+    EventNotifier,
 )
 from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator
 from neo4j_graphrag.experimental.pipeline.pipeline_graph import (
@@ -103,7 +104,7 @@ async def run(
         res = await self.execute(context, inputs)
         end_time = default_timer()
         logger.debug(
-            f"TASK FINISHED {self.name} in {end_time - start_time} res={prettify(res)}"
+            f"TASK FINISHED {self.name} in {round(end_time - start_time, 2)}s res={prettify(res)}"
         )
         return res

@@ -124,7 +125,6 @@ def __init__(
     ) -> None:
         super().__init__()
         self.store = store or InMemoryStore()
-        self.callbacks = [callback] if callback else []
         self.final_results = InMemoryStore()
         self.is_validated = False
         self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
@@ -139,6 +139,7 @@ def __init__(
             }
         """
         self.missing_inputs: dict[str, list[str]] = defaultdict()
+        self.event_notifier = EventNotifier([callback] if callback else [])

     @classmethod
     def from_template(
@@ -507,14 +508,13 @@ async def stream(
         """
         # Create queue for events
         event_queue: asyncio.Queue[Event] = asyncio.Queue()
-        run_id = None

         async def event_stream(event: Event) -> None:
             # Put event in queue for streaming
             await event_queue.put(event)

         # Add event streaming callback
-        self.callbacks.append(event_stream)
+        self.event_notifier.add_callback(event_stream)

         event_queue_getter_task = None
         try:
@@ -542,39 +542,48 @@ async def event_stream(event: Event) -> None:
                 # we are sure to get an Event here, since this is the only
                 # thing we put in the queue, but mypy still complains
                 event = event_future.result()
-                run_id = getattr(event, "run_id", None)
                 yield event  # type: ignore

             if exc := run_task.exception():
-                yield PipelineEvent(
-                    event_type=EventType.PIPELINE_FAILED,
-                    # run_id is null if pipeline fails before even starting
-                    # ie during pipeline validation
-                    run_id=run_id or "",
-                    message=str(exc),
-                )
                 if raise_exception:
                     raise exc
         finally:
             # Restore original callback
-            self.callbacks.remove(event_stream)
+            self.event_notifier.remove_callback(event_stream)
             if event_queue_getter_task and not event_queue_getter_task.done():
                 event_queue_getter_task.cancel()

     async def run(self, data: dict[str, Any]) -> PipelineResult:
-        logger.debug("PIPELINE START")
         start_time = default_timer()
-        self.invalidate()
-        self.validate_input_data(data)
-        orchestrator = Orchestrator(self)
-        logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
-        await orchestrator.run(data)
+        run_id = str(uuid.uuid4())
+        logger.debug(f"PIPELINE START with {run_id=}")
+        try:
+            res = await self._run(run_id, data)
+        except Exception as e:
+            await self.event_notifier.notify_pipeline_failed(
+                run_id,
+                message=f"Pipeline failed with error {e}",
+            )
+            raise e
         end_time = default_timer()
         logger.debug(
-            f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
+            f"PIPELINE FINISHED {run_id} in {round(end_time - start_time, 2)}s"
         )
-        return PipelineResult(
+        return res
+
+    async def _run(self, run_id: str, data: dict[str, Any]) -> PipelineResult:
+        await self.event_notifier.notify_pipeline_started(run_id, data)
+        self.invalidate()
+        self.validate_input_data(data)
+        orchestrator = Orchestrator(self, run_id)
+        await orchestrator.run(data)
+        result = PipelineResult(
             run_id=orchestrator.run_id,
             result=await self.get_final_results(orchestrator.run_id),
         )
+        await self.event_notifier.notify_pipeline_finished(
+            run_id,
+            await self.get_final_results(run_id),
+        )
+        return result
diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py
index d37fc65fc..16e23cd27 100644
--- a/tests/unit/experimental/pipeline/test_pipeline.py
+++ b/tests/unit/experimental/pipeline/test_pipeline.py
@@ -483,6 +483,23 @@ async def test_pipeline_event_notification() -> None:
         previous_ts = actual_event.timestamp


+@pytest.mark.asyncio
+async def test_pipeline_event_notification_error_in_pipeline_run() -> None:
+    callback = AsyncMock(spec=EventCallbackProtocol)
+    pipe = Pipeline(callback=callback)
+    component_a = ComponentAdd()
+    component_b = ComponentAdd()
+    pipe.add_component(component_a, "a")
+    pipe.add_component(component_b, "b")
+    pipe.connect("a", "b", {"number1": "a.result"})
+
+    with pytest.raises(PipelineDefinitionError):
+        await pipe.run({"a": {"number1": 1, "number2": 2}})
+    assert len(callback.await_args_list) == 2
+    assert callback.await_args_list[0][0][0].event_type == EventType.PIPELINE_STARTED
+    assert callback.await_args_list[1][0][0].event_type == EventType.PIPELINE_FAILED
+
+
 def test_event_model_no_warning(recwarn: Sized) -> None:
     event = Event(
         event_type=EventType.PIPELINE_STARTED,
@@ -503,7 +520,7 @@ async def test_pipeline_streaming_no_user_callback_happy_path() -> None:
     assert len(events) == 2
     assert events[0].event_type == EventType.PIPELINE_STARTED
     assert events[1].event_type == EventType.PIPELINE_FINISHED
-    assert len(pipe.callbacks) == 0
+    assert len(pipe.event_notifier.callbacks) == 0
@@ -515,7 +532,7 @@ async def test_pipeline_streaming_with_user_callback_happy_path() -> None:
         events.append(e)
     assert len(events) == 2
     assert len(callback.call_args_list) == 2
-    assert len(pipe.callbacks) == 1
+    assert len(pipe.event_notifier.callbacks) == 1
@@ -528,7 +545,7 @@ async def callback(event: Event) -> None:
     async for e in pipe.stream({}):
         events.append(e)
     assert len(events) == 2
-    assert len(pipe.callbacks) == 1
+    assert len(pipe.event_notifier.callbacks) == 1
@@ -557,11 +574,9 @@ async def test_pipeline_streaming_error_in_pipeline_definition() -> None:
     with pytest.raises(PipelineDefinitionError):
         async for e in pipe.stream({"a": {"number1": 1, "number2": 2}}):
             events.append(e)
-    # validation happens before pipeline run actually starts
-    # but we have the PIPELINE_FAILED event
-    assert len(events) == 1
-    assert events[0].event_type == EventType.PIPELINE_FAILED
-    assert events[0].run_id == ""
+    assert len(events) == 2
+    assert events[0].event_type == EventType.PIPELINE_STARTED
+    assert events[1].event_type == EventType.PIPELINE_FAILED
@@ -589,4 +604,4 @@ async def callback(event: Event) -> None:
     async for e in pipe.stream({}):
         events.append(e)
     assert len(events) == 2
-    assert len(pipe.callbacks) == 1
+    assert len(pipe.event_notifier.callbacks) == 1
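Usage sketch (not part of the patch): after this change the callback registry lives on Pipeline.event_notifier, so extra listeners are attached with add_callback/remove_callback instead of mutating pipeline.callbacks, and notify_pipeline_failed fires when run() raises. The snippet below is a minimal sketch assuming the module paths shown in the diff headers and, like the streaming happy-path tests above, that an empty Pipeline() can be run with an empty input dict.

# Illustrative sketch only -- not part of the diff.
import asyncio

from neo4j_graphrag.experimental.pipeline.notification import Event
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline


async def log_event(event: Event) -> None:
    # Expect PIPELINE_STARTED then PIPELINE_FINISHED, or PIPELINE_FAILED if run() raises.
    print(event.event_type)


async def main() -> None:
    pipe = Pipeline()  # empty pipeline, mirroring the streaming happy-path tests
    pipe.event_notifier.add_callback(log_event)  # registry now owned by the pipeline
    try:
        await pipe.run({})
    finally:
        pipe.event_notifier.remove_callback(log_event)


if __name__ == "__main__":
    asyncio.run(main())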