query_rag_v2.py
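"""Interactive RAG query script.

Loads a persisted Chroma vector store, extracts keywords from each user
query, retrieves matching documents by similarity search, and answers
with a local Ollama LLM.
"""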
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import re
from typing import List

# Path to the persisted vector store
vector_store_folder = 'vector_store'
def extract_keywords(query: str) -> List[str]:
    """
    Extract keywords from the input query by removing common stop words
    and all non-alphabetic characters.
    """
    # Lowercase and strip everything except letters and whitespace
    query = query.lower()
    query = re.sub(r'[^a-z\s]', '', query)
    # Tokenize and drop common stop words
    stop_words = {"a", "an", "the", "and", "or", "but", "of", "to", "in", "with", "on", "for", "at", "by"}
    keywords = [word for word in query.split() if word not in stop_words]
    return keywords
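# Illustrative example (not part of the pipeline): for the query
# "What are the benefits of RAG?" this returns
# ['what', 'are', 'benefits', 'rag']; digits and punctuation are
# stripped by the regex, so a token like "GPT-4" becomes "gpt".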
def build_qa_chain(llm):
    RAG_TEMPLATE = """\
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.

<context>
{context}
</context>

Answer the following question:

{question}
"""
    prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)
    # Compose the chain with LCEL (prompt -> LLM -> string output),
    # which replaces the deprecated LLMChain wrapper
    chain = prompt | llm | StrOutputParser()
    return chain
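# Usage sketch (hypothetical values): the chain takes a dict with
# "context" and "question" keys and returns a plain string, e.g.
#   chain = build_qa_chain(llm)
#   chain.invoke({"context": "Paris is the capital of France.",
#                 "question": "What is the capital of France?"})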
def query_rag():
    # Initialize embedding model and vector store
    embedding_model = "plutonioumguy/bge-m3"
    vectorstore = Chroma(
        persist_directory=vector_store_folder,
        embedding_function=OllamaEmbeddings(model=embedding_model),
    )
    # Initialize LLM
    llm = OllamaLLM(model="llama3.1", base_url="http://localhost:11434")
    # Build the QA chain once, outside the query loop
    qa_chain = build_qa_chain(llm)
    # Query loop
    while True:
        query = input("\nEnter your query (type 'quit' to exit): ")
        if query.lower() == 'quit':
            break
        # Extract keywords from the query to use for retrieval
        keywords = extract_keywords(query)
        keyword_query = ' '.join(keywords)
        # Retrieve documents from the vector store using the keywords
        docs = vectorstore.similarity_search(keyword_query)
        # Join the retrieved documents into a single context string;
        # passing Document objects directly would interpolate their repr
        context = "\n\n".join(doc.page_content for doc in docs)
        # Get and print the answer
        answer = qa_chain.invoke({"context": context, "question": query})
        print("Answer:", answer)
if __name__ == "__main__":
query_rag()
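# To run (assumes an Ollama server on localhost:11434 and a 'vector_store'
# folder already populated by a prior ingestion step):
#   $ python query_rag_v2.py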