Skip to content

Commit

Permalink
feat: changed summary to new prompt (#1469)
Browse files Browse the repository at this point in the history
Co-authored-by: Shahules786 <[email protected]>
  • Loading branch information
jjmachan and shahules786 authored Oct 11, 2024
1 parent a4b1912 commit c06b131
Showing 1 changed file with 126 additions and 172 deletions.
298 changes: 126 additions & 172 deletions src/ragas/metrics/_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,140 +5,134 @@
from dataclasses import dataclass, field
from typing import Dict

from langchain.pydantic_v1 import BaseModel
from pydantic import BaseModel

from ragas.dataset_schema import SingleTurnSample
from ragas.llms.output_parser import RagasOutputParserOld, get_json_format_instructions
from ragas.llms.prompt import Prompt
from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric
from ragas.prompt import PydanticPrompt, StringIO

if t.TYPE_CHECKING:
from langchain.callbacks.base import Callbacks

logger = logging.getLogger(__name__)


class ExtractKeyphrasesResponse(BaseModel):
class ExtractedKeyphrases(BaseModel):
keyphrases: t.List[str]


class GenerateQuestionsResponse(BaseModel):
class QuestionsGenerated(BaseModel):
questions: t.List[str]


class GenerateAnswersResponse(BaseModel):
class AnswersGenerated(BaseModel):
answers: t.List[str]


_output_instructions_question_generation = get_json_format_instructions(
pydantic_object=GenerateQuestionsResponse # type: ignore
)
_output_instructions_answer_generation = get_json_format_instructions(
pydantic_object=GenerateAnswersResponse # type: ignore
)
_output_instructions_keyphrase_extraction = get_json_format_instructions(
pydantic_object=ExtractKeyphrasesResponse # type: ignore
)
_output_parser_question_generation = RagasOutputParserOld(
pydantic_object=GenerateQuestionsResponse
)
_output_parser_answer_generation = RagasOutputParserOld(
pydantic_object=GenerateAnswersResponse
)
_output_parser_keyphrase_extraction = RagasOutputParserOld(
pydantic_object=ExtractKeyphrasesResponse
)


TEXT_EXTRACT_KEYPHRASES = Prompt(
name="text_extract_keyphrases",
instruction="Extract the keyphrases essential for summarizing the text.",
output_format_instruction=_output_instructions_keyphrase_extraction,
input_keys=["text"],
output_key="keyphrases",
output_type="json",
examples=[
{
"text": """JPMorgan Chase & Co. is an American multinational finance company headquartered in New York City. It is the largest bank in the United States and the world's largest by market capitalization as of 2023. Founded in 1799, it is a major provider of investment banking services, with US$3.9 trillion in total assets, and ranked #1 in the Forbes Global 2000 ranking in 2023.""",
"keyphrases": [
"JPMorgan Chase & Co.",
"American multinational finance company",
"headquartered in New York City",
"largest bank in the United States",
"world's largest bank by market capitalization",
"founded in 1799",
"major provider of investment banking services",
"US$3.9 trillion in total assets",
"ranked #1 in Forbes Global 2000 ranking",
],
}
],
)


TEXT_GENERATE_QUESTIONS = Prompt(
name="text_generate_questions",
instruction="Based on the given text and keyphrases, generate closed-ended questions that can be answered with '1' if the question can be answered using the text, or '0' if it cannot. The questions should ALWAYS result in a '1' based on the given text.",
output_format_instruction=_output_instructions_question_generation,
input_keys=["text", "keyphrases"],
output_key="questions",
output_type="json",
examples=[
{
"text": """JPMorgan Chase & Co. is an American multinational finance company headquartered in New York City. It is the largest bank in the United States and the world's largest by market capitalization as of 2023. Founded in 1799, it is a major provider of investment banking services, with US$3.9 trillion in total assets, and ranked #1 in the Forbes Global 2000 ranking in 2023.""",
"keyphrases": [
"JPMorgan Chase & Co.",
"American multinational finance company",
"headquartered in New York City",
"largest bank in the United States",
"world's largest bank by market capitalization",
"founded in 1799",
"major provider of investment banking services",
"US$3.9 trillion in total assets",
"ranked #1 in Forbes Global 2000 ranking",
],
"questions": [
"Is JPMorgan Chase & Co. an American multinational finance company?",
"Is JPMorgan Chase & Co. headquartered in New York City?",
"Is JPMorgan Chase & Co. the largest bank in the United States?",
"Is JPMorgan Chase & Co. the world's largest bank by market capitalization as of 2023?",
"Was JPMorgan Chase & Co. founded in 1799?",
"Is JPMorgan Chase & Co. a major provider of investment banking services?",
"Does JPMorgan Chase & Co. have US$3.9 trillion in total assets?",
"Was JPMorgan Chase & Co. ranked #1 in the Forbes Global 2000 ranking in 2023?",
],
}
],
)


TEXT_GENERATE_ANSWERS = Prompt(
name="text_generate_answers",
instruction="Based on the list of close-ended '1' or '0' questions, generate a JSON with key 'answers', which is a list of strings that determines whether the provided summary contains sufficient information to answer EACH question. Answers should STRICTLY be either '1' or '0'. Answer '0' if the provided summary does not contain enough information to answer the question and answer '1' if the provided summary can answer the question.",
output_format_instruction=_output_instructions_answer_generation,
input_keys=["summary", "questions"],
output_key="answers",
output_type="json",
examples=[
{
"summary": """JPMorgan Chase & Co., headquartered in New York City, is the largest bank in the US and the world's largest by market capitalization as of 2023. Founded in 1799, it offers extensive investment, private, asset management, and retail banking services, and has $3.9 trillion in assets, making it the fifth-largest bank globally. It operates the world's largest investment bank by revenue and was ranked #1 in the 2023 Forbes Global 2000.""",
"questions": [
"Is JPMorgan Chase & Co. an American multinational finance company?",
"Is JPMorgan Chase & Co. headquartered in New York City?",
"Is JPMorgan Chase & Co. the largest bank in the United States?",
"Is JPMorgan Chase & Co. the world's largest bank by market capitalization as of 2023?",
"Is JPMorgan Chase & Co. considered systemically important by the Financial Stability Board?",
"Was JPMorgan Chase & Co. founded in 1799 as the Chase Manhattan Company?",
"Is JPMorgan Chase & Co. a major provider of investment banking services?",
"Is JPMorgan Chase & Co. the fifth-largest bank in the world by assets?",
"Does JPMorgan Chase & Co. operate the largest investment bank by revenue?",
"Was JPMorgan Chase & Co. ranked #1 in the Forbes Global 2000 ranking?",
"Does JPMorgan Chase & Co. provide investment banking services?",
],
"answers": ["0", "1", "1", "1", "0", "0", "1", "1", "1", "1", "1"],
}
],
)
class ExtractKeyphrasePrompt(PydanticPrompt[StringIO, ExtractedKeyphrases]):
name: str = "extract_keyphrases"
instruction: str = "Extract keyphrases of type: Person, Organization, Location, Date/Time, Monetary Values, and Percentages."
input_model = StringIO
output_model = ExtractedKeyphrases
examples: t.List[t.Tuple[StringIO, ExtractedKeyphrases]] = [
(
StringIO(
text="Apple Inc. is a technology company based in Cupertino, California. Founded by Steve Jobs in 1976, it reached a market capitalization of $3 trillion in 2023."
),
ExtractedKeyphrases(
keyphrases=[
"Apple Inc.",
"Cupertino, California",
"Steve Jobs",
"1976",
"$3 trillion",
"2023",
]
),
)
]


class GenerateQuestionsPromptInput(BaseModel):
text: str
keyphrases: t.List[str]


class GenerateQuestionsPrompt(
PydanticPrompt[GenerateQuestionsPromptInput, QuestionsGenerated]
):
name: str = "generate_questions"
instruction: str = "Based on the given text and keyphrases, generate closed-ended questions that can be answered with '1' if the question can be answered using the text, or '0' if it cannot. The questions should ALWAYS result in a '1' based on the given text."
input_model = GenerateQuestionsPromptInput
output_model = QuestionsGenerated
examples: t.List[t.Tuple[GenerateQuestionsPromptInput, QuestionsGenerated]] = [
(
GenerateQuestionsPromptInput(
text="Apple Inc. is a technology company based in Cupertino, California. Founded by Steve Jobs in 1976, it reached a market capitalization of $3 trillion in 2023.",
keyphrases=[
"Apple Inc.",
"Cupertino, California",
"Steve Jobs",
"1976",
"$3 trillion",
"2023",
],
),
QuestionsGenerated(
questions=[
"Is Apple Inc. a technology company?",
"Is Apple Inc. based in Cupertino, California?",
"Was Apple Inc. founded by Steve Jobs?",
"Was Apple Inc. founded in 1976?",
"Did Apple Inc. reach a market capitalization of $3 trillion?",
"Did Apple Inc. reach a market capitalization of $3 trillion in 2023?",
]
),
)
]


class SummaryAndQuestions(BaseModel):
summary: str
questions: t.List[str]


class GenerateAnswersPrompt(PydanticPrompt[SummaryAndQuestions, AnswersGenerated]):
name: str = "generate_answers"
instruction: str = "Based on the list of close-ended '1' or '0' questions, generate a JSON with key 'answers', which is a list of strings that determines whether the provided summary contains sufficient information to answer EACH question. Answers should STRICTLY be either '1' or '0'. Answer '0' if the provided summary does not contain enough information to answer the question and answer '1' if the provided summary can answer the question."
input_model = SummaryAndQuestions
output_model = AnswersGenerated
examples: t.List[t.Tuple[SummaryAndQuestions, AnswersGenerated]] = [
(
SummaryAndQuestions(
summary="Apple Inc. is a technology company based in Cupertino, California. Founded by Steve Jobs in 1976, it reached a market capitalization of $3 trillion in 2023.",
questions=[
"Is Apple Inc. a technology company?",
"Is Apple Inc. based in Cupertino, California?",
"Was Apple Inc. founded by Steve Jobs?",
"Was Apple Inc. founded in 1976?",
"Did Apple Inc. reach a market capitalization of $3 trillion?",
"Did Apple Inc. reach a market capitalization of $3 trillion in 2023?",
"Is Apple Inc. a major software company?",
"Is Apple Inc. known for the iPhone?",
"Was Steve Jobs the co-founder of Apple Inc.?",
],
),
AnswersGenerated(
answers=[
"1",
"1",
"1",
"1",
"1",
"1",
"0",
"0",
"1",
]
),
)
]


@dataclass
Expand All @@ -155,14 +149,14 @@ class SummarizationScore(MetricWithLLM, SingleTurnMetric):
}
)
coeff: float = 0.5
question_generation_prompt: Prompt = field(
default_factory=lambda: TEXT_GENERATE_QUESTIONS
question_generation_prompt: PydanticPrompt = field(
default_factory=GenerateQuestionsPrompt
)
answer_generation_prompt: Prompt = field(
default_factory=lambda: TEXT_GENERATE_ANSWERS
answer_generation_prompt: PydanticPrompt = field(
default_factory=GenerateAnswersPrompt
)
extract_keyphrases_prompt: Prompt = field(
default_factory=lambda: TEXT_EXTRACT_KEYPHRASES
extract_keyphrases_prompt: PydanticPrompt = field(
default_factory=ExtractKeyphrasePrompt
)

async def _single_turn_ascore(
Expand Down Expand Up @@ -201,17 +195,11 @@ def _compute_conciseness_score(self, text, summary) -> float:

async def _extract_keyphrases(self, text: str, callbacks: Callbacks) -> t.List[str]:
assert self.llm is not None, "LLM is not initialized"
p_value = self.extract_keyphrases_prompt.format(text=text)
result = await self.llm.generate(
prompt=p_value,
callbacks=callbacks,
)
result_text = result.generations[0][0].text
response = await _output_parser_keyphrase_extraction.aparse(
result_text, p_value, self.llm, self.max_retries
)

if not response or not response.keyphrases:
response: ExtractedKeyphrases = await self.extract_keyphrases_prompt.generate(
data=StringIO(text=text), llm=self.llm, callbacks=callbacks
)
if not response:
logging.error("No keyphrases generated, unable to calculate the score.")
return []

Expand All @@ -221,20 +209,12 @@ async def _get_questions(
self, text: str, keyphrases: list[str], callbacks: Callbacks
) -> t.List[str]:
assert self.llm is not None, "LLM is not initialized"
p_value = self.question_generation_prompt.format(
text=text, keyphrases=keyphrases
)
result = await self.llm.generate(
prompt=p_value,
response: QuestionsGenerated = await self.question_generation_prompt.generate(
data=GenerateQuestionsPromptInput(text=text, keyphrases=keyphrases),
llm=self.llm,
callbacks=callbacks,
)

result_text = result.generations[0][0].text
response = await _output_parser_question_generation.aparse(
result_text, p_value, self.llm, self.max_retries
)

if not response or not response.questions:
if not response:
logging.error("No questions generated, unable to calculate the score.")
return []

Expand All @@ -244,38 +224,12 @@ async def _get_answers(
self, questions: t.List[str], summary: str, callbacks: Callbacks
) -> t.List[str]:
assert self.llm is not None, "LLM is not initialized"
p_value = self.answer_generation_prompt.format(
questions=questions, summary=summary
)
result = await self.llm.generate(
prompt=p_value,
response: AnswersGenerated = await self.answer_generation_prompt.generate(
data=SummaryAndQuestions(questions=questions, summary=summary),
llm=self.llm,
callbacks=callbacks,
)

result_text = result.generations[0][0].text
response = await _output_parser_answer_generation.aparse(
result_text, p_value, self.llm, self.max_retries
)

if not response or not response.answers:
logger.error("No answers generated, unable to calculate the score.")
return []

return response.answers

def adapt(self, language: str, cache_dir: str | None = None) -> None:
assert self.llm is not None, "set LLM before use"

logger.info(f"Adapting summarization to {language}")
self.question_generation_prompt = self.question_generation_prompt.adapt(
language, self.llm, cache_dir
)
self.answer_generation_prompt = self.answer_generation_prompt.adapt(
language, self.llm, cache_dir
)
self.answer_generation_prompt = self.answer_generation_prompt.adapt(
language, self.llm, cache_dir
)


summarization_score = SummarizationScore()

0 comments on commit c06b131

Please sign in to comment.