Skip to content

Commit c65ef27

Browse files
committed
refactor simplenode
1 parent 16f2c93 commit c65ef27

File tree

2 files changed

+44
-28
lines changed

2 files changed

+44
-28
lines changed

src/ragas/testset/evolutions.py

Lines changed: 39 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -73,26 +73,18 @@ async def afilter(self, question: str) -> bool:
7373

7474
@dataclass
7575
class Evolution:
76-
...
77-
78-
79-
@dataclass
80-
class SimpleEvolution(Evolution):
8176
node_filter: NodeFilter
8277
question_filter: QuestionFilter
8378
nodes: t.List[Node] = field(default_factory=list)
8479
max_tries: int = 5
80+
_root_node: t.Optional[Node] = field(default=None, init=False, repr=False)
8581
_tries: int = field(default=0, init=False, repr=False)
8682

8783
def merged_nodes(self) -> Node:
    """
    Collapse every node currently held in ``self.nodes`` into one ``Node``.

    The page contents are concatenated in list order, separated by a single
    space, and the resulting node is given the fixed ``doc_id`` ``"merged"``.
    """
    combined_content = " ".join(node.page_content for node in self.nodes)
    return Node(doc_id="merged", page_content=combined_content)
9187

92-
def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore):
93-
logger.info("evolving question")
94-
return asyncio.get_event_loop().run_until_complete(self.aevolve(llm, docstore))
95-
9688
async def aretry_evolve(
9789
self, llm: BaseRagasLLM, docstore: DocumentStore, update_count: bool = True
9890
):
@@ -104,10 +96,47 @@ async def aretry_evolve(
10496
raise ValueError("Max tries reached")
10597
return await self.aevolve(llm, docstore)
10698

99+
@abstractmethod
def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str:
    """
    Synchronous evolution entry point; concrete evolutions must override it.

    Takes the LLM and the document store to work with and, per the
    annotation, returns a string.
    """
    # NOTE(review): ``abstractmethod`` is only enforced at instantiation time
    # when the class's metaclass derives from ``ABCMeta`` -- confirm that
    # ``Evolution`` is declared that way.
102+
103+
@abstractmethod
async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str:
    """
    Asynchronous evolution entry point; concrete evolutions must override it.

    Takes the LLM and the document store to work with and, per the
    annotation, resolves to a string.
    """
    # NOTE(review): ``abstractmethod`` is only enforced at instantiation time
    # when the class's metaclass derives from ``ABCMeta`` -- confirm that
    # ``Evolution`` is declared that way.
106+
107+
108+
@dataclass
109+
class SimpleEvolution(Evolution):
110+
def evolve(self, llm: BaseRagasLLM, docstore: DocumentStore) -> str:
    """
    Blocking wrapper around ``aevolve``.

    Runs the ``aevolve`` coroutine to completion on the calling thread's
    event loop and returns whatever it returns.

    If the calling thread has no event loop (``asyncio.get_event_loop``
    raises ``RuntimeError``, e.g. inside a worker thread), a fresh loop is
    created and installed for the thread before running.

    This must not be called from code that is already running inside an
    event loop: ``run_until_complete`` raises ``RuntimeError`` in that
    case. Await ``aevolve`` directly instead.
    """
    logger.info("evolving question")
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No loop is set for this thread; create one rather than crashing.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(self.aevolve(llm, docstore))
113+
114+
def _get_more_adjacent_nodes(self, docstore: DocumentStore) -> None:
    """
    If the evolution doesn't have enough nodes to frame a question, widen
    the context window (``self.nodes``) by one node.

    The window is extended in this order of preference:

    1. the node preceding the first node of the window is inserted at
       index 0;
    2. otherwise the node following the last node of the window is
       appended;
    3. otherwise (no neighbour on either side) the window is discarded and
       restarted from a single random node, which becomes the new root.

    ``self.nodes`` (and, in case 3, ``self._root_node``) are mutated in
    place; nothing is returned.
    """
    assert self._root_node is not None, "root node cannot be None"

    # Anchor the lookups on the *edges* of the current window rather than on
    # the root node. Anchoring on the root would return the same neighbour
    # on every retry and insert duplicates instead of growing the window.
    # Fall back to the root only if the window happens to be empty.
    window_start = self.nodes[0] if self.nodes else self._root_node
    window_end = self.nodes[-1] if self.nodes else self._root_node

    # get more nodes from above the context window
    prev_adjacent_node = docstore.get_adjacent(window_start, Direction.PREV)
    if prev_adjacent_node is not None:
        # add prev nodes in index 0
        self.nodes.insert(0, prev_adjacent_node)
        return

    # get more nodes from below the context window
    next_adjacent_node = docstore.get_adjacent(window_end, Direction.NEXT)
    if next_adjacent_node is not None:
        # add next nodes towards the end
        self.nodes.append(next_adjacent_node)
        return

    # nothing left on either side: retry with new base node
    self.nodes = docstore.get_random_nodes(k=1)
    self._root_node = self.nodes[0]
107135
async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore):
108136
# can the node be used to frame a question?
109137
if self._tries == 0:
110138
self.nodes = docstore.get_random_nodes(k=1)
139+
self._root_node = self.nodes[0]
111140
merged_node = self.merged_nodes()
112141
passed, table_is_present = await self.node_filter.afilter(self.nodes[0])
113142
if not passed:
@@ -122,20 +151,7 @@ async def aevolve(self, llm: BaseRagasLLM, docstore: DocumentStore):
122151
is_valid_question = await self.question_filter.afilter(seed_question)
123152
if not is_valid_question:
124153
# get more context to rewrite question
125-
prev_adjacent_node = docstore.get_adjacent(self.nodes[0], Direction.PREV)
126-
if prev_adjacent_node is None:
127-
next_adjacent_node = docstore.get_adjacent(
128-
self.nodes[-1], Direction.NEXT
129-
)
130-
if next_adjacent_node is not None:
131-
# add nodes
132-
self.nodes.append(next_adjacent_node)
133-
else:
134-
# retry with new base node
135-
self.nodes = docstore.get_random_nodes(k=1)
136-
else:
137-
# add prev nodes
138-
self.nodes.insert(0, prev_adjacent_node)
154+
self._get_more_adjacent_nodes(docstore)
139155
# retry with new nodes added
140156
return await self.aretry_evolve(llm, docstore)
141157
else:

tests/unit/testset_generator/test_docstore.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,14 +16,14 @@ def test_adjacent_nodes():
1616
store = InMemoryDocumentStore(splitter=None) # type: ignore
1717
store.nodes = [a1, a2, b]
1818

19-
assert store.get_adjascent(a1) == a2
20-
assert store.get_adjascent(a2, Direction.PREV) == a1
21-
assert store.get_adjascent(a2, Direction.NEXT) is None
22-
assert store.get_adjascent(b, Direction.PREV) is None
19+
assert store.get_adjacent(a1) == a2
20+
assert store.get_adjacent(a2, Direction.PREV) == a1
21+
assert store.get_adjacent(a2, Direction.NEXT) is None
22+
assert store.get_adjacent(b, Direction.PREV) is None
2323

2424
# raise ValueError if doc not in store
2525
c = Node(doc_id="c", page_content="c", filename="c")
26-
pytest.raises(ValueError, store.get_adjascent, c)
26+
pytest.raises(ValueError, store.get_adjacent, c)
2727

2828

2929
def create_test_nodes(with_embeddings=True):

0 commit comments

Comments
 (0)