|
16 | 16 | "Please, install it with `pip install llama_index`."
|
17 | 17 | )
|
18 | 18 |
|
| 19 | +try: |
| 20 | + from pydantic.v1 import ValidationError |
| 21 | +except ImportError: |
| 22 | + from pydantic import ValidationError |
| 23 | + |
19 | 24 | import numpy as np
|
20 | 25 | import numpy.testing as npt
|
21 | 26 | import pandas as pd
|
|
58 | 63 | "conditional": "_condition_question",
|
59 | 64 | }
|
60 | 65 |
|
| 66 | +retry_errors = ( |
| 67 | + ValidationError, |
| 68 | +) |
| 69 | + |
61 | 70 | DataRow = namedtuple(
|
62 | 71 | "DataRow",
|
63 | 72 | [
|
|
69 | 78 | ],
|
70 | 79 | )
|
71 | 80 |
|
| 81 | +Proposal = namedtuple("Proposal", ["question", "text_chunk"]) |
| 82 | + |
72 | 83 |
|
73 | 84 | @dataclass
|
74 | 85 | class TestDataset:
|
@@ -291,6 +302,70 @@ def _embed_nodes(self, nodes: t.List[BaseNode]) -> t.Dict[str, t.List[float]]:
|
291 | 302 |
|
292 | 303 | return embeddings
|
293 | 304 |
|
| 305 | + def _make_proposal( |
| 306 | + self, cur_node: BaseNode, neighbor_nodes: t.List[BaseNode], evolve_type: str |
| 307 | + ) -> t.Union[Proposal, None]: |
| 308 | + # Append multiple nodes randomly to remove chunking bias |
| 309 | + size = self.rng.integers(1, 3) |
| 310 | + nodes = ( |
| 311 | + self._get_neighbour_node(cur_node, neighbor_nodes) |
| 312 | + if size > 1 and evolve_type != "multi_context" |
| 313 | + else [cur_node] |
| 314 | + ) |
| 315 | + |
| 316 | + text_chunk = " ".join([node.get_content() for node in nodes]) |
| 317 | + score = self._filter_context(text_chunk) |
| 318 | + if not score: |
| 319 | + return None |
| 320 | + seed_question = self._seed_question(text_chunk) |
| 321 | + is_valid_question = self._filter_question(seed_question) |
| 322 | + if not is_valid_question: |
| 323 | + return None |
| 324 | + |
| 325 | + if evolve_type == "multi_context": |
| 326 | + # Find most similar chunk in same document |
| 327 | + node_embedding = self._embed_nodes([nodes[-1]]) |
| 328 | + neighbor_nodes = self._remove_nodes(neighbor_nodes, nodes) |
| 329 | + neighbor_emb = self._embed_nodes(neighbor_nodes) |
| 330 | + |
| 331 | + _, indices = get_top_k_embeddings( |
| 332 | + list(node_embedding.values())[0], |
| 333 | + list(neighbor_emb.values()), |
| 334 | + similarity_cutoff=self.threshold / 10, |
| 335 | + ) |
| 336 | + if indices: |
| 337 | + # type cast indices from list[Any] to list[int] |
| 338 | + indices = t.cast(t.List[int], indices) |
| 339 | + best_neighbor = neighbor_nodes[indices[0]] |
| 340 | + question = self._multicontext_question( |
| 341 | + question=seed_question, |
| 342 | + context1=text_chunk, |
| 343 | + context2=best_neighbor.get_content(), |
| 344 | + ) |
| 345 | + text_chunk = "\n".join([text_chunk, best_neighbor.get_content()]) |
| 346 | + else: |
| 347 | + return None |
| 348 | + |
| 349 | + # for reasoning and conditional modes, evolve question with the |
| 350 | + # functions from question_deep_map |
| 351 | + else: |
| 352 | + evolve_fun = question_deep_map.get(evolve_type) |
| 353 | + question = ( |
| 354 | + getattr(self, evolve_fun)(seed_question, text_chunk) |
| 355 | + if evolve_fun |
| 356 | + else seed_question |
| 357 | + ) |
| 358 | + |
| 359 | + # compress question or convert into conversational questions |
| 360 | + if evolve_type != "simple": |
| 361 | + prob = self.rng.uniform(0, 1) |
| 362 | + if self.chat_qa and prob <= self.chat_qa: |
| 363 | + question = self._conversational_question(question=question) |
| 364 | + else: |
| 365 | + question = self._compress_question(question=question) |
| 366 | + |
| 367 | + return Proposal(question=question, text_chunk=text_chunk) |
| 368 | + |
294 | 369 | def generate(
|
295 | 370 | self,
|
296 | 371 | documents: t.List[LlamaindexDocument] | t.List[LangchainDocument],
|
@@ -339,64 +414,20 @@ def generate(
|
339 | 414 |
|
340 | 415 | neighbor_nodes = doc_nodes_map[curr_node.source_node.node_id]
|
341 | 416 |
|
342 |
| - # Append multiple nodes randomly to remove chunking bias |
343 |
| - size = self.rng.integers(1, 3) |
344 |
| - nodes = ( |
345 |
| - self._get_neighbour_node(curr_node, neighbor_nodes) |
346 |
| - if size > 1 and evolve_type != "multi_context" |
347 |
| - else [curr_node] |
348 |
| - ) |
349 |
| - |
350 |
| - text_chunk = " ".join([node.get_content() for node in nodes]) |
351 |
| - score = self._filter_context(text_chunk) |
352 |
| - if not score: |
353 |
| - continue |
354 |
| - seed_question = self._seed_question(text_chunk) |
355 |
| - is_valid_question = self._filter_question(seed_question) |
356 |
| - if not is_valid_question: |
357 |
| - continue |
358 |
| - |
359 |
| - if evolve_type == "multi_context": |
360 |
| - # Find most similar chunk in same document |
361 |
| - node_embedding = self._embed_nodes([nodes[-1]]) |
362 |
| - neighbor_nodes = self._remove_nodes(neighbor_nodes, nodes) |
363 |
| - neighbor_emb = self._embed_nodes(neighbor_nodes) |
364 |
| - |
365 |
| - _, indices = get_top_k_embeddings( |
366 |
| - list(node_embedding.values())[0], |
367 |
| - list(neighbor_emb.values()), |
368 |
| - similarity_cutoff=self.threshold / 10, |
369 |
| - ) |
370 |
| - if indices: |
371 |
| - # type cast indices from list[Any] to list[int] |
372 |
| - indices = t.cast(t.List[int], indices) |
373 |
| - best_neighbor = neighbor_nodes[indices[0]] |
374 |
| - question = self._multicontext_question( |
375 |
| - question=seed_question, |
376 |
| - context1=text_chunk, |
377 |
| - context2=best_neighbor.get_content(), |
378 |
| - ) |
379 |
| - text_chunk = "\n".join([text_chunk, best_neighbor.get_content()]) |
380 |
| - else: |
381 |
| - continue |
382 |
| - |
383 |
| - # for reasoning and conditional modes, evolve question with the |
384 |
| - # functions from question_deep_map |
385 |
| - else: |
386 |
| - evolve_fun = question_deep_map.get(evolve_type) |
387 |
| - question = ( |
388 |
| - getattr(self, evolve_fun)(seed_question, text_chunk) |
389 |
| - if evolve_fun |
390 |
| - else seed_question |
| 417 | + proposal = None |
| 418 | + try: |
| 419 | + proposal = self._make_proposal( |
| 420 | + curr_node, neighbor_nodes, evolve_type |
391 | 421 | )
|
| 422 | + except Exception as e: |
| 423 | + err_cause = e.__cause__ |
| 424 | + if not isinstance(err_cause, retry_errors): |
| 425 | + raise e |
392 | 426 |
|
393 |
| - # compress question or convert into conversational questions |
394 |
| - if evolve_type != "simple": |
395 |
| - prob = self.rng.uniform(0, 1) |
396 |
| - if self.chat_qa and prob <= self.chat_qa: |
397 |
| - question = self._conversational_question(question=question) |
398 |
| - else: |
399 |
| - question = self._compress_question(question=question) |
| 427 | + if proposal is None: |
| 428 | + continue |
| 429 | + question = proposal.question |
| 430 | + text_chunk = proposal.text_chunk |
400 | 431 |
|
401 | 432 | is_valid_question = self._filter_question(question)
|
402 | 433 | if is_valid_question:
|
|
0 commit comments