Skip to content

Commit 06fe74f

Browse files
authored
Send pipeline failed events to the event callback (#454)
* Send pipeline failed events to the event callback
* Ruff
* Remove prints
* Move pipeline started event notification so that we have both started and failure event - add unit test
* ruff
* Round time in debug log
* Rm test code
* round with 2 digits
1 parent 3353643 commit 06fe74f

File tree

4 files changed: +76 -40 lines changed

src/neo4j_graphrag/experimental/pipeline/notification.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ class EventNotifier:
8686
def __init__(self, callbacks: list[EventCallbackProtocol]) -> None:
8787
self.callbacks = callbacks
8888

89+
def add_callback(self, callback: EventCallbackProtocol) -> None:
90+
self.callbacks.append(callback)
91+
92+
def remove_callback(self, callback: EventCallbackProtocol) -> None:
93+
self.callbacks.remove(callback)
94+
8995
async def notify(self, event: Event) -> None:
9096
await asyncio.gather(
9197
*[c(event) for c in self.callbacks],
@@ -114,6 +120,17 @@ async def notify_pipeline_finished(
114120
)
115121
await self.notify(event)
116122

123+
async def notify_pipeline_failed(
124+
self, run_id: str, message: Optional[str] = None
125+
) -> None:
126+
event = PipelineEvent(
127+
event_type=EventType.PIPELINE_FAILED,
128+
run_id=run_id,
129+
message=message,
130+
payload=None,
131+
)
132+
await self.notify(event)
133+
117134
async def notify_task_started(
118135
self,
119136
run_id: str,

src/neo4j_graphrag/experimental/pipeline/orchestrator.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
PipelineMissingDependencyError,
2727
PipelineStatusUpdateError,
2828
)
29-
from neo4j_graphrag.experimental.pipeline.notification import EventNotifier
3029
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
3130
from neo4j_graphrag.experimental.pipeline.types.orchestration import (
3231
RunResult,
@@ -52,10 +51,10 @@ class Orchestrator:
5251
(checking that all dependencies are met), and run them.
5352
"""
5453

55-
def __init__(self, pipeline: Pipeline):
54+
def __init__(self, pipeline: Pipeline, run_id: Optional[str] = None):
5655
self.pipeline = pipeline
57-
self.event_notifier = EventNotifier(pipeline.callbacks)
58-
self.run_id = str(uuid.uuid4())
56+
self.event_notifier = self.pipeline.event_notifier
57+
self.run_id = run_id or str(uuid.uuid4())
5958

6059
async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
6160
"""Get inputs and run a specific task. Once the task is done,
@@ -265,9 +264,5 @@ async def run(self, data: dict[str, Any]) -> None:
265264
(node without any parent). Then the callback on_task_complete
266265
will handle the task dependencies.
267266
"""
268-
await self.event_notifier.notify_pipeline_started(self.run_id, data)
269267
tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
270268
await asyncio.gather(*tasks)
271-
await self.event_notifier.notify_pipeline_finished(
272-
self.run_id, await self.pipeline.get_final_results(self.run_id)
273-
)

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from timeit import default_timer
2222
from typing import Any, AsyncGenerator, Optional
2323

24+
import uuid
25+
2426
from neo4j_graphrag.utils.logging import prettify
2527

2628
try:
@@ -39,8 +41,7 @@
3941
from neo4j_graphrag.experimental.pipeline.notification import (
4042
Event,
4143
EventCallbackProtocol,
42-
EventType,
43-
PipelineEvent,
44+
EventNotifier,
4445
)
4546
from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator
4647
from neo4j_graphrag.experimental.pipeline.pipeline_graph import (
@@ -103,7 +104,7 @@ async def run(
103104
res = await self.execute(context, inputs)
104105
end_time = default_timer()
105106
logger.debug(
106-
f"TASK FINISHED {self.name} in {end_time - start_time} res={prettify(res)}"
107+
f"TASK FINISHED {self.name} in {round(end_time - start_time, 2)}s res={prettify(res)}"
107108
)
108109
return res
109110

@@ -124,7 +125,6 @@ def __init__(
124125
) -> None:
125126
super().__init__()
126127
self.store = store or InMemoryStore()
127-
self.callbacks = [callback] if callback else []
128128
self.final_results = InMemoryStore()
129129
self.is_validated = False
130130
self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
@@ -139,6 +139,7 @@ def __init__(
139139
}
140140
"""
141141
self.missing_inputs: dict[str, list[str]] = defaultdict()
142+
self.event_notifier = EventNotifier([callback] if callback else [])
142143

143144
@classmethod
144145
def from_template(
@@ -507,14 +508,13 @@ async def stream(
507508
"""
508509
# Create queue for events
509510
event_queue: asyncio.Queue[Event] = asyncio.Queue()
510-
run_id = None
511511

512512
async def event_stream(event: Event) -> None:
513513
# Put event in queue for streaming
514514
await event_queue.put(event)
515515

516516
# Add event streaming callback
517-
self.callbacks.append(event_stream)
517+
self.event_notifier.add_callback(event_stream)
518518

519519
event_queue_getter_task = None
520520
try:
@@ -542,39 +542,48 @@ async def event_stream(event: Event) -> None:
542542
# we are sure to get an Event here, since this is the only
543543
# thing we put in the queue, but mypy still complains
544544
event = event_future.result()
545-
run_id = getattr(event, "run_id", None)
546545
yield event # type: ignore
547546

548547
if exc := run_task.exception():
549-
yield PipelineEvent(
550-
event_type=EventType.PIPELINE_FAILED,
551-
# run_id is null if pipeline fails before even starting
552-
# ie during pipeline validation
553-
run_id=run_id or "",
554-
message=str(exc),
555-
)
556548
if raise_exception:
557549
raise exc
558550

559551
finally:
560552
# Restore original callback
561-
self.callbacks.remove(event_stream)
553+
self.event_notifier.remove_callback(event_stream)
562554
if event_queue_getter_task and not event_queue_getter_task.done():
563555
event_queue_getter_task.cancel()
564556

565557
async def run(self, data: dict[str, Any]) -> PipelineResult:
566-
logger.debug("PIPELINE START")
567558
start_time = default_timer()
568-
self.invalidate()
569-
self.validate_input_data(data)
570-
orchestrator = Orchestrator(self)
571-
logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
572-
await orchestrator.run(data)
559+
run_id = str(uuid.uuid4())
560+
logger.debug(f"PIPELINE START with {run_id=}")
561+
try:
562+
res = await self._run(run_id, data)
563+
except Exception as e:
564+
await self.event_notifier.notify_pipeline_failed(
565+
run_id,
566+
message=f"Pipeline failed with error {e}",
567+
)
568+
raise e
573569
end_time = default_timer()
574570
logger.debug(
575-
f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
571+
f"PIPELINE FINISHED {run_id} in {round(end_time - start_time, 2)}s"
576572
)
577-
return PipelineResult(
573+
return res
574+
575+
async def _run(self, run_id: str, data: dict[str, Any]) -> PipelineResult:
576+
await self.event_notifier.notify_pipeline_started(run_id, data)
577+
self.invalidate()
578+
self.validate_input_data(data)
579+
orchestrator = Orchestrator(self, run_id)
580+
await orchestrator.run(data)
581+
result = PipelineResult(
578582
run_id=orchestrator.run_id,
579583
result=await self.get_final_results(orchestrator.run_id),
580584
)
585+
await self.event_notifier.notify_pipeline_finished(
586+
run_id,
587+
await self.get_final_results(run_id),
588+
)
589+
return result

tests/unit/experimental/pipeline/test_pipeline.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,23 @@ async def test_pipeline_event_notification() -> None:
483483
previous_ts = actual_event.timestamp
484484

485485

486+
@pytest.mark.asyncio
487+
async def test_pipeline_event_notification_error_in_pipeline_run() -> None:
488+
callback = AsyncMock(spec=EventCallbackProtocol)
489+
pipe = Pipeline(callback=callback)
490+
component_a = ComponentAdd()
491+
component_b = ComponentAdd()
492+
pipe.add_component(component_a, "a")
493+
pipe.add_component(component_b, "b")
494+
pipe.connect("a", "b", {"number1": "a.result"})
495+
496+
with pytest.raises(PipelineDefinitionError):
497+
await pipe.run({"a": {"number1": 1, "number2": 2}})
498+
assert len(callback.await_args_list) == 2
499+
assert callback.await_args_list[0][0][0].event_type == EventType.PIPELINE_STARTED
500+
assert callback.await_args_list[1][0][0].event_type == EventType.PIPELINE_FAILED
501+
502+
486503
def test_event_model_no_warning(recwarn: Sized) -> None:
487504
event = Event(
488505
event_type=EventType.PIPELINE_STARTED,
@@ -503,7 +520,7 @@ async def test_pipeline_streaming_no_user_callback_happy_path() -> None:
503520
assert len(events) == 2
504521
assert events[0].event_type == EventType.PIPELINE_STARTED
505522
assert events[1].event_type == EventType.PIPELINE_FINISHED
506-
assert len(pipe.callbacks) == 0
523+
assert len(pipe.event_notifier.callbacks) == 0
507524

508525

509526
@pytest.mark.asyncio
@@ -515,7 +532,7 @@ async def test_pipeline_streaming_with_user_callback_happy_path() -> None:
515532
events.append(e)
516533
assert len(events) == 2
517534
assert len(callback.call_args_list) == 2
518-
assert len(pipe.callbacks) == 1
535+
assert len(pipe.event_notifier.callbacks) == 1
519536

520537

521538
@pytest.mark.asyncio
@@ -528,7 +545,7 @@ async def callback(event: Event) -> None:
528545
async for e in pipe.stream({}):
529546
events.append(e)
530547
assert len(events) == 2
531-
assert len(pipe.callbacks) == 1
548+
assert len(pipe.event_notifier.callbacks) == 1
532549

533550

534551
@pytest.mark.asyncio
@@ -557,11 +574,9 @@ async def test_pipeline_streaming_error_in_pipeline_definition() -> None:
557574
with pytest.raises(PipelineDefinitionError):
558575
async for e in pipe.stream({"a": {"number1": 1, "number2": 2}}):
559576
events.append(e)
560-
# validation happens before pipeline run actually starts
561-
# but we have the PIPELINE_FAILED event
562-
assert len(events) == 1
563-
assert events[0].event_type == EventType.PIPELINE_FAILED
564-
assert events[0].run_id == ""
577+
assert len(events) == 2
578+
assert events[0].event_type == EventType.PIPELINE_STARTED
579+
assert events[1].event_type == EventType.PIPELINE_FAILED
565580

566581

567582
@pytest.mark.asyncio
@@ -589,4 +604,4 @@ async def callback(event: Event) -> None:
589604
async for e in pipe.stream({}):
590605
events.append(e)
591606
assert len(events) == 2
592-
assert len(pipe.callbacks) == 1
607+
assert len(pipe.event_notifier.callbacks) == 1

0 commit comments

Comments
 (0)