Skip to content

Commit d551499

Browse files
committed
Send pipeline failed events to the event callback
1 parent 03ddb3c commit d551499

File tree

4 files changed

+48
-25
lines changed

4 files changed

+48
-25
lines changed

src/neo4j_graphrag/experimental/pipeline/notification.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ class EventNotifier:
8686
def __init__(self, callbacks: list[EventCallbackProtocol]) -> None:
8787
self.callbacks = callbacks
8888

89+
def add_callback(self, callback: EventCallbackProtocol) -> None:
    """Register ``callback`` with this notifier.

    The callback is appended to ``self.callbacks``. No de-duplication is
    performed, so registering the same callback twice stores it twice.

    Args:
        callback: The event callback to append to ``self.callbacks``.
    """
    self.callbacks.append(callback)
91+
92+
def remove_callback(self, callback: EventCallbackProtocol) -> None:
    """Unregister ``callback`` from this notifier.

    Only the first matching entry of ``self.callbacks`` is removed.

    Args:
        callback: The event callback to remove from ``self.callbacks``.

    Raises:
        ValueError: If ``callback`` is not currently registered
            (propagated from ``list.remove``).
    """
    self.callbacks.remove(callback)
94+
8995
async def notify(self, event: Event) -> None:
9096
await asyncio.gather(
9197
*[c(event) for c in self.callbacks],
@@ -114,6 +120,17 @@ async def notify_pipeline_finished(
114120
)
115121
await self.notify(event)
116122

123+
async def notify_pipeline_failed(
    self, run_id: str, message: Optional[str] = None
) -> None:
    """Send a ``PIPELINE_FAILED`` event for the given run.

    Builds a ``PipelineEvent`` with no payload and passes it to
    ``self.notify``.

    Args:
        run_id: Identifier of the pipeline run that failed.
        message: Optional text stored in the event's ``message`` field.
    """
    event = PipelineEvent(
        event_type=EventType.PIPELINE_FAILED,
        run_id=run_id,
        message=message,
        payload=None,
    )
    await self.notify(event)
133+
117134
async def notify_task_started(
118135
self,
119136
run_id: str,

src/neo4j_graphrag/experimental/pipeline/orchestrator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ class Orchestrator:
5252
(checking that all dependencies are met), and run them.
5353
"""
5454

55-
def __init__(self, pipeline: Pipeline):
55+
def __init__(self, pipeline: Pipeline, run_id: Optional[str] = None):
5656
self.pipeline = pipeline
57-
self.event_notifier = EventNotifier(pipeline.callbacks)
58-
self.run_id = str(uuid.uuid4())
57+
self.event_notifier = self.pipeline.event_notifier
58+
self.run_id = run_id or str(uuid.uuid4())
5959

6060
async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
6161
"""Get inputs and run a specific task. Once the task is done,

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from timeit import default_timer
2222
from typing import Any, AsyncGenerator, Optional
2323

24+
import uuid
25+
2426
from neo4j_graphrag.utils.logging import prettify
2527

2628
try:
@@ -41,6 +43,7 @@
4143
EventCallbackProtocol,
4244
EventType,
4345
PipelineEvent,
46+
EventNotifier,
4447
)
4548
from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator
4649
from neo4j_graphrag.experimental.pipeline.pipeline_graph import (
@@ -124,7 +127,6 @@ def __init__(
124127
) -> None:
125128
super().__init__()
126129
self.store = store or InMemoryStore()
127-
self.callbacks = [callback] if callback else []
128130
self.final_results = InMemoryStore()
129131
self.is_validated = False
130132
self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
@@ -139,6 +141,7 @@ def __init__(
139141
}
140142
"""
141143
self.missing_inputs: dict[str, list[str]] = defaultdict()
144+
self.event_notifier = EventNotifier([callback] if callback else [])
142145

143146
@classmethod
144147
def from_template(
@@ -514,7 +517,7 @@ async def event_stream(event: Event) -> None:
514517
await event_queue.put(event)
515518

516519
# Add event streaming callback
517-
self.callbacks.append(event_stream)
520+
self.event_notifier.add_callback(event_stream)
518521

519522
event_queue_getter_task = None
520523
try:
@@ -546,34 +549,36 @@ async def event_stream(event: Event) -> None:
546549
yield event # type: ignore
547550

548551
if exc := run_task.exception():
549-
yield PipelineEvent(
550-
event_type=EventType.PIPELINE_FAILED,
551-
# run_id is null if pipeline fails before even starting
552-
# ie during pipeline validation
553-
run_id=run_id or "",
554-
message=str(exc),
555-
)
556552
if raise_exception:
557553
raise exc
558554

559555
finally:
560556
# Restore original callback
561-
self.callbacks.remove(event_stream)
557+
self.event_notifier.remove_callback(event_stream)
562558
if event_queue_getter_task and not event_queue_getter_task.done():
563559
event_queue_getter_task.cancel()
564560

565561
async def run(self, data: dict[str, Any]) -> PipelineResult:
    """Run the pipeline once and return its result.

    A fresh run id is generated here and handed to ``_run``. If ``_run``
    raises, a pipeline-failed notification carrying that run id and the
    error text is sent through ``self.event_notifier`` before the
    exception is re-raised to the caller.

    Args:
        data: Input data forwarded unchanged to ``_run``.

    Returns:
        The value returned by ``_run``.

    Raises:
        Exception: Any exception raised by ``_run``, re-raised after the
            failure notification has been sent.
    """
    start_time = default_timer()
    run_id = str(uuid.uuid4())
    logger.debug(f"PIPELINE START with {run_id=}")
    try:
        res = await self._run(run_id, data)
    except Exception as e:
        await self.event_notifier.notify_pipeline_failed(
            run_id,
            message=f"Pipeline failed with error {e}",
        )
        # Bare ``raise`` re-raises the active exception as-is.
        raise
    end_time = default_timer()
    logger.debug(f"PIPELINE FINISHED {run_id} in {end_time - start_time}s")
    return res
576+
577+
async def _run(self, run_id: str, data: dict[str, Any]) -> PipelineResult:
568578
self.invalidate()
569579
self.validate_input_data(data)
570-
orchestrator = Orchestrator(self)
571-
logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
580+
orchestrator = Orchestrator(self, run_id)
572581
await orchestrator.run(data)
573-
end_time = default_timer()
574-
logger.debug(
575-
f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
576-
)
577582
return PipelineResult(
578583
run_id=orchestrator.run_id,
579584
result=await self.get_final_results(orchestrator.run_id),

tests/unit/experimental/pipeline/test_pipeline.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ async def test_pipeline_streaming_no_user_callback_happy_path() -> None:
503503
assert len(events) == 2
504504
assert events[0].event_type == EventType.PIPELINE_STARTED
505505
assert events[1].event_type == EventType.PIPELINE_FINISHED
506-
assert len(pipe.callbacks) == 0
506+
assert len(pipe.event_notifier.callbacks) == 0
507507

508508

509509
@pytest.mark.asyncio
@@ -515,7 +515,7 @@ async def test_pipeline_streaming_with_user_callback_happy_path() -> None:
515515
events.append(e)
516516
assert len(events) == 2
517517
assert len(callback.call_args_list) == 2
518-
assert len(pipe.callbacks) == 1
518+
assert len(pipe.event_notifier.callbacks) == 1
519519

520520

521521
@pytest.mark.asyncio
@@ -528,7 +528,7 @@ async def callback(event: Event) -> None:
528528
async for e in pipe.stream({}):
529529
events.append(e)
530530
assert len(events) == 2
531-
assert len(pipe.callbacks) == 1
531+
assert len(pipe.event_notifier.callbacks) == 1
532532

533533

534534
@pytest.mark.asyncio
@@ -559,9 +559,9 @@ async def test_pipeline_streaming_error_in_pipeline_definition() -> None:
559559
events.append(e)
560560
# validation happens before pipeline run actually starts
561561
# but we have the PIPELINE_FAILED event
562+
print(events)
562563
assert len(events) == 1
563564
assert events[0].event_type == EventType.PIPELINE_FAILED
564-
assert events[0].run_id == ""
565565

566566

567567
@pytest.mark.asyncio
@@ -573,6 +573,7 @@ async def test_pipeline_streaming_error_in_component() -> None:
573573
with pytest.raises(TypeError):
574574
async for e in pipe.stream({"component": {"number1": None, "number2": 2}}):
575575
events.append(e)
576+
print(events)
576577
assert len(events) == 3
577578
assert events[0].event_type == EventType.PIPELINE_STARTED
578579
assert events[1].event_type == EventType.TASK_STARTED
@@ -589,4 +590,4 @@ async def callback(event: Event) -> None:
589590
async for e in pipe.stream({}):
590591
events.append(e)
591592
assert len(events) == 2
592-
assert len(pipe.callbacks) == 1
593+
assert len(pipe.event_notifier.callbacks) == 1

0 commit comments

Comments
 (0)