Skip to content

Commit 67c2483

Browse files
committed on Aug 22, 2024
refactor - in local & global search api, stick to callable objects that return langchain chains to keep the API surface minimal & consistent
1 parent e0dc062 commit 67c2483

File tree

9 files changed

+64
-86
lines changed

9 files changed

+64
-86
lines changed
 

‎README.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,16 @@ retriever = LocalSearchRetriever(
105105
artifacts=artifacts,
106106
)
107107

108-
# Get a langchain chain to do local search
109-
search_chain = make_local_search_chain(
108+
# Build the LocalSearch object
109+
local_search = LocalSearch(
110110
prompt_builder=LocalSearchPromptBuilder(),
111111
llm=make_llm_instance(llm_type, llm_model, cache_dir),
112112
retriever=retriever,
113113
)
114114

115+
# it's a callable that returns the chain
116+
search_chain = local_search()
117+
115118
# you could invoke
116119
# print(search_chain.invoke(query))
117120

‎examples/simple-app/app/query.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
KeyPointsGeneratorPromptBuilder,
4040
)
4141
from langchain_graphrag.query.local_search import (
42+
LocalSearch,
4243
LocalSearchPromptBuilder,
4344
LocalSearchRetriever,
44-
make_local_search_chain,
4545
)
4646
from langchain_graphrag.query.local_search.context_builders import (
4747
ContextBuilder,
@@ -163,13 +163,15 @@ def local_search(
163163
artifacts=artifacts,
164164
)
165165

166-
# Get a langchain chain to do local search
167-
search_chain = make_local_search_chain(
166+
local_search = LocalSearch(
168167
prompt_builder=LocalSearchPromptBuilder(),
169168
llm=make_llm_instance(llm_type, llm_model, cache_dir),
170169
retriever=retriever,
171170
)
172171

172+
# get the chain
173+
search_chain = local_search()
174+
173175
# you could invoke
174176
# print(search_chain.invoke(query))
175177

‎pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "langchain-graphrag"
3-
version = "0.0.2-beta.9"
3+
version = "0.0.2-beta.10"
44
description = "Implementation of GraphRAG (https://arxiv.org/pdf/2404.16130)"
55
authors = [{ name = "Kapil Sachdeva", email = "notan@email.com" }]
66
dependencies = [

‎src/langchain_graphrag/query/global_search/key_points_aggregator/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""KeyPointsAggregator module."""
22

3-
from .aggregator import KeyPointsAggregator, make_key_points_aggregator_chain
3+
from .aggregator import KeyPointsAggregator
44
from .context_builder import KeyPointsContextBuilder
55
from .prompt_builder import KeyPointsAggregatorPromptBuilder
66

77
__all__ = [
8-
"make_key_points_aggregator_chain",
98
"KeyPointsAggregatorPromptBuilder",
109
"KeyPointsContextBuilder",
1110
"KeyPointsAggregator",
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from __future__ import annotations
2-
31
import operator
42
from functools import partial
53

@@ -28,30 +26,6 @@ def _kp_result_to_docs(
2826
return context_builder(key_points)
2927

3028

31-
def make_key_points_aggregator_chain(
32-
llm: BaseLLM,
33-
prompt_builder: PromptBuilder,
34-
context_builder: KeyPointsContextBuilder,
35-
) -> Runnable:
36-
kp_lambda = partial(_kp_result_to_docs, context_builder=context_builder)
37-
38-
prompt, output_parser = prompt_builder.build()
39-
40-
search_chain: Runnable = (
41-
{
42-
"report_data": operator.itemgetter("report_data")
43-
| RunnableLambda(kp_lambda)
44-
| _format_docs,
45-
"global_query": operator.itemgetter("global_query"),
46-
}
47-
| prompt
48-
| llm
49-
| output_parser
50-
)
51-
52-
return search_chain
53-
54-
5529
class KeyPointsAggregator:
5630
def __init__(
5731
self,
@@ -64,8 +38,19 @@ def __init__(
6438
self._context_builder = context_builder
6539

6640
def __call__(self) -> Runnable:
67-
return make_key_points_aggregator_chain(
68-
llm=self._llm,
69-
prompt_builder=self._prompt_builder,
41+
kp_lambda = partial(
42+
_kp_result_to_docs,
7043
context_builder=self._context_builder,
7144
)
45+
46+
prompt, output_parser = self._prompt_builder.build()
47+
base_chain = prompt | self._llm | output_parser
48+
49+
search_chain: Runnable = {
50+
"report_data": operator.itemgetter("report_data")
51+
| RunnableLambda(kp_lambda)
52+
| _format_docs,
53+
"global_query": operator.itemgetter("global_query"),
54+
} | base_chain
55+
56+
return search_chain

‎src/langchain_graphrag/query/global_search/key_points_generator/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""Key Points generator module."""
22

33
from .context_builder import CommunityReportContextBuilder
4-
from .generator import KeyPointsGenerator, make_key_points_generator_chain
4+
from .generator import KeyPointsGenerator
55
from .prompt_builder import KeyPointsGeneratorPromptBuilder
66

77
__all__ = [
8-
"make_key_points_generator_chain",
98
"KeyPointsGeneratorPromptBuilder",
109
"CommunityReportContextBuilder",
1110
"KeyPointsGenerator",
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from __future__ import annotations
2-
31
from langchain_core.documents import Document
42
from langchain_core.language_models import BaseLLM
53
from langchain_core.runnables import Runnable, RunnableParallel
@@ -15,28 +13,6 @@ def _format_docs(documents: list[Document]) -> str:
1513
return context_data_str
1614

1715

18-
def make_key_points_generator_chain(
19-
llm: BaseLLM,
20-
prompt_builder: PromptBuilder,
21-
context_builder: CommunityReportContextBuilder,
22-
) -> Runnable:
23-
prompt, output_parser = prompt_builder.build()
24-
25-
documents = context_builder()
26-
27-
chains: list[Runnable] = []
28-
29-
for d in documents:
30-
d_context_data = _format_docs([d])
31-
d_prompt = prompt.partial(context_data=d_context_data)
32-
generator_chain: Runnable = d_prompt | llm | output_parser
33-
chains.append(generator_chain)
34-
35-
analysts = [f"Analayst-{i}" for i in range(1, len(chains) + 1)]
36-
37-
return RunnableParallel(dict(zip(analysts, chains, strict=True)))
38-
39-
4016
class KeyPointsGenerator:
4117
def __init__(
4218
self,
@@ -49,8 +25,18 @@ def __init__(
4925
self._context_builder = context_builder
5026

5127
def __call__(self) -> Runnable:
52-
return make_key_points_generator_chain(
53-
llm=self._llm,
54-
prompt_builder=self._prompt_builder,
55-
context_builder=self._context_builder,
56-
)
28+
prompt, output_parser = self._prompt_builder.build()
29+
30+
documents = self._context_builder()
31+
32+
chains: list[Runnable] = []
33+
34+
for d in documents:
35+
d_context_data = _format_docs([d])
36+
d_prompt = prompt.partial(context_data=d_context_data)
37+
generator_chain: Runnable = d_prompt | self._llm | output_parser
38+
chains.append(generator_chain)
39+
40+
analysts = [f"Analayst-{i}" for i in range(1, len(chains) + 1)]
41+
42+
return RunnableParallel(dict(zip(analysts, chains, strict=True)))

‎src/langchain_graphrag/query/local_search/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from .prompt_builder import LocalSearchPromptBuilder
44
from .retriever import LocalSearchRetriever
5-
from .search import make_local_search_chain
5+
from .search import LocalSearch
66

77
__all__ = [
8-
"make_local_search_chain",
8+
"LocalSearch",
99
"LocalSearchPromptBuilder",
1010
"LocalSearchRetriever",
1111
]

‎src/langchain_graphrag/query/local_search/search.py

+20-16
Original file line numberDiff line numberDiff line change
class LocalSearch:
    """Callable object that assembles a langchain chain for local search.

    Mirrors the global-search API (``KeyPointsGenerator`` /
    ``KeyPointsAggregator``): construct it with its collaborators, then
    call it with no arguments to obtain the Runnable chain.
    """

    def __init__(
        self,
        llm: BaseLLM,
        prompt_builder: PromptBuilder,
        retriever: BaseRetriever,
    ):
        self._llm = llm
        self._prompt_builder = prompt_builder
        self._retriever = retriever

    def __call__(self) -> Runnable:
        """Assemble and return the local-search chain.

        The returned Runnable takes the user's query, retrieves and formats
        context via the retriever, and produces the parsed LLM answer.
        """
        prompt, output_parser = self._prompt_builder.build()

        # Core pipeline: prompt -> llm -> parser. The mapping in front feeds
        # it the formatted retrieval context plus the raw query.
        answer_chain = prompt | self._llm | output_parser

        search_chain: Runnable = {
            "context_data": self._retriever | _format_docs,
            "local_query": RunnablePassthrough(),
        } | answer_chain

        return search_chain

0 commit comments

Comments (0)
Please sign in to comment.