-
-
Notifications
You must be signed in to change notification settings - Fork 851
/
self_rag.py
170 lines (132 loc) · 7.3 KB
/
self_rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
import sys
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
sys.path.append(os.path.abspath(
os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path since we work with notebooks
from helper_functions import *
from evaluation.evalute_rag import *
# Load environment variables from a .env file into os.environ.
load_dotenv()
# Re-export the OpenAI API key only when it is actually defined. Assigning
# ``None`` into ``os.environ`` raises ``TypeError`` (values must be str) at
# import time, which masked the real problem -- a missing key -- behind an
# unrelated-looking traceback. When the key is absent we leave it to the
# OpenAI client to report the missing credential when it is constructed.
_openai_api_key = os.getenv('OPENAI_API_KEY')
if _openai_api_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_api_key
# Structured-output schemas, one per Self-RAG decision step.
class RetrievalResponse(BaseModel):
    # Answer to "is retrieval needed for this query?"; the field description
    # asks the model for exactly 'Yes' or 'No'.
    response: str = Field(
        default=...,
        description="Output only 'Yes' or 'No'.",
        title="Determines if retrieval is necessary",
    )
class RelevanceResponse(BaseModel):
    # Relevance verdict for one retrieved chunk; the field description asks
    # the model for exactly 'Relevant' or 'Irrelevant'.
    response: str = Field(
        default=...,
        description="Output only 'Relevant' or 'Irrelevant'.",
        title="Determines if context is relevant",
    )
class GenerationResponse(BaseModel):
    # Free-text answer produced by the generation step.
    response: str = Field(
        default=...,
        description="The generated response.",
        title="Generated response",
    )
class SupportResponse(BaseModel):
    # Grade of how well a generated answer is backed by its context; the
    # field description lists the three allowed labels.
    response: str = Field(
        default=...,
        description="Output 'Fully supported', 'Partially supported', or 'No support'.",
        title="Determines if response is supported",
    )
class UtilityResponse(BaseModel):
    # Integer usefulness rating. The description requests 1-5, but no range
    # constraint is enforced on the field itself.
    response: int = Field(
        default=...,
        description="Rate the utility of the response from 1 to 5.",
        title="Utility rating",
    )
# Prompt templates, one per Self-RAG step. Caller-supplied values are quoted
# inline, and where a fixed vocabulary is expected the allowed answers are
# spelled out so the structured output can be compared to literal labels.
retrieval_prompt = PromptTemplate(
    template=(
        "Given the query '{query}', determine if retrieval is necessary. "
        "Output only 'Yes' or 'No'."
    ),
    input_variables=["query"],
)

relevance_prompt = PromptTemplate(
    template=(
        "Given the query '{query}' and the context '{context}', "
        "determine if the context is relevant. "
        "Output only 'Relevant' or 'Irrelevant'."
    ),
    input_variables=["query", "context"],
)

generation_prompt = PromptTemplate(
    template=(
        "Given the query '{query}' and the context '{context}', "
        "generate a response."
    ),
    input_variables=["query", "context"],
)

support_prompt = PromptTemplate(
    template=(
        "Given the response '{response}' and the context '{context}', "
        "determine if the response is supported by the context. "
        "Output 'Fully supported', 'Partially supported', or 'No support'."
    ),
    input_variables=["response", "context"],
)

utility_prompt = PromptTemplate(
    template=(
        "Given the query '{query}' and the response '{response}', "
        "rate the utility of the response from 1 to 5."
    ),
    input_variables=["query", "response"],
)
# Main pipeline class.
class SelfRAG:
    """Self-reflective RAG pipeline over a single PDF.

    For each query the LLM first decides whether retrieval is needed. If it
    is, ``top_k`` chunks are retrieved from a vector store built from the
    PDF, the chunks the LLM labels relevant are kept, one candidate answer is
    generated per kept chunk, each candidate is graded for support and
    utility, and the best-graded candidate is returned.
    """

    # Ordinal ranking of the support labels requested by the support prompt.
    # Any other label (including 'no support') ranks lowest, so a partially
    # supported answer is preferred over an unsupported one.
    _SUPPORT_RANK = {"fully supported": 2, "partially supported": 1}

    # Characters stripped from both ends of an LLM label before comparison,
    # so answers such as "Yes." or "'Relevant'" still match the expected word.
    _LABEL_STRIP_CHARS = " \t\r\n.!\"'"

    def __init__(self, path, top_k=3):
        """Build the vector store and the per-step LLM runnables.

        Args:
            path: Path to the PDF handed to ``encode_pdf``.
            top_k: Number of chunks retrieved per query.
        """
        self.vectorstore = encode_pdf(path)
        self.top_k = top_k
        self.llm = ChatOpenAI(model="gpt-4o-mini", max_tokens=1000, temperature=0)

        # One ``prompt | structured-output model`` runnable per Self-RAG step.
        self.retrieval_chain = retrieval_prompt | self.llm.with_structured_output(RetrievalResponse)
        self.relevance_chain = relevance_prompt | self.llm.with_structured_output(RelevanceResponse)
        self.generation_chain = generation_prompt | self.llm.with_structured_output(GenerationResponse)
        self.support_chain = support_prompt | self.llm.with_structured_output(SupportResponse)
        self.utility_chain = utility_prompt | self.llm.with_structured_output(UtilityResponse)

    @classmethod
    def _normalize_label(cls, text):
        """Lower-case *text* and strip surrounding whitespace, quotes and '.'/'!'."""
        return text.strip(cls._LABEL_STRIP_CHARS).lower()

    def _generate(self, query, context):
        """Return the generated answer for *query* given *context* (may be a placeholder)."""
        return self.generation_chain.invoke({"query": query, "context": context}).response

    def _filter_relevant(self, query, contexts):
        """Return, in order, the *contexts* the LLM labels 'Relevant'."""
        relevant_contexts = []
        for i, context in enumerate(contexts, start=1):
            relevance = self._normalize_label(
                self.relevance_chain.invoke({"query": query, "context": context}).response
            )
            print(f"Document {i} relevance: {relevance}")
            if relevance == 'relevant':
                relevant_contexts.append(context)
        return relevant_contexts

    def _generate_and_grade(self, query, context, index):
        """Generate one candidate answer from *context* and grade it.

        Returns:
            ``(response, support, utility)`` where ``support`` is the
            normalized support label and ``utility`` is an int.
        """
        print(f"Generating response for context {index}...")
        response = self._generate(query, context)

        # Step 5: is the answer grounded in the context it was generated from?
        print(f"Step 5: Assessing support for response {index}...")
        support = self._normalize_label(
            self.support_chain.invoke({"response": response, "context": context}).response
        )
        print(f"Support assessment: {support}")

        # Step 6: how useful is the answer for the original query?
        print(f"Step 6: Evaluating utility for response {index}...")
        utility = int(self.utility_chain.invoke({"query": query, "response": response}).response)
        print(f"Utility score: {utility}")
        return response, support, utility

    def run(self, query):
        """Answer *query* with the Self-RAG procedure and return the answer text.

        Falls back to generation without retrieved context when the LLM says
        retrieval is unnecessary or when no retrieved chunk is relevant.
        Otherwise the candidate with the highest support rank wins, with
        utility breaking ties (the first candidate wins a full tie).
        """
        print(f"\nProcessing query: {query}")

        # Step 1: Determine if retrieval is necessary
        print("Step 1: Determining if retrieval is necessary...")
        retrieval_decision = self._normalize_label(
            self.retrieval_chain.invoke({"query": query}).response
        )
        print(f"Retrieval decision: {retrieval_decision}")
        if retrieval_decision != 'yes':
            print("Generating without retrieval...")
            return self._generate(query, "No retrieval necessary.")

        # Step 2: Retrieve relevant documents
        print("Step 2: Retrieving relevant documents...")
        docs = self.vectorstore.similarity_search(query, k=self.top_k)
        contexts = [doc.page_content for doc in docs]
        print(f"Retrieved {len(contexts)} documents")

        # Step 3: Evaluate relevance of retrieved documents
        print("Step 3: Evaluating relevance of retrieved documents...")
        relevant_contexts = self._filter_relevant(query, contexts)
        print(f"Number of relevant contexts: {len(relevant_contexts)}")
        if not relevant_contexts:
            print("No relevant contexts found. Generating without retrieval...")
            return self._generate(query, "No relevant context found.")

        # Step 4: Generate (and grade) one response per relevant context
        print("Step 4: Generating responses using relevant contexts...")
        candidates = []
        for i, context in enumerate(relevant_contexts, start=1):
            candidates.append(self._generate_and_grade(query, context, i))

        # Rank by support (fully > partially > none) first, then by utility.
        print("Selecting the best response...")
        best_response, best_support, best_utility = max(
            candidates,
            key=lambda cand: (self._SUPPORT_RANK.get(cand[1], 0), cand[2]),
        )
        print(f"Best response support: {best_support}, utility: {best_utility}")
        return best_response
# Command-line interface
def parse_args():
    """Parse ``sys.argv`` for the Self-RAG demo.

    Returns:
        An ``argparse.Namespace`` with ``path`` (PDF to index) and
        ``query`` (question to answer) attributes.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Self-RAG method")
    cli.add_argument(
        "--path",
        type=str,
        default="../data/Understanding_Climate_Change.pdf",
        help="Path to the PDF file for vector store",
    )
    cli.add_argument(
        "--query",
        type=str,
        default="What is the impact of climate change on the environment?",
        help="Query to be processed",
    )
    return cli.parse_args()
# Script entry point: parse the CLI, index the PDF, answer the query and
# print the result.
if __name__ == "__main__":
    cli_args = parse_args()
    pipeline = SelfRAG(path=cli_args.path)
    final_answer = pipeline.run(cli_args.query)
    print("\nFinal response:")
    print(final_answer)