# faiss_index.py

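"""Build a FAISS index from document embeddings stored in Supabase and compare
FAISS retrieval against the original Supabase vector store on a QA set,
scoring both pipelines with RAGAS (context precision, faithfulness, answer
relevancy, context recall)."""
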
import faiss
import numpy as np
import json
from time import time
import asyncio
from datasets import Dataset
from typing import List
from dotenv import load_dotenv
import os
import pickle
import pandas as pd
from supabase.client import Client, create_client
from langchain_openai import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document

# Import from the parent directory
import sys
sys.path.append('..')
from RAG_strategy_Taide import taide_llm, system_prompt, multi_query
from add_vectordb import GetVectorStore

# RAGAS metrics
from ragas import evaluate
from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision

# Load environment variables
load_dotenv('../environment.env')
supabase_url = os.getenv("SUPABASE_URL")
supabase_key = os.getenv("SUPABASE_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")
document_table = "documents"

# Initialize Supabase client
supabase: Client = create_client(supabase_url, supabase_key)

# Initialize the embeddings model
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
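
# Pull every stored row (id, embedding, metadata, content) from Supabase so
# the vectors can be re-indexed locally with FAISS. Note: this issues a single
# select, which may run into the API's default row limit on large tables.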
def download_embeddings():
    response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
    vectors = []  # local name kept distinct from the module-level embeddings model
    ids = []
    metadatas = []
    contents = []
    for item in response.data:
        # Embeddings are stored as JSON strings; decode them back into float lists
        vectors.append(json.loads(item['embedding']))
        ids.append(item['id'])
        metadatas.append(item['metadata'])
        contents.append(item['content'])
    return np.array(vectors, dtype=np.float32), ids, metadatas, contents

def create_faiss_index(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)  # inner product over L2-normalized vectors = cosine similarity
    faiss.normalize_L2(embeddings)        # normalize in place for cosine similarity
    index.add(embeddings)
    return index

def save_faiss_index(index, file_path):
    faiss.write_index(index, file_path)
    print(f"FAISS index saved to {file_path}")

def load_faiss_index(file_path):
    if os.path.exists(file_path):
        index = faiss.read_index(file_path)
        print(f"FAISS index loaded from {file_path}")
        return index
    return None

def save_metadata(ids, metadatas, contents, file_path):
    with open(file_path, 'wb') as f:
        pickle.dump((ids, metadatas, contents), f)
    print(f"Metadata saved to {file_path}")

def load_metadata(file_path):
    if os.path.exists(file_path):
        with open(file_path, 'rb') as f:
            ids, metadatas, contents = pickle.load(f)
        print(f"Metadata loaded from {file_path}")
        return ids, metadatas, contents
    return None, None, None
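
# Queries must be L2-normalized the same way as the indexed vectors, otherwise
# the inner-product scores returned by IndexFlatIP are not cosine similarities.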
def search_faiss(index, query_vector, k=4):
    query_vector = np.array(query_vector, dtype=np.float32).reshape(1, -1)
    faiss.normalize_L2(query_vector)
    distances, indices = index.search(query_vector, k)
    return distances[0], indices[0]
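
# A minimal retriever exposing the same get_relevant_documents() interface as
# a LangChain retriever, so it can stand in for the Supabase vector store below.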
class FAISSRetriever:
    def __init__(self, index, ids, metadatas, contents, embeddings_model):
        self.index = index
        self.ids = ids
        self.metadatas = metadatas
        self.contents = contents
        self.embeddings_model = embeddings_model

    def get_relevant_documents(self, query: str) -> List[Document]:
        query_vector = self.embeddings_model.embed_query(query)
        _, indices = search_faiss(self.index, query_vector)
        return [
            Document(page_content=self.contents[i], metadata=self.metadatas[i])
            for i in indices
        ]

def load_qa_pairs():
    df = pd.read_csv("../QA_database_rows.csv")
    return df['Question'].tolist(), df['Answer'].tolist()

def faiss_query(question: str, retriever: FAISSRetriever) -> str:
    docs = retriever.get_relevant_documents(question)
    context = "\n".join(doc.page_content for doc in docs)
    prompt = ChatPromptTemplate.from_template(
        system_prompt + "\n\n" +
        "Answer the following question based on this context:\n\n"
        "{context}\n\n"
        "Question: {question}\n"
        "Answer in the same language as the question. If you don't know the answer, "
        "say 'I'm sorry, I don't have enough information to answer that question.'"
    )
    chain = prompt | taide_llm | StrOutputParser()
    return chain.invoke({"context": context, "question": question})
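
# Load (or build) the FAISS index and metadata, then answer each QA pair with
# both the FAISS retriever and the original Supabase retriever, timing each
# and scoring both with RAGAS.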
async def run_evaluation():
    faiss_index_path = "faiss_index.bin"
    metadata_path = "faiss_metadata.pkl"

    index = load_faiss_index(faiss_index_path)
    ids, metadatas, contents = load_metadata(metadata_path)
    if index is None or ids is None:
        print("FAISS index or metadata not found. Creating new index...")
        print("Downloading embeddings from Supabase...")
        embeddings_array, ids, metadatas, contents = download_embeddings()
        print("Creating FAISS index...")
        index = create_faiss_index(embeddings_array)
        save_faiss_index(index, faiss_index_path)
        save_metadata(ids, metadatas, contents, metadata_path)
    else:
        print("Using existing FAISS index and metadata.")

    print("Creating FAISS retriever...")
    faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)

    print("Creating original vector store...")
    original_vector_store = GetVectorStore(embeddings, supabase, document_table)
    original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})

    questions, ground_truths = load_qa_pairs()
    for question, ground_truth in zip(questions, ground_truths):
        print(f"\nQuestion: {question}")

        # FAISS pipeline (the extra get_relevant_documents call re-runs
        # retrieval to capture the docs for RAGAS, and is included in the timing)
        start_time = time()
        faiss_answer = faiss_query(question, faiss_retriever)
        faiss_docs = faiss_retriever.get_relevant_documents(question)
        faiss_time = time() - start_time
        print(f"FAISS Answer: {faiss_answer}")
        print(f"FAISS Time: {faiss_time:.4f} seconds")

        # Original pipeline
        start_time = time()
        original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
        original_time = time() - start_time
        print(f"Original Answer: {original_answer}")
        print(f"Original Time: {original_time:.4f} seconds")

        # RAGAS evaluation of the FAISS pipeline
        faiss_datasets = {
            "question": [question],
            "answer": [faiss_answer],
            "contexts": [[doc.page_content for doc in faiss_docs]],
            "ground_truth": [ground_truth]
        }
        faiss_evalsets = Dataset.from_dict(faiss_datasets)
        faiss_result = evaluate(
            faiss_evalsets,
            metrics=[
                context_precision,
                faithfulness,
                answer_relevancy,
                context_recall,
            ],
        )
        print("FAISS RAGAS Evaluation:")
        print(faiss_result.to_pandas())

        # RAGAS evaluation of the original pipeline
        original_datasets = {
            "question": [question],
            "answer": [original_answer],
            "contexts": [[doc.page_content for doc in original_docs]],
            "ground_truth": [ground_truth]
        }
        original_evalsets = Dataset.from_dict(original_datasets)
        original_result = evaluate(
            original_evalsets,
            metrics=[
                context_precision,
                faithfulness,
                answer_relevancy,
                context_recall,
            ],
        )
        print("Original RAGAS Evaluation:")
        print(original_result.to_pandas())

    print("\nPerformance comparison complete.")

if __name__ == "__main__":
    asyncio.run(run_evaluation())