- import faiss
- import numpy as np
- import json
- from time import time
- import asyncio
- from datasets import Dataset
- from typing import List
- from dotenv import load_dotenv
- import os
- import pickle
- from supabase.client import Client, create_client
- from langchain_openai import OpenAIEmbeddings, ChatOpenAI
- from langchain.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import StrOutputParser
- import pandas as pd
- from langchain_core.documents import Document
- from langchain.load import dumps, loads
- # Allow imports from the parent directory before importing from it
- import sys
- sys.path.append('..')
- from RAG_strategy import multi_query_chain
- # from RAG_strategy_Taide import taide_llm, system_prompt, multi_query
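- # System persona (kept in Traditional Chinese): "You are an AI assistant from Taiwan named TAIDE, happy to help users from a Taiwanese standpoint, and you answer in Traditional Chinese."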
- system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
- # llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
- from file_loader.add_vectordb import GetVectorStore
- # from local_llm import ollama_, hf
- # # from local_llm import ollama_, taide_llm, hf
- # llm = hf()
- # llm = taide_llm
- # Import RAGAS metrics
- from ragas import evaluate
- from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision
- # Load environment variables
- load_dotenv('../../.env')
- supabase_url = os.getenv("SUPABASE_URL")
- supabase_key = os.getenv("SUPABASE_KEY")
- openai_api_key = os.getenv("OPENAI_API_KEY")
- document_table = "documents"
- # Initialize Supabase client
- supabase: Client = create_client(supabase_url, supabase_key)
- # Initialize embeddings and chat model
- embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
- # gpt-4o-mini is used as the default chat model here; the commented-out local models above can be swapped in instead
- llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0, openai_api_key=openai_api_key)
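- # Fetch every stored document row (id, JSON-encoded embedding, metadata, content) from the Supabase documents table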
- def download_embeddings():
- response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
- embeddings = []
- ids = []
- metadatas = []
- contents = []
- for item in response.data:
- embedding = json.loads(item['embedding'])
- embeddings.append(embedding)
- ids.append(item['id'])
- metadatas.append(item['metadata'])
- contents.append(item['content'])
- return np.array(embeddings, dtype=np.float32), ids, metadatas, contents
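- # Build a flat (exact-search) FAISS index; inner product over L2-normalized vectors is equivalent to cosine similarity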
- def create_faiss_index(embeddings):
- dimension = embeddings.shape[1]
- index = faiss.IndexFlatIP(dimension) # Use Inner Product for cosine similarity
- faiss.normalize_L2(embeddings) # Normalize embeddings for cosine similarity
- index.add(embeddings)
- return index
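- # Persistence helpers: the index is written to a .bin file via faiss, the ids/metadata/contents to a pickle file,
- # so they only need to be rebuilt when missing on disk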
- def save_faiss_index(index, file_path):
- faiss.write_index(index, file_path)
- print(f"FAISS index saved to {file_path}")
- def load_faiss_index(file_path):
- if os.path.exists(file_path):
- index = faiss.read_index(file_path)
- print(f"FAISS index loaded from {file_path}")
- return index
- return None
- def save_metadata(ids, metadatas, contents, file_path):
- with open(file_path, 'wb') as f:
- pickle.dump((ids, metadatas, contents), f)
- print(f"Metadata saved to {file_path}")
- def load_metadata(file_path):
- if os.path.exists(file_path):
- with open(file_path, 'rb') as f:
- ids, metadatas, contents = pickle.load(f)
- print(f"Metadata loaded from {file_path}")
- return ids, metadatas, contents
- return None, None, None
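- # Normalize the query vector and return the top-k similarity scores and row indices from the index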
- def search_faiss(index, query_vector, k=4):
- query_vector = np.array(query_vector, dtype=np.float32).reshape(1, -1)
- faiss.normalize_L2(query_vector)
- distances, indices = index.search(query_vector, k)
- return distances[0], indices[0]
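- # Small retriever over the FAISS index: get_relevant_documents embeds the query and returns the top-k matches
- # as LangChain Documents; map runs several queries and returns their de-duplicated union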
- class FAISSRetriever:
- def __init__(self, index, ids, metadatas, contents, embeddings_model):
- self.index = index
- self.ids = ids
- self.metadatas = metadatas
- self.contents = contents
- self.embeddings_model = embeddings_model
- def get_relevant_documents(self, query: str, k: int = 4) -> List[Document]:
- query_vector = self.embeddings_model.embed_query(query)
- _, indices = search_faiss(self.index, query_vector, k=k)
- return [
- Document(page_content=self.contents[i], metadata=self.metadatas[i])
- for i in indices
- ]
- def map(self, query_list: List[str]) -> List[Document]:
- def get_unique_union(documents: List[list]):
- """ Unique union of retrieved docs """
- # Flatten list of lists, and convert each Document to string
- flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
- # Get unique documents
- unique_docs = list(set(flattened_docs))
- # Return
- return [loads(doc) for doc in unique_docs]
-
- documents = []
- for query in query_list:
- if query != "":
- docs = self.get_relevant_documents(query)
- documents.extend(docs)
- return get_unique_union(documents)
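- # Load evaluation question/answer pairs from the QA_database table in Supabase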
- def load_qa_pairs():
- # df = pd.read_csv("../QA_database_rows.csv")
- response = supabase.table('QA_database').select("Question, Answer").execute()
- df = pd.DataFrame(response.data)
- return df['Question'].tolist(), df['Answer'].tolist()
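- # Multi-query retrieval: let the LLM rewrite the question into several variants (multi_query_chain)
- # and retrieve documents for each non-empty variant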
- def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):
- generate_queries = multi_query_chain(llm)
- questions = generate_queries.invoke(question)
- questions = [item for item in questions if item != ""]
- # docs = list(map(retriever.get_relevant_documents, questions))
- docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))
- docs = [item for sublist in docs for item in sublist]
- return docs
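- # Answer a question via RAG: gather context with single- or multi-query retrieval,
- # fill the Llama-3-style prompt template below, and invoke the LLM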
- def faiss_query(question: str, retriever: FAISSRetriever, llm, multi_query: bool = False) -> str:
- if multi_query:
- docs = faiss_multiquery(question, retriever, llm)
- # print(docs)
- else:
- docs = retriever.get_relevant_documents(question, k=10)
- # print(docs)
- context = "\n".join(doc.page_content for doc in docs)
-
- template = """
- <|begin_of_text|>
-
- <|start_header_id|>system<|end_header_id|>
- 你是一個來自台灣的ESG的AI助理,
- 請用繁體中文回答問題 \n
- You should not mention anything about "根據提供的文件內容" or other similar terms.
- Use five sentences maximum and keep the answer concise.
- 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
- 勿回答無關資訊
- <|eot_id|>
-
- <|start_header_id|>user<|end_header_id|>
- Answer the following question based on this context:
- {context}
- Question: {question}
- <|eot_id|>
-
- <|start_header_id|>assistant<|end_header_id|>
- """
- prompt = ChatPromptTemplate.from_template(
- system_prompt + "\n\n" +
- template
- )
-
- # prompt = ChatPromptTemplate.from_template(
- # system_prompt + "\n\n" +
- # "Answer the following question based on this context:\n\n"
- # "{context}\n\n"
- # "Question: {question}\n"
- # "Answer in the same language as the question. If you don't know the answer, "
- # "say 'I'm sorry, I don't have enough information to answer that question.'"
- # )
-
- # chain = prompt | taide_llm | StrOutputParser()
- chain = prompt | llm | StrOutputParser()
- return chain.invoke({"context": context, "question": question})
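- # Load the FAISS index and metadata from disk if present, otherwise rebuild them from Supabase,
- # and wrap everything in a FAISSRetriever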
- def create_faiss_retriever():
- faiss_index_path = "faiss_index.bin"
- metadata_path = "faiss_metadata.pkl"
- index = load_faiss_index(faiss_index_path)
- ids, metadatas, contents = load_metadata(metadata_path)
- if index is None or ids is None:
- print("FAISS index or metadata not found. Creating new index...")
- print("Downloading embeddings from Supabase...")
- embeddings_array, ids, metadatas, contents = download_embeddings()
- print("Creating FAISS index...")
- index = create_faiss_index(embeddings_array)
- save_faiss_index(index, faiss_index_path)
- save_metadata(ids, metadatas, contents, metadata_path)
- else:
- print("Using existing FAISS index and metadata.")
- print("Creating FAISS retriever...")
- faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
- return faiss_retriever
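- # Compare the FAISS retriever against the original Supabase vector store on the QA pairs;
- # the RAGAS scoring blocks below are commented out, so only answers and timings are printed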
- async def run_evaluation():
- faiss_index_path = "faiss_index.bin"
- metadata_path = "faiss_metadata.pkl"
- index = load_faiss_index(faiss_index_path)
- ids, metadatas, contents = load_metadata(metadata_path)
- if index is None or ids is None:
- print("FAISS index or metadata not found. Creating new index...")
- print("Downloading embeddings from Supabase...")
- embeddings_array, ids, metadatas, contents = download_embeddings()
- print("Creating FAISS index...")
- index = create_faiss_index(embeddings_array)
- save_faiss_index(index, faiss_index_path)
- save_metadata(ids, metadatas, contents, metadata_path)
- else:
- print("Using existing FAISS index and metadata.")
- print("Creating FAISS retriever...")
- faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
- print("Creating original vector store...")
- original_vector_store = GetVectorStore(embeddings, supabase, document_table)
- original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
- questions, ground_truths = load_qa_pairs()
- for question, ground_truth in zip(questions, ground_truths):
- print(f"\nQuestion: {question}")
- start_time = time()
- faiss_answer = faiss_query(question, faiss_retriever, llm)
- faiss_docs = faiss_retriever.get_relevant_documents(question)
- faiss_time = time() - start_time
- print(f"FAISS Answer: {faiss_answer}")
- print(f"FAISS Time: {faiss_time:.4f} seconds")
- start_time = time()
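- # NOTE: multi_query is expected to come from the RAG strategy module (see the commented-out import near the top); it is not imported here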
- original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
- original_time = time() - start_time
- print(f"Original Answer: {original_answer}")
- print(f"Original Time: {original_time:.4f} seconds")
- # faiss_datasets = {
- # "question": [question],
- # "answer": [faiss_answer],
- # "contexts": [[doc.page_content for doc in faiss_docs]],
- # "ground_truth": [ground_truth]
- # }
- # faiss_evalsets = Dataset.from_dict(faiss_datasets)
- # faiss_result = evaluate(
- # faiss_evalsets,
- # metrics=[
- # context_precision,
- # faithfulness,
- # answer_relevancy,
- # context_recall,
- # ],
- # )
- # print("FAISS RAGAS Evaluation:")
- # print(faiss_result.to_pandas())
- # original_datasets = {
- # "question": [question],
- # "answer": [original_answer],
- # "contexts": [[doc.page_content for doc in original_docs]],
- # "ground_truth": [ground_truth]
- # }
- # original_evalsets = Dataset.from_dict(original_datasets)
- # original_result = evaluate(
- # original_evalsets,
- # metrics=[
- # context_precision,
- # faithfulness,
- # answer_relevancy,
- # context_recall,
- # ],
- # )
- # print("Original RAGAS Evaluation:")
- # print(original_result.to_pandas())
- print("\nPerformance comparison complete.")
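- # Interactive loop: read questions from stdin and answer them with the FAISS retriever until "exit" is typed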
- async def ask_question():
- faiss_index_path = "faiss_index.bin"
- metadata_path = "faiss_metadata.pkl"
- index = load_faiss_index(faiss_index_path)
- ids, metadatas, contents = load_metadata(metadata_path)
- if index is None or ids is None:
- print("FAISS index or metadata not found. Creating new index...")
- print("Downloading embeddings from Supabase...")
- embeddings_array, ids, metadatas, contents = download_embeddings()
- print("Creating FAISS index...")
- index = create_faiss_index(embeddings_array)
- save_faiss_index(index, faiss_index_path)
- save_metadata(ids, metadatas, contents, metadata_path)
- else:
- print("Using existing FAISS index and metadata.")
- print("Creating FAISS retriever...")
- faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
- # print("Creating original vector store...")
- # original_vector_store = GetVectorStore(embeddings, supabase, document_table)
- # original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
- # questions, ground_truths = load_qa_pairs()
- # for question, ground_truth in zip(questions, ground_truths):
- question = ""
- while question != "exit":
- question = input("Question: ")
- if question.strip().lower() == "exit":
- break
- print(f"\nQuestion: {question}")
- start_time = time()
- faiss_answer = faiss_query(question, faiss_retriever, llm)
- faiss_docs = faiss_retriever.get_relevant_documents(question)
- faiss_time = time() - start_time
- print(f"FAISS Answer: {faiss_answer}")
- print(f"FAISS Time: {faiss_time:.4f} seconds")
- # start_time = time()
- # original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
- # original_time = time() - start_time
- # print(f"Original Answer: {original_answer}")
- # print(f"Original Time: {original_time:.4f} seconds")
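- # Entry point: run every QA pair through faiss_query with and without multi-query expansion,
- # record answers and timings, append each result to qa_results.json and write the full list to qa_results_all.json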
- if __name__ == "__main__":
- global_retriever = create_faiss_retriever()
- questions, ground_truths = load_qa_pairs()
- results = []
- for question, ground_truth in zip(questions, ground_truths):
- # For multi_query=True
- start = time()
- final_answer_multi = faiss_query(question, global_retriever, llm, multi_query=True)
- processing_time_multi = time() - start
- # print(final_answer_multi)
- # print(processing_time_multi)
- # For multi_query=False
- start = time()
- final_answer_single = faiss_query(question, global_retriever, llm, multi_query=False)
- processing_time_single = time() - start
- # print(final_answer_single)
- # print(processing_time_single)
- # Store results in a dictionary
- result = {
- "question": question,
- "ground_truth": ground_truth,
- "final_answer_multi_query": final_answer_multi,
- "processing_time_multi_query": processing_time_multi,
- "final_answer_single_query": final_answer_single,
- "processing_time_single_query": processing_time_single
- }
- print(result)
-
- results.append(result)
- with open('qa_results.json', 'a', encoding='utf8') as outfile:
- json.dump(result, outfile, indent=4, ensure_ascii=False)
- outfile.write("\n") # Ensure each result is on a new line
-
- # Save results to a JSON file
- with open('qa_results_all.json', 'w', encoding='utf8') as outfile:
- json.dump(results, outfile, indent=4, ensure_ascii=False)
- print('All questions done!')
- # question = ""
- # while question != "exit":
- # # question = "國家溫室氣體長期減量目標"
- # question = input("Question: ")
- # if question.strip().lower() == "exit": break
- # start = time()
- # final_answer = faiss_query(question, global_retriever, multi_query=True)
- # print(final_answer)
- # processing_time = time() - start
- # print(processing_time)
-
- # start = time()
- # final_answer = faiss_query(question, global_retriever, multi_query=False)
- # print(final_answer)
- # processing_time = time() - start
- # print(processing_time)
- # print("Chatbot closed!")
- # asyncio.run(ask_question())
|