-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathexplainable_retrieval.py
86 lines (65 loc) · 2.85 KB
/
explainable_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import sys
from dotenv import load_dotenv
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..'))) # Add the parent directory to the path
from helper_functions import *
from evaluation.evalute_rag import *
# Load environment variables from a .env file
load_dotenv()
# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
# Define utility classes/functions
class ExplainableRetriever:
def __init__(self, texts):
self.embeddings = OpenAIEmbeddings()
self.vectorstore = FAISS.from_texts(texts, self.embeddings)
self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
explain_prompt = PromptTemplate(
input_variables=["query", "context"],
template="""
Analyze the relationship between the following query and the retrieved context.
Explain why this context is relevant to the query and how it might help answer the query.
Query: {query}
Context: {context}
Explanation:
"""
)
self.explain_chain = explain_prompt | self.llm
def retrieve_and_explain(self, query):
docs = self.retriever.get_relevant_documents(query)
explained_results = []
for doc in docs:
input_data = {"query": query, "context": doc.page_content}
explanation = self.explain_chain.invoke(input_data).content
explained_results.append({
"content": doc.page_content,
"explanation": explanation
})
return explained_results
class ExplainableRAGMethod:
def __init__(self, texts):
self.explainable_retriever = ExplainableRetriever(texts)
def run(self, query):
return self.explainable_retriever.retrieve_and_explain(query)
# Argument Parsing
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="Explainable RAG Method")
parser.add_argument('--query', type=str, default='Why is the sky blue?', help="Query for the retriever")
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# Sample texts (these can be replaced by actual data)
texts = [
"The sky is blue because of the way sunlight interacts with the atmosphere.",
"Photosynthesis is the process by which plants use sunlight to produce energy.",
"Global warming is caused by the increase of greenhouse gases in Earth's atmosphere."
]
explainable_rag = ExplainableRAGMethod(texts)
results = explainable_rag.run(args.query)
for i, result in enumerate(results, 1):
print(f"Result {i}:")
print(f"Content: {result['content']}")
print(f"Explanation: {result['explanation']}")
print()