faiss_index.py

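"""Build and query a FAISS index over document embeddings stored in Supabase.

The script downloads embeddings from the Supabase `documents2` table, builds (or
reloads) a cosine-similarity FAISS index plus matching metadata, wraps them in a
simple retriever, and answers questions through a LangChain RAG chain, optionally
with multi-query expansion. It also includes a RAGAS evaluation loop comparing a
local Ollama Llama 3 model against gpt-4o-mini, and a batch QA run whose results
are written to CSV and logged back to Supabase.
"""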
import faiss
import numpy as np
import json
from time import time
import asyncio
from datasets import Dataset
from typing import List, Tuple
from dotenv import load_dotenv
import os
import pickle
from supabase.client import Client, create_client
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import pandas as pd
from langchain_core.documents import Document
from langchain.load import dumps, loads
from langchain_community.chat_models import ChatOllama
from langchain.callbacks import get_openai_callback

# Import from the parent directory
import sys
sys.path.append('..')
from RAG_strategy import multi_query_chain
# from RAG_strategy_Taide import taide_llm, system_prompt, multi_query
from file_loader.add_vectordb import GetVectorStore
# from local_llm import ollama_, hf
# from local_llm import ollama_, taide_llm, hf
# llm = hf()
# llm = taide_llm

# System prompt in Traditional Chinese: "You are an AI assistant from Taiwan, your name
# is TAIDE, happy to help users from a Taiwanese standpoint, and you answer in Traditional Chinese."
system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
# llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

# Import RAGAS metrics
from ragas import evaluate
from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision

# Load environment variables
load_dotenv('../../.env')
supabase_url = os.getenv("SUPABASE_URL")
supabase_key = os.getenv("SUPABASE_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")
document_table = "documents2"

# Initialize Supabase client
supabase: Client = create_client(supabase_url, supabase_key)

# Initialize the embeddings model
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

def download_embeddings():
    response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
    # response = supabase.table(document_table).select("id, embedding, metadata, content").eq('metadata ->> source', 'supplement.docx').execute()
    embeddings = []
    ids = []
    metadatas = []
    contents = []
    for item in response.data:
        embedding = json.loads(item['embedding'])
        embeddings.append(embedding)
        ids.append(item['id'])
        metadatas.append(item['metadata'])
        contents.append(item['content'])
    return np.array(embeddings, dtype=np.float32), ids, metadatas, contents

def create_faiss_index(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)  # Use inner product for cosine similarity
    faiss.normalize_L2(embeddings)  # Normalize embeddings for cosine similarity
    index.add(embeddings)
    return index

def save_faiss_index(index, file_path):
    faiss.write_index(index, file_path)
    print(f"FAISS index saved to {file_path}")

def load_faiss_index(file_path):
    if os.path.exists(file_path):
        index = faiss.read_index(file_path)
        print(f"FAISS index loaded from {file_path}")
        return index
    return None

def save_metadata(ids, metadatas, contents, file_path):
    with open(file_path, 'wb') as f:
        pickle.dump((ids, metadatas, contents), f)
    print(f"Metadata saved to {file_path}")

def load_metadata(file_path):
    if os.path.exists(file_path):
        with open(file_path, 'rb') as f:
            ids, metadatas, contents = pickle.load(f)
        print(f"Metadata loaded from {file_path}")
        return ids, metadatas, contents
    return None, None, None

def search_faiss(index, query_vector, k=4):
    query_vector = np.array(query_vector, dtype=np.float32).reshape(1, -1)
    faiss.normalize_L2(query_vector)
    distances, indices = index.search(query_vector, k)
    return distances[0], indices[0]

class FAISSRetriever:
    def __init__(self, index, ids, metadatas, contents, embeddings_model):
        self.index = index
        self.ids = ids
        self.metadatas = metadatas
        self.contents = contents
        self.embeddings_model = embeddings_model

    def get_relevant_documents(self, query: str, k: int = 4) -> List[Document]:
        query_vector = self.embeddings_model.embed_query(query)
        _, indices = search_faiss(self.index, query_vector, k=k)
        return [
            Document(page_content=self.contents[i], metadata=self.metadatas[i])
            for i in indices
        ]

    def map(self, query_list: List[str]) -> List[Document]:
        def get_unique_union(documents: List[list]):
            """Unique union of retrieved docs."""
            # Flatten list of lists, and convert each Document to string
            flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
            # Get unique documents
            unique_docs = list(set(flattened_docs))
            return [loads(doc) for doc in unique_docs]

        documents = []
        for query in query_list:
            if query != "":
                docs = self.get_relevant_documents(query)
                # Keep one list per query so get_unique_union can flatten them
                documents.append(docs)
        return get_unique_union(documents)
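
# Illustrative sketch (hypothetical query strings; not executed in this module):
# `FAISSRetriever.map` returns the de-duplicated union of documents retrieved for a
# list of queries, such as the reformulations produced by multi_query_chain, e.g.
#
#     queries = ["什麼是溫室氣體盤查?", "如何進行碳盤查?"]
#     unique_docs = retriever.map(queries)
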
def load_qa_pairs():
    # df = pd.read_csv("../QA_database_rows.csv")
    response = supabase.table('QA_database').select("Question, Answer").execute()
    df = pd.DataFrame(response.data)
    return df['Question'].tolist(), df['Answer'].tolist()

def faiss_multiquery(question: str, retriever: FAISSRetriever, llm, k: int = 4):
    generate_queries = multi_query_chain(llm)
    questions = generate_queries.invoke(question)
    questions = [item for item in questions if item != ""]
    questions.append(question)
    for q in questions:
        print(q)
    # docs = list(map(retriever.get_relevant_documents, questions))
    docs = list(map(lambda query: retriever.get_relevant_documents(query, k=k), questions))
    docs = [item for sublist in docs for item in sublist]
    return docs

def faiss_query(retriever, question: str, llm, k: int = 4, multi_query: bool = False) -> Tuple[List[Document], str]:
    if multi_query:
        docs = faiss_multiquery(question, retriever, llm, k)
        # print(docs)
    else:
        docs = retriever.get_relevant_documents(question, k)
        # print(docs)
    context = docs
    system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
    # Llama-3-style chat template. The Traditional Chinese instructions tell the model to act as
    # an ESG assistant from Taiwan, answer in detail in Traditional Chinese, avoid questions about
    # specific companies, and refer unknown questions to test@systex.com.
    template = """
    <|begin_of_text|>
    <|start_header_id|>system<|end_header_id|>
    你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
    You should not mention anything about "根據提供的文件內容" or other similar terms.
    請盡可能的詳細回答問題。
    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
    勿回答無關資訊或任何與某特定公司相關的問題
    <|eot_id|>
    <|start_header_id|>user<|end_header_id|>
    Answer the following question based on this context:
    {context}
    Question: {question}
    用繁體中文回答問題,請用一段話詳細的回答。勿回答無關資訊或任何與某特定公司相關的問題。
    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
    <|eot_id|>
    <|start_header_id|>assistant<|end_header_id|>
    """
    prompt = ChatPromptTemplate.from_template(
        system_prompt + "\n\n" + template
    )
    rag_chain = prompt | llm | StrOutputParser()
    return context, rag_chain.invoke({"context": context, "question": question})

def create_faiss_retriever():
    faiss_index_path = "faiss_index.bin"
    metadata_path = "faiss_metadata.pkl"
    index = load_faiss_index(faiss_index_path)
    ids, metadatas, contents = load_metadata(metadata_path)
    if index is None or ids is None:
        print("FAISS index or metadata not found. Creating new index...")
        print("Downloading embeddings from Supabase...")
        embeddings_array, ids, metadatas, contents = download_embeddings()
        print("Creating FAISS index...")
        index = create_faiss_index(embeddings_array)
        save_faiss_index(index, faiss_index_path)
        save_metadata(ids, metadatas, contents, metadata_path)
    else:
        print("Using existing FAISS index and metadata.")
    print("Creating FAISS retriever...")
    faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
    return faiss_retriever
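
# Example usage (illustrative sketch, not called anywhere in this module; assumes valid
# Supabase credentials in ../../.env, a populated documents2 table, and an OPENAI_API_KEY;
# the question string is a placeholder):
#
#     retriever = create_faiss_retriever()
#     llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
#     docs, answer = faiss_query(retriever, "什麼是溫室氣體盤查?", llm, k=4, multi_query=True)
#     print(answer)
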
async def run_evaluation():
    local_llm = "llama3-groq-tool-use:latest"
    llama3 = ChatOllama(model=local_llm, temperature=0)
    openai = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    retriever = create_faiss_retriever()
    questions, ground_truths = load_qa_pairs()
    for question, ground_truth in zip(questions[:5], ground_truths[:5]):
        print(f"\nQuestion: {question}")

        start_time = time()
        llama3_docs, llama3_answer = faiss_query(retriever, question, llama3, multi_query=True)
        llama3_time = time() - start_time
        print(f"llama3 Answer: {llama3_answer}")
        print(f"llama3 Time: {llama3_time:.4f} seconds")

        llama3_datasets = {
            "question": [question],
            "answer": [llama3_answer],
            "contexts": [[doc.page_content for doc in llama3_docs]],
            "ground_truth": [ground_truth]
        }
        llama3_evalsets = Dataset.from_dict(llama3_datasets)
        llama3_result = evaluate(
            llama3_evalsets,
            metrics=[
                context_precision,
                faithfulness,
                answer_relevancy,
                context_recall,
            ],
        )
        print("llama3 RAGAS Evaluation:")
        llama3_result['time'] = llama3_time
        df = llama3_result.to_pandas()
        print(df)
        df.to_csv("llama.csv", mode='a')

        #############################################################

        start_time = time()
        openai_docs, openai_answer = faiss_query(retriever, question, openai, multi_query=True)
        openai_time = time() - start_time
        print(f"openai Answer: {openai_answer}")
        print(f"openai Time: {openai_time:.4f} seconds")

        openai_datasets = {
            "question": [question],
            "answer": [openai_answer],
            "contexts": [[doc.page_content for doc in openai_docs]],
            "ground_truth": [ground_truth]
        }
        openai_evalsets = Dataset.from_dict(openai_datasets)
        openai_result = evaluate(
            openai_evalsets,
            metrics=[
                context_precision,
                faithfulness,
                answer_relevancy,
                context_recall,
            ],
        )
        print("openai RAGAS Evaluation:")
        openai_result['time'] = openai_time
        df = openai_result.to_pandas()
        print(df)
        df.to_csv("openai.csv", mode='a')

    print("\nPerformance comparison complete.")

def load_kg_q():
    sheet_id = "1eVmO8fIjjtEQBlmqtipi2d0H-8VgJkI-icqhSTfdhHQ"
    gid = "0"
    df = pd.read_csv(f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?format=csv&gid={gid}")
    return df['Question'].tolist()

async def run_q_batch():
    # local_llm = "llama3-groq-tool-use:latest"
    # llama3 = ChatOllama(model=local_llm, temperature=0)
    openai = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    retriever = create_faiss_retriever()
    questions = load_kg_q()
    result_list = []
    for question in questions:
        print(f"\nQuestion: {question}")
        with get_openai_callback() as cb:
            start_time = time()
            # openai_docs, openai_answer = faiss_query(retriever, question, openai, multi_query=True)
            documents, answer = faiss_query(retriever, question, openai, multi_query=True)
            openai_time = time() - start_time
            print(f"Answer: {answer}")
            print(f"Time: {openai_time:.4f} seconds")
            save_history(question, answer, cb, openai_time)
        result = {'Question': question, 'Answer': answer}
        result_list.append(result)
    df = pd.DataFrame.from_records(result_list)
    print(df)
    df.to_csv("kg_qa.csv", mode='w')
    print("\nQA complete.")

def save_history(question, answer, cb, processing_time):
    # reference = [doc.dict() for doc in reference]
    record = {
        'Question': question,
        'Answer': answer,
        'Total_Tokens': cb.total_tokens,
        'Total_Cost': cb.total_cost,
        'Processing_time': processing_time,
    }
    response = (
        supabase.table("agent_records")
        .insert(record)
        .execute()
    )

if __name__ == "__main__":
    # asyncio.run(run_evaluation())
    asyncio.run(run_q_batch())
    # print(load_kg_q())