
Integrate Llama 3.1 with RAG using LangChain's Huggingface module

ling committed 10 months ago
commit
bf1c706ac0
11 changed files with 1,837 additions and 0 deletions
  1. .gitignore  (+5, -0)
  2. Indexing_Split.py  (+174, -0)
  3. RAG_app.py  (+198, -0)
  4. RAG_strategy.py  (+260, -0)
  5. add_vectordb.py  (+194, -0)
  6. conda_env.txt  (+80, -0)
  7. faiss_index.py  (+440, -0)
  8. local_llm.py  (+119, -0)
  9. pip_env.txt  (+107, -0)
  10. semantic_search.py  (+73, -0)
  11. text_to_sql.py  (+187, -0)

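In outline, the commit wires a local Llama 3.1 pipeline (local_llm.hf) into a FAISS-backed retriever (faiss_index) and serves it through the FastAPI app (RAG_app.py). Below is a minimal usage sketch of the new modules, assuming the document embeddings already exist in the Supabase "documents" table and SUPABASE_URL, SUPABASE_KEY and OPENAI_API_KEY are set in .env; it mirrors what RAG_app.py's lifespan hook does at startup, not a separate code path.

    # Sketch only: load the retriever and local Llama 3.1 model, then answer one question.
    from faiss_index import create_faiss_retriever, faiss_query
    from local_llm import hf

    retriever = create_faiss_retriever()   # loads or builds faiss_index.bin / faiss_metadata.pkl
    llm = hf()                             # meta-llama/Meta-Llama-3.1-8B-Instruct via HuggingFacePipeline
    print(faiss_query("什麼是逸散排放源?", retriever, llm, multi_query=False))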
+ 5 - 0
.gitignore

@@ -0,0 +1,5 @@
+__pycache__/
+chroma_db_carbon_questions/
+faiss_index.bin
+faiss_metadata.pkl
+.env

+ 174 - 0
Indexing_Split.py

@@ -0,0 +1,174 @@
+from dotenv import load_dotenv
+load_dotenv()
+
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.vectorstores import Chroma
+from langchain_community.document_loaders import TextLoader
+from langchain.text_splitter import CharacterTextSplitter
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.document_loaders import Docx2txtLoader
+from langchain_community.document_loaders import WebBaseLoader
+from PyPDF2 import PdfReader
+from json import loads
+import pandas as pd
+from sqlalchemy import create_engine
+
+from langchain.prompts import ChatPromptTemplate
+from langchain_openai import ChatOpenAI
+from langchain_core.output_parsers import StrOutputParser
+from langchain import hub
+from tqdm import tqdm
+
+# __import__('pysqlite3')
+# import sys
+# sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+
+from datasets import Dataset 
+# from ragas import evaluate
+# from ragas.metrics import (
+#     answer_relevancy,
+#     faithfulness,
+#     context_recall,
+#     context_precision,
+# )
+import os
+import glob
+
+URI = os.getenv("SUPABASE_URI")
+
+from RAG_strategy import multi_query, naive_rag
+
+def create_retriever(path='Documents', extension="pdf"):
+    txt_files = glob.glob(os.path.join(path, f"*.{extension}"))
+    
+    doc = []
+    for file_path in txt_files:
+        doc.append(file_path)
+    
+    def load_and_split(file_list):
+        chunks = []
+        for file in file_list:
+            if file.endswith(".txt"):
+                loader = TextLoader(file, encoding='utf-8')
+            elif file.endswith(".pdf"):
+                loader = PyPDFLoader(file)
+            elif file.endswith(".docx"):
+                loader = Docx2txtLoader(file)
+            else:
+                raise ValueError(f"Unsupported file extension: {file}")
+            
+
+            docs = loader.load()
+
+            # Split
+            if file.endswith(".docx"):
+                # separators = ["\n\n\u25cb", "\n\n\u25cf"]
+                # text_splitter = RecursiveCharacterTextSplitter(separators=separators, chunk_size=500, chunk_overlap=0)
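+                # Regulation .docx files are split on the article markers (○第…條 / ●第…條) so each chunk roughly covers one article.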
+                separators = ['\u25cb\\s*第.*?條', '\u25cf\\s*第.*?條']
+                text_splitter = RecursiveCharacterTextSplitter(is_separator_regex=True, separators=separators, chunk_size=300, chunk_overlap=0)
+            else:
+                text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=500, chunk_overlap=0)
+            
+            splits = text_splitter.split_documents(docs)
+
+            chunks.extend(splits)
+
+        return chunks
+
+    # Index
+    docs = load_and_split(doc)
+    qa_history_doc = gen_doc_from_history()
+    docs.extend(qa_history_doc)
+    # web_doc = web_data(os.path.join(path, 'web_url.csv'))
+    # docs.extend(web_doc)
+
+    # vectorstore
+    # vectorstore = Chroma.from_texts(texts=docs, embedding=OpenAIEmbeddings())
+    vectorstore = Chroma.from_documents(documents=docs, embedding=OpenAIEmbeddings())
+    # vectorstore = Chroma.from_documents(documents=docs, embedding=OllamaEmbeddings(model="llama3", num_gpu=1))
+    vectorstore.persist()
+
+    retriever = vectorstore.as_retriever()
+
+    return retriever
+
+def web_data(url_file):
+    df = pd.read_csv(url_file, header = 0)
+    url_list = df['url'].to_list()
+
+    loader = WebBaseLoader(url_list)
+    docs = loader.load()
+
+    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+                chunk_size=1000, chunk_overlap=0)
+    splits = text_splitter.split_documents(docs)
+    
+    return splits
+
+def gen_doc_from_history():
+    engine = create_engine(URI, echo=True)
+
+    df = pd.read_sql_table("systex_records", engine.connect())  
+    df.fillna('', inplace=True)
+    result = df.to_json(orient='index', force_ascii=False)
+    result = loads(result)
+
+
+    df = pd.DataFrame(result).T
+    qa_history_doc = []
+    for i in range(len(df)):
+        if df.iloc[i]['used_as_document'] is not True: continue
+        Question = df.iloc[i]['Question']
+        Answer = df.iloc[i]['Answer']
+        context = f'Question: {Question}\nAnswer: {Answer}'
+        
+        doc =  Document(page_content=context, metadata={"source": "History"})
+        qa_history_doc.append(doc)
+        # print(doc)
+
+    return qa_history_doc
+
+def gen_doc_from_database():
+    engine = create_engine(URI, echo=True)
+
+    df = pd.read_sql_table("QA_database", engine.connect())  
+    # df.fillna('', inplace=True)
+    result = df[['Question', 'Answer']].to_json(orient='index', force_ascii=False)
+    result = loads(result)
+
+
+    df = pd.DataFrame(result).T
+    qa_doc = []
+    for i in range(len(df)):
+        # if df.iloc[i]['used_as_document'] is not True: continue
+        Question = df.iloc[i]['Question']
+        Answer = df.iloc[i]['Answer']
+        context = f'Question: {Question}\nAnswer: {Answer}'
+        
+        doc = Document(page_content=context, metadata={"source": "History"})
+        qa_doc.append(doc)
+        # print(doc)
+
+    return qa_doc
+
+if __name__ == "__main__":
+
+    retriever = create_retriever(path='./Documents', extension="pdf")
+    question = 'CEV系統可以支援盤查到什麼程度'
+    final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
+    print(question, final_answer)
+    question = 'CEV系統依循標準為何'
+    final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
+    print(question, final_answer)
+
+
+
+

+ 198 - 0
RAG_app.py

@@ -0,0 +1,198 @@
+import re
+import threading
+from fastapi import FastAPI, Request, HTTPException, Response, status, Body
+# from fastapi.templating import Jinja2Templates
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, JSONResponse
+from fastapi import Depends
+from contextlib import asynccontextmanager
+from pydantic import BaseModel
+from typing import List, Optional
+import uvicorn
+
+import requests
+
+import sqlparse
+from sqlalchemy import create_engine
+import pandas as pd
+#from retrying import retry
+import datetime
+import json
+from json import loads
+import time
+from langchain.callbacks import get_openai_callback
+
+from langchain_community.vectorstores import Chroma
+from langchain_openai import OpenAIEmbeddings
+# from RAG_strategy import multi_query, naive_rag
+# from Indexing_Split import create_retriever as split_retriever
+# from Indexing_Split import gen_doc_from_database, gen_doc_from_history
+from semantic_search import semantic_cache
+
+from dotenv import load_dotenv
+import os
+from langchain_community.vectorstores import SupabaseVectorStore
+from supabase.client import Client, create_client
+
+
+from add_vectordb import GetVectorStore
+from faiss_index import create_faiss_retriever, faiss_query
+from local_llm import ollama_, hf
+# from local_llm import ollama_, taide_llm, hf
+# llm = hf()
+
+load_dotenv()
+URI = os.getenv("SUPABASE_URI")
+supabase_url = os.environ.get("SUPABASE_URL")
+supabase_key = os.environ.get("SUPABASE_KEY")
+supabase: Client = create_client(supabase_url, supabase_key)
+
+global_retriever = None
+llm = None
+
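+# Build the FAISS retriever and load the local Llama 3.1 pipeline once at startup (FastAPI lifespan) so every request reuses them.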
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global global_retriever
+    global llm
+    global vector_store
+    
+    start = time.time()
+    document_table = "documents"
+
+    embeddings = OpenAIEmbeddings()
+    # vector_store = GetVectorStore(embeddings, supabase, document_table)
+    # global_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
+    global_retriever = create_faiss_retriever()
+    llm = hf()
+
+    print(time.time() - start)
+    yield
+
+def get_retriever():
+    return global_retriever
+
+def get_llm():
+    return llm
+
+def get_vector_store():
+    return vector_store
+
+app = FastAPI(lifespan=lifespan)
+
+# templates = Jinja2Templates(directory="temp")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+class ChatHistoryItem(BaseModel):
+    q: str
+    a: str
+
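+# Convert literal \uXXXX escape sequences left in the model output back into the characters they encode.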
+def replace_unicode_escapes(match):
+    return chr(int(match.group(1), 16))
+
+@app.post("/answer_with_history")
+def multi_query_answer(question: Optional[str] = '什麼是逸散排放源?', chat_history: List[ChatHistoryItem] = Body(...), 
+                       retriever=Depends(get_retriever), llm=Depends(get_llm)):
+    start = time.time()
+    
+    chat_history = [(item.q, item.a) for item in chat_history if item.a != "" and item.a != "string"]
+    print(chat_history)
+
+    # TODO: similarity search
+    
+    with get_openai_callback() as cb:
+
+        # cache_question, cache_answer = semantic_cache(supabase, question)
+        # if cache_answer:
+        #     processing_time = time.time() - start
+        #     save_history(question, cache_answer, cache_question, cb, processing_time)
+
+        #     return {"Answer": cache_answer}
+        
+        # final_answer, reference_docs = multi_query(question, retriever, chat_history)
+        # final_answer, reference_docs = naive_rag(question, retriever, chat_history)
+        final_answer = faiss_query(question, global_retriever, llm)
+        
+        decoded_string = re.sub(r'\\u([0-9a-fA-F]{4})', replace_unicode_escapes, final_answer)
+        print(decoded_string)
+
+        reference_docs = global_retriever.get_relevant_documents(question)
+    processing_time = time.time() - start
+    print(processing_time)
+    save_history(question, decoded_string, reference_docs, cb, processing_time)
+
+    # print(response)
+    # JSONResponse serializes the content itself (with ensure_ascii=False), so pass the dict
+    # directly rather than a pre-dumped JSON string, which would be double-encoded.
+    return JSONResponse(content={"Answer": decoded_string})
+    # response_content = json.dumps({"Answer": final_answer}, ensure_ascii=False)
+    # print(response_content)
+    # return json.loads(response_content)
+
+    
+
+def save_history(question, answer, reference, cb, processing_time):
+    # reference = [doc.dict() for doc in reference]
+    record = {
+        'Question': [question],
+        'Answer': [answer],
+        'Total_Tokens': [cb.total_tokens],
+        'Total_Cost': [cb.total_cost],
+        'Processing_time': [processing_time],
+        'Contexts': [str(reference)] if isinstance(reference, list) else [reference]
+    }
+    df = pd.DataFrame(record)
+    engine = create_engine(URI)
+    df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
+
+class history_output(BaseModel):
+    Question: str
+    Answer: str
+    Contexts: str
+    Total_Tokens: int
+    Total_Cost: float
+    Processing_time: float
+    Time: datetime.datetime
+    
+@app.get('/history', response_model=List[history_output])
+async def get_history():
+    engine = create_engine(URI, echo=True)
+
+    df = pd.read_sql_table("systex_records", engine.connect())  
+    df.fillna('', inplace=True)
+    result = df.to_json(orient='index', force_ascii=False)
+    result = loads(result)
+    return result.values()
+
+def send_heartbeat(url):
+    while True:
+        try:
+            response = requests.get(url)
+            if response.status_code != 200:
+                print(f"Failed to send heartbeat, status code: {response.status_code}")
+        except requests.RequestException as e:
+            print(f"Error occurred: {e}")
+        # wait 10 minutes between heartbeats
+        time.sleep(600)
+
+def start_heartbeat(url):
+    heartbeat_thread = threading.Thread(target=send_heartbeat, args=(url,))
+    heartbeat_thread.daemon = True
+    heartbeat_thread.start()
+    
+if __name__ == "__main__":
+    
+    # url = 'http://db.ptt.cx:3001/api/push/luX7WcY3Gz?status=up&msg=OK&ping='
+    # start_heartbeat(url)
+    
+    uvicorn.run("RAG_app:app", host='0.0.0.0', reload=True, port=8080)
+

+ 260 - 0
RAG_strategy.py

@@ -0,0 +1,260 @@
+from langchain.prompts import ChatPromptTemplate
+from langchain.load import dumps, loads
+from langchain_core.output_parsers import StrOutputParser
+from langchain_openai import ChatOpenAI
+from langchain_community.llms import Ollama
+from langchain_community.chat_models import ChatOllama
+from operator import itemgetter
+from langchain_core.runnables import RunnablePassthrough
+from langchain import hub
+from langchain.globals import set_llm_cache
+from langchain_core.prompts import PromptTemplate
+
+
+from langchain_core.runnables import (
+    RunnableBranch,
+    RunnableLambda,
+    RunnableParallel,
+    RunnablePassthrough,
+)
+from typing import Tuple, List, Optional
+from langchain_core.messages import AIMessage, HumanMessage
+
+from typing import List
+from dotenv import load_dotenv
+load_dotenv()
+
+# from local_llm import ollama_, hf
+# llm = hf()
+# llm = taide_llm
+########################################################################################################################
+# from langchain.cache import SQLiteCache
+# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
+########################################################################################################################
+# Default LLM for the chains below; swap in hf() or ollama_() from local_llm to run fully locally.
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+# llm = ollama_()
+
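+# Multi-query retrieval: the LLM rewrites the user question into several variants, each variant is retrieved separately, and the unique union of the hits is passed to the answer prompt.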
+def multi_query_chain(llm):
+    # Multi Query: Different Perspectives
+    template = """You are an AI language model assistant. Your task is to generate three 
+    different versions of the given user question to retrieve relevant documents from a vector 
+    database. By generating multiple perspectives on the user question, your goal is to help
+    the user overcome some of the limitations of the distance-based similarity search. 
+    Provide these alternative questions separated by newlines. 
+
+    You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
+    
+    
+    Original question: {question}"""
+    prompt_perspectives = ChatPromptTemplate.from_template(template)
+
+    
+    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
+    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
+
+    generate_queries = (
+        prompt_perspectives 
+        | llm
+        | StrOutputParser() 
+        | (lambda x: x.split("\n"))
+    )
+
+    return generate_queries
+
+def multi_query(question, retriever, chat_history):
+
+    def get_unique_union(documents: List[list]):
+        """ Unique union of retrieved docs """
+        # Flatten list of lists, and convert each Document to string
+        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
+        # Get unique documents
+        unique_docs = list(set(flattened_docs))
+        # Return
+        return [loads(doc) for doc in unique_docs]
+    
+
+    _search_query = get_search_query()
+    modified_question = _search_query.invoke({"question":question, "chat_history": chat_history})
+    print(modified_question)
+
+    generate_queries = multi_query_chain(llm)
+
+    retrieval_chain = generate_queries | retriever.map() | get_unique_union
+    docs = retrieval_chain.invoke({"question":modified_question})
+
+    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
+
+    return answer, docs
+
+def multi_query_rag_prompt(retrieval_chain, question):
+    # RAG
+    template = """Answer the following question based on this context:
+
+    {context}
+
+    Question: {question}
+    Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. \n
+    You should not mention anything about "根據提供的文件內容" or other similar terms.
+    Use three sentences maximum and keep the answer concise.
+    If you don't know the answer, just say that "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    """
+
+    prompt = ChatPromptTemplate.from_template(template)
+
+    # llm = ChatOpenAI(temperature=0)
+    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
+    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
+
+    final_rag_chain = (
+        {"context": retrieval_chain, 
+        "question": itemgetter("question")} 
+        | prompt
+        | llm
+        | StrOutputParser()
+    )
+
+    # answer = final_rag_chain.invoke({"question":question})
+
+    answer = ""
+    for text in final_rag_chain.stream({"question":question}):
+        print(text, end="", flush=True)
+        answer += text
+
+
+    return answer
+########################################################################################################################
+
+def get_search_query():
+    # Condense a chat history and follow-up question into a standalone question
+    # 
+    # _template = """Given the following conversation and a follow up question, 
+    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
+    # Generate standalone question in its original language.
+    # Chat History:
+    # {chat_history}
+    # Follow Up Input: {question}
+
+    # Hint:
+    # * Refer to chat history and add the subject to the question
+    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
+    
+    # Standalone question:"""  # noqa: E501
+    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
+    The rewritten query should:
+    
+    - Preserve the core intent and meaning of the original query
+    - Expand and clarify the query to make it more specific and informative for retrieving relevant context
+    - Avoid introducing new topics or queries that deviate from the original query
+    - DONT EVER ANSWER the Original query, but instead focus on rephrasing and expanding it into a new query
+    - The rewritten query should be in its original language.
+    
+    Return ONLY the rewritten query text, without any additional formatting or explanations.
+    
+    Conversation History:
+    {chat_history}
+    
+    Original query: [{question}]
+    
+    Rewritten query: 
+    """
+    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
+        buffer = []
+        for human, ai in chat_history:
+            buffer.append(HumanMessage(content=human))
+            buffer.append(AIMessage(content=ai))
+        return buffer
+
+    _search_query = RunnableBranch(
+        # If input includes chat_history, we condense it with the follow-up question
+        (
+            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
+                run_name="HasChatHistoryCheck"
+            ),  # Condense follow-up question and chat into a standalone_question
+            RunnablePassthrough.assign(
+                chat_history=lambda x: _format_chat_history(x["chat_history"])
+            )
+            | CONDENSE_QUESTION_PROMPT
+            | ChatOpenAI(temperature=0)
+            | StrOutputParser(),
+        ),
+        # Else, we have no chat history, so just pass through the question
+        RunnableLambda(lambda x : x["question"]),
+    )
+
+    return _search_query
+########################################################################################################################
+def naive_rag(question, retriever, chat_history):
+    _search_query = get_search_query()
+    modified_question = _search_query.invoke({"question":question, "chat_history": chat_history})
+    print(modified_question)
+
+    #### RETRIEVAL and GENERATION ####
+
+    # Prompt
+    prompt = hub.pull("rlm/rag-prompt")
+
+    # LLM
+    # llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
+
+    # Post-processing
+    def format_docs(docs):
+        return "\n\n".join(doc.page_content for doc in docs)
+
+    reference = retriever.get_relevant_documents(modified_question)
+    
+    # Chain
+    rag_chain = (
+        {"context": retriever | format_docs, "question": RunnablePassthrough()}
+        | prompt
+        | llm
+        | StrOutputParser()
+    )
+
+    # Question
+    answer = rag_chain.invoke(modified_question)
+
+    return answer, reference
+################################################################################################
+
+
+if __name__ == "__main__":
+    from faiss_index import create_faiss_retriever, faiss_query 
+    global_retriever = create_faiss_retriever()
+    generate_queries = multi_query_chain(llm)
+    question = "台灣為什麼要制定氣候變遷因應法?"
+    
+    questions = generate_queries.invoke(question)
+    questions = [item for item in questions if item != ""]
+    # print(questions)
+
+    
+    results = list(map(global_retriever.get_relevant_documents, questions))
+    results = [item for sublist in results for item in sublist]
+    print(len(results))
+    print(results)
+
+    
+    # retrieval_chain = generate_queries | global_retriever.map
+    # docs = retrieval_chain.invoke(question)
+    # print(docs)
+    # print(len(docs))
+
+
+
+    # print(len(results))
+    # for doc in results[:10]:
+    #     print(doc)
+    #     print("-----------------------------------------------------------------------")
+
+
+    # results = get_unique_union(results)
+    # print(len(results))
+
+    # retrieval_chain = generate_queries | global_retriever.map | get_unique_union
+    # docs = retrieval_chain.invoke(question)
+    # print(len(docs))
+
+
+

+ 194 - 0
add_vectordb.py

@@ -0,0 +1,194 @@
+from dotenv import load_dotenv
+load_dotenv()
+
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.vectorstores import Chroma
+from langchain_community.document_loaders import TextLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.document_loaders import Docx2txtLoader
+
+import os
+import glob
+
+from langchain_community.vectorstores import SupabaseVectorStore
+from langchain_openai import OpenAIEmbeddings
+from supabase.client import Client, create_client
+
+
+def get_data_list(data_list=None, path=None, extension=None, update=False):
+    files = data_list or glob.glob(os.path.join(path, f"*.{extension}"))
+    if update:    
+        doc = files.copy()
+    else:
+        existed_data = check_existed_data(supabase)
+        doc = []
+        for file_path in files:
+            filename = os.path.basename(file_path)
+            if filename not in existed_data:
+                doc.append(file_path)
+
+    return doc
+
+
+def read_and_split_files(data_list=None, path=None, extension=None, update=False):
+
+    def load_and_split(file_list):
+        chunks = []
+        for file in file_list:
+            if file.endswith(".txt"):
+                loader = TextLoader(file, encoding='utf-8')
+            elif file.endswith(".pdf"):
+                loader = PyPDFLoader(file)
+            elif file.endswith(".docx"):
+                loader = Docx2txtLoader(file)
+            else:
+                print(f"Unsupported file extension: {file}")
+                continue
+
+            docs = loader.load()
+
+            # Split
+            if file.endswith(".docx"):
+                separators = ['\u25cb\\s*第.*?條', '\u25cf\\s*第.*?條']
+                text_splitter = RecursiveCharacterTextSplitter(is_separator_regex=True, separators=separators, chunk_size=300, chunk_overlap=0)
+            else:
+                text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=500, chunk_overlap=0)
+            splits = text_splitter.split_documents(docs)
+
+            chunks.extend(splits)
+
+        return chunks
+
+
+    doc = get_data_list(data_list=data_list, path=path, extension=extension, update=update)
+    # Index
+    docs = load_and_split(doc)
+
+    return docs
+
+def create_ids(docs):
+    # Create a dictionary to count occurrences of each page in each document
+    page_counter = {}
+
+    # List to store the resulting IDs
+    document_ids = []
+
+    # Generate IDs
+    for doc in [docs[i].metadata for i in range(len(docs))]:
+        source = doc['source']
+        file_name = os.path.basename(source).split('.')[0]
+
+        if "page" in doc.keys():
+            page = doc['page']
+            key = f"{source}_{page}"
+        else:
+            key = f"{source}"
+
+        if key not in page_counter:
+            page_counter[key] = 1
+        else:
+            page_counter[key] += 1
+        
+        if "page" in doc.keys():
+            doc_id = f"{file_name} | page {page} | chunk {page_counter[key]}"
+        else:
+            doc_id = f"{file_name} | chunk {page_counter[key]}"
+
+        
+        document_ids.append(doc_id)
+
+    return document_ids
+
+def get_document(data_list=None, path=None, extension=None, update=False):
+    docs = read_and_split_files(data_list=data_list, path=path, extension=extension, update=update)
+    document_ids = create_ids(docs)
+
+    for doc in docs:
+        doc.metadata['source'] = os.path.basename(doc.metadata['source'])
+        # print(doc.metadata)
+
+    # document_metadatas = [{'source': doc.metadata['source'], 'page': doc.metadata['page'], 'chunk': int(id.split("chunk ")[-1])} for doc, id in zip(docs, document_ids)]
+    document_metadatas = []
+
+    for doc, id in zip(docs, document_ids):
+        chunk_number = int(id.split("chunk ")[-1])
+        doc.metadata['chunk'] = chunk_number
+        doc.metadata['extension'] = os.path.basename(doc.metadata['source']).split(".")[-1]
+        document_metadatas.append(doc.metadata)
+
+    documents = [doc.metadata['source'].split(".")[0] + doc.page_content for doc in docs]
+
+    return document_ids, documents, document_metadatas
+
+def check_existed_data(supabase):
+    response = supabase.table('documents').select("id, metadata").execute()
+    existed_data = list(set([data['metadata']['source'] for data in response.data]))
+    # existed_data = [(data['id'], data['metadata']['source']) for data in response.data]
+    return existed_data
+
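+# Thin wrapper around SupabaseVectorStore that adds insert/delete/update helpers keyed on the metadata 'source' field.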
+class GetVectorStore(SupabaseVectorStore):
+    def __init__(self, embeddings, supabase, table_name):
+        super().__init__(embedding=embeddings, client=supabase, table_name=table_name, query_name="match_documents")
+
+    def insert(self, documents, document_metadatas):
+        self.add_texts(
+            texts=documents,
+            metadatas=document_metadatas,
+        )
+
+    def delete(self, file_list):
+        for file_name in file_list:
+            self._client.table(self.table_name).delete().eq('metadata->>source', file_name).execute()
+
+    def update(self, documents, document_metadatas, update_existing_data=False):
+        if not document_metadatas:  # no new data
+            return
+
+        if update_existing_data:
+            file_list = list(set(metadata['source'] for metadata in document_metadatas))
+            self.delete(file_list)
+
+        self.insert(documents, document_metadatas)
+
+if __name__ == "__main__":
+
+    load_dotenv()
+    supabase_url = os.environ.get("SUPABASE_URL")
+    supabase_key = os.environ.get("SUPABASE_KEY")
+    document_table = "documents"
+    supabase: Client = create_client(supabase_url, supabase_key)
+
+    embeddings = OpenAIEmbeddings()
+
+    ###################################################################################
+    # get vector store
+    vector_store = GetVectorStore(embeddings, supabase, document_table)
+
+    ###################################################################################
+    # update data (old + new / all new / all old)
+    path = "/home/mia/systex/Documents"
+    extension = "pdf"
+    # file = None
+
+    # file_list = ["溫室氣體排放量盤查作業指引113年版.pdf"]
+    # file = [os.path.join(path, file) for file in file_list]
+    file_list = glob.glob(os.path.join(path, "*"))
+    print(file_list)
+    
+    update = False
+    document_ids, documents, document_metadatas = get_document(data_list=file_list, path=path, extension=extension, update=update)
+    vector_store.update(documents, document_metadatas, update_existing_data=update)
+
+    ###################################################################################
+    # insert new data (all new)
+    # vector_store.insert(documents, document_metadatas)
+
+    ###################################################################################
+    # delete data
+    # file_list = ["溫室氣體排放量盤查作業指引113年版.pdf"]
+    # vector_store.delete(file_list)
+
+    ###################################################################################
+    # get retriver
+    # retriever = vector_store.as_retriever(search_kwargs={"k": 6})

+ 80 - 0
conda_env.txt

@@ -0,0 +1,80 @@
+# This file may be used to create an environment using:
+# $ conda create --name llama3 --file conda_env.txt
+# platform: linux-64
+_libgcc_mutex=0.1=conda_forge
+_openmp_mutex=4.5=2_gnu
+aiosignal=1.3.1=pyhd8ed1ab_0
+annotated-types=0.7.0=pyhd8ed1ab_0
+anyio=4.2.0=py312h06a4308_0
+async-timeout=4.0.3=pyhd8ed1ab_0
+attrs=24.2.0=pyh71513ae_0
+blas=1.1=openblas
+bottleneck=1.3.7=py312ha883a20_0
+brotli-python=1.0.9=py312h6a678d5_8
+bzip2=1.0.8=h5eee18b_6
+ca-certificates=2024.7.4=hbcca054_0
+certifi=2024.7.4=py312h06a4308_0
+cffi=1.16.0=py312h5eee18b_1
+charset-normalizer=3.3.2=pyhd8ed1ab_0
+click=8.1.7=py312h06a4308_0
+deprecation=2.1.0=pyh9f0ad1d_0
+expat=2.6.2=h6a678d5_0
+fastapi=0.103.0=py312h06a4308_0
+frozenlist=1.4.0=py312h5eee18b_0
+greenlet=3.0.1=py312h6a678d5_0
+h11=0.14.0=py312h06a4308_0
+h2=4.1.0=pyhd8ed1ab_0
+hpack=4.0.0=pyh9f0ad1d_0
+httpcore=1.0.5=pyhd8ed1ab_0
+hyperframe=6.0.1=pyhd8ed1ab_0
+idna=3.7=pyhd8ed1ab_0
+jsonpatch=1.33=pyhd8ed1ab_0
+jsonpointer=2.0=py_0
+ld_impl_linux-64=2.38=h1181459_1
+libffi=3.4.4=h6a678d5_1
+libgcc-ng=14.1.0=h77fa898_0
+libgfortran-ng=14.1.0=h69a702a_0
+libgfortran5=14.1.0=hc5f4f2c_0
+libgomp=14.1.0=h77fa898_0
+libopenblas=0.3.28=pthreads_h94d23a6_0
+libstdcxx-ng=11.2.0=h1234567_1
+libuuid=1.41.5=h5eee18b_0
+lz4-c=1.9.4=h6a678d5_1
+multidict=6.0.4=py312h5eee18b_0
+ncurses=6.4=h6a678d5_0
+numexpr=2.8.7=py312he7dcb8a_0
+numpy=1.26.4=py312h2809609_0
+numpy-base=1.26.4=py312he1a6c75_0
+openssl=3.3.1=h4bc722e_2
+orjson=3.9.15=py312h97a8848_0
+packaging=24.1=pyhd8ed1ab_0
+pandas=2.2.2=py312h526ad5a_0
+pip=24.2=py312h06a4308_0
+pycparser=2.22=pyhd8ed1ab_0
+pysocks=1.7.1=pyha2e5f31_6
+python=3.12.4=h5148396_1
+python-dateutil=2.9.0post0=py312h06a4308_2
+python-tzdata=2023.3=pyhd3eb1b0_0
+pytz=2024.1=py312h06a4308_0
+readline=8.2=h5eee18b_0
+requests=2.32.3=pyhd8ed1ab_0
+setuptools=72.1.0=py312h06a4308_0
+six=1.16.0=pyhd3eb1b0_1
+sniffio=1.3.0=py312h06a4308_0
+sqlite=3.45.3=h5eee18b_0
+starlette=0.27.0=py312h06a4308_0
+tk=8.6.14=h39e8969_0
+typing-extensions=4.12.2=hd8ed1ab_0
+typing_extensions=4.12.2=pyha770c72_0
+tzdata=2024a=h04d1e81_0
+tzlocal=5.2=py312h06a4308_0
+urllib3=2.2.2=pyhd8ed1ab_1
+uvicorn=0.20.0=py312h06a4308_0
+websockets=10.4=py312h5eee18b_1
+wheel=0.43.0=py312h06a4308_0
+xz=5.4.6=h5eee18b_1
+yaml=0.2.5=h7f98852_2
+yarl=1.9.3=py312h5eee18b_0
+zlib=1.2.13=h5eee18b_1
+zstandard=0.22.0=py312h2c38b39_0
+zstd=1.5.5=hc292b87_2

+ 440 - 0
faiss_index.py

@@ -0,0 +1,440 @@
+import faiss
+import numpy as np
+import json
+from time import time
+import asyncio
+from datasets import Dataset
+from typing import List
+from dotenv import load_dotenv
+import os
+import pickle
+from supabase.client import Client, create_client
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+from langchain.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+import pandas as pd
+from langchain_core.documents import Document
+from langchain.load import dumps, loads
+
+# Import from the parent directory
+import sys
+
+from RAG_strategy import multi_query_chain, multi_query
+sys.path.append('..')
+# from RAG_strategy_Taide import taide_llm, system_prompt, multi_query
+system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+# llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+
+
+from langchain.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from add_vectordb import GetVectorStore
+from local_llm import hf
+# from local_llm import ollama_, taide_llm, hf
+# llm = hf()
+# llm = taide_llm
+
+
+# Import RAGAS metrics
+from ragas import evaluate
+from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision
+
+# Load environment variables
+load_dotenv('../../.env')
+supabase_url = os.getenv("SUPABASE_URL")
+supabase_key = os.getenv("SUPABASE_KEY")
+openai_api_key = os.getenv("OPENAI_API_KEY")
+document_table = "documents"
+
+# Initialize Supabase client
+supabase: Client = create_client(supabase_url, supabase_key)
+
+# Initialize embeddings and chat model
+embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
+
+def download_embeddings():
+    response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
+    embeddings = []
+    ids = []
+    metadatas = []
+    contents = []
+    for item in response.data:
+        embedding = json.loads(item['embedding'])
+        embeddings.append(embedding)
+        ids.append(item['id'])
+        metadatas.append(item['metadata'])
+        contents.append(item['content'])
+    return np.array(embeddings, dtype=np.float32), ids, metadatas, contents
+
+def create_faiss_index(embeddings):
+    dimension = embeddings.shape[1]
+    index = faiss.IndexFlatIP(dimension)  # Use Inner Product for cosine similarity
+    faiss.normalize_L2(embeddings)  # Normalize embeddings for cosine similarity
+    index.add(embeddings)
+    return index
+
+def save_faiss_index(index, file_path):
+    faiss.write_index(index, file_path)
+    print(f"FAISS index saved to {file_path}")
+
+def load_faiss_index(file_path):
+    if os.path.exists(file_path):
+        index = faiss.read_index(file_path)
+        print(f"FAISS index loaded from {file_path}")
+        return index
+    return None
+
+def save_metadata(ids, metadatas, contents, file_path):
+    with open(file_path, 'wb') as f:
+        pickle.dump((ids, metadatas, contents), f)
+    print(f"Metadata saved to {file_path}")
+
+def load_metadata(file_path):
+    if os.path.exists(file_path):
+        with open(file_path, 'rb') as f:
+            ids, metadatas, contents = pickle.load(f)
+        print(f"Metadata loaded from {file_path}")
+        return ids, metadatas, contents
+    return None, None, None
+
+def search_faiss(index, query_vector, k=4):
+    query_vector = np.array(query_vector, dtype=np.float32).reshape(1, -1)
+    faiss.normalize_L2(query_vector)
+    distances, indices = index.search(query_vector, k)
+    return distances[0], indices[0]
+
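+# Lightweight retriever over the local FAISS index; implements the get_relevant_documents and map methods the chains below rely on.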
+class FAISSRetriever:
+    def __init__(self, index, ids, metadatas, contents, embeddings_model):
+        self.index = index
+        self.ids = ids
+        self.metadatas = metadatas
+        self.contents = contents
+        self.embeddings_model = embeddings_model
+
+    def get_relevant_documents(self, query: str, k: int = 4) -> List[Document]:
+        query_vector = self.embeddings_model.embed_query(query)
+        _, indices = search_faiss(self.index, query_vector, k=k)
+        return [
+            Document(page_content=self.contents[i], metadata=self.metadatas[i])
+            for i in indices
+        ]
+    def map(self, query_list: List[str]) -> List[Document]:
+        def get_unique_union(documents: List[list]):
+            """ Unique union of retrieved docs """
+            # Flatten list of lists, and convert each Document to string
+            flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
+            # Get unique documents
+            unique_docs = list(set(flattened_docs))
+            # Return
+            return [loads(doc) for doc in unique_docs]
+        
+        documents = []
+        for query in query_list:
+            if query != "":
+                docs = self.get_relevant_documents(query)
+                documents.extend(docs)
+
+        return get_unique_union(documents)
+
+def load_qa_pairs():
+    # df = pd.read_csv("../QA_database_rows.csv")
+    response = supabase.table('QA_database').select("Question, Answer").execute()
+    df = pd.DataFrame(response.data)
+
+    return df['Question'].tolist(), df['Answer'].tolist()
+
+def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):
+    generate_queries = multi_query_chain(llm)
+
+    questions = generate_queries.invoke(question)
+    questions = [item for item in questions if item != ""]
+
+    # docs = list(map(retriever.get_relevant_documents, questions))
+    docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))
+    docs = [item for sublist in docs for item in sublist]
+
+    return docs
+
+def faiss_query(question: str, retriever: FAISSRetriever, llm, multi_query: bool = False) -> str:
+    if multi_query:
+        docs = faiss_multiquery(question, retriever, llm)
+        # print(docs)
+    else:
+        docs = retriever.get_relevant_documents(question, k=10)
+        # print(docs)
+
+    context = "\n".join(doc.page_content for doc in docs)
+    
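+    # The prompt below uses the Llama 3 chat special tokens (<|begin_of_text|>, <|start_header_id|>, <|eot_id|>) expected by the Meta-Llama-3.1 instruct models.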
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    你是一個來自台灣的ESG的AI助理,
+    請用繁體中文回答問題 \n
+    You should not mention anything about "根據提供的文件內容" or other similar terms.
+    Use five sentences maximum and keep the answer concise.
+    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    勿回答無關資訊
+    <|eot_id|>
+    
+    <|start_header_id|>user<|end_header_id|>
+    Answer the following question based on this context:
+
+    {context}
+
+    Question: {question}
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
+    """
+    prompt = ChatPromptTemplate.from_template(
+        system_prompt + "\n\n" +
+        template
+    )
+    
+    # prompt = ChatPromptTemplate.from_template(
+    #     system_prompt + "\n\n" +
+    #     "Answer the following question based on this context:\n\n"
+    #     "{context}\n\n"
+    #     "Question: {question}\n"
+    #     "Answer in the same language as the question. If you don't know the answer, "
+    #     "say 'I'm sorry, I don't have enough information to answer that question.'"
+    # )
+
+    
+    # chain = prompt | taide_llm | StrOutputParser()
+    chain = prompt | llm | StrOutputParser()
+    return chain.invoke({"context": context, "question": question})
+
+
+def create_faiss_retriever():
+    faiss_index_path = "faiss_index.bin"
+    metadata_path = "faiss_metadata.pkl"
+
+    index = load_faiss_index(faiss_index_path)
+    ids, metadatas, contents = load_metadata(metadata_path)
+
+    if index is None or ids is None:
+        print("FAISS index or metadata not found. Creating new index...")
+        print("Downloading embeddings from Supabase...")
+        embeddings_array, ids, metadatas, contents = download_embeddings()
+
+        print("Creating FAISS index...")
+        index = create_faiss_index(embeddings_array)
+
+        save_faiss_index(index, faiss_index_path)
+        save_metadata(ids, metadatas, contents, metadata_path)
+    else:
+        print("Using existing FAISS index and metadata.")
+
+    print("Creating FAISS retriever...")
+    faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
+
+    return faiss_retriever
+
+
+async def run_evaluation():
+    faiss_index_path = "faiss_index.bin"
+    metadata_path = "faiss_metadata.pkl"
+
+    index = load_faiss_index(faiss_index_path)
+    ids, metadatas, contents = load_metadata(metadata_path)
+
+    if index is None or ids is None:
+        print("FAISS index or metadata not found. Creating new index...")
+        print("Downloading embeddings from Supabase...")
+        embeddings_array, ids, metadatas, contents = download_embeddings()
+
+        print("Creating FAISS index...")
+        index = create_faiss_index(embeddings_array)
+
+        save_faiss_index(index, faiss_index_path)
+        save_metadata(ids, metadatas, contents, metadata_path)
+    else:
+        print("Using existing FAISS index and metadata.")
+
+    print("Creating FAISS retriever...")
+    faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
+
+    print("Creating original vector store...")
+    original_vector_store = GetVectorStore(embeddings, supabase, document_table)
+    original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
+
+    questions, ground_truths = load_qa_pairs()
+    llm = hf()  # local Llama 3.1 pipeline used for the FAISS answers
+
+    for question, ground_truth in zip(questions, ground_truths):
+        print(f"\nQuestion: {question}")
+
+        start_time = time()
+        faiss_answer = faiss_query(question, faiss_retriever, llm)
+        faiss_docs = faiss_retriever.get_relevant_documents(question)
+        faiss_time = time() - start_time
+        print(f"FAISS Answer: {faiss_answer}")
+        print(f"FAISS Time: {faiss_time:.4f} seconds")
+
+        start_time = time()
+        original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
+        original_time = time() - start_time
+        print(f"Original Answer: {original_answer}")
+        print(f"Original Time: {original_time:.4f} seconds")
+
+        # faiss_datasets = {
+        #     "question": [question],
+        #     "answer": [faiss_answer],
+        #     "contexts": [[doc.page_content for doc in faiss_docs]],
+        #     "ground_truth": [ground_truth]
+        # }
+        # faiss_evalsets = Dataset.from_dict(faiss_datasets)
+
+        # faiss_result = evaluate(
+        #     faiss_evalsets,
+        #     metrics=[
+        #         context_precision,
+        #         faithfulness,
+        #         answer_relevancy,
+        #         context_recall,
+        #     ],
+        # )
+
+        # print("FAISS RAGAS Evaluation:")
+        # print(faiss_result.to_pandas())
+
+        # original_datasets = {
+        #     "question": [question],
+        #     "answer": [original_answer],
+        #     "contexts": [[doc.page_content for doc in original_docs]],
+        #     "ground_truth": [ground_truth]
+        # }
+        # original_evalsets = Dataset.from_dict(original_datasets)
+
+        # original_result = evaluate(
+        #     original_evalsets,
+        #     metrics=[
+        #         context_precision,
+        #         faithfulness,
+        #         answer_relevancy,
+        #         context_recall,
+        #     ],
+        # )
+
+        # print("Original RAGAS Evaluation:")
+        # print(original_result.to_pandas())
+
+    print("\nPerformance comparison complete.")
+
+
+async def ask_question():
+    faiss_index_path = "faiss_index.bin"
+    metadata_path = "faiss_metadata.pkl"
+
+    index = load_faiss_index(faiss_index_path)
+    ids, metadatas, contents = load_metadata(metadata_path)
+
+    if index is None or ids is None:
+        print("FAISS index or metadata not found. Creating new index...")
+        print("Downloading embeddings from Supabase...")
+        embeddings_array, ids, metadatas, contents = download_embeddings()
+
+        print("Creating FAISS index...")
+        index = create_faiss_index(embeddings_array)
+
+        save_faiss_index(index, faiss_index_path)
+        save_metadata(ids, metadatas, contents, metadata_path)
+    else:
+        print("Using existing FAISS index and metadata.")
+
+    print("Creating FAISS retriever...")
+    faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
+    llm = hf()  # local Llama 3.1 pipeline used for answering
+
+    # print("Creating original vector store...")
+    # original_vector_store = GetVectorStore(embeddings, supabase, document_table)
+    # original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
+
+    # questions, ground_truths = load_qa_pairs()
+
+    # for question, ground_truth in zip(questions, ground_truths):
+    question = ""
+    while question != "exit":
+        question = input("Question: ")
+        print(f"\nQuestion: {question}")
+
+        start_time = time()
+        faiss_answer = faiss_query(question, faiss_retriever, llm)
+        faiss_docs = faiss_retriever.get_relevant_documents(question)
+        faiss_time = time() - start_time
+        print(f"FAISS Answer: {faiss_answer}")
+        print(f"FAISS Time: {faiss_time:.4f} seconds")
+
+        # start_time = time()
+        # original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
+        # original_time = time() - start_time
+        # print(f"Original Answer: {original_answer}")
+        # print(f"Original Time: {original_time:.4f} seconds")
+
+if __name__ == "__main__":
+
+    global_retriever = create_faiss_retriever()
+    llm = hf()  # local Llama 3.1 pipeline (meta-llama/Meta-Llama-3.1-8B-Instruct)
+
+    questions, ground_truths = load_qa_pairs()
+    results = []
+
+    for question, ground_truth in zip(questions, ground_truths):
+        # For multi_query=True
+        start = time()
+        final_answer_multi = faiss_query(question, global_retriever, llm, multi_query=True)
+        processing_time_multi = time() - start
+        # print(final_answer_multi)
+        # print(processing_time_multi)
+
+        # For multi_query=False
+        start = time()
+        final_answer_single = faiss_query(question, global_retriever, llm, multi_query=False)
+        processing_time_single = time() - start
+        # print(final_answer_single)
+        # print(processing_time_single)
+
+        # Store results in a dictionary
+        result = {
+            "question": question,
+            "ground_truth": ground_truth,
+            "final_answer_multi_query": final_answer_multi,
+            "processing_time_multi_query": processing_time_multi,
+            "final_answer_single_query": final_answer_single,
+            "processing_time_single_query": processing_time_single
+        }
+        print(result)
+        
+        results.append(result)
+
+        with open('qa_results.json', 'a', encoding='utf8') as outfile:
+            json.dump(result, outfile, indent=4, ensure_ascii=False)
+            outfile.write("\n")  # Ensure each result is on a new line
+        
+
+    # Save results to a JSON file
+    with open('qa_results+all.json', 'w', encoding='utf8') as outfile:
+        json.dump(results, outfile, indent=4, ensure_ascii=False)
+
+    print('All questions done!')
+    # question = ""
+    # while question != "exit":
+    #     # question = "國家溫室氣體長期減量目標" 
+    #     question = input("Question: ")
+    #     if question.strip().lower == "exit": break
+
+    #     start = time()
+    #     final_answer = faiss_query(question, global_retriever, multi_query=True)
+    #     print(final_answer)
+    #     processing_time = time() - start
+    #     print(processing_time)
+
+        
+    #     start = time() 
+    #     final_answer = faiss_query(question, global_retriever, multi_query=False)
+    #     print(final_answer)
+    #     processing_time = time() - start
+    #     print(processing_time)
+    # print("Chatbot closed!")
+
+    # asyncio.run(ask_question())

+ 119 - 0
local_llm.py

@@ -0,0 +1,119 @@
+from langchain_community.chat_models import ChatOllama
+from langchain_openai import ChatOpenAI
+from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
+import torch
+from langchain_huggingface import HuggingFacePipeline
+
+
+from typing import Any, List, Optional, Dict
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
+from langchain_core.outputs import ChatResult, ChatGeneration
+from pydantic import Field
+import subprocess
+
+import time
+
+from dotenv import load_dotenv
+load_dotenv()
+
+system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+
+def hf():
+    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    llm = HuggingFacePipeline.from_model_id(
+        model_id=model_id,
+        task="text-generation",
+        model_kwargs={"torch_dtype": torch.bfloat16},
+        pipeline_kwargs={"return_full_text": False,
+            "max_new_tokens": 512,
+            "repetition_penalty":1.03},
+        device=0, device_map='cuda')
+    # print(llm.pipeline)
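+    # Llama 3.1 has no pad token; reuse the first eos_token_id so the pipeline can pad during generation.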
+    llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
+
+    return llm
+
+
+def ollama_():
+    # model = "cwchang/llama3-taide-lx-8b-chat-alpha1"
+    model = "llama3.1:latest"
+    # model = "llama3.1:70b"
+    # model = "893379029/piccolo-large-zh-v2"
+    sys = "你是一個來自台灣的 AI 助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。請用 5 句話以內回答問題。"
+    # llm = ChatOllama(model=model, num_gpu=2, num_thread=32, temperature=0, system=sys, keep_alive="10m", verbose=True)
+    llm = ChatOllama(model=model, num_gpu=2, temperature=0, system=sys, keep_alive="10m")
+
+    return llm
+
+
+def openai_():  # not local
+    llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
+
+    return llm
+
+class OllamaChatModel(BaseChatModel):
+    model_name: str = Field(default="taide-local-llama3")
+
+    def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> ChatResult:
+        formatted_messages = []
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                formatted_messages.append({"role": "user", "content": msg.content})
+            elif isinstance(msg, AIMessage):
+                formatted_messages.append({"role": "assistant", "content": msg.content})
+            elif isinstance(msg, SystemMessage):
+                 formatted_messages.append({"role": "system", "content": msg.content})
+
+        # prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n" # TAIDE llama2
+        prompt = f"<|begin_of_text|><|start_header_id|>{system_prompt}<|end_header_id|>" # TAIDE llama3
+        for msg in formatted_messages:
+            if msg['role'] == 'user':
+                # prompt += f"{msg['content']} [/INST]" # TAIDE llama2
+                prompt += f"<|eot_id|><|start_header_id|>{msg['content']}<|end_header_id|>" # TAIDE llama3
+            elif msg['role'] == "assistant":
+                # prompt += f"{msg['content']} </s><s>[INST]" # TAIDE llama2
+                prompt += f"<|eot_id|><|start_header_id|>{msg['content']}<|end_header_id|>" # TAIDE llama3
+
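+        # Run the prompt through the ollama CLI inside the "ollama" Docker container instead of calling the HTTP API.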
+        command = ["docker", "exec", "-it", "ollama", "ollama", "run", self.model_name, prompt]
+        result = subprocess.run(command, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise Exception(f"Ollama command failed: {result.stderr}")
+        
+        content = result.stdout.strip()
+
+        message = AIMessage(content=content)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+    
+    @property
+    def _llm_type(self) -> str:
+        return "ollama-chat-model"
+    
+# taide_llm = OllamaChatModel(model_name="taide-local-llama2")    
+
+if __name__ == "__main__":
+    question = ""
+    while question.lower() != "exit": 
+        question = input("Question: ")
+        # 溫室氣體是什麼?
+
+
+        for function in [ollama_, hf, openai_]:
+            start = time.time()
+            llm = function()
+            answer = llm.invoke(question)
+            print(answer)
+
+            processing_time = time.time() - start
+            print(processing_time)

+ 107 - 0
pip_env.txt

@@ -0,0 +1,107 @@
+accelerate==0.33.0
+aenum==3.1.15
+aiohappyeyeballs==2.3.7
+aiohttp==3.10.3
+asgiref==3.8.1
+backoff==2.2.1
+bcrypt==4.2.0
+blobfile==2.1.1
+build==1.2.1
+cachetools==5.5.0
+chroma-hnswlib==0.7.6
+chromadb==0.5.5
+coloredlogs==15.0.1
+dataclasses-json==0.6.7
+deprecated==1.2.14
+distro==1.9.0
+fairscale==0.4.13
+filelock==3.15.4
+fire==0.6.0
+flatbuffers==24.3.25
+fsspec==2024.6.1
+future==1.0.0
+google-auth==2.34.0
+googleapis-common-protos==1.63.2
+grpcio==1.65.5
+httptools==0.6.1
+httpx==0.27.0
+huggingface-hub==0.24.5
+humanfriendly==10.0
+importlib-metadata==8.0.0
+importlib-resources==6.4.3
+jinja2==3.1.4
+jiter==0.5.0
+kubernetes==30.1.0
+langchain-community==0.2.12
+langchain-openai==0.1.22
+line-bot-sdk==3.11.1
+lxml==4.9.4
+markdown-it-py==3.0.0
+markupsafe==2.1.5
+marshmallow==3.21.3
+mdurl==0.1.2
+mmh3==4.1.0
+monotonic==1.6
+mpmath==1.3.0
+mypy-extensions==1.0.0
+networkx==3.3
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.6.20
+nvidia-nvtx-cu12==12.1.105
+oauthlib==3.2.2
+onnxruntime==1.19.0
+openai==1.41.0
+opentelemetry-api==1.26.0
+opentelemetry-exporter-otlp-proto-common==1.26.0
+opentelemetry-exporter-otlp-proto-grpc==1.26.0
+opentelemetry-instrumentation==0.47b0
+opentelemetry-instrumentation-asgi==0.47b0
+opentelemetry-instrumentation-fastapi==0.47b0
+opentelemetry-proto==1.26.0
+opentelemetry-sdk==1.26.0
+opentelemetry-semantic-conventions==0.47b0
+opentelemetry-util-http==0.47b0
+overrides==7.7.0
+posthog==3.5.0
+protobuf==4.25.4
+psutil==6.0.0
+pyasn1==0.6.0
+pyasn1-modules==0.4.0
+pycryptodomex==3.20.0
+pydantic==2.8.2
+pydantic-core==2.20.1
+pygments==2.18.0
+pypika==0.48.9
+pyproject-hooks==1.1.0
+python-dotenv==1.0.1
+pyyaml==6.0.2
+regex==2024.7.24
+requests-oauthlib==2.0.0
+rich==13.7.1
+rsa==4.9
+safetensors==0.4.4
+shellingham==1.5.4
+sympy==1.13.2
+termcolor==2.4.0
+tiktoken==0.7.0
+tokenizers==0.19.1
+torch==2.4.0
+tqdm==4.66.5
+transformers==4.44.0
+triton==3.0.0
+typer==0.12.4
+typing-inspect==0.9.0
+uvloop==0.20.0
+watchfiles==0.23.0
+websocket-client==1.8.0
+wrapt==1.16.0
+zipp==3.20.0

+ 73 - 0
semantic_search.py

@@ -0,0 +1,73 @@
+### Python = 3.9
+import os
+from dotenv import load_dotenv
+load_dotenv('.env')
+
+import openai 
+openai_api_key = os.getenv("OPENAI_API_KEY")
+openai.api_key = openai_api_key
+
+from langchain_openai import OpenAIEmbeddings
+embeddings_model = OpenAIEmbeddings()
+
+from langchain_community.document_loaders.csv_loader import CSVLoader
+from langchain_community.vectorstores import Chroma
+
+import pandas as pd
+import re
+
+from langchain_community.vectorstores import SupabaseVectorStore
+from supabase.client import create_client
+
+
+def create_qa_vectordb(supabase, vectordb_directory="./chroma_db_carbon_questions"):
+
+    if os.path.isdir(vectordb_directory):
+        vectorstore = Chroma(persist_directory=vectordb_directory, embedding_function=embeddings_model)
+        vectorstore.delete_collection()
+
+    response = supabase.table("QA_database").select("Question, Answer").execute()
+    questions = [row["Question"] for row in response.data]
+
+    ######### generate embedding ###########
+    # embedding = embeddings_model.embed_documents(questions)
+
+    ########## Write embedding to the supabase table  #######
+    # for id, new_embedding in zip(ids, embedding):
+    #     supabase.table("video_cache_rows_duplicate").insert({"embedding": embedding.tolist()}).eq("id", id).execute()
+
+    ########### Vector Store #############
+    # Put pre-compute embeddings to vector store. ## save to disk
+    vectorstore = Chroma.from_texts(
+        texts=questions,
+        embedding=embeddings_model,
+        persist_directory=vectordb_directory
+        )
+    
+    return vectorstore
+
+
+# vectorstore = Chroma(persist_directory="./chroma_db_carbon_questions", embedding_function=embeddings_model)
+def semantic_cache(supabase, q, SIMILARITY_THRESHOLD=0.83, k=1, vectordb_directory="./chroma_db_carbon_questions"):
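+    """Return the cached (question, answer) pair whose question is most similar to `q`,
+    or (None, None) when the best relevance score is below SIMILARITY_THRESHOLD.
+    """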
+
+    if os.path.isdir(vectordb_directory):
+        vectorstore = Chroma(persist_directory=vectordb_directory, embedding_function=embeddings_model)
+    else:
+        print("create new vector db ...")
+        vectorstore = create_qa_vectordb(supabase, vectordb_directory)
+
+    docs_and_scores = vectorstore.similarity_search_with_relevance_scores(q, k=k)
+    if not docs_and_scores:
+        return None, None
+    doc, score = docs_and_scores[0]
+
+    if score >= SIMILARITY_THRESHOLD:
+        cache_question = doc.page_content
+        response = supabase.table("QA_database").select("Question, Answer").eq("Question", cache_question).execute()
+        # qa_df = pd.DataFrame(response.data)
+        # print(response.data[0])
+        answer = response.data[0]["Answer"]
+        # video_cache = response.data[0]["video_cache"]
+        return cache_question, answer
+    else:
+        return None, None
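+
+# Usage sketch (assumption: SUPABASE_URL and SUPABASE_KEY are set in .env;
+# adjust the variable names to match however the rest of this repo builds its Supabase client).
+if __name__ == "__main__":
+    supabase_client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_KEY"))
+    hit_question, hit_answer = semantic_cache(supabase_client, "去年的固定燃燒總排放量是多少?")
+    if hit_answer is None:
+        print("cache miss")
+    else:
+        print(hit_question, hit_answer)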

+ 187 - 0
text_to_sql.py

@@ -0,0 +1,187 @@
+import re
+from dotenv import load_dotenv
+load_dotenv()
+
+from langchain_community.utilities import SQLDatabase
+import os
+URI: str =  os.environ.get('SUPABASE_URI')
+db = SQLDatabase.from_uri(URI)
+
+# print(db.dialect)
+# print(db.get_usable_table_names())
+# db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
+
+context = db.get_context()
+# print(list(context))
+# print(context["table_info"])
+
+from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
+from langchain.chains import create_sql_query_chain
+from langchain_community.llms import Ollama
+
+from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
+from operator import itemgetter
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+
+# Load the model locally through LangChain's HuggingFacePipeline wrapper
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import torch
+from langchain_huggingface import HuggingFacePipeline
+# model_id = "defog/llama-3-sqlcoder-8b"
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
+# sql_llm = HuggingFacePipeline.from_model_id(
+#     model_id=model_id,
+#     task="text-generation",
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     pipeline_kwargs={"return_full_text": False},
+#     device=0, device_map='cuda')
+
+
+model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+llm = HuggingFacePipeline.from_model_id(
+    model_id=model_id,
+    task="text-generation",
+    model_kwargs={"torch_dtype": torch.bfloat16},
+    pipeline_kwargs={"return_full_text": False,
+        "max_new_tokens": 512},
+    device=0, device_map='cuda')
+print(llm.pipeline)
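+# Llama 3.1 defines no pad token and stores eos_token_id as a list,
+# so reuse the first EOS id as the pad id to avoid padding warnings during generation.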
+llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
+# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
+
+# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1, 
+#                 model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
+#, device="auto", load_in_4bit=True
+# llm = HuggingFacePipeline(pipeline=pipe)
+
+# llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+def get_examples():
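+    """Few-shot examples pairing ESG questions with PostgreSQL queries over the
+    "2023 清冊數據(GHG)" table; consumed by the FewShotPromptTemplate below."""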
+    examples = [
+        {
+            "input": "去年的固定燃燒總排放量是多少?",
+            "query": 'SELECT SUM("高雄總部及運通廠" + "台北辦事處" + "昆山廣興廠" + "北海建準廠" + "北海立準廠" + "菲律賓建準廠" + "Inc" + "SAS" + "India") AS "固定燃燒總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "排放源" = \'固定燃燒\'',
+        },
+        {
+            "input": "建準廣興廠去年的類別1總排放量是多少?",
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'',
+        },
+        {
+            "input": "建準廣興廠去年的直接排放總排放量是多少?",
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%直接排放%\'',
+        },
+
+    ]
+
+    return examples
+
+def write_query_chain(db):
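+    """Build the text-to-SQL chain: a Llama-3.1-formatted few-shot prompt passed
+    to create_sql_query_chain so the local HuggingFace pipeline writes the query."""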
+
+    template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+    Generate a SQL query to answer this question: `{input}`
+
+    You are a PostgreSQL expert in the ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
+    then look at the results of the query and return the answer to the input question.\n\
+    Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. 
+    You can order the results to return the most informative data in the database.\n\
+    Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
+    Wrap each column name in double quotation marks (") to denote it as a delimited identifier.\n\
+    
+    ***Return only the PostgreSQL query, WITHOUT any "```sql" fence, and DO NOT include any other words.\n\
+    ***Return nothing but the PostgreSQL query.\n\
+
+    DDL statements:
+    {table_info}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+    The following SQL query best answers the question `{input}`:
+    ```sql
+    """
+    # prompt_template = PromptTemplate.from_template(template)
+
+    example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
+    prompt = FewShotPromptTemplate(
+        examples=get_examples(),
+        example_prompt=example_prompt,
+        prefix=template,
+        suffix="User input: {input}\nSQL query: ",
+        input_variables=["input", "top_k", "table_info"],
+    )
+
+    # llm = Ollama(model = "mannix/defog-llama3-sqlcoder-8b", num_gpu=1)
+    # llm = HuggingFacePipeline(pipeline=pipe)
+    
+    
+    write_query = create_sql_query_chain(llm, db, prompt)
+
+
+    return write_query
+
+def sql_to_nl_chain():
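+    """Chain that rewrites (question, SQL query, SQL result) into a Traditional
+    Chinese natural-language answer using the same local LLM."""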
+    # llm = Ollama(model = "llama3.1", num_gpu=1)
+    # llm = Ollama(model = "llama3.1:8b-instruct-q2_K", num_gpu=1)
+    # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+    answer_prompt = PromptTemplate.from_template(
+        """
+        <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+        Given the following user question, corresponding SQL query, and SQL result, answer the user question.
+        給定以下使用者問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
+
+        For example
+        Question: 建準廣興廠去年的類別1總排放量是多少?
+        SQL Query: SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'
+        SQL Result: [(1102.3712,)]
+        Answer: 建準廣興廠去年的類別1總排放量是1102.3712
+
+        Question: {question}
+        SQL Query: {query}
+        SQL Result: {result}
+        Answer: """
+        )
+
+    chain = answer_prompt | llm | StrOutputParser()
+
+    return chain
+
+def run(db, question, selected_table):
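+    """Generate the SQL for `question`, execute it against `db`, then phrase the
+    result as a natural-language answer; returns (query, result, answer)."""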
+
+    write_query = write_query_chain(db)
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"]})
+    
+    query = re.split('SQL query: ', query)[-1]
+    print(query)
+
+    execute_query = QuerySQLDataBaseTool(db=db)
+    result = execute_query.invoke(query)
+    print(result)
+
+    chain = sql_to_nl_chain()
+    answer = chain.invoke({"question": question, "query": query, "result": result})
+
+    return query, result, answer
+
+
+if __name__ == "__main__":
+    import time
+    
+    start = time.time()
+    
+    selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)']
+    question = "去年的固定燃燒總排放量是多少?"
+    query, result, answer = run(db, question, selected_table)
+    print("query: ", query)
+    print("result: ", result)
+    print("answer: ", answer)
+    
+    print(time.time()-start)
+