Commit f15abdb

fix: improve inline citation parsing bump:patch

taprosoft committed Nov 26, 2024
1 parent f3a2a29 commit f15abdb
Showing 2 changed files with 90 additions and 19 deletions.
4 changes: 2 additions & 2 deletions libs/kotaemon/kotaemon/indices/qa/citation_qa.py
@@ -342,11 +342,11 @@ def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document

        span_idx = span.get("idx", None)
        if span_idx is not None:
-            to_highlight = f"【{span_idx + 1}】" + to_highlight
+            to_highlight = f"【{span_idx}】" + to_highlight

        text += Render.highlight(
            to_highlight,
-            elem_id=str(span_idx + 1) if span_idx is not None else None,
+            elem_id=str(span_idx) if span_idx is not None else None,
        )
        if idx < len(ss) - 1:
            text += cur_doc.text[span["end"] : ss[idx + 1]["start"]]
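Note: the hunk above drops the `+ 1` offset because span "idx" values now carry the citation number already parsed from the answer (1-based), as set by match_evidence_with_context further down. A minimal sketch of the off-by-one this fixes; the span dict here is illustrative, not from the repository:

    # Span as produced downstream; "idx" already holds the 1-based citation number.
    span = {"start": 10, "end": 42, "idx": 1}
    span_idx = span.get("idx", None)
    marker = f"【{span_idx}】" if span_idx is not None else ""
    assert marker == "【1】"  # the old `span_idx + 1` would have rendered 【2】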
105 changes: 88 additions & 17 deletions libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py
@@ -1,14 +1,14 @@
+import re
 import threading
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import Generator

 import numpy as np

 from kotaemon.base import AIMessage, Document, HumanMessage, SystemMessage
 from kotaemon.llms import PromptTemplate

-from .citation import CiteEvidence
 from .citation_qa import CITATION_TIMEOUT, MAX_IMAGES, AnswerWithContextPipeline
 from .format_context import EVIDENCE_MODE_FIGURE
 from .utils import find_start_end_phrase
@@ -61,12 +61,27 @@
END_PHRASE: this shows good retrieval quality.
FINAL ANSWER
-An alternative to semantic chunking is fixed-size chunking. This traditional method involves splitting documents into chunks of a predetermined or user-specified size, regardless of semantic content, which is computationally efficient【1】. However, it may result in the fragmentation of semantically related content, thereby potentially degrading retrieval performance【2】.
+An alternative to semantic chunking is fixed-size chunking. This traditional method involves splitting documents into chunks of a predetermined or user-specified size, regardless of semantic content, which is computationally efficient【1】. However, it may result in the fragmentation of semantically related content, thereby potentially degrading retrieval performance【1】【2】.
QUESTION: {question}\n
ANSWER:
"""  # noqa

+START_ANSWER = "FINAL ANSWER"
+START_CITATION = "CITATION LIST"
+CITATION_PATTERN = r"citation【(\d+)】"
+START_ANSWER_PATTERN = "start_phrase:"
+END_ANSWER_PATTERN = "end_phrase:"
+
+
+@dataclass
+class InlineEvidence:
+    """List of evidences to support the answer."""
+
+    start_phrase: str | None = None
+    end_phrase: str | None = None
+    idx: int | None = None


class AnswerWithInlineCitation(AnswerWithContextPipeline):
    """Answer the question based on the evidence with inline citation"""
@@ -85,15 +100,54 @@ def get_prompt(self, question, evidence, evidence_mode: int):

        return prompt, evidence

-    def answer_to_citations(self, answer):
-        evidences = []
+    def answer_to_citations(self, answer) -> list[InlineEvidence]:
+        citations: list[InlineEvidence] = []
        lines = answer.split("\n")
-        for line in lines:
-            for keyword in ["START_PHRASE:", "END_PHRASE:"]:
-                if line.startswith(keyword):
-                    evidences.append(line[len(keyword) :].strip())
-
-        return CiteEvidence(evidences=evidences)
+        current_evidence = None
+
+        for line in lines:
+            # check citation idx using regex
+            match = re.match(CITATION_PATTERN, line.lower())
+
+            if match:
+                try:
+                    parsed_citation_idx = int(match.group(1))
+                except ValueError:
+                    parsed_citation_idx = None
+
+                # conclude the current evidence if exists
+                if current_evidence:
+                    citations.append(current_evidence)
+                    current_evidence = None
+
+                current_evidence = InlineEvidence(idx=parsed_citation_idx)
+            else:
+                for keyword in [START_ANSWER_PATTERN, END_ANSWER_PATTERN]:
+                    if line.lower().startswith(keyword):
+                        matched_phrase = line[len(keyword) :].strip()
+
+                        if not current_evidence:
+                            current_evidence = InlineEvidence(idx=None)
+
+                        if keyword == START_ANSWER_PATTERN:
+                            current_evidence.start_phrase = matched_phrase
+                        else:
+                            current_evidence.end_phrase = matched_phrase
+
+                        break
+
+            if (
+                current_evidence
+                and current_evidence.end_phrase
+                and current_evidence.start_phrase
+            ):
+                citations.append(current_evidence)
+                current_evidence = None
+
+        if current_evidence:
+            citations.append(current_evidence)
+
+        return citations
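For illustration, feeding the parser output shaped like the prompt example (the answer text below is invented) would yield one InlineEvidence whose idx comes from the CITATION line:

    raw_answer = (
        "CITATION【1】\n"
        "start_phrase: This traditional method\n"
        "end_phrase: computationally efficient.\n"
        "FINAL ANSWER\n"
        "Fixed-size chunking is computationally efficient【1】."
    )
    # Expected result, per the parsing loop above:
    # [InlineEvidence(start_phrase="This traditional method",
    #                 end_phrase="computationally efficient.", idx=1)]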

    def replace_citation_with_link(self, answer: str):
        # Define the regex pattern to match 【number】

@@ -114,6 +168,8 @@ def replace_citation_with_link(self, answer: str):
                ),
            )

+        answer = answer.replace(START_CITATION, "")
+
        return answer
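The replacement body is mostly collapsed above; a minimal sketch of the idea, with illustrative anchor markup rather than the repository's exact HTML:

    import re

    def replace_citation_with_link_sketch(answer: str) -> str:
        # Swap every 【n】 marker for a clickable anchor (markup is hypothetical).
        for citation in set(re.findall(r"【\d+】", answer)):
            idx = citation[1:-1]
            answer = answer.replace(citation, f"<a href='#cite-{idx}'>{citation}</a>")
        # Drop the trailing "CITATION LIST" marker, as the added line above does.
        return answer.replace("CITATION LIST", "")

    print(replace_citation_with_link_sketch("Efficient【1】.\nCITATION LIST"))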

def stream( # type: ignore
@@ -178,21 +234,31 @@ def mindmap_call():
        # append main prompt
        messages.append(HumanMessage(content=prompt))

-        START_ANSWER = "FINAL ANSWER"
-        start_of_answer = True
        final_answer = ""

        try:
            # try streaming first
            print("Trying LLM streaming")
            for out_msg in self.llm.stream(messages):
                if START_ANSWER in output:
+                    if not final_answer:
                        try:
                            left_over_answer = output.split(START_ANSWER)[1].lstrip()
                        except IndexError:
                            left_over_answer = ""
                        if left_over_answer:
                            out_msg.text = left_over_answer + out_msg.text

                    final_answer += (
-                        out_msg.text.lstrip() if start_of_answer else out_msg.text
+                        out_msg.text.lstrip() if not final_answer else out_msg.text
                    )
-                    start_of_answer = False
                    yield Document(channel="chat", content=out_msg.text)

+                    # check for the edge case of citation list is repeated
+                    # with smaller LLMs
+                    if START_CITATION in out_msg.text:
+                        break

                output += out_msg.text
                logprobs += out_msg.logprobs
except NotImplementedError:
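A toy trace of the buffering logic above, using made-up chunks: text before FINAL ANSWER is withheld, the leftover after the marker is prepended to the next chunk, and streaming stops once CITATION LIST shows up:

    START_ANSWER, START_CITATION = "FINAL ANSWER", "CITATION LIST"
    chunks = ["thinking...\nFINAL", " ANSWER\n  Fixed-size", " chunking【1】.\nCITATION LIST"]
    output, final_answer = "", ""
    for text in chunks:
        if START_ANSWER in output:
            if not final_answer:
                left_over = output.split(START_ANSWER)[1].lstrip()
                if left_over:
                    text = left_over + text
            final_answer += text.lstrip() if not final_answer else text
            if START_CITATION in text:
                break
        output += text
    print(final_answer)  # "Fixed-size chunking【1】.\nCITATION LIST"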
@@ -235,10 +301,15 @@ def match_evidence_with_context(self, answer, docs) -> dict[str, list[dict]]:
        if not answer.metadata["citation"]:
            return spans

-        evidences = answer.metadata["citation"].evidences
+        evidences = answer.metadata["citation"]
+
+        for e_id, evidence in enumerate(evidences):
+            start_phrase, end_phrase = evidence.start_phrase, evidence.end_phrase
+            evidence_idx = evidence.idx
+
+            if evidence_idx is None:
+                evidence_idx = e_id + 1

-        for start_idx in range(0, len(evidences), 2):
-            start_phrase, end_phrase = evidences[start_idx : start_idx + 2]
            best_match = None
            best_match_length = 0
            best_match_doc_idx = None
@@ -259,7 +330,7 @@ def match_evidence_with_context(self, answer, docs) -> dict[str, list[dict]]:
                    {
                        "start": best_match[0],
                        "end": best_match[1],
-                        "idx": start_idx // 2,  # implicitly set from the start_idx
+                        "idx": evidence_idx,
                    }
                )
        return spans
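The span's "idx" now comes from the citation number the model wrote, falling back to the evidence's 1-based position when no CITATION line was emitted. A small sketch of that fallback, mirroring the dataclass added above:

    from dataclasses import dataclass

    @dataclass
    class InlineEvidence:  # mirrors the dataclass in this commit
        start_phrase: str | None = None
        end_phrase: str | None = None
        idx: int | None = None

    evidences = [InlineEvidence("foo", "bar", idx=None), InlineEvidence("baz", "qux", idx=7)]
    for e_id, evidence in enumerate(evidences):
        evidence_idx = evidence.idx if evidence.idx is not None else e_id + 1
        print(evidence_idx)  # 1, then 7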
