|
@@ -0,0 +1,440 @@
|
|
|
+import faiss
|
|
|
+import numpy as np
|
|
|
+import json
|
|
|
+from time import time
|
|
|
+import asyncio
|
|
|
+from datasets import Dataset
|
|
|
+from typing import List
|
|
|
+from dotenv import load_dotenv
|
|
|
+import os
|
|
|
+import pickle
|
|
|
+from supabase.client import Client, create_client
|
|
|
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
|
|
+from langchain.prompts import ChatPromptTemplate
|
|
|
+from langchain_core.output_parsers import StrOutputParser
|
|
|
+import pandas as pd
|
|
|
+from langchain_core.documents import Document
|
|
|
+from langchain.load import dumps, loads
|
|
|
+
|
|
|
+# Import from the parent directory
|
|
|
+import sys
|
|
|
+
|
|
|
+from RAG_strategy import multi_query_chain
|
|
|
+sys.path.append('..')
|
|
|
+# from RAG_strategy_Taide import taide_llm, system_prompt, multi_query
|
|
|
+system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
|
|
|
+# llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
|
|
|
+
|
|
|
+
|
|
|
+from langchain.prompts import ChatPromptTemplate
|
|
|
+from langchain_core.output_parsers import StrOutputParser
|
|
|
+from add_vectordb import GetVectorStore
|
|
|
+# from local_llm import ollama_, hf
|
|
|
+# # from local_llm import ollama_, taide_llm, hf
|
|
|
+# llm = hf()
|
|
|
+# llm = taide_llm
|
|
|
+
|
|
|
+
|
|
|
+# Import RAGAS metrics
|
|
|
+from ragas import evaluate
|
|
|
+from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision
|
|
|
+
|
|
|
+# Load environment variables
|
|
|
+load_dotenv('../../.env')
|
|
|
+supabase_url = os.getenv("SUPABASE_URL")
|
|
|
+supabase_key = os.getenv("SUPABASE_KEY")
|
|
|
+openai_api_key = os.getenv("OPENAI_API_KEY")
|
|
|
+document_table = "documents"
|
|
|
+
|
|
|
+# Initialize Supabase client
|
|
|
+supabase: Client = create_client(supabase_url, supabase_key)
|
|
|
+
|
|
|
+# Initialize embeddings and chat model
|
|
|
+embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
|
|
|
+
|
|
|
+def download_embeddings():
|
|
|
+ response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
|
|
|
+ embeddings = []
|
|
|
+ ids = []
|
|
|
+ metadatas = []
|
|
|
+ contents = []
|
|
|
+ for item in response.data:
|
|
|
+ embedding = json.loads(item['embedding'])
|
|
|
+ embeddings.append(embedding)
|
|
|
+ ids.append(item['id'])
|
|
|
+ metadatas.append(item['metadata'])
|
|
|
+ contents.append(item['content'])
|
|
|
+ return np.array(embeddings, dtype=np.float32), ids, metadatas, contents
|
|
|
+
|
|
|
+def create_faiss_index(embeddings):
|
|
|
+ dimension = embeddings.shape[1]
|
|
|
+ index = faiss.IndexFlatIP(dimension) # Use Inner Product for cosine similarity
|
|
|
+ faiss.normalize_L2(embeddings) # Normalize embeddings for cosine similarity
|
|
|
+ index.add(embeddings)
|
|
|
+ return index
|
|
|
+
|
|
|
+def save_faiss_index(index, file_path):
|
|
|
+ faiss.write_index(index, file_path)
|
|
|
+ print(f"FAISS index saved to {file_path}")
|
|
|
+
|
|
|
+def load_faiss_index(file_path):
|
|
|
+ if os.path.exists(file_path):
|
|
|
+ index = faiss.read_index(file_path)
|
|
|
+ print(f"FAISS index loaded from {file_path}")
|
|
|
+ return index
|
|
|
+ return None
|
|
|
+
|
|
|
+def save_metadata(ids, metadatas, contents, file_path):
|
|
|
+ with open(file_path, 'wb') as f:
|
|
|
+ pickle.dump((ids, metadatas, contents), f)
|
|
|
+ print(f"Metadata saved to {file_path}")
|
|
|
+
|
|
|
+def load_metadata(file_path):
|
|
|
+ if os.path.exists(file_path):
|
|
|
+ with open(file_path, 'rb') as f:
|
|
|
+ ids, metadatas, contents = pickle.load(f)
|
|
|
+ print(f"Metadata loaded from {file_path}")
|
|
|
+ return ids, metadatas, contents
|
|
|
+ return None, None, None
|
|
|
+
|
|
|
+def search_faiss(index, query_vector, k=4):
|
|
|
+ query_vector = np.array(query_vector, dtype=np.float32).reshape(1, -1)
|
|
|
+ faiss.normalize_L2(query_vector)
|
|
|
+ distances, indices = index.search(query_vector, k)
|
|
|
+ return distances[0], indices[0]
|
|
|
+
|
|
|
+class FAISSRetriever:
|
|
|
+ def __init__(self, index, ids, metadatas, contents, embeddings_model):
|
|
|
+ self.index = index
|
|
|
+ self.ids = ids
|
|
|
+ self.metadatas = metadatas
|
|
|
+ self.contents = contents
|
|
|
+ self.embeddings_model = embeddings_model
|
|
|
+
|
|
|
+ def get_relevant_documents(self, query: str, k: int = 4) -> List[Document]:
|
|
|
+ query_vector = self.embeddings_model.embed_query(query)
|
|
|
+ _, indices = search_faiss(self.index, query_vector, k=k)
|
|
|
+ return [
|
|
|
+ Document(page_content=self.contents[i], metadata=self.metadatas[i])
|
|
|
+ for i in indices
|
|
|
+ ]
|
|
|
+ def map(self, query_list: List[list]) -> List[Document]:
|
|
|
+ def get_unique_union(documents: List[list]):
|
|
|
+ """ Unique union of retrieved docs """
|
|
|
+ # Flatten list of lists, and convert each Document to string
|
|
|
+ flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
|
|
|
+ # Get unique documents
|
|
|
+ unique_docs = list(set(flattened_docs))
|
|
|
+ # Return
|
|
|
+ return [loads(doc) for doc in unique_docs]
|
|
|
+
|
|
|
+ documents = []
|
|
|
+ for query in query_list:
|
|
|
+ if query != "":
|
|
|
+ docs = self.get_relevant_documents(query)
|
|
|
+
|
|
|
+ documents.extend(docs)
|
|
|
+
|
|
|
+ return get_unique_union(documents)
|
|
|
+
|
|
|
+def load_qa_pairs():
|
|
|
+ # df = pd.read_csv("../QA_database_rows.csv")
|
|
|
+ response = supabase.table('QA_database').select("Question, Answer").execute()
|
|
|
+ df = pd.DataFrame(response.data)
|
|
|
+
|
|
|
+ return df['Question'].tolist(), df['Answer'].tolist()
|
|
|
+
|
|
|
+def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):
|
|
|
+ generate_queries = multi_query_chain(llm)
|
|
|
+
|
|
|
+ questions = generate_queries.invoke(question)
|
|
|
+ questions = [item for item in questions if item != ""]
|
|
|
+
|
|
|
+ # docs = list(map(retriever.get_relevant_documents, questions))
|
|
|
+ docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))
|
|
|
+ docs = [item for sublist in docs for item in sublist]
|
|
|
+
|
|
|
+ return docs
|
|
|
+
|
|
|
+def faiss_query(question: str, retriever: FAISSRetriever, llm, multi_query: bool = False) -> str:
|
|
|
+ if multi_query:
|
|
|
+ docs = faiss_multiquery(question, retriever, llm)
|
|
|
+ # print(docs)
|
|
|
+ else:
|
|
|
+ docs = retriever.get_relevant_documents(question, k=10)
|
|
|
+ # print(docs)
|
|
|
+
|
|
|
+ context = "\n".join(doc.page_content for doc in docs)
|
|
|
+
|
|
|
+ template = """
|
|
|
+ <|begin_of_text|>
|
|
|
+
|
|
|
+ <|start_header_id|>system<|end_header_id|>
|
|
|
+ 你是一個來自台灣的ESG的AI助理,
|
|
|
+ 請用繁體中文回答問題 \n
|
|
|
+ You should not mention anything about "根據提供的文件內容" or other similar terms.
|
|
|
+ Use five sentences maximum and keep the answer concise.
|
|
|
+ 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
|
|
|
+ 勿回答無關資訊
|
|
|
+ <|eot_id|>
|
|
|
+
|
|
|
+ <|start_header_id|>user<|end_header_id|>
|
|
|
+ Answer the following question based on this context:
|
|
|
+
|
|
|
+ {context}
|
|
|
+
|
|
|
+ Question: {question}
|
|
|
+ <|eot_id|>
|
|
|
+
|
|
|
+ <|start_header_id|>assistant<|end_header_id|>
|
|
|
+ """
|
|
|
+ prompt = ChatPromptTemplate.from_template(
|
|
|
+ system_prompt + "\n\n" +
|
|
|
+ template
|
|
|
+ )
|
|
|
+
|
|
|
+ # prompt = ChatPromptTemplate.from_template(
|
|
|
+ # system_prompt + "\n\n" +
|
|
|
+ # "Answer the following question based on this context:\n\n"
|
|
|
+ # "{context}\n\n"
|
|
|
+ # "Question: {question}\n"
|
|
|
+ # "Answer in the same language as the question. If you don't know the answer, "
|
|
|
+ # "say 'I'm sorry, I don't have enough information to answer that question.'"
|
|
|
+ # )
|
|
|
+
|
|
|
+
|
|
|
+ # chain = prompt | taide_llm | StrOutputParser()
|
|
|
+ chain = prompt | llm | StrOutputParser()
|
|
|
+ return chain.invoke({"context": context, "question": question})
|
|
|
+
|
|
|
+
|
|
|
+def create_faiss_retriever():
|
|
|
+ faiss_index_path = "faiss_index.bin"
|
|
|
+ metadata_path = "faiss_metadata.pkl"
|
|
|
+
|
|
|
+ index = load_faiss_index(faiss_index_path)
|
|
|
+ ids, metadatas, contents = load_metadata(metadata_path)
|
|
|
+
|
|
|
+ if index is None or ids is None:
|
|
|
+ print("FAISS index or metadata not found. Creating new index...")
|
|
|
+ print("Downloading embeddings from Supabase...")
|
|
|
+ embeddings_array, ids, metadatas, contents = download_embeddings()
|
|
|
+
|
|
|
+ print("Creating FAISS index...")
|
|
|
+ index = create_faiss_index(embeddings_array)
|
|
|
+
|
|
|
+ save_faiss_index(index, faiss_index_path)
|
|
|
+ save_metadata(ids, metadatas, contents, metadata_path)
|
|
|
+ else:
|
|
|
+ print("Using existing FAISS index and metadata.")
|
|
|
+
|
|
|
+ print("Creating FAISS retriever...")
|
|
|
+ faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
|
|
|
+
|
|
|
+ return faiss_retriever
|
|
|
+
|
|
|
+
|
|
|
+async def run_evaluation():
|
|
|
+ faiss_index_path = "faiss_index.bin"
|
|
|
+ metadata_path = "faiss_metadata.pkl"
|
|
|
+
|
|
|
+ index = load_faiss_index(faiss_index_path)
|
|
|
+ ids, metadatas, contents = load_metadata(metadata_path)
|
|
|
+
|
|
|
+ if index is None or ids is None:
|
|
|
+ print("FAISS index or metadata not found. Creating new index...")
|
|
|
+ print("Downloading embeddings from Supabase...")
|
|
|
+ embeddings_array, ids, metadatas, contents = download_embeddings()
|
|
|
+
|
|
|
+ print("Creating FAISS index...")
|
|
|
+ index = create_faiss_index(embeddings_array)
|
|
|
+
|
|
|
+ save_faiss_index(index, faiss_index_path)
|
|
|
+ save_metadata(ids, metadatas, contents, metadata_path)
|
|
|
+ else:
|
|
|
+ print("Using existing FAISS index and metadata.")
|
|
|
+
|
|
|
+ print("Creating FAISS retriever...")
|
|
|
+ faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
|
|
|
+
|
|
|
+ print("Creating original vector store...")
|
|
|
+ original_vector_store = GetVectorStore(embeddings, supabase, document_table)
|
|
|
+ original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
|
|
|
+
|
|
|
+ questions, ground_truths = load_qa_pairs()
|
|
|
+
|
|
|
+ for question, ground_truth in zip(questions, ground_truths):
|
|
|
+ print(f"\nQuestion: {question}")
|
|
|
+
|
|
|
+ start_time = time()
|
|
|
+ faiss_answer = faiss_query(question, faiss_retriever)
|
|
|
+ faiss_docs = faiss_retriever.get_relevant_documents(question)
|
|
|
+ faiss_time = time() - start_time
|
|
|
+ print(f"FAISS Answer: {faiss_answer}")
|
|
|
+ print(f"FAISS Time: {faiss_time:.4f} seconds")
|
|
|
+
|
|
|
+ start_time = time()
|
|
|
+ original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
|
|
|
+ original_time = time() - start_time
|
|
|
+ print(f"Original Answer: {original_answer}")
|
|
|
+ print(f"Original Time: {original_time:.4f} seconds")
|
|
|
+
|
|
|
+ # faiss_datasets = {
|
|
|
+ # "question": [question],
|
|
|
+ # "answer": [faiss_answer],
|
|
|
+ # "contexts": [[doc.page_content for doc in faiss_docs]],
|
|
|
+ # "ground_truth": [ground_truth]
|
|
|
+ # }
|
|
|
+ # faiss_evalsets = Dataset.from_dict(faiss_datasets)
|
|
|
+
|
|
|
+ # faiss_result = evaluate(
|
|
|
+ # faiss_evalsets,
|
|
|
+ # metrics=[
|
|
|
+ # context_precision,
|
|
|
+ # faithfulness,
|
|
|
+ # answer_relevancy,
|
|
|
+ # context_recall,
|
|
|
+ # ],
|
|
|
+ # )
|
|
|
+
|
|
|
+ # print("FAISS RAGAS Evaluation:")
|
|
|
+ # print(faiss_result.to_pandas())
|
|
|
+
|
|
|
+ # original_datasets = {
|
|
|
+ # "question": [question],
|
|
|
+ # "answer": [original_answer],
|
|
|
+ # "contexts": [[doc.page_content for doc in original_docs]],
|
|
|
+ # "ground_truth": [ground_truth]
|
|
|
+ # }
|
|
|
+ # original_evalsets = Dataset.from_dict(original_datasets)
|
|
|
+
|
|
|
+ # original_result = evaluate(
|
|
|
+ # original_evalsets,
|
|
|
+ # metrics=[
|
|
|
+ # context_precision,
|
|
|
+ # faithfulness,
|
|
|
+ # answer_relevancy,
|
|
|
+ # context_recall,
|
|
|
+ # ],
|
|
|
+ # )
|
|
|
+
|
|
|
+ # print("Original RAGAS Evaluation:")
|
|
|
+ # print(original_result.to_pandas())
|
|
|
+
|
|
|
+ print("\nPerformance comparison complete.")
|
|
|
+
|
|
|
+
|
|
|
+async def ask_question():
|
|
|
+ faiss_index_path = "faiss_index.bin"
|
|
|
+ metadata_path = "faiss_metadata.pkl"
|
|
|
+
|
|
|
+ index = load_faiss_index(faiss_index_path)
|
|
|
+ ids, metadatas, contents = load_metadata(metadata_path)
|
|
|
+
|
|
|
+ if index is None or ids is None:
|
|
|
+ print("FAISS index or metadata not found. Creating new index...")
|
|
|
+ print("Downloading embeddings from Supabase...")
|
|
|
+ embeddings_array, ids, metadatas, contents = download_embeddings()
|
|
|
+
|
|
|
+ print("Creating FAISS index...")
|
|
|
+ index = create_faiss_index(embeddings_array)
|
|
|
+
|
|
|
+ save_faiss_index(index, faiss_index_path)
|
|
|
+ save_metadata(ids, metadatas, contents, metadata_path)
|
|
|
+ else:
|
|
|
+ print("Using existing FAISS index and metadata.")
|
|
|
+
|
|
|
+ print("Creating FAISS retriever...")
|
|
|
+ faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
|
|
|
+
|
|
|
+ # print("Creating original vector store...")
|
|
|
+ # original_vector_store = GetVectorStore(embeddings, supabase, document_table)
|
|
|
+ # original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
|
|
|
+
|
|
|
+ # questions, ground_truths = load_qa_pairs()
|
|
|
+
|
|
|
+ # for question, ground_truth in zip(questions, ground_truths):
|
|
|
+ question = ""
|
|
|
+ while question != "exit":
|
|
|
+ question = input("Question: ")
|
|
|
+ print(f"\nQuestion: {question}")
|
|
|
+
|
|
|
+ start_time = time()
|
|
|
+ faiss_answer = faiss_query(question, faiss_retriever)
|
|
|
+ faiss_docs = faiss_retriever.get_relevant_documents(question)
|
|
|
+ faiss_time = time() - start_time
|
|
|
+ print(f"FAISS Answer: {faiss_answer}")
|
|
|
+ print(f"FAISS Time: {faiss_time:.4f} seconds")
|
|
|
+
|
|
|
+ # start_time = time()
|
|
|
+ # original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
|
|
|
+ # original_time = time() - start_time
|
|
|
+ # print(f"Original Answer: {original_answer}")
|
|
|
+ # print(f"Original Time: {original_time:.4f} seconds")
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+
|
|
|
+ global_retriever = create_faiss_retriever()
|
|
|
+
|
|
|
+ questions, ground_truths = load_qa_pairs()
|
|
|
+ results = []
|
|
|
+
|
|
|
+ for question, ground_truth in zip(questions, ground_truths):
|
|
|
+ # For multi_query=True
|
|
|
+ start = time()
|
|
|
+ final_answer_multi = faiss_query(question, global_retriever, multi_query=True)
|
|
|
+ processing_time_multi = time() - start
|
|
|
+ # print(final_answer_multi)
|
|
|
+ # print(processing_time_multi)
|
|
|
+
|
|
|
+ # For multi_query=False
|
|
|
+ start = time()
|
|
|
+ final_answer_single = faiss_query(question, global_retriever, multi_query=False)
|
|
|
+ processing_time_single = time() - start
|
|
|
+ # print(final_answer_single)
|
|
|
+ # print(processing_time_single)
|
|
|
+
|
|
|
+ # Store results in a dictionary
|
|
|
+ result = {
|
|
|
+ "question": question,
|
|
|
+ "ground_truth": ground_truth,
|
|
|
+ "final_answer_multi_query": final_answer_multi,
|
|
|
+ "processing_time_multi_query": processing_time_multi,
|
|
|
+ "final_answer_single_query": final_answer_single,
|
|
|
+ "processing_time_single_query": processing_time_single
|
|
|
+ }
|
|
|
+ print(result)
|
|
|
+
|
|
|
+ results.append(result)
|
|
|
+
|
|
|
+ with open('qa_results.json', 'a', encoding='utf8') as outfile:
|
|
|
+ json.dump(result, outfile, indent=4, ensure_ascii=False)
|
|
|
+ outfile.write("\n") # Ensure each result is on a new line
|
|
|
+
|
|
|
+
|
|
|
+ # Save results to a JSON file
|
|
|
+ with open('qa_results+all.json', 'w', encoding='utf8') as outfile:
|
|
|
+ json.dump(results, outfile, indent=4, ensure_ascii=False)
|
|
|
+
|
|
|
+ print('All questions done!')
|
|
|
+ # question = ""
|
|
|
+ # while question != "exit":
|
|
|
+ # # question = "國家溫室氣體長期減量目標"
|
|
|
+ # question = input("Question: ")
|
|
|
+ # if question.strip().lower == "exit": break
|
|
|
+
|
|
|
+ # start = time()
|
|
|
+ # final_answer = faiss_query(question, global_retriever, multi_query=True)
|
|
|
+ # print(final_answer)
|
|
|
+ # processing_time = time() - start
|
|
|
+ # print(processing_time)
|
|
|
+
|
|
|
+
|
|
|
+ # start = time()
|
|
|
+ # final_answer = faiss_query(question, global_retriever, multi_query=False)
|
|
|
+ # print(final_answer)
|
|
|
+ # processing_time = time() - start
|
|
|
+ # print(processing_time)
|
|
|
+ # print("Chatbot closed!")
|
|
|
+
|
|
|
+ # asyncio.run(ask_question())
|