Skip to content

Commit 81b0209

Browse files
Structured retriever results (#73)
* Added result_formatter to all vector, hybrid, and t2c retrievers * Added unit tests to test retrievers work with a format function * Ruff formatting and fixed weaviate e2e tests * Fixed Weaviate tests * Typo in docs * Added neo4j.Record variables in result_formatter in docstring * Update CHANGELOG --------- Co-authored-by: Will Tai <[email protected]>
1 parent 23fd585 commit 81b0209

File tree

16 files changed

+324
-31
lines changed

16 files changed

+324
-31
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
- Improved developer experience by copying the docstring from the `Retriever.get_search_results` method to the `Retriever.search` method
1515
- Support for specifying database names in index handling methods and retrievers.
1616
- User Guide in documentation.
17+
- Introduced the `result_formatter` argument to all retrievers, allowing custom formatting of retriever results.
1718

1819
### Changed
1920
- Refactored import paths for retrievers to neo4j_genai.retrievers.

docs/source/user_guide.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ Format the Results
430430

431431
.. warning::
432432

433-
This API is in beta mode and will be subject to change is the future.
433+
This API is in beta mode and will be subject to change in the future.
434434

435435
For improved readability and ease in prompt-engineering, formatting the result to suit
436436
specific needs involves providing a `result_formatter` function to the Cypher retrievers.

src/neo4j_genai/retrievers/external/pinecone/pinecone.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
EmbedderModel,
3939
Neo4jDriverModel,
4040
RawSearchResult,
41+
RetrieverResultItem,
4142
)
4243

4344
logger = logging.getLogger(__name__)
@@ -78,7 +79,7 @@ class PineconeNeo4jRetriever(ExternalRetriever):
7879
id_property_neo4j (str): The name of the Neo4j node property that's used as the identifier for relating matches from Pinecone to Neo4j nodes.
7980
embedder (Optional[Embedder]): Embedder object to embed query text.
8081
return_properties (Optional[list[str]]): List of node properties to return.
81-
result_formatter (Optional[Callable[[Any], Any]]): Function to transform a neo4j.Record to a RetrieverResultItem.
82+
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
8283
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
8384
8485
Raises:
@@ -94,7 +95,9 @@ def __init__(
9495
embedder: Optional[Embedder] = None,
9596
return_properties: Optional[list[str]] = None,
9697
retrieval_query: Optional[str] = None,
97-
result_formatter: Optional[Callable[[Any], Any]] = None,
98+
result_formatter: Optional[
99+
Callable[[neo4j.Record], RetrieverResultItem]
100+
] = None,
98101
neo4j_database: Optional[str] = None,
99102
):
100103
try:

src/neo4j_genai/retrievers/external/pinecone/types.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
field_validator,
2525
)
2626

27-
from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, VectorSearchModel
27+
from neo4j_genai.types import (
28+
EmbedderModel,
29+
Neo4jDriverModel,
30+
RetrieverResultItem,
31+
VectorSearchModel,
32+
)
2833

2934

3035
class PineconeSearchModel(VectorSearchModel):
@@ -52,5 +57,5 @@ class PineconeNeo4jRetrieverModel(BaseModel):
5257
embedder_model: Optional[EmbedderModel] = None
5358
return_properties: Optional[list[str]] = None
5459
retrieval_query: Optional[str] = None
55-
result_formatter: Optional[Callable[[neo4j.Record], str]] = None
60+
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
5661
neo4j_database: Optional[str] = None

src/neo4j_genai/retrievers/external/weaviate/types.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
from weaviate.client import WeaviateClient
2626
from weaviate.collections.classes.filters import _Filters
2727

28-
from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, VectorSearchModel
28+
from neo4j_genai.types import (
29+
EmbedderModel,
30+
Neo4jDriverModel,
31+
RetrieverResultItem,
32+
VectorSearchModel,
33+
)
2934

3035

3136
class WeaviateModel(BaseModel):
@@ -50,7 +55,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel):
5055
embedder_model: Optional[EmbedderModel]
5156
return_properties: Optional[list[str]] = None
5257
retrieval_query: Optional[str] = None
53-
result_formatter: Optional[Callable[[neo4j.Record], str]] = None
58+
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
5459
neo4j_database: Optional[str] = None
5560

5661

src/neo4j_genai/retrievers/external/weaviate/weaviate.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131
WeaviateNeo4jRetrieverModel,
3232
WeaviateNeo4jSearchModel,
3333
)
34-
from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, RawSearchResult
34+
from neo4j_genai.types import (
35+
EmbedderModel,
36+
Neo4jDriverModel,
37+
RawSearchResult,
38+
RetrieverResultItem,
39+
)
3540

3641
logger = logging.getLogger(__name__)
3742

@@ -69,7 +74,7 @@ class WeaviateNeo4jRetriever(ExternalRetriever):
6974
id_property_neo4j (str): The name of the Neo4j node property that's used as the identifier for relating matches from Weaviate to Neo4j nodes.
7075
embedder (Optional[Embedder]): Embedder object to embed query text.
7176
return_properties (Optional[list[str]]): List of node properties to return.
72-
result_formatter (Optional[Callable[[Any], Any]]): Function to transform a neo4j.Record to a RetrieverResultItem.
77+
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Function to transform a neo4j.Record to a RetrieverResultItem.
7378
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
7479
7580
Raises:
@@ -86,7 +91,9 @@ def __init__(
8691
embedder: Optional[Embedder] = None,
8792
return_properties: Optional[list[str]] = None,
8893
retrieval_query: Optional[str] = None,
89-
result_formatter: Optional[Callable[[Any], Any]] = None,
94+
result_formatter: Optional[
95+
Callable[[neo4j.Record], RetrieverResultItem]
96+
] = None,
9097
neo4j_database: Optional[str] = None,
9198
):
9299
try:

src/neo4j_genai/retrievers/hybrid.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,12 @@ class HybridRetriever(Retriever):
7070
embedder (Optional[Embedder]): Embedder object to embed query text.
7171
return_properties (Optional[list[str]]): List of node properties to return.
7272
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
73+
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
7374
75+
Two variables are provided in the neo4j.Record:
76+
77+
- node: Represents the node retrieved from the hybrid (vector and fulltext) index search.
78+
- score: Denotes the similarity score.
7479
"""
7580

7681
def __init__(
@@ -80,6 +85,9 @@ def __init__(
8085
fulltext_index_name: str,
8186
embedder: Optional[Embedder] = None,
8287
return_properties: Optional[list[str]] = None,
88+
result_formatter: Optional[
89+
Callable[[neo4j.Record], RetrieverResultItem]
90+
] = None,
8391
neo4j_database: Optional[str] = None,
8492
) -> None:
8593
try:
@@ -91,6 +99,7 @@ def __init__(
9199
fulltext_index_name=fulltext_index_name,
92100
embedder_model=embedder_model,
93101
return_properties=return_properties,
102+
result_formatter=result_formatter,
94103
neo4j_database=neo4j_database,
95104
)
96105
except ValidationError as e:
@@ -107,6 +116,7 @@ def __init__(
107116
if validated_data.embedder_model
108117
else None
109118
)
119+
self.result_formatter = validated_data.result_formatter
110120

111121
def default_record_formatter(self, record: neo4j.Record) -> RetrieverResultItem:
112122
"""
@@ -219,7 +229,7 @@ class HybridCypherRetriever(Retriever):
219229
fulltext_index_name (str): Fulltext index name.
220230
retrieval_query (str): Cypher query that gets appended.
221231
embedder (Optional[Embedder]): Embedder object to embed query text.
222-
result_formatter (Optional[Callable[[Any], Any]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
232+
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
223233
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
224234
225235
Raises:
@@ -233,7 +243,9 @@ def __init__(
233243
fulltext_index_name: str,
234244
retrieval_query: str,
235245
embedder: Optional[Embedder] = None,
236-
result_formatter: Optional[Callable[[Any], Any]] = None,
246+
result_formatter: Optional[
247+
Callable[[neo4j.Record], RetrieverResultItem]
248+
] = None,
237249
neo4j_database: Optional[str] = None,
238250
) -> None:
239251
try:
@@ -245,6 +257,7 @@ def __init__(
245257
fulltext_index_name=fulltext_index_name,
246258
retrieval_query=retrieval_query,
247259
embedder_model=embedder_model,
260+
result_formatter=result_formatter,
248261
neo4j_database=neo4j_database,
249262
)
250263
except ValidationError as e:
@@ -261,7 +274,7 @@ def __init__(
261274
if validated_data.embedder_model
262275
else None
263276
)
264-
self.result_formatter = result_formatter
277+
self.result_formatter = validated_data.result_formatter
265278

266279
def get_search_results(
267280
self,

src/neo4j_genai/retrievers/text2cypher.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18-
from typing import Optional
18+
from typing import Callable, Optional
1919

2020
import neo4j
2121
from neo4j.exceptions import CypherSyntaxError, DriverError, Neo4jError
@@ -36,6 +36,7 @@
3636
Neo4jDriverModel,
3737
Neo4jSchemaModel,
3838
RawSearchResult,
39+
RetrieverResultItem,
3940
Text2CypherRetrieverModel,
4041
Text2CypherSearchModel,
4142
)
@@ -65,6 +66,9 @@ def __init__(
6566
llm: LLMInterface,
6667
neo4j_schema: Optional[str] = None,
6768
examples: Optional[list[str]] = None,
69+
result_formatter: Optional[
70+
Callable[[neo4j.Record], RetrieverResultItem]
71+
] = None,
6872
) -> None:
6973
try:
7074
driver_model = Neo4jDriverModel(driver=driver)
@@ -77,13 +81,15 @@ def __init__(
7781
llm_model=llm_model,
7882
neo4j_schema_model=neo4j_schema_model,
7983
examples=examples,
84+
result_formatter=result_formatter,
8085
)
8186
except ValidationError as e:
8287
raise RetrieverInitializationError(e.errors()) from e
8388

8489
super().__init__(validated_data.driver_model.driver)
8590
self.llm = validated_data.llm_model.llm
8691
self.examples = validated_data.examples
92+
self.result_formatter = validated_data.result_formatter
8793
try:
8894
self.neo4j_schema = (
8995
validated_data.neo4j_schema_model.neo4j_schema

src/neo4j_genai/retrievers/vector.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ class VectorRetriever(Retriever):
6565
index_name (str): Vector index name.
6666
embedder (Optional[Embedder]): Embedder object to embed query text.
6767
return_properties (Optional[list[str]]): List of node properties to return.
68+
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
69+
70+
Two variables are provided in the neo4j.Record:
71+
72+
- node: Represents the node retrieved from the vector index search.
73+
- score: Denotes the similarity score.
74+
6875
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
6976
7077
Raises:
@@ -77,6 +84,9 @@ def __init__(
7784
index_name: str,
7885
embedder: Optional[Embedder] = None,
7986
return_properties: Optional[list[str]] = None,
87+
result_formatter: Optional[
88+
Callable[[neo4j.Record], RetrieverResultItem]
89+
] = None,
8090
neo4j_database: Optional[str] = None,
8191
) -> None:
8292
try:
@@ -87,6 +97,7 @@ def __init__(
8797
index_name=index_name,
8898
embedder_model=embedder_model,
8999
return_properties=return_properties,
100+
result_formatter=result_formatter,
90101
neo4j_database=neo4j_database,
91102
)
92103
except ValidationError as e:
@@ -102,6 +113,7 @@ def __init__(
102113
if validated_data.embedder_model
103114
else None
104115
)
116+
self.result_formatter = validated_data.result_formatter
105117
self._node_label = None
106118
self._embedding_node_property = None
107119
self._embedding_dimension = None
@@ -222,7 +234,7 @@ class VectorCypherRetriever(Retriever):
222234
index_name (str): Vector index name.
223235
retrieval_query (str): Cypher query that gets appended.
224236
embedder (Optional[Embedder]): Embedder object to embed query text.
225-
result_formatter (Optional[Callable[[Any], Any]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
237+
result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
226238
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
227239
228240
"""
@@ -233,7 +245,9 @@ def __init__(
233245
index_name: str,
234246
retrieval_query: str,
235247
embedder: Optional[Embedder] = None,
236-
result_formatter: Optional[Callable[[Any], Any]] = None,
248+
result_formatter: Optional[
249+
Callable[[neo4j.Record], RetrieverResultItem]
250+
] = None,
237251
neo4j_database: Optional[str] = None,
238252
) -> None:
239253
try:
@@ -244,6 +258,7 @@ def __init__(
244258
index_name=index_name,
245259
retrieval_query=retrieval_query,
246260
embedder_model=embedder_model,
261+
result_formatter=result_formatter,
247262
neo4j_database=neo4j_database,
248263
)
249264
except ValidationError as e:
@@ -259,7 +274,7 @@ def __init__(
259274
if validated_data.embedder_model
260275
else None
261276
)
262-
self.result_formatter = result_formatter
277+
self.result_formatter = validated_data.result_formatter
263278
self._node_label = None
264279
self._node_embedding_property = None
265280
self._embedding_dimension = None

src/neo4j_genai/types.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
from enum import Enum
18-
from typing import Any, Literal, Optional
18+
from typing import Any, Callable, Literal, Optional
1919

2020
import neo4j
2121
from pydantic import (
@@ -201,6 +201,7 @@ class VectorRetrieverModel(BaseModel):
201201
index_name: str
202202
embedder_model: Optional[EmbedderModel] = None
203203
return_properties: Optional[list[str]] = None
204+
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
204205
neo4j_database: Optional[str] = None
205206

206207

@@ -209,6 +210,7 @@ class VectorCypherRetrieverModel(BaseModel):
209210
index_name: str
210211
retrieval_query: str
211212
embedder_model: Optional[EmbedderModel] = None
213+
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
212214
neo4j_database: Optional[str] = None
213215

214216

@@ -218,6 +220,7 @@ class HybridRetrieverModel(BaseModel):
218220
fulltext_index_name: str
219221
embedder_model: Optional[EmbedderModel] = None
220222
return_properties: Optional[list[str]] = None
223+
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
221224
neo4j_database: Optional[str] = None
222225

223226

@@ -227,6 +230,7 @@ class HybridCypherRetrieverModel(BaseModel):
227230
fulltext_index_name: str
228231
retrieval_query: str
229232
embedder_model: Optional[EmbedderModel] = None
233+
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
230234
neo4j_database: Optional[str] = None
231235

232236

@@ -235,3 +239,4 @@ class Text2CypherRetrieverModel(BaseModel):
235239
llm_model: LLMModel
236240
neo4j_schema_model: Optional[Neo4jSchemaModel] = None
237241
examples: Optional[list[str]] = None
242+
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None

tests/unit/conftest.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from typing import Callable
1617
from unittest.mock import MagicMock, patch
1718

1819
import neo4j
@@ -25,6 +26,7 @@
2526
VectorCypherRetriever,
2627
VectorRetriever,
2728
)
29+
from neo4j_genai.types import RetrieverResultItem
2830

2931

3032
@pytest.fixture(scope="function")
@@ -84,4 +86,15 @@ def t2c_retriever(
8486

8587
@pytest.fixture(scope="function")
8688
def neo4j_record() -> neo4j.Record:
87-
return neo4j.Record({"node": "dummy-node", "score": 1.0})
89+
return neo4j.Record({"node": "dummy-node", "score": 1.0, "node_id": 123})
90+
91+
92+
@pytest.fixture(scope="function")
93+
def result_formatter() -> Callable[[neo4j.Record], RetrieverResultItem]:
94+
def format_function(record: neo4j.Record) -> RetrieverResultItem:
95+
return RetrieverResultItem(
96+
content=record.get("node"),
97+
metadata={"score": record.get("score"), "node_id": record.get("node_id")},
98+
)
99+
100+
return format_function

0 commit comments

Comments
 (0)