Skip to content

Replace pygraphviz with neo4j-viz for graph visualization #306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
### Changed

- Improved log output readability in Retrievers and GraphRAG and added embedded vector to retriever result metadata for debugging.
- Switched from pygraphviz to neo4j-viz
- Renders interactive graph now on HTML instead of PNG
- Removed `get_pygraphviz_graph` method

## 1.6.1

Expand Down
23 changes: 12 additions & 11 deletions docs/source/user_guide_pipeline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,26 @@ Pipelines can be visualized using the `draw` method:
pipe = Pipeline()
# ... define components and connections

pipe.draw("pipeline.png")
pipe.draw("pipeline.html")

Here is an example pipeline rendering:
Here is an example pipeline rendering as an interactive HTML visualization:

.. image:: images/pipeline_no_unused_outputs.png
:alt: Pipeline visualisation with hidden outputs if unused
.. code:: python

# To view the visualization in a browser
import webbrowser
webbrowser.open("pipeline.html")

By default, output fields which are not mapped to any component are hidden. They
can be added to the canvas by setting `hide_unused_outputs` to `False`:
can be added to the visualization by setting `hide_unused_outputs` to `False`:

.. code:: python

pipe.draw("pipeline.png", hide_unused_outputs=False)

Here is an example of final result:

.. image:: images/pipeline_full.png
:alt: Pipeline visualisation
pipe.draw("pipeline_full.html", hide_unused_outputs=False)

# To view the full visualization in a browser
import webbrowser
webbrowser.open("pipeline_full.html")


************************
Expand Down
4 changes: 2 additions & 2 deletions examples/customize/build_graph/pipeline/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,5 @@ async def run(self, number: IntDataModel) -> IntDataModel:
pipe.connect("times_two", "addition", {"a": "times_two.value"})
pipe.connect("times_ten", "addition", {"b": "times_ten.value"})
pipe.connect("addition", "save", {"number": "addition"})
pipe.draw("graph.png")
pipe.draw("graph_full.png", hide_unused_outputs=False)
pipe.draw("graph.html")
pipe.draw("graph_full.html", hide_unused_outputs=False)
2,137 changes: 1,205 additions & 932 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 4 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ pyyaml = "^6.0.2"
types-pyyaml = "^6.0.12.20240917"
# optional deps
langchain-text-splitters = {version = "^0.3.0", optional = true }
pygraphviz = [
{version = "^1.13.0", python = ">=3.10,<4.0.0", optional = true},
{version = "^1.0.0", python = "<3.10", optional = true}
]
neo4j-viz = {version = "^0.2.2", optional = true }
weaviate-client = {version = "^4.6.1", optional = true }
pinecone-client = {version = "^4.1.0", optional = true }
google-cloud-aiplatform = {version = "^1.66.0", optional = true }
Expand All @@ -68,6 +65,7 @@ sphinx = { version = "^7.2.6", python = "^3.9" }
langchain-openai = {version = "^0.2.2", optional = true }
langchain-huggingface = {version = "^0.1.0", optional = true }
enum-tools = {extras = ["sphinx"], version = "^0.12.0"}
neo4j-viz = "^0.2.2"

[tool.poetry.extras]
weaviate = ["weaviate-client"]
Expand All @@ -79,9 +77,9 @@ ollama = ["ollama"]
openai = ["openai"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]
kg_creation_tools = ["pygraphviz"]
kg_creation_tools = ["neo4j-viz"]
sentence-transformers = ["sentence-transformers"]
experimental = ["langchain-text-splitters", "pygraphviz", "llama-index"]
experimental = ["langchain-text-splitters", "neo4j-viz", "llama-index"]
examples = ["langchain-openai", "langchain-huggingface"]
nlp = ["spacy"]
fuzzy-matching = ["rapidfuzz"]
Expand Down
54 changes: 54 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/neo4j_viz.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Union

class Node:
id: Union[str, int]
caption: Optional[str] = None
size: Optional[float] = None
properties: Optional[Dict[str, Any]] = None

def __init__(
self,
id: Union[str, int],
caption: Optional[str] = None,
size: Optional[float] = None,
properties: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None: ...

class Relationship:
source: Union[str, int]
target: Union[str, int]
caption: Optional[str] = None
properties: Optional[Dict[str, Any]] = None

def __init__(
self,
source: Union[str, int],
target: Union[str, int],
caption: Optional[str] = None,
properties: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None: ...

class VisualizationGraph:
nodes: List[Node]
relationships: List[Relationship]

def __init__(
self, nodes: List[Node], relationships: List[Relationship]
) -> None: ...
149 changes: 115 additions & 34 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@
import warnings
from collections import defaultdict
from timeit import default_timer
from typing import Any, Optional, AsyncGenerator
from typing import Any, Optional, AsyncGenerator, cast
import asyncio

from neo4j_graphrag.utils.logging import prettify

try:
import pygraphviz as pgv
from neo4j_viz import Node, Relationship, VisualizationGraph

neo4j_viz_available = True
except ImportError:
pgv = None
neo4j_viz_available = False

from pydantic import BaseModel

Expand Down Expand Up @@ -198,53 +200,132 @@ def show_as_dict(self) -> dict[str, Any]:
def draw(
self, path: str, layout: str = "dot", hide_unused_outputs: bool = True
) -> Any:
G = self.get_pygraphviz_graph(hide_unused_outputs)
G.layout(layout)
G.draw(path)
"""Render the pipeline graph to an HTML file at the specified path"""
G = self._get_neo4j_viz_graph(hide_unused_outputs)

# Write the visualization to an HTML file
with open(path, "w") as f:
f.write(G.render().data)

return G

def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph:
if pgv is None:
def _get_neo4j_viz_graph(
self, hide_unused_outputs: bool = True
) -> VisualizationGraph:
"""Generate a neo4j-viz visualization of the pipeline graph"""
if not neo4j_viz_available:
raise ImportError(
"Could not import pygraphviz. "
"Follow installation instruction in pygraphviz documentation "
"to get it up and running on your system."
"Could not import neo4j-viz. Install it with 'pip install \"neo4j-graphrag[experimental]\"'"
)

self.validate_parameter_mapping()
G = pgv.AGraph(strict=False, directed=True)
# create a node for each component
for n, node in self._nodes.items():
comp_inputs = ",".join(

nodes = []
relationships = []
node_ids = {} # Map node names to their numeric IDs
next_id = 0

# Create nodes for each component
for n, pipeline_node in self._nodes.items():
comp_inputs = ", ".join(
f"{i}: {d['annotation']}"
for i, d in node.component.component_inputs.items()
for i, d in pipeline_node.component.component_inputs.items()
)
G.add_node(
n,
node_type="component",
shape="rectangle",
label=f"{node.component.__class__.__name__}: {n}({comp_inputs})",

node_ids[n] = next_id
label = f"{pipeline_node.component.__class__.__name__}: {n}({comp_inputs})"

# Create Node with properties parameter
viz_node = Node( # type: ignore
id=next_id,
caption=label,
size=20,
properties={"node_type": "component"},
)
# create a node for each output field and connect them it to its component
for o in node.component.component_outputs:
# Cast the node to Any before adding it to the list
nodes.append(cast(Any, viz_node))
next_id += 1

# Create nodes for each output field
for o in pipeline_node.component.component_outputs:
param_node_name = f"{n}.{o}"
G.add_node(param_node_name, label=o, node_type="output")
G.add_edge(n, param_node_name)
# then we create the edges between a component output
# and the component it gets added to

# Skip if we're hiding unused outputs and it's not used
if hide_unused_outputs:
# Check if this output is used as a source in any parameter mapping
is_used = False
for params in self.param_mapping.values():
for mapping in params.values():
source_component = mapping["component"]
source_param_name = mapping.get("param")
if source_component == n and source_param_name == o:
is_used = True
break
if is_used:
break

if not is_used:
continue

node_ids[param_node_name] = next_id
# Create Node with properties parameter
output_node = Node( # type: ignore
id=next_id,
caption=o,
size=15,
properties={"node_type": "output"},
)
# Cast the node to Any before adding it to the list
nodes.append(cast(Any, output_node))

# Connect component to its output
# Add type ignore comment to suppress mypy errors
rel = Relationship( # type: ignore
source=node_ids[n],
target=node_ids[param_node_name],
properties={"type": "HAS_OUTPUT"},
)
relationships.append(rel)
next_id += 1

# Create edges between components based on parameter mapping
for component_name, params in self.param_mapping.items():
for param, mapping in params.items():
source_component = mapping["component"]
source_param_name = mapping.get("param")

if source_param_name:
source_output_node = f"{source_component}.{source_param_name}"
else:
source_output_node = source_component
G.add_edge(source_output_node, component_name, label=param)
# remove outputs that are not mapped
if hide_unused_outputs:
for n in G.nodes():
if n.attr["node_type"] == "output" and G.out_degree(n) == 0: # type: ignore
G.remove_node(n)
return G

if source_output_node in node_ids and component_name in node_ids:
# Add type ignore comment to suppress mypy errors
rel = Relationship( # type: ignore
source=node_ids[source_output_node],
target=node_ids[component_name],
caption=param,
properties={"type": "CONNECTS_TO"},
)
relationships.append(rel)

# Cast the constructor to Any, then cast the result back to VisualizationGraph
viz_graph = cast(Any, VisualizationGraph)(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the mypy error without this cast?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, isn't the pyi file supposed to help mypy here? Otherwise why do we have this file?

nodes=nodes, relationships=relationships
)
# Cast the result back to the expected return type
return cast(VisualizationGraph, viz_graph)

def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> Any:
"""Legacy method for backward compatibility.
Uses neo4j-viz instead of pygraphviz.
"""
warnings.warn(
"get_pygraphviz_graph is deprecated, use draw instead",
DeprecationWarning,
stacklevel=2,
)
return self._get_neo4j_viz_graph(hide_unused_outputs)

def add_component(self, component: Component, name: str) -> None:
"""Add a new component. Components are uniquely identified
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/experimental/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,39 +380,39 @@ async def test_pipeline_async() -> None:
assert pipeline_result[1].result == {"add": {"result": 12}}


def test_pipeline_to_pgv() -> None:
def test_pipeline_to_viz() -> None:
pipe = Pipeline()
component_a = ComponentAdd()
component_b = ComponentMultiply()
pipe.add_component(component_a, "a")
pipe.add_component(component_b, "b")
pipe.connect("a", "b", {"number1": "a.result"})
g = pipe.get_pygraphviz_graph()
g = pipe._get_neo4j_viz_graph()
# 3 nodes:
# - 2 components 'a' and 'b'
# - 1 output 'a.result'
assert len(g.nodes()) == 3
g = pipe.get_pygraphviz_graph(hide_unused_outputs=False)
assert len(g.nodes) == 3
g = pipe._get_neo4j_viz_graph(hide_unused_outputs=False)
# 4 nodes:
# - 2 components 'a' and 'b'
# - 2 output 'a.result' and 'b.result'
assert len(g.nodes()) == 4
assert len(g.nodes) == 4


def test_pipeline_draw() -> None:
pipe = Pipeline()
pipe.add_component(ComponentAdd(), "add")
t = tempfile.NamedTemporaryFile()
t = tempfile.NamedTemporaryFile(suffix=".html")
pipe.draw(t.name)
content = t.file.read()
assert len(content) > 0


@patch("neo4j_graphrag.experimental.pipeline.pipeline.pgv", None)
def test_pipeline_draw_missing_pygraphviz_dep() -> None:
@patch("neo4j_graphrag.experimental.pipeline.pipeline.neo4j_viz_available", False)
def test_pipeline_draw_missing_neo4j_viz_dep() -> None:
pipe = Pipeline()
pipe.add_component(ComponentAdd(), "add")
t = tempfile.NamedTemporaryFile()
t = tempfile.NamedTemporaryFile(suffix=".html")
with pytest.raises(ImportError):
pipe.draw(t.name)

Expand Down
Loading