@@ -73,26 +73,18 @@ async def afilter(self, question: str) -> bool:
73
73
74
74
@dataclass
75
75
class Evolution :
76
- ...
77
-
78
-
79
- @dataclass
80
- class SimpleEvolution (Evolution ):
81
76
node_filter : NodeFilter
82
77
question_filter : QuestionFilter
83
78
nodes : t .List [Node ] = field (default_factory = list )
84
79
max_tries : int = 5
80
+ _root_node : t .Optional [Node ] = field (default = None , init = False , repr = False )
85
81
_tries : int = field (default = 0 , init = False , repr = False )
86
82
87
83
def merged_nodes (self ) -> Node :
88
84
return Node (
89
85
doc_id = "merged" , page_content = " " .join (n .page_content for n in self .nodes )
90
86
)
91
87
92
- def evolve (self , llm : BaseRagasLLM , docstore : DocumentStore ):
93
- logger .info ("evolving question" )
94
- return asyncio .get_event_loop ().run_until_complete (self .aevolve (llm , docstore ))
95
-
96
88
async def aretry_evolve (
97
89
self , llm : BaseRagasLLM , docstore : DocumentStore , update_count : bool = True
98
90
):
@@ -104,10 +96,47 @@ async def aretry_evolve(
104
96
raise ValueError ("Max tries reached" )
105
97
return await self .aevolve (llm , docstore )
106
98
99
+ @abstractmethod
100
+ def evolve (self , llm : BaseRagasLLM , docstore : DocumentStore ) -> str :
101
+ ...
102
+
103
+ @abstractmethod
104
+ async def aevolve (self , llm : BaseRagasLLM , docstore : DocumentStore ) -> str :
105
+ ...
106
+
107
+
108
+ @dataclass
109
+ class SimpleEvolution (Evolution ):
110
+ def evolve (self , llm : BaseRagasLLM , docstore : DocumentStore ):
111
+ logger .info ("evolving question" )
112
+ return asyncio .get_event_loop ().run_until_complete (self .aevolve (llm , docstore ))
113
+
114
+ def _get_more_adjacent_nodes (self , docstore : DocumentStore ):
115
+ """
116
+ if the evolutions doesn't have enough nodes to frame a question, get more nodes
117
+ """
118
+ assert self ._root_node is not None , "root node cannot be None"
119
+ # get more nodes from above the context window
120
+ prev_adjacent_node = docstore .get_adjacent (self ._root_node , Direction .PREV )
121
+ if prev_adjacent_node is None :
122
+ # get more nodes from below the context window
123
+ next_adjacent_node = docstore .get_adjacent (self ._root_node , Direction .NEXT )
124
+ if next_adjacent_node is not None :
125
+ # add next nodes towards the end
126
+ self .nodes .append (next_adjacent_node )
127
+ else :
128
+ # retry with new base node
129
+ self .nodes = docstore .get_random_nodes (k = 1 )
130
+ self ._root_node = self .nodes [0 ]
131
+ else :
132
+ # add prev nodes in index 0
133
+ self .nodes .insert (0 , prev_adjacent_node )
134
+
107
135
async def aevolve (self , llm : BaseRagasLLM , docstore : DocumentStore ):
108
136
# can the node be used to frame a question?
109
137
if self ._tries == 0 :
110
138
self .nodes = docstore .get_random_nodes (k = 1 )
139
+ self ._root_node = self .nodes [0 ]
111
140
merged_node = self .merged_nodes ()
112
141
passed , table_is_present = await self .node_filter .afilter (self .nodes [0 ])
113
142
if not passed :
@@ -122,20 +151,7 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore):
122
151
is_valid_question = await self .question_filter .afilter (seed_question )
123
152
if not is_valid_question :
124
153
# get more context to rewrite question
125
- prev_adjacent_node = docstore .get_adjacent (self .nodes [0 ], Direction .PREV )
126
- if prev_adjacent_node is None :
127
- next_adjacent_node = docstore .get_adjacent (
128
- self .nodes [- 1 ], Direction .NEXT
129
- )
130
- if next_adjacent_node is not None :
131
- # add nodes
132
- self .nodes .append (next_adjacent_node )
133
- else :
134
- # retry with new base node
135
- self .nodes = docstore .get_random_nodes (k = 1 )
136
- else :
137
- # add prev nodes
138
- self .nodes .insert (0 , prev_adjacent_node )
154
+ self ._get_more_adjacent_nodes (docstore )
139
155
# retry with new nodes added
140
156
return await self .aretry_evolve (llm , docstore )
141
157
else :
0 commit comments