faiss_index.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. import faiss
  2. import numpy as np
  3. import json
  4. from time import time
  5. import asyncio
  6. from datasets import Dataset
  7. from typing import List
  8. from dotenv import load_dotenv
  9. import os
  10. import pickle
  11. from supabase.client import Client, create_client
  12. from langchain_openai import OpenAIEmbeddings, ChatOpenAI
  13. from langchain.prompts import ChatPromptTemplate
  14. from langchain_core.output_parsers import StrOutputParser
  15. import pandas as pd
  16. from langchain_core.documents import Document
  17. from langchain.load import dumps, loads
  18. from langchain_community.chat_models import ChatOllama
  # Import from the parent directory: '..' must be on sys.path *before* the
  # project import below so that RAG_strategy can also be resolved from there
  # (a relative sys.path entry is resolved against the current working directory).
  import sys
  sys.path.append('..')
  from RAG_strategy import multi_query_chain
  23. # from RAG_strategy_Taide import taide_llm, system_prompt, multi_query
  24. system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
  25. # llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
  26. from langchain.prompts import ChatPromptTemplate
  27. from langchain_core.output_parsers import StrOutputParser
  28. from file_loader.add_vectordb import GetVectorStore
  29. # from local_llm import ollama_, hf
  30. # # from local_llm import ollama_, taide_llm, hf
  31. # llm = hf()
  32. # llm = taide_llm
  33. # Import RAGAS metrics
  34. from ragas import evaluate
  35. from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision
  # Supabase table that holds the pre-computed document embeddings.
  document_table = "documents2"

  # Read credentials from the shared .env file. The path is relative, so it is
  # resolved against the current working directory, not this file's location.
  load_dotenv('../../.env')
  supabase_url = os.environ.get("SUPABASE_URL")
  supabase_key = os.environ.get("SUPABASE_KEY")
  openai_api_key = os.environ.get("OPENAI_API_KEY")

  # Module-wide clients: the Supabase connection used for downloads, and the
  # OpenAI embedding model handed to the retriever to embed queries.
  supabase: Client = create_client(supabase_url, supabase_key)
  embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
  def download_embeddings(page_size: int = 1000):
      """Fetch every stored chunk (embedding, metadata, content) from Supabase.

      Rows are read from ``document_table`` page by page, ordered by ``id`` so
      that paging is stable. A single un-paged ``select`` is capped by the
      project's API max-rows setting (1000 by default on Supabase), which would
      silently leave documents out of the FAISS index.

      Args:
          page_size: Number of rows requested per round-trip. Paging advances
              by the number of rows actually returned, so a server-side cap
              smaller than this value is still handled correctly.

      Returns:
          ``(embeddings, ids, metadatas, contents)``: ``embeddings`` is a
          float32 NumPy array with one row per document, and the three lists
          are aligned with it row-for-row.

      Raises:
          ValueError: if ``page_size`` is not positive.
      """
      if page_size <= 0:
          raise ValueError("page_size must be a positive integer")

      vectors, ids, metadatas, contents = [], [], [], []
      start = 0
      while True:
          response = (
              supabase.table(document_table)
              .select("id, embedding, metadata, content")
              .order("id")
              .range(start, start + page_size - 1)  # inclusive bounds
              .execute()
          )
          rows = response.data or []
          if not rows:
              break
          for item in rows:
              raw = item['embedding']
              # The embedding column arrives as text such as "[0.1, 0.2, ...]"
              # and must be JSON-decoded; tolerate an already-decoded list too.
              vectors.append(json.loads(raw) if isinstance(raw, str) else raw)
              ids.append(item['id'])
              metadatas.append(item['metadata'])
              contents.append(item['content'])
          start += len(rows)
      return np.array(vectors, dtype=np.float32), ids, metadatas, contents
  def create_faiss_index(embeddings):
      """Build an exact cosine-similarity FAISS index over ``embeddings``.

      Cosine similarity is obtained by L2-normalising every vector and storing
      them in an inner-product index (``IndexFlatIP``).

      Args:
          embeddings: 2-D array-like of shape ``(n_docs, dim)``. It is copied to
              a C-contiguous float32 array first, so the caller's data is not
              normalised in place.

      Returns:
          A ``faiss.IndexFlatIP`` holding the normalised vectors in input row
          order.

      Raises:
          ValueError: if ``embeddings`` is not two-dimensional (e.g. the empty
              1-D array produced when no rows were downloaded).
      """
      vectors = np.array(embeddings, dtype=np.float32, order="C", copy=True)
      if vectors.ndim != 2:
          raise ValueError(
              f"expected a 2-D (n_docs, dim) array, got shape {vectors.shape}"
          )
      index = faiss.IndexFlatIP(vectors.shape[1])
      faiss.normalize_L2(vectors)  # in place, but only on our private copy
      index.add(vectors)
      return index
  def save_faiss_index(index, file_path):
      """Persist ``index`` to ``file_path`` via ``faiss.write_index`` and report it."""
      destination = file_path
      faiss.write_index(index, destination)
      print(f"FAISS index saved to {destination}")
  def load_faiss_index(file_path):
      """Read a FAISS index from ``file_path``.

      Returns the loaded index, or ``None`` when the file does not exist.
      """
      if not os.path.exists(file_path):
          return None
      loaded = faiss.read_index(file_path)
      print(f"FAISS index loaded from {file_path}")
      return loaded
  def save_metadata(ids, metadatas, contents, file_path):
      """Pickle the ``(ids, metadatas, contents)`` triple to ``file_path``.

      The three lists are written together as one tuple so they stay aligned
      with each other (and with the FAISS index rows they describe).
      """
      payload = (ids, metadatas, contents)
      with open(file_path, 'wb') as handle:
          pickle.dump(payload, handle)
      print(f"Metadata saved to {file_path}")
  def load_metadata(file_path):
      """Load the ``(ids, metadatas, contents)`` triple written by ``save_metadata``.

      Returns ``(None, None, None)`` when ``file_path`` does not exist.

      NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
      at files this script produced itself.
      """
      if not os.path.exists(file_path):
          return None, None, None
      with open(file_path, 'rb') as handle:
          ids, metadatas, contents = pickle.load(handle)
      print(f"Metadata loaded from {file_path}")
      return ids, metadatas, contents
  def search_faiss(index, query_vector, k=4):
      """Return the top-``k`` hits in ``index`` for one query vector.

      The query is cast to float32, shaped as a one-row batch and
      L2-normalised, so inner-product scores from an index built by
      ``create_faiss_index`` are cosine similarities.

      Returns:
          ``(scores, positions)`` for the single query row. Following FAISS
          convention, unfilled slots have position ``-1`` when the index holds
          fewer than ``k`` vectors.
      """
      batch = np.array(query_vector, dtype=np.float32).reshape(1, -1)
      faiss.normalize_L2(batch)
      scores, positions = index.search(batch, k)
      first_row_scores, first_row_positions = scores[0], positions[0]
      return first_row_scores, first_row_positions
  class FAISSRetriever:
      """Minimal retriever that maps FAISS hits back to LangChain ``Document``s.

      ``ids``, ``metadatas`` and ``contents`` must be aligned with the rows of
      ``index``: position ``i`` in the index describes element ``i`` of each
      list, as produced by ``download_embeddings`` + ``create_faiss_index``.
      """

      def __init__(self, index, ids, metadatas, contents, embeddings_model):
          self.index = index
          self.ids = ids
          self.metadatas = metadatas
          self.contents = contents
          # Any object exposing ``embed_query(str)`` (e.g. ``OpenAIEmbeddings``).
          self.embeddings_model = embeddings_model

      def get_relevant_documents(self, query: str, k: int = 4) -> List[Document]:
          """Embed ``query`` and return up to ``k`` most similar documents.

          FAISS pads missing results with position ``-1`` when the index holds
          fewer than ``k`` vectors. Those slots are skipped, because indexing a
          Python list with ``-1`` would silently return the *last* document.
          """
          query_vector = self.embeddings_model.embed_query(query)
          _, positions = search_faiss(self.index, query_vector, k=k)
          return [
              Document(page_content=self.contents[i], metadata=self.metadatas[i])
              for i in positions
              if i >= 0
          ]

      def map(self, query_list: List[str], k: int = 4) -> List[Document]:
          """Retrieve for every non-blank query and return the de-duplicated union.

          Documents are de-duplicated by their LangChain serialisation
          (``dumps``) and returned in first-seen order, so the result is
          deterministic across runs.

          Args:
              query_list: Query strings; empty or whitespace-only ones are skipped.
              k: Documents retrieved per query.
          """
          unique_docs = {}
          for query in query_list:
              if not query or not query.strip():
                  continue
              for doc in self.get_relevant_documents(query, k=k):
                  # setdefault keeps the first occurrence and its position.
                  unique_docs.setdefault(dumps(doc), doc)
          return list(unique_docs.values())
  def load_qa_pairs():
      """Fetch the evaluation question/answer pairs from the ``QA_database`` table.

      Returns:
          ``(questions, answers)`` as two aligned lists. Both are empty when the
          table has no rows; building a DataFrame from an empty result would
          have no ``Question``/``Answer`` columns and raise ``KeyError``.
      """
      response = supabase.table('QA_database').select("Question, Answer").execute()
      if not response.data:
          return [], []
      df = pd.DataFrame(response.data)
      return df['Question'].tolist(), df['Answer'].tolist()
  def faiss_multiquery(question: str, retriever: FAISSRetriever, llm, k: int = 4):
      """Retrieve documents for ``question`` plus LLM-generated query variants.

      ``multi_query_chain(llm)`` is invoked on ``question``. Its output is
      presumably a list of query strings (TODO confirm against RAG_strategy).
      Blank entries are dropped, the original question is appended, repeated
      queries are removed, and each remaining query is printed for inspection.

      Args:
          question: The user question.
          retriever: Retriever used for every query.
          llm: Chat model passed to ``multi_query_chain``.
          k: Documents retrieved per query.

      Returns:
          The documents retrieved for all queries, de-duplicated (different
          queries often hit the same chunks) and kept in first-seen order.
      """
      generate_queries = multi_query_chain(llm)
      queries = [q for q in generate_queries.invoke(question) if q and q.strip()]
      queries.append(question)
      # Drop repeats (e.g. the model echoing the original question) but keep order.
      queries = list(dict.fromkeys(queries))
      for q in queries:
          print(q)

      unique_docs = {}
      for query in queries:
          for doc in retriever.get_relevant_documents(query, k=k):
              unique_docs.setdefault(dumps(doc), doc)
      return list(unique_docs.values())
  def faiss_query(retriever, question: str, llm, multi_query: bool = False) -> tuple[List[Document], str]:
      """Answer ``question`` with retrieval-augmented generation over the FAISS index.

      Args:
          retriever: ``FAISSRetriever`` used to fetch context documents.
          question: The user question.
          llm: Chat model used for query expansion (when ``multi_query``) and
              for generating the answer.
          multi_query: If True, retrieve with LLM-generated query variants via
              ``faiss_multiquery``; otherwise retrieve the 10 nearest documents
              for ``question`` alone.

      Returns:
          ``(docs, answer)``: the retrieved ``Document`` list and the model's
          answer string.
      """
      if multi_query:
          docs = faiss_multiquery(question, retriever, llm)
      else:
          docs = retriever.get_relevant_documents(question, k=10)

      # Give the model plain passage text. Interpolating the Document list
      # directly would insert its repr (metadata dicts, quotes and literal
      # backslash-n escapes) into the prompt.
      context_text = "\n\n".join(doc.page_content for doc in docs)

      rag_system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
      # Llama-3 chat-format prompt. The literals concatenate to the full
      # template text, including its leading and trailing newline.
      template = (
          "\n"
          "<|begin_of_text|>\n"
          "<|start_header_id|>system<|end_header_id|>\n"
          "你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n\n"
          'You should not mention anything about "根據提供的文件內容" or other similar terms.\n'
          "請盡可能的詳細回答問題。\n"
          '如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"\n'
          "勿回答無關資訊\n"
          "<|eot_id|>\n"
          "<|start_header_id|>user<|end_header_id|>\n"
          "Answer the following question based on this context:\n"
          "{context}\n"
          "Question: {question}\n"
          "用繁體中文回答問題,請用一段話詳細的回答。\n"
          '如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"\n'
          "<|eot_id|>\n"
          "<|start_header_id|>assistant<|end_header_id|>\n"
      )
      prompt = ChatPromptTemplate.from_template(rag_system_prompt + "\n\n" + template)
      rag_chain = prompt | llm | StrOutputParser()
      answer = rag_chain.invoke({"context": context_text, "question": question})
      return docs, answer
  def create_faiss_retriever(faiss_index_path: str = "faiss_index.bin", metadata_path: str = "faiss_metadata.pkl"):
      """Return a ``FAISSRetriever``, building and caching the index when needed.

      The cached index and metadata are reused when both files exist and
      describe the same number of documents. Otherwise embeddings are downloaded
      from Supabase, a fresh index is built, and both files are (re)written.

      Args:
          faiss_index_path: Cache file for the FAISS index.
          metadata_path: Cache file for the pickled ``(ids, metadatas, contents)``.

      Returns:
          A ``FAISSRetriever`` over the index, using the module-level
          ``embeddings`` model to embed queries.
      """
      index = load_faiss_index(faiss_index_path)
      ids, metadatas, contents = load_metadata(metadata_path)

      if index is None or ids is None:
          print("FAISS index or metadata not found. Creating new index...")
          rebuild = True
      elif index.ntotal != len(ids):
          # A stale pair (one file regenerated without the other) would map
          # search hits to the wrong text, so rebuild both.
          print("FAISS index and metadata are out of sync. Creating new index...")
          rebuild = True
      else:
          print("Using existing FAISS index and metadata.")
          rebuild = False

      if rebuild:
          print("Downloading embeddings from Supabase...")
          embeddings_array, ids, metadatas, contents = download_embeddings()
          print("Creating FAISS index...")
          index = create_faiss_index(embeddings_array)
          save_faiss_index(index, faiss_index_path)
          save_metadata(ids, metadatas, contents, metadata_path)

      print("Creating FAISS retriever...")
      return FAISSRetriever(index, ids, metadatas, contents, embeddings)
  def _evaluate_model(name, llm, retriever, question, ground_truth, csv_path):
      """Answer one question with ``llm``, score it with RAGAS, append to ``csv_path``.

      Prints the answer, the wall-clock time of retrieval + generation, and the
      RAGAS scores. The elapsed time is stored in a ``time`` column of the
      appended row.
      """
      start_time = time()
      docs, answer = faiss_query(retriever, question, llm, multi_query=True)
      elapsed = time() - start_time
      print(f"{name} Answer: {answer}")
      print(f"{name} Time: {elapsed:.4f} seconds")

      eval_set = Dataset.from_dict({
          "question": [question],
          "answer": [answer],
          "contexts": [[doc.page_content for doc in docs]],
          "ground_truth": [ground_truth],
      })
      result = evaluate(
          eval_set,
          metrics=[
              context_precision,
              faithfulness,
              answer_relevancy,
              context_recall,
          ],
      )
      print(f"{name} RAGAS Evaluation:")
      # NOTE: a key set on the RAGAS result object does not reliably reach
      # to_pandas() output (version-dependent), so add timing to the frame itself.
      df = result.to_pandas()
      df["time"] = elapsed
      print(df)
      # Emit the header only when creating the file, so repeated appends yield
      # one well-formed CSV instead of a header line before every row.
      df.to_csv(csv_path, mode='a', header=not os.path.exists(csv_path))


  async def run_evaluation(num_questions: int = 5):
      """Compare a local Ollama model with gpt-4o-mini on stored QA pairs.

      For each of the first ``num_questions`` question/ground-truth pairs from
      ``QA_database``, both models answer via multi-query FAISS RAG and are
      scored with RAGAS; rows are appended to ``llama.csv`` and ``openai.csv``.
      Declared ``async`` only so it can be driven with ``asyncio.run``; nothing
      inside is awaited.
      """
      local_llm = "llama3-groq-tool-use:latest"
      llama3 = ChatOllama(model=local_llm, temperature=0)
      openai_llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
      retriever = create_faiss_retriever()
      questions, ground_truths = load_qa_pairs()
      for question, ground_truth in zip(questions[:num_questions], ground_truths[:num_questions]):
          print(f"\nQuestion: {question}")
          _evaluate_model("llama3", llama3, retriever, question, ground_truth, "llama.csv")
          _evaluate_model("openai", openai_llm, retriever, question, ground_truth, "openai.csv")
      print("\nPerformance comparison complete.")
  if __name__ == "__main__":
      # run_evaluation is a coroutine function; drive it to completion on a
      # fresh event loop.
      evaluation = run_evaluation()
      asyncio.run(evaluation)