|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import logging |
| 4 | +import math |
4 | 5 | import typing as t |
5 | 6 |
|
6 | | -from ragas.dataset_schema import EvaluationDataset, SingleTurnSample |
| 7 | +from ragas.dataset_schema import EvaluationDataset, EvaluationResult, SingleTurnSample |
7 | 8 | from ragas.embeddings import LlamaIndexEmbeddingsWrapper |
8 | 9 | from ragas.evaluation import evaluate as ragas_evaluate |
9 | 10 | from ragas.executor import Executor |
|
18 | 19 | BaseEmbedding as LlamaIndexEmbeddings, |
19 | 20 | ) |
20 | 21 | from llama_index.core.base.llms.base import BaseLLM as LlamaindexLLM |
| 22 | + from llama_index.core.base.response.schema import Response as LlamaIndexResponse |
21 | 23 | from llama_index.core.workflow import Event |
22 | 24 |
|
23 | 25 | from ragas.cost import TokenUsageParser |
24 | | - from ragas.evaluation import EvaluationResult |
25 | 26 |
|
26 | 27 |
|
27 | 28 | logger = logging.getLogger(__name__) |
@@ -78,12 +79,21 @@ def evaluate( |
78 | 79 | exec.submit(query_engine.aquery, q, name=f"query-{i}") |
79 | 80 |
|
80 | 81 | # get responses and retrieved contexts |
81 | | - responses: t.List[str] = [] |
82 | | - retrieved_contexts: t.List[t.List[str]] = [] |
| 82 | + responses: t.List[t.Optional[str]] = [] |
| 83 | + retrieved_contexts: t.List[t.Optional[t.List[str]]] = [] |
83 | 84 | results = exec.results() |
84 | | - for r in results: |
85 | | - responses.append(r.response) |
86 | | - retrieved_contexts.append([n.node.text for n in r.source_nodes]) |
| 85 | + for i, r in enumerate(results): |
| 86 | + # Handle failed jobs which are recorded as NaN in the executor |
| 87 | + if isinstance(r, float) and math.isnan(r): |
| 88 | + responses.append(None) |
| 89 | + retrieved_contexts.append(None) |
| 90 | + logger.warning(f"Query engine failed for query {i}: '{queries[i]}'") |
| 91 | + continue |
| 92 | + |
| 93 | + # Cast to LlamaIndex Response type for proper type checking |
| 94 | + response: LlamaIndexResponse = t.cast("LlamaIndexResponse", r) |
| 95 | + responses.append(response.response if response.response is not None else "") |
| 96 | + retrieved_contexts.append([n.get_text() for n in response.source_nodes]) |
87 | 97 |
|
88 | 98 | # append the extra information to the dataset |
89 | 99 | for i, sample in enumerate(samples): |
|
0 commit comments