1
- from __future__ import annotations
2
-
3
1
from langchain_core .documents import Document
4
2
from langchain_core .language_models import BaseLLM
5
3
from langchain_core .runnables import Runnable , RunnableParallel
@@ -15,28 +13,6 @@ def _format_docs(documents: list[Document]) -> str:
15
13
return context_data_str
16
14
17
15
18
- def make_key_points_generator_chain (
19
- llm : BaseLLM ,
20
- prompt_builder : PromptBuilder ,
21
- context_builder : CommunityReportContextBuilder ,
22
- ) -> Runnable :
23
- prompt , output_parser = prompt_builder .build ()
24
-
25
- documents = context_builder ()
26
-
27
- chains : list [Runnable ] = []
28
-
29
- for d in documents :
30
- d_context_data = _format_docs ([d ])
31
- d_prompt = prompt .partial (context_data = d_context_data )
32
- generator_chain : Runnable = d_prompt | llm | output_parser
33
- chains .append (generator_chain )
34
-
35
- analysts = [f"Analayst-{ i } " for i in range (1 , len (chains ) + 1 )]
36
-
37
- return RunnableParallel (dict (zip (analysts , chains , strict = True )))
38
-
39
-
40
16
class KeyPointsGenerator :
41
17
def __init__ (
42
18
self ,
@@ -49,8 +25,18 @@ def __init__(
49
25
self ._context_builder = context_builder
50
26
51
27
def __call__ (self ) -> Runnable :
52
- return make_key_points_generator_chain (
53
- llm = self ._llm ,
54
- prompt_builder = self ._prompt_builder ,
55
- context_builder = self ._context_builder ,
56
- )
28
+ prompt , output_parser = self ._prompt_builder .build ()
29
+
30
+ documents = self ._context_builder ()
31
+
32
+ chains : list [Runnable ] = []
33
+
34
+ for d in documents :
35
+ d_context_data = _format_docs ([d ])
36
+ d_prompt = prompt .partial (context_data = d_context_data )
37
+ generator_chain : Runnable = d_prompt | self ._llm | output_parser
38
+ chains .append (generator_chain )
39
+
40
+ analysts = [f"Analayst-{ i } " for i in range (1 , len (chains ) + 1 )]
41
+
42
+ return RunnableParallel (dict (zip (analysts , chains , strict = True )))
0 commit comments