faiss_index.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. import faiss
  2. import numpy as np
  3. import json
  4. from time import time
  5. import asyncio
  6. from datasets import Dataset
  7. from typing import List
  8. from dotenv import load_dotenv
  9. import os
  10. import pickle
  11. from supabase.client import Client, create_client
  12. from langchain_openai import OpenAIEmbeddings, ChatOpenAI
  13. from langchain.prompts import ChatPromptTemplate
  14. from langchain_core.output_parsers import StrOutputParser
  15. import pandas as pd
  16. from langchain_core.documents import Document
  17. from langchain.load import dumps, loads
  18. # Import from the parent directory
  19. import sys
  20. from RAG_strategy import multi_query_chain
  21. sys.path.append('..')
  22. # from RAG_strategy_Taide import taide_llm, system_prompt, multi_query
  23. system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
  24. # llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
  25. from langchain.prompts import ChatPromptTemplate
  26. from langchain_core.output_parsers import StrOutputParser
  27. from file_loader.add_vectordb import GetVectorStore
  28. # from local_llm import ollama_, hf
  29. # # from local_llm import ollama_, taide_llm, hf
  30. # llm = hf()
  31. # llm = taide_llm
  32. # Import RAGAS metrics
  33. from ragas import evaluate
  34. from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision
  35. # Load environment variables
  36. load_dotenv('../../.env')
  37. supabase_url = os.getenv("SUPABASE_URL")
  38. supabase_key = os.getenv("SUPABASE_KEY")
  39. openai_api_key = os.getenv("OPENAI_API_KEY")
  40. document_table = "documents"
  41. # Initialize Supabase client
  42. supabase: Client = create_client(supabase_url, supabase_key)
  43. # Initialize embeddings and chat model
  44. embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
  45. def download_embeddings():
  46. response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
  47. embeddings = []
  48. ids = []
  49. metadatas = []
  50. contents = []
  51. for item in response.data:
  52. embedding = json.loads(item['embedding'])
  53. embeddings.append(embedding)
  54. ids.append(item['id'])
  55. metadatas.append(item['metadata'])
  56. contents.append(item['content'])
  57. return np.array(embeddings, dtype=np.float32), ids, metadatas, contents
  58. def create_faiss_index(embeddings):
  59. dimension = embeddings.shape[1]
  60. index = faiss.IndexFlatIP(dimension) # Use Inner Product for cosine similarity
  61. faiss.normalize_L2(embeddings) # Normalize embeddings for cosine similarity
  62. index.add(embeddings)
  63. return index
  64. def save_faiss_index(index, file_path):
  65. faiss.write_index(index, file_path)
  66. print(f"FAISS index saved to {file_path}")
  67. def load_faiss_index(file_path):
  68. if os.path.exists(file_path):
  69. index = faiss.read_index(file_path)
  70. print(f"FAISS index loaded from {file_path}")
  71. return index
  72. return None
  73. def save_metadata(ids, metadatas, contents, file_path):
  74. with open(file_path, 'wb') as f:
  75. pickle.dump((ids, metadatas, contents), f)
  76. print(f"Metadata saved to {file_path}")
  77. def load_metadata(file_path):
  78. if os.path.exists(file_path):
  79. with open(file_path, 'rb') as f:
  80. ids, metadatas, contents = pickle.load(f)
  81. print(f"Metadata loaded from {file_path}")
  82. return ids, metadatas, contents
  83. return None, None, None
  84. def search_faiss(index, query_vector, k=4):
  85. query_vector = np.array(query_vector, dtype=np.float32).reshape(1, -1)
  86. faiss.normalize_L2(query_vector)
  87. distances, indices = index.search(query_vector, k)
  88. return distances[0], indices[0]
  89. class FAISSRetriever:
  90. def __init__(self, index, ids, metadatas, contents, embeddings_model):
  91. self.index = index
  92. self.ids = ids
  93. self.metadatas = metadatas
  94. self.contents = contents
  95. self.embeddings_model = embeddings_model
  96. def get_relevant_documents(self, query: str, k: int = 4) -> List[Document]:
  97. query_vector = self.embeddings_model.embed_query(query)
  98. _, indices = search_faiss(self.index, query_vector, k=k)
  99. return [
  100. Document(page_content=self.contents[i], metadata=self.metadatas[i])
  101. for i in indices
  102. ]
  103. def map(self, query_list: List[list]) -> List[Document]:
  104. def get_unique_union(documents: List[list]):
  105. """ Unique union of retrieved docs """
  106. # Flatten list of lists, and convert each Document to string
  107. flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
  108. # Get unique documents
  109. unique_docs = list(set(flattened_docs))
  110. # Return
  111. return [loads(doc) for doc in unique_docs]
  112. documents = []
  113. for query in query_list:
  114. if query != "":
  115. docs = self.get_relevant_documents(query)
  116. documents.extend(docs)
  117. return get_unique_union(documents)
  118. def load_qa_pairs():
  119. # df = pd.read_csv("../QA_database_rows.csv")
  120. response = supabase.table('QA_database').select("Question, Answer").execute()
  121. df = pd.DataFrame(response.data)
  122. return df['Question'].tolist(), df['Answer'].tolist()
  123. def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):
  124. generate_queries = multi_query_chain(llm)
  125. questions = generate_queries.invoke(question)
  126. questions = [item for item in questions if item != ""]
  127. # docs = list(map(retriever.get_relevant_documents, questions))
  128. docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))
  129. docs = [item for sublist in docs for item in sublist]
  130. return docs
  131. def faiss_query(question: str, retriever: FAISSRetriever, llm, multi_query: bool = False) -> str:
  132. if multi_query:
  133. docs = faiss_multiquery(question, retriever, llm)
  134. # print(docs)
  135. else:
  136. docs = retriever.get_relevant_documents(question, k=10)
  137. # print(docs)
  138. context = "\n".join(doc.page_content for doc in docs)
  139. template = """
  140. <|begin_of_text|>
  141. <|start_header_id|>system<|end_header_id|>
  142. 你是一個來自台灣的ESG的AI助理,
  143. 請用繁體中文回答問題 \n
  144. You should not mention anything about "根據提供的文件內容" or other similar terms.
  145. Use five sentences maximum and keep the answer concise.
  146. 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
  147. 勿回答無關資訊
  148. <|eot_id|>
  149. <|start_header_id|>user<|end_header_id|>
  150. Answer the following question based on this context:
  151. {context}
  152. Question: {question}
  153. <|eot_id|>
  154. <|start_header_id|>assistant<|end_header_id|>
  155. """
  156. prompt = ChatPromptTemplate.from_template(
  157. system_prompt + "\n\n" +
  158. template
  159. )
  160. # prompt = ChatPromptTemplate.from_template(
  161. # system_prompt + "\n\n" +
  162. # "Answer the following question based on this context:\n\n"
  163. # "{context}\n\n"
  164. # "Question: {question}\n"
  165. # "Answer in the same language as the question. If you don't know the answer, "
  166. # "say 'I'm sorry, I don't have enough information to answer that question.'"
  167. # )
  168. # chain = prompt | taide_llm | StrOutputParser()
  169. chain = prompt | llm | StrOutputParser()
  170. return chain.invoke({"context": context, "question": question})
  171. def create_faiss_retriever():
  172. faiss_index_path = "faiss_index.bin"
  173. metadata_path = "faiss_metadata.pkl"
  174. index = load_faiss_index(faiss_index_path)
  175. ids, metadatas, contents = load_metadata(metadata_path)
  176. if index is None or ids is None:
  177. print("FAISS index or metadata not found. Creating new index...")
  178. print("Downloading embeddings from Supabase...")
  179. embeddings_array, ids, metadatas, contents = download_embeddings()
  180. print("Creating FAISS index...")
  181. index = create_faiss_index(embeddings_array)
  182. save_faiss_index(index, faiss_index_path)
  183. save_metadata(ids, metadatas, contents, metadata_path)
  184. else:
  185. print("Using existing FAISS index and metadata.")
  186. print("Creating FAISS retriever...")
  187. faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
  188. return faiss_retriever
  189. async def run_evaluation():
  190. faiss_index_path = "faiss_index.bin"
  191. metadata_path = "faiss_metadata.pkl"
  192. index = load_faiss_index(faiss_index_path)
  193. ids, metadatas, contents = load_metadata(metadata_path)
  194. if index is None or ids is None:
  195. print("FAISS index or metadata not found. Creating new index...")
  196. print("Downloading embeddings from Supabase...")
  197. embeddings_array, ids, metadatas, contents = download_embeddings()
  198. print("Creating FAISS index...")
  199. index = create_faiss_index(embeddings_array)
  200. save_faiss_index(index, faiss_index_path)
  201. save_metadata(ids, metadatas, contents, metadata_path)
  202. else:
  203. print("Using existing FAISS index and metadata.")
  204. print("Creating FAISS retriever...")
  205. faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
  206. print("Creating original vector store...")
  207. original_vector_store = GetVectorStore(embeddings, supabase, document_table)
  208. original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
  209. questions, ground_truths = load_qa_pairs()
  210. for question, ground_truth in zip(questions, ground_truths):
  211. print(f"\nQuestion: {question}")
  212. start_time = time()
  213. faiss_answer = faiss_query(question, faiss_retriever)
  214. faiss_docs = faiss_retriever.get_relevant_documents(question)
  215. faiss_time = time() - start_time
  216. print(f"FAISS Answer: {faiss_answer}")
  217. print(f"FAISS Time: {faiss_time:.4f} seconds")
  218. start_time = time()
  219. original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
  220. original_time = time() - start_time
  221. print(f"Original Answer: {original_answer}")
  222. print(f"Original Time: {original_time:.4f} seconds")
  223. # faiss_datasets = {
  224. # "question": [question],
  225. # "answer": [faiss_answer],
  226. # "contexts": [[doc.page_content for doc in faiss_docs]],
  227. # "ground_truth": [ground_truth]
  228. # }
  229. # faiss_evalsets = Dataset.from_dict(faiss_datasets)
  230. # faiss_result = evaluate(
  231. # faiss_evalsets,
  232. # metrics=[
  233. # context_precision,
  234. # faithfulness,
  235. # answer_relevancy,
  236. # context_recall,
  237. # ],
  238. # )
  239. # print("FAISS RAGAS Evaluation:")
  240. # print(faiss_result.to_pandas())
  241. # original_datasets = {
  242. # "question": [question],
  243. # "answer": [original_answer],
  244. # "contexts": [[doc.page_content for doc in original_docs]],
  245. # "ground_truth": [ground_truth]
  246. # }
  247. # original_evalsets = Dataset.from_dict(original_datasets)
  248. # original_result = evaluate(
  249. # original_evalsets,
  250. # metrics=[
  251. # context_precision,
  252. # faithfulness,
  253. # answer_relevancy,
  254. # context_recall,
  255. # ],
  256. # )
  257. # print("Original RAGAS Evaluation:")
  258. # print(original_result.to_pandas())
  259. print("\nPerformance comparison complete.")
  260. async def ask_question():
  261. faiss_index_path = "faiss_index.bin"
  262. metadata_path = "faiss_metadata.pkl"
  263. index = load_faiss_index(faiss_index_path)
  264. ids, metadatas, contents = load_metadata(metadata_path)
  265. if index is None or ids is None:
  266. print("FAISS index or metadata not found. Creating new index...")
  267. print("Downloading embeddings from Supabase...")
  268. embeddings_array, ids, metadatas, contents = download_embeddings()
  269. print("Creating FAISS index...")
  270. index = create_faiss_index(embeddings_array)
  271. save_faiss_index(index, faiss_index_path)
  272. save_metadata(ids, metadatas, contents, metadata_path)
  273. else:
  274. print("Using existing FAISS index and metadata.")
  275. print("Creating FAISS retriever...")
  276. faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
  277. # print("Creating original vector store...")
  278. # original_vector_store = GetVectorStore(embeddings, supabase, document_table)
  279. # original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
  280. # questions, ground_truths = load_qa_pairs()
  281. # for question, ground_truth in zip(questions, ground_truths):
  282. question = ""
  283. while question != "exit":
  284. question = input("Question: ")
  285. print(f"\nQuestion: {question}")
  286. start_time = time()
  287. faiss_answer = faiss_query(question, faiss_retriever)
  288. faiss_docs = faiss_retriever.get_relevant_documents(question)
  289. faiss_time = time() - start_time
  290. print(f"FAISS Answer: {faiss_answer}")
  291. print(f"FAISS Time: {faiss_time:.4f} seconds")
  292. # start_time = time()
  293. # original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
  294. # original_time = time() - start_time
  295. # print(f"Original Answer: {original_answer}")
  296. # print(f"Original Time: {original_time:.4f} seconds")
  297. if __name__ == "__main__":
  298. global_retriever = create_faiss_retriever()
  299. questions, ground_truths = load_qa_pairs()
  300. results = []
  301. for question, ground_truth in zip(questions, ground_truths):
  302. # For multi_query=True
  303. start = time()
  304. final_answer_multi = faiss_query(question, global_retriever, multi_query=True)
  305. processing_time_multi = time() - start
  306. # print(final_answer_multi)
  307. # print(processing_time_multi)
  308. # For multi_query=False
  309. start = time()
  310. final_answer_single = faiss_query(question, global_retriever, multi_query=False)
  311. processing_time_single = time() - start
  312. # print(final_answer_single)
  313. # print(processing_time_single)
  314. # Store results in a dictionary
  315. result = {
  316. "question": question,
  317. "ground_truth": ground_truth,
  318. "final_answer_multi_query": final_answer_multi,
  319. "processing_time_multi_query": processing_time_multi,
  320. "final_answer_single_query": final_answer_single,
  321. "processing_time_single_query": processing_time_single
  322. }
  323. print(result)
  324. results.append(result)
  325. with open('qa_results.json', 'a', encoding='utf8') as outfile:
  326. json.dump(result, outfile, indent=4, ensure_ascii=False)
  327. outfile.write("\n") # Ensure each result is on a new line
  328. # Save results to a JSON file
  329. with open('qa_results_all.json', 'w', encoding='utf8') as outfile:
  330. json.dump(results, outfile, indent=4, ensure_ascii=False)
  331. print('All questions done!')
  332. # question = ""
  333. # while question != "exit":
  334. # # question = "國家溫室氣體長期減量目標"
  335. # question = input("Question: ")
  336. # if question.strip().lower == "exit": break
  337. # start = time()
  338. # final_answer = faiss_query(question, global_retriever, multi_query=True)
  339. # print(final_answer)
  340. # processing_time = time() - start
  341. # print(processing_time)
  342. # start = time()
  343. # final_answer = faiss_query(question, global_retriever, multi_query=False)
  344. # print(final_answer)
  345. # processing_time = time() - start
  346. # print(processing_time)
  347. # print("Chatbot closed!")
  348. # asyncio.run(ask_question())