Skip to content

Expose all components in SimpleKGPipeline init and config #314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Added

- Added the `run_with_context` method to `Component`. This method includes a `context_` parameter, which provides information about the pipeline from which the component is executed (e.g., the `run_id`). It also enables the component to send events to the pipeline's callback function.
- Exposed `schema_builder`, `chunk_embedder`, `extractor` and `resolver` in the `SimpleKGPipeline` constructor so that they can be customized.


## 1.6.0
Expand Down
8 changes: 8 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ Neo4jWriter
.. autoclass:: neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter
:members: run


EntityResolver
==============

.. autoclass:: neo4j_graphrag.experimental.components.resolver.EntityResolver
:members: run


SinglePropertyExactMatchResolver
================================

Expand Down
24 changes: 23 additions & 1 deletion docs/source/user_guide_kg_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,13 @@ For advanced customization or when using a custom implementation, you can pass
instances of specific components to the `SimpleKGPipeline`. The components that can
customized at the moment are:

- `text_splitter`: must be an instance of :ref:`TextSplitter`
- `pdf_loader`: must be an instance of :ref:`PdfLoader`
- `schema_builder`: must be an instance of :ref:`SchemaBuilder`
- `text_splitter`: must be an instance of :ref:`TextSplitter`
- `chunk_embedder`: must be an instance of :ref:`TextChunkEmbedder`
- `extractor`: must be an instance of :ref:`EntityRelationExtractor`
- `kg_writer`: must be an instance of :ref:`KGWriter`
- `resolver`: must be an instance of :ref:`EntityResolver`

For instance, the following code can be used to customize the chunk size and
chunk overlap in the text splitter component:
Expand All @@ -200,6 +204,24 @@ chunk overlap in the text splitter component:
)


.. warning::

When providing a custom component, all other related parameters in the SimpleKGPipeline constructor are ignored. For instance, in the following example:

.. code:: python

kg_builder = SimpleKGPipeline(
# ...
writer=Neo4jKGWriter(neo4j_database="db_1"),
neo4j_database="db_2",
# ...
)


The graph will be saved to the **db_1** database.



Using a Config file
===================

Expand Down
28 changes: 18 additions & 10 deletions src/neo4j_graphrag/experimental/pipeline/config/object_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,10 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder:
return self.root.parse(resolved_data)


class ComponentConfig(ObjectConfig[Component]):
ComponentGeneric = TypeVar("ComponentGeneric")


class ComponentConfig(ObjectConfig[ComponentGeneric], Generic[ComponentGeneric]):
"""A config model for all components.

In addition to the object config, components can have pre-defined parameters
Expand All @@ -256,22 +259,27 @@ class ComponentConfig(ObjectConfig[Component]):
DEFAULT_MODULE = "neo4j_graphrag.experimental.components"
INTERFACE = Component

model_config = ConfigDict(arbitrary_types_allowed=True)

def get_run_params(self, resolved_data: dict[str, Any]) -> dict[str, Any]:
self._global_data = resolved_data
return self.resolve_params(self.run_params_)


class ComponentType(RootModel): # type: ignore[type-arg]
root: Union[Component, ComponentConfig]
class ComponentType(
RootModel[Union[ComponentGeneric, ComponentConfig[ComponentGeneric]]],
Generic[ComponentGeneric],
):
root: Union[ComponentGeneric, ComponentConfig[ComponentGeneric]]

model_config = ConfigDict(arbitrary_types_allowed=True)

def parse(self, resolved_data: dict[str, Any] | None = None) -> Component:
if isinstance(self.root, Component):
return self.root
return self.root.parse(resolved_data)
def parse(self, resolved_data: dict[str, Any] | None = None) -> ComponentGeneric:
if isinstance(self.root, ComponentConfig):
return self.root.parse(resolved_data)
return self.root

def get_run_params(self, resolved_data: dict[str, Any]) -> dict[str, Any]:
if isinstance(self.root, Component):
return {}
return self.root.get_run_params(resolved_data)
if isinstance(self.root, ComponentConfig):
return self.root.get_run_params(resolved_data)
return {}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pydantic import field_validator

from neo4j_graphrag.embeddings import Embedder
from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig
from neo4j_graphrag.experimental.pipeline.config.object_config import (
ComponentType,
Expand Down Expand Up @@ -83,7 +84,7 @@ def validate_embedders(
return embedders

def _resolve_component_definition(
self, name: str, config: ComponentType
self, name: str, config: ComponentType[Component]
) -> ComponentDefinition:
component = config.parse(self._global_data)
if hasattr(config.root, "run_params_"):
Expand Down Expand Up @@ -188,7 +189,7 @@ class PipelineConfig(AbstractPipelineConfig):
"""Configuration class for raw pipelines.
Config must contain all components and connections."""

component_config: dict[str, ComponentType]
component_config: dict[str, ComponentType[Component]]
connection_config: list[ConnectionDefinition]
template_: Literal[PipelineType.NONE] = PipelineType.NONE

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
OnError,
)
from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader, DataLoader
from neo4j_graphrag.experimental.components.resolver import (
EntityResolver,
SinglePropertyExactMatchResolver,
Expand Down Expand Up @@ -81,17 +81,21 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
lexical_graph_config: Optional[LexicalGraphConfig] = None
neo4j_database: Optional[str] = None

pdf_loader: Optional[ComponentType] = None
kg_writer: Optional[ComponentType] = None
text_splitter: Optional[ComponentType] = None
pdf_loader: Optional[ComponentType[DataLoader]] = None
schema_builder: Optional[ComponentType[SchemaBuilder]] = None
text_splitter: Optional[ComponentType[TextSplitter]] = None
chunk_embedder: Optional[ComponentType[TextChunkEmbedder]] = None
extractor: Optional[ComponentType[EntityRelationExtractor]] = None
kg_writer: Optional[ComponentType[KGWriter]] = None
resolver: Optional[ComponentType[EntityResolver]] = None

model_config = ConfigDict(arbitrary_types_allowed=True)

def _get_pdf_loader(self) -> Optional[PdfLoader]:
def _get_pdf_loader(self) -> Optional[DataLoader]:
if not self.from_pdf:
return None
if self.pdf_loader:
return self.pdf_loader.parse(self._global_data) # type: ignore
return self.pdf_loader.parse(self._global_data)
return PdfLoader()

def _get_run_params_for_pdf_loader(self) -> dict[str, Any]:
Expand All @@ -103,7 +107,7 @@ def _get_run_params_for_pdf_loader(self) -> dict[str, Any]:

def _get_splitter(self) -> TextSplitter:
if self.text_splitter:
return self.text_splitter.parse(self._global_data) # type: ignore
return self.text_splitter.parse(self._global_data)
return FixedSizeSplitter()

def _get_run_params_for_splitter(self) -> dict[str, Any]:
Expand All @@ -112,9 +116,13 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]:
return {}

def _get_chunk_embedder(self) -> TextChunkEmbedder:
if self.chunk_embedder:
return self.chunk_embedder.parse(self._global_data)
return TextChunkEmbedder(embedder=self.get_default_embedder())

def _get_schema(self) -> SchemaBuilder:
if self.schema_builder:
return self.schema_builder.parse(self._global_data)
return SchemaBuilder()

def _get_run_params_for_schema(self) -> dict[str, Any]:
Expand All @@ -125,6 +133,8 @@ def _get_run_params_for_schema(self) -> dict[str, Any]:
}

def _get_extractor(self) -> EntityRelationExtractor:
if self.extractor:
return self.extractor.parse(self._global_data)
return LLMEntityRelationExtractor(
llm=self.get_default_llm(),
prompt_template=self.prompt_template,
Expand All @@ -134,7 +144,7 @@ def _get_extractor(self) -> EntityRelationExtractor:

def _get_writer(self) -> KGWriter:
if self.kg_writer:
return self.kg_writer.parse(self._global_data) # type: ignore
return self.kg_writer.parse(self._global_data)
return Neo4jWriter(
driver=self.get_default_neo4j_driver(),
neo4j_database=self.neo4j_database,
Expand All @@ -148,6 +158,8 @@ def _get_run_params_for_writer(self) -> dict[str, Any]:
def _get_resolver(self) -> Optional[EntityResolver]:
if not self.perform_entity_resolution:
return None
if self.resolver:
return self.resolver.parse(self._global_data)
return SinglePropertyExactMatchResolver(
driver=self.get_default_neo4j_driver(),
neo4j_database=self.neo4j_database,
Expand Down
24 changes: 21 additions & 3 deletions src/neo4j_graphrag/experimental/pipeline/kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,15 @@
from pydantic import ValidationError

from neo4j_graphrag.embeddings import Embedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
OnError,
EntityRelationExtractor,
)
from neo4j_graphrag.experimental.components.kg_writer import KGWriter
from neo4j_graphrag.experimental.components.pdf_loader import DataLoader
from neo4j_graphrag.experimental.components.resolver import EntityResolver
from neo4j_graphrag.experimental.components.schema import SchemaBuilder
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType
Expand Down Expand Up @@ -82,9 +88,13 @@ def __init__(
relations: Optional[Sequence[RelationInputType]] = None,
potential_schema: Optional[List[tuple[str, str, str]]] = None,
from_pdf: bool = True,
text_splitter: Optional[TextSplitter] = None,
pdf_loader: Optional[DataLoader] = None,
schema_builder: Optional[SchemaBuilder] = None,
text_splitter: Optional[TextSplitter] = None,
chunk_embedder: Optional[TextChunkEmbedder] = None,
extractor: Optional[EntityRelationExtractor] = None,
kg_writer: Optional[KGWriter] = None,
resolver: Optional[EntityResolver] = None,
on_error: str = "IGNORE",
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(),
perform_entity_resolution: bool = True,
Expand All @@ -102,8 +112,16 @@ def __init__(
potential_schema=potential_schema,
from_pdf=from_pdf,
pdf_loader=ComponentType(pdf_loader) if pdf_loader else None,
kg_writer=ComponentType(kg_writer) if kg_writer else None,
schema_builder=ComponentType(schema_builder)
if schema_builder
else None,
text_splitter=ComponentType(text_splitter) if text_splitter else None,
chunk_embedder=ComponentType(chunk_embedder)
if chunk_embedder
else None,
extractor=ComponentType(extractor) if extractor else None,
kg_writer=ComponentType(kg_writer) if kg_writer else None,
resolver=ComponentType(resolver) if resolver else None,
on_error=OnError(on_error),
prompt_template=prompt_template,
perform_entity_resolution=perform_entity_resolution,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
)
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
from neo4j_graphrag.experimental.components.resolver import (
SinglePropertyExactMatchResolver,
)
from neo4j_graphrag.experimental.components.schema import (
SchemaBuilder,
SchemaEntity,
Expand Down Expand Up @@ -71,7 +74,7 @@ def test_simple_kg_pipeline_config_pdf_loader_class_overwrite_but_from_pdf_is_fa
def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_true_class_overwrite_from_config(
mock_component_parse: Mock,
) -> None:
my_pdf_loader_config = ComponentConfig(
my_pdf_loader_config: ComponentConfig[PdfLoader] = ComponentConfig(
class_="",
)
my_pdf_loader = PdfLoader()
Expand All @@ -92,7 +95,7 @@ def test_simple_kg_pipeline_config_text_splitter() -> None:
def test_simple_kg_pipeline_config_text_splitter_overwrite(
mock_component_parse: Mock,
) -> None:
my_text_splitter_config = ComponentConfig(
my_text_splitter_config: ComponentConfig[FixedSizeSplitter] = ComponentConfig(
class_="",
)
my_text_splitter = FixedSizeSplitter()
Expand Down Expand Up @@ -152,6 +155,28 @@ def test_simple_kg_pipeline_config_extractor(mock_llm: Mock, llm: LLMInterface)
assert extractor.prompt_template.template == "my template {text}"


@patch(
"neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
)
@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
def test_simple_kg_pipeline_config_extractor_overwrite(
mock_component_parse: Mock, mock_llm: Mock
) -> None:
my_extractor = LLMEntityRelationExtractor(llm=mock_llm)
mock_component_parse.return_value = my_extractor
config = SimpleKGPipelineConfig(
on_error="IGNORE", # type: ignore
prompt_template=ERExtractionTemplate(template="my template {text}"),
extractor={}, # type: ignore
)
extractor = config._get_extractor()
assert isinstance(extractor, LLMEntityRelationExtractor)
assert extractor.llm == mock_llm
# default values are not overwritten by the parameters:
assert extractor.on_error == OnError.RAISE
assert extractor.prompt_template.template == ERExtractionTemplate.DEFAULT_TEMPLATE


@patch(
"neo4j_graphrag.experimental.components.kg_writer.get_version",
return_value=((5, 23, 0), False, False),
Expand Down Expand Up @@ -184,13 +209,10 @@ def test_simple_kg_pipeline_config_writer_overwrite(
_: Mock,
driver: neo4j.Driver,
) -> None:
my_writer_config = ComponentConfig(
class_="",
)
my_writer = Neo4jWriter(driver, neo4j_database="my_db")
mock_component_parse.return_value = my_writer
config = SimpleKGPipelineConfig(
kg_writer=my_writer_config, # type: ignore
kg_writer={}, # type: ignore
neo4j_database="my_other_db",
)
writer: Neo4jWriter = config._get_writer() # type: ignore
Expand All @@ -199,6 +221,42 @@ def test_simple_kg_pipeline_config_writer_overwrite(
assert writer.neo4j_database == "my_db"


@patch(
"neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
)
@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
def test_simple_kg_pipeline_config_resolver_overwrite(
mock_component_parse: Mock, driver: neo4j.Driver
) -> None:
my_resolver = SinglePropertyExactMatchResolver(driver, resolve_property="full_name")
mock_component_parse.return_value = my_resolver
config = SimpleKGPipelineConfig(
perform_entity_resolution=True,
resolver={}, # type: ignore
)
resolver = config._get_resolver()
assert isinstance(resolver, SinglePropertyExactMatchResolver)
assert resolver.driver == driver
assert resolver.resolve_property == "full_name"


@patch(
"neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
)
@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
def test_simple_kg_pipeline_config_resolver_overwrite_but_disabled(
mock_component_parse: Mock, driver: neo4j.Driver
) -> None:
my_resolver = SinglePropertyExactMatchResolver(driver, resolve_property="full_name")
mock_component_parse.return_value = my_resolver
config = SimpleKGPipelineConfig(
perform_entity_resolution=False,
resolver={}, # type: ignore
)
resolver = config._get_resolver()
assert resolver is None


def test_simple_kg_pipeline_config_connections_from_pdf() -> None:
config = SimpleKGPipelineConfig(
from_pdf=True,
Expand Down Expand Up @@ -234,7 +292,7 @@ def test_simple_kg_pipeline_config_connections_from_text() -> None:
assert (actual.start, actual.end) == expected


def test_simple_kg_pipeline_config_connections_with_er() -> None:
def test_simple_kg_pipeline_config_connections_with_entity_resolution() -> None:
config = SimpleKGPipelineConfig(
from_pdf=True,
perform_entity_resolution=True,
Expand Down
Loading
Loading