Browse source code

Integrate Llama 3.1 with RAG using LangChain's Huggingface module
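
A minimal sketch of how the new modules are meant to fit together, based on the functions added in this commit (hf() in local_llm.py, create_faiss_retriever() and faiss_query() in faiss_index.py); the sample question is illustrative only:

    # sketch only, not part of the commit
    from local_llm import hf
    from faiss_index import create_faiss_retriever, faiss_query

    llm = hf()                            # Meta-Llama-3.1-8B-Instruct via langchain_huggingface.HuggingFacePipeline
    retriever = create_faiss_retriever()  # loads faiss_index.bin / faiss_metadata.pkl, or rebuilds them from Supabase
    answer = faiss_query("什麼是逸散排放源?", retriever, llm, multi_query=False)
    print(answer)

RAG_app.py performs this wiring in its FastAPI lifespan handler and serves the result through the /answer_with_history endpoint.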

ling 7 months ago
revision
bf1c706ac0
11 changed files with 1837 additions and 0 deletions
  1. .gitignore (+5, -0)
  2. Indexing_Split.py (+174, -0)
  3. RAG_app.py (+198, -0)
  4. RAG_strategy.py (+260, -0)
  5. add_vectordb.py (+194, -0)
  6. conda_env.txt (+80, -0)
  7. faiss_index.py (+440, -0)
  8. local_llm.py (+119, -0)
  9. pip_env.txt (+107, -0)
  10. semantic_search.py (+73, -0)
  11. text_to_sql.py (+187, -0)

+ 5 - 0
.gitignore

@@ -0,0 +1,5 @@
+__pycache__/
+chroma_db_carbon_questions/
+faiss_index.bin
+faiss_metadata.pkl
+.env

+ 174 - 0
Indexing_Split.py

@@ -0,0 +1,174 @@
+from dotenv import load_dotenv
+load_dotenv()
+
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_community.vectorstores import Chroma
+from langchain_community.document_loaders import TextLoader
+from langchain.text_splitter import CharacterTextSplitter
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.document_loaders import Docx2txtLoader
+from langchain_community.document_loaders import WebBaseLoader
+from json import loads
+import pandas as pd
+from sqlalchemy import create_engine
+
+from langchain.prompts import ChatPromptTemplate
+from langchain_openai import ChatOpenAI
+from langchain_core.output_parsers import StrOutputParser
+from langchain import hub
+from tqdm import tqdm
+
+# __import__('pysqlite3')
+# import sys
+# sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+
+from datasets import Dataset 
+# from ragas import evaluate
+# from ragas.metrics import (
+#     answer_relevancy,
+#     faithfulness,
+#     context_recall,
+#     context_precision,
+# )
+import os
+import glob
+URI = os.getenv("SUPABASE_URI")
+
+from RAG_strategy import multi_query, naive_rag
+
+def create_retriever(path='Documents', extension="pdf"):
+    txt_files = glob.glob(os.path.join(path, f"*.{extension}"))
+    
+    doc = []
+    for file_path in txt_files:
+        doc.append(file_path)
+    
+    def load_and_split(file_list):
+        chunks = []
+        for file in file_list:
+            if file.endswith(".txt"):
+                loader = TextLoader(file, encoding='utf-8')
+            elif file.endswith(".pdf"):
+                loader = PyPDFLoader(file)
+            elif file.endswith(".docx"):
+                loader = Docx2txtLoader(file)
+            else:
+                raise ValueError(f"Unsupported file extension: {file}")
+            
+
+            docs = loader.load()
+
+            # Split
+            if file.endswith(".docx"):
+                # separators = ["\n\n\u25cb", "\n\n\u25cf"]
+                # text_splitter = RecursiveCharacterTextSplitter(separators=separators, chunk_size=500, chunk_overlap=0)
+                separators = ['\u25cb\\s*第.*?條', '\u25cf\\s*第.*?條']  # regex separators: ○/● followed by "第…條"
+                text_splitter = RecursiveCharacterTextSplitter(is_separator_regex=True, separators=separators, chunk_size=300, chunk_overlap=0)
+            else:
+                text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=500, chunk_overlap=0)
+            
+            splits = text_splitter.split_documents(docs)
+
+            chunks.extend(splits)
+
+        return chunks
+
+    # Index
+    docs = load_and_split(doc)
+    qa_history_doc = gen_doc_from_history()
+    docs.extend(qa_history_doc)
+    # web_doc = web_data(os.path.join(path, 'web_url.csv'))
+    # docs.extend(web_doc)
+
+    # vectorstore
+    # vectorstore = Chroma.from_texts(texts=docs, embedding=OpenAIEmbeddings())
+    vectorstore = Chroma.from_documents(documents=docs, embedding=OpenAIEmbeddings())
+    # vectorstore = Chroma.from_documents(documents=docs, embedding=OllamaEmbeddings(model="llama3", num_gpu=1))
+    vectorstore.persist()
+
+    retriever = vectorstore.as_retriever()
+
+    return retriever
+
+def web_data(url_file):
+    df = pd.read_csv(url_file, header = 0)
+    url_list = df['url'].to_list()
+
+    loader = WebBaseLoader(url_list)
+    docs = loader.load()
+
+    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+                chunk_size=1000, chunk_overlap=0)
+    splits = text_splitter.split_documents(docs)
+    
+    return splits
+
+def gen_doc_from_history():
+    engine = create_engine(URI, echo=True)
+
+    df = pd.read_sql_table("systex_records", engine.connect())  
+    df.fillna('', inplace=True)
+    result = df.to_json(orient='index', force_ascii=False)
+    result = loads(result)
+
+
+    df = pd.DataFrame(result).T
+    qa_history_doc = []
+    for i in range(len(df)):
+        if df.iloc[i]['used_as_document'] is not True: continue
+        Question = df.iloc[i]['Question']
+        Answer = df.iloc[i]['Answer']
+        context = f'Question: {Question}\nAnswer: {Answer}'
+        
+        doc =  Document(page_content=context, metadata={"source": "History"})
+        qa_history_doc.append(doc)
+        # print(doc)
+
+    return qa_history_doc
+
+def gen_doc_from_database():
+    engine = create_engine(URI, echo=True)
+
+    df = pd.read_sql_table("QA_database", engine.connect())  
+    # df.fillna('', inplace=True)
+    result = df[['Question', 'Answer']].to_json(orient='index', force_ascii=False)
+    result = loads(result)
+
+
+    df = pd.DataFrame(result).T
+    qa_doc = []
+    for i in range(len(df)):
+        # if df.iloc[i]['used_as_document'] is not True: continue
+        Question = df.iloc[i]['Question']
+        Answer = df.iloc[i]['Answer']
+        context = f'Question: {Question}\nAnswer: {Answer}'
+        
+        doc = Document(page_content=context, metadata={"source": "History"})
+        qa_doc.append(doc)
+        # print(doc)
+
+    return qa_doc
+
+if __name__ == "__main__":
+
+    retriever = create_retriever(path='./Documents', extension="pdf")
+    question = 'CEV系統可以支援盤查到什麼程度'
+    final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
+    print(question, final_answer)
+    question = 'CEV系統依循標準為何'
+    final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
+    print(question, final_answer)
+
+
+
+

+ 198 - 0
RAG_app.py

@@ -0,0 +1,198 @@
+import re
+import threading
+from fastapi import FastAPI, Request, HTTPException, Response, status, Body
+# from fastapi.templating import Jinja2Templates
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, JSONResponse
+from fastapi import Depends
+from contextlib import asynccontextmanager
+from pydantic import BaseModel
+from typing import List, Optional
+import uvicorn
+
+import requests
+
+from typing import List, Optional
+import sqlparse
+from sqlalchemy import create_engine
+import pandas as pd
+#from retrying import retry
+import datetime
+import json
+from json import loads
+import pandas as pd
+import time
+from langchain.callbacks import get_openai_callback
+
+from langchain_community.vectorstores import Chroma
+from langchain_openai import OpenAIEmbeddings
+# from RAG_strategy import multi_query, naive_rag
+# from Indexing_Split import create_retriever as split_retriever
+# from Indexing_Split import gen_doc_from_database, gen_doc_from_history
+from semantic_search import semantic_cache
+
+from dotenv import load_dotenv
+import os
+from langchain_community.vectorstores import SupabaseVectorStore
+from langchain_openai import OpenAIEmbeddings
+from supabase.client import Client, create_client
+
+
+from add_vectordb import GetVectorStore
+from faiss_index import create_faiss_retriever, faiss_query
+from local_llm import ollama_, hf
+# from local_llm import ollama_, taide_llm, hf
+# llm = hf()
+
+load_dotenv()
+URI = os.getenv("SUPABASE_URI")
+supabase_url = os.environ.get("SUPABASE_URL")
+supabase_key = os.environ.get("SUPABASE_KEY")
+supabase: Client = create_client(supabase_url, supabase_key)
+
+global_retriever = None
+llm = None
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global global_retriever
+    global llm
+    global vector_store
+    
+    start = time.time()
+    document_table = "documents"
+
+    embeddings = OpenAIEmbeddings()
+    # vector_store = GetVectorStore(embeddings, supabase, document_table)
+    # global_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
+    global_retriever = create_faiss_retriever()
+    llm = hf()
+
+    print(time.time() - start)
+    yield
+
+def get_retriever():
+    return global_retriever
+
+def get_llm():
+    return llm
+
+def get_vector_store():
+    return vector_store  # NOTE: only valid when the SupabaseVectorStore branch in lifespan() is enabled
+
+app = FastAPI(lifespan=lifespan)
+
+# templates = Jinja2Templates(directory="temp")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+class ChatHistoryItem(BaseModel):
+    q: str
+    a: str
+
+def replace_unicode_escapes(match):
+    return chr(int(match.group(1), 16))
+
+@app.post("/answer_with_history")
+def multi_query_answer(question: Optional[str] = '什麼是逸散排放源?', chat_history: List[ChatHistoryItem] = Body(...), 
+                       retriever=Depends(get_retriever), llm=Depends(get_llm)):
+    start = time.time()
+    
+    chat_history = [(item.q, item.a) for item in chat_history if item.a != "" and item.a != "string"]
+    print(chat_history)
+
+    # TODO: similarity search
+    
+    with get_openai_callback() as cb:
+
+        # cache_question, cache_answer = semantic_cache(supabase, question)
+        # if cache_answer:
+        #     processing_time = time.time() - start
+        #     save_history(question, cache_answer, cache_question, cb, processing_time)
+
+        #     return {"Answer": cache_answer}
+        
+        # final_answer, reference_docs = multi_query(question, retriever, chat_history)
+        # final_answer, reference_docs = naive_rag(question, retriever, chat_history)
+        final_answer = faiss_query(question, global_retriever, llm)
+        
+        decoded_string = re.sub(r'\\u([0-9a-fA-F]{4})', replace_unicode_escapes, final_answer)
+        print(decoded_string)
+
+        reference_docs = global_retriever.get_relevant_documents(question)
+    processing_time = time.time() - start
+    print(processing_time)
+    save_history(question, decoded_string, reference_docs, cb, processing_time)
+
+    # print(response)
+    # JSONResponse serializes the dict itself (Starlette uses ensure_ascii=False),
+    # so return the dict directly instead of a pre-serialized JSON string.
+    return JSONResponse(content={"Answer": decoded_string})
+    # response_content = json.dumps({"Answer": final_answer}, ensure_ascii=False)
+    # print(response_content)
+    # return json.loads(response_content)
+
+    
+
+def save_history(question, answer, reference, cb, processing_time):
+    # reference = [doc.dict() for doc in reference]
+    record = {
+        'Question': [question],
+        'Answer': [answer],
+        'Total_Tokens': [cb.total_tokens],
+        'Total_Cost': [cb.total_cost],
+        'Processing_time': [processing_time],
+        'Contexts': [str(reference)] if isinstance(reference, list) else [reference]
+    }
+    df = pd.DataFrame(record)
+    engine = create_engine(URI)
+    df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
+
+class history_output(BaseModel):
+    Question: str
+    Answer: str
+    Contexts: str
+    Total_Tokens: int
+    Total_Cost: float
+    Processing_time: float
+    Time: datetime.datetime
+    
+@app.get('/history', response_model=List[history_output])
+async def get_history():
+    engine = create_engine(URI, echo=True)
+
+    df = pd.read_sql_table("systex_records", engine.connect())  
+    df.fillna('', inplace=True)
+    result = df.to_json(orient='index', force_ascii=False)
+    result = loads(result)
+    return result.values()
+
+def send_heartbeat(url):
+    while True:
+        try:
+            response = requests.get(url)
+            if response.status_code != 200:
+                print(f"Failed to send heartbeat, status code: {response.status_code}")
+        except requests.RequestException as e:
+            print(f"Error occurred: {e}")
+        # wait 10 minutes between heartbeat pings
+        time.sleep(600)
+
+def start_heartbeat(url):
+    heartbeat_thread = threading.Thread(target=send_heartbeat, args=(url,))
+    heartbeat_thread.daemon = True
+    heartbeat_thread.start()
+    
+if __name__ == "__main__":
+    
+    # url = 'http://db.ptt.cx:3001/api/push/luX7WcY3Gz?status=up&msg=OK&ping='
+    # start_heartbeat(url)
+    
+    uvicorn.run("RAG_app:app", host='0.0.0.0', reload=True, port=8080)
+
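
For reference, a request against the new /answer_with_history endpoint might look like the sketch below (host and port follow the uvicorn.run call above; the question and the empty chat-history entry are illustrative, and items with empty answers are filtered out by the handler):

    import requests

    resp = requests.post(
        "http://localhost:8080/answer_with_history",
        params={"question": "什麼是逸散排放源?"},  # plain query parameter in the endpoint signature
        json=[{"q": "", "a": ""}],  # request body for chat_history (List[ChatHistoryItem])
    )
    print(resp.text)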

+ 260 - 0
RAG_strategy.py

@@ -0,0 +1,260 @@
+from langchain.prompts import ChatPromptTemplate
+from langchain.load import dumps, loads
+from langchain_core.output_parsers import StrOutputParser
+from langchain_openai import ChatOpenAI
+from langchain_community.llms import Ollama
+from langchain_community.chat_models import ChatOllama
+from operator import itemgetter
+from langchain_core.runnables import RunnablePassthrough
+from langchain import hub
+from langchain.globals import set_llm_cache
+from langchain import PromptTemplate
+
+
+from langchain_core.runnables import (
+    RunnableBranch,
+    RunnableLambda,
+    RunnableParallel,
+    RunnablePassthrough,
+)
+from typing import Tuple, List, Optional
+from langchain_core.messages import AIMessage, HumanMessage
+
+from typing import List
+from dotenv import load_dotenv
+load_dotenv()
+
+# from local_llm import ollama_, hf
+# llm = hf()
+# llm = taide_llm
+########################################################################################################################
+# from langchain.cache import SQLiteCache
+# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
+########################################################################################################################
+# `llm` must be defined at module level: multi_query_rag_prompt() and naive_rag() below use it directly.
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+# llm = ollama_()
+
+def multi_query_chain(llm):
+    # Multi Query: Different Perspectives
+    template = """You are an AI language model assistant. Your task is to generate three 
+    different versions of the given user question to retrieve relevant documents from a vector 
+    database. By generating multiple perspectives on the user question, your goal is to help
+    the user overcome some of the limitations of the distance-based similarity search. 
+    Provide these alternative questions separated by newlines. 
+
+    You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
+    
+    
+    Original question: {question}"""
+    prompt_perspectives = ChatPromptTemplate.from_template(template)
+
+    
+    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
+    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
+
+    generate_queries = (
+        prompt_perspectives 
+        | llm
+        | StrOutputParser() 
+        | (lambda x: x.split("\n"))
+    )
+
+    return generate_queries
+
+def multi_query(question, retriever, chat_history):
+
+    def get_unique_union(documents: List[list]):
+        """ Unique union of retrieved docs """
+        # Flatten list of lists, and convert each Document to string
+        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
+        # Get unique documents
+        unique_docs = list(set(flattened_docs))
+        # Return
+        return [loads(doc) for doc in unique_docs]
+    
+
+    _search_query = get_search_query()
+    modified_question = _search_query.invoke({"question":question, "chat_history": chat_history})
+    print(modified_question)
+
+    generate_queries = multi_query_chain(llm)
+
+    retrieval_chain = generate_queries | retriever.map() | get_unique_union
+    docs = retrieval_chain.invoke({"question":modified_question})
+
+    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
+
+    return answer, docs
+
+def multi_query_rag_prompt(retrieval_chain, question):
+    # RAG
+    template = """Answer the following question based on this context:
+
+    {context}
+
+    Question: {question}
+    Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. \n
+    You should not mention anything about "根據提供的文件內容" or other similar terms.
+    Use three sentences maximum and keep the answer concise.
+    If you don't know the answer, just say that "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    """
+
+    prompt = ChatPromptTemplate.from_template(template)
+
+    # llm = ChatOpenAI(temperature=0)
+    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
+    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
+
+    final_rag_chain = (
+        {"context": retrieval_chain, 
+        "question": itemgetter("question")} 
+        | prompt
+        | llm
+        | StrOutputParser()
+    )
+
+    # answer = final_rag_chain.invoke({"question":question})
+
+    answer = ""
+    for text in final_rag_chain.stream({"question":question}):
+        print(text, end="", flush=True)
+        answer += text
+
+
+    return answer
+########################################################################################################################
+
+def get_search_query():
+    # Condense a chat history and follow-up question into a standalone question
+    # 
+    # _template = """Given the following conversation and a follow up question, 
+    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
+    # Generate standalone question in its original language.
+    # Chat History:
+    # {chat_history}
+    # Follow Up Input: {question}
+
+    # Hint:
+    # * Refer to chat history and add the subject to the question
+    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
+    
+    # Standalone question:"""  # noqa: E501
+    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
+    The rewritten query should:
+    
+    - Preserve the core intent and meaning of the original query
+    - Expand and clarify the query to make it more specific and informative for retrieving relevant context
+    - Avoid introducing new topics or queries that deviate from the original query
+    - DONT EVER ANSWER the Original query, but instead focus on rephrasing and expanding it into a new query
+    - The rewritten query should be in its original language.
+    
+    Return ONLY the rewritten query text, without any additional formatting or explanations.
+    
+    Conversation History:
+    {chat_history}
+    
+    Original query: [{question}]
+    
+    Rewritten query: 
+    """
+    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
+        buffer = []
+        for human, ai in chat_history:
+            buffer.append(HumanMessage(content=human))
+            buffer.append(AIMessage(content=ai))
+        return buffer
+
+    _search_query = RunnableBranch(
+        # If input includes chat_history, we condense it with the follow-up question
+        (
+            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
+                run_name="HasChatHistoryCheck"
+            ),  # Condense follow-up question and chat into a standalone_question
+            RunnablePassthrough.assign(
+                chat_history=lambda x: _format_chat_history(x["chat_history"])
+            )
+            | CONDENSE_QUESTION_PROMPT
+            | ChatOpenAI(temperature=0)
+            | StrOutputParser(),
+        ),
+        # Else, we have no chat history, so just pass through the question
+        RunnableLambda(lambda x : x["question"]),
+    )
+
+    return _search_query
+########################################################################################################################
+def naive_rag(question, retriever, chat_history):
+    _search_query = get_search_query()
+    modified_question = _search_query.invoke({"question":question, "chat_history": chat_history})
+    print(modified_question)
+
+    #### RETRIEVAL and GENERATION ####
+
+    # Prompt
+    prompt = hub.pull("rlm/rag-prompt")
+
+    # LLM
+    # llm = ChatOpenAI(model_name="gpt-4o", temperature=0)
+
+    # Post-processing
+    def format_docs(docs):
+        return "\n\n".join(doc.page_content for doc in docs)
+
+    reference = retriever.get_relevant_documents(modified_question)
+    
+    # Chain
+    rag_chain = (
+        {"context": retriever | format_docs, "question": RunnablePassthrough()}
+        | prompt
+        | llm
+        | StrOutputParser()
+    )
+
+    # Question
+    answer = rag_chain.invoke(modified_question)
+
+    return answer, reference
+################################################################################################
+
+
+if __name__ == "__main__":
+    from faiss_index import create_faiss_retriever, faiss_query 
+    global_retriever = create_faiss_retriever()
+    generate_queries = multi_query_chain(llm)
+    question = "台灣為什麼要制定氣候變遷因應法?"
+    
+    questions = generate_queries.invoke(question)
+    questions = [item for item in questions if item != ""]
+    # print(questions)
+
+    
+    results = list(map(global_retriever.get_relevant_documents, questions))
+    results = [item for sublist in results for item in sublist]
+    print(len(results))
+    print(results)
+
+    
+    # retrieval_chain = generate_queries | global_retriever.map
+    # docs = retrieval_chain.invoke(question)
+    # print(docs)
+    # print(len(docs))
+
+
+
+    # print(len(results))
+    # for doc in results[:10]:
+    #     print(doc)
+    #     print("-----------------------------------------------------------------------")
+
+
+    # results = get_unique_union(results)
+    # print(len(results))
+
+    # retrieval_chain = generate_queries | global_retriever.map | get_unique_union
+    # docs = retrieval_chain.invoke(question)
+    # print(len(docs))
+
+
+

+ 194 - 0
add_vectordb.py

@@ -0,0 +1,194 @@
+from dotenv import load_dotenv
+load_dotenv()
+
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.vectorstores import Chroma
+from langchain_community.document_loaders import TextLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.document_loaders import Docx2txtLoader
+
+import os
+import glob
+
+from langchain_community.vectorstores import SupabaseVectorStore
+from langchain_openai import OpenAIEmbeddings
+from supabase.client import Client, create_client
+
+
+def get_data_list(data_list=None, path=None, extension=None, update=False):
+    files = data_list or glob.glob(os.path.join(path, f"*.{extension}"))
+    if update:    
+        doc = files.copy()
+    else:
+        existed_data = check_existed_data(supabase)  # NOTE: relies on a module-level `supabase` client (see __main__)
+        doc = []
+        for file_path in files:
+            filename = os.path.basename(file_path)
+            if filename not in existed_data:
+                doc.append(file_path)
+
+    return doc
+
+
+def read_and_split_files(data_list=None, path=None, extension=None, update=False):
+
+    def load_and_split(file_list):
+        chunks = []
+        for file in file_list:
+            if file.endswith(".txt"):
+                loader = TextLoader(file, encoding='utf-8')
+            elif file.endswith(".pdf"):
+                loader = PyPDFLoader(file)
+            elif file.endswith(".docx"):
+                loader = Docx2txtLoader(file)
+            else:
+                print(f"Unsupported file extension: {file}")
+                continue
+
+            docs = loader.load()
+
+            # Split
+            if file.endswith(".docx"):
+                separators = ['\u25cb\\s*第.*?條', '\u25cf\\s*第.*?條']  # regex separators: ○/● followed by "第…條"
+                text_splitter = RecursiveCharacterTextSplitter(is_separator_regex=True, separators=separators, chunk_size=300, chunk_overlap=0)
+            else:
+                text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=500, chunk_overlap=0)
+            splits = text_splitter.split_documents(docs)
+
+            chunks.extend(splits)
+
+        return chunks
+
+
+    doc = get_data_list(data_list=data_list, path=path, extension=extension, update=update)
+    # Index
+    docs = load_and_split(doc)
+
+    return docs
+
+def create_ids(docs):
+    # Create a dictionary to count occurrences of each page in each document
+    page_counter = {}
+
+    # List to store the resulting IDs
+    document_ids = []
+
+    # Generate IDs
+    for doc in [docs[i].metadata for i in range(len(docs))]:
+        source = doc['source']
+        file_name = os.path.basename(source).split('.')[0]
+
+        if "page" in doc.keys():
+            page = doc['page']
+            key = f"{source}_{page}"
+        else:
+            key = f"{source}"
+
+        if key not in page_counter:
+            page_counter[key] = 1
+        else:
+            page_counter[key] += 1
+        
+        if "page" in doc.keys():
+            doc_id = f"{file_name} | page {page} | chunk {page_counter[key]}"
+        else:
+            doc_id = f"{file_name} | chunk {page_counter[key]}"
+
+        
+        document_ids.append(doc_id)
+
+    return document_ids
+
+def get_document(data_list=None, path=None, extension=None, update=False):
+    docs = read_and_split_files(data_list=data_list, path=path, extension=extension, update=update)
+    document_ids = create_ids(docs)
+
+    for doc in docs:
+        doc.metadata['source'] = os.path.basename(doc.metadata['source'])
+        # print(doc.metadata)
+
+    # document_metadatas = [{'source': doc.metadata['source'], 'page': doc.metadata['page'], 'chunk': int(id.split("chunk ")[-1])} for doc, id in zip(docs, document_ids)]
+    document_metadatas = []
+
+    for doc, id in zip(docs, document_ids):
+        chunk_number = int(id.split("chunk ")[-1])
+        doc.metadata['chunk'] = chunk_number
+        doc.metadata['extension'] = os.path.basename(doc.metadata['source']).split(".")[-1]
+        document_metadatas.append(doc.metadata)
+
+    documents = [doc.metadata['source'].split(".")[0] + doc.page_content for doc in docs]
+
+    return document_ids, documents, document_metadatas
+
+def check_existed_data(supabase):
+    response = supabase.table('documents').select("id, metadata").execute()
+    existed_data = list(set([data['metadata']['source'] for data in response.data]))
+    # existed_data = [(data['id'], data['metadata']['source']) for data in response.data]
+    return existed_data
+
+class GetVectorStore(SupabaseVectorStore):
+    def __init__(self, embeddings, supabase, table_name):
+        super().__init__(embedding=embeddings, client=supabase, table_name=table_name, query_name="match_documents")
+
+    def insert(self, documents, document_metadatas):
+        self.add_texts(
+            texts=documents,
+            metadatas=document_metadatas,
+        )
+
+    def delete(self, file_list):
+        for file_name in file_list:
+            self._client.table(self.table_name).delete().eq('metadata->>source', file_name).execute()
+
+    def update(self, documents, document_metadatas, update_existing_data=False):
+        if not document_metadatas:  # no new data
+            return
+
+        if update_existing_data:
+            file_list = list(set(metadata['source'] for metadata in document_metadatas))
+            self.delete(file_list)
+
+        self.insert(documents, document_metadatas)
+
+if __name__ == "__main__":
+
+    load_dotenv()
+    supabase_url = os.environ.get("SUPABASE_URL")
+    supabase_key = os.environ.get("SUPABASE_KEY")
+    document_table = "documents"
+    supabase: Client = create_client(supabase_url, supabase_key)
+
+    embeddings = OpenAIEmbeddings()
+
+    ###################################################################################
+    # get vector store
+    vector_store = GetVectorStore(embeddings, supabase, document_table)
+
+    ###################################################################################
+    # update data (old + new / all new / all old)
+    path = "/home/mia/systex/Documents"
+    extension = "pdf"
+    # file = None
+
+    # file_list = ["溫室氣體排放量盤查作業指引113年版.pdf"]
+    # file = [os.path.join(path, file) for file in file_list]
+    file_list = glob.glob(os.path.join(path, "*"))
+    print(file_list)
+    
+    update = False
+    document_ids, documents, document_metadatas = get_document(data_list=file_list, path=path, extension=extension, update=update)
+    vector_store.update(documents, document_metadatas, update_existing_data=update)
+
+    ###################################################################################
+    # insert new data (all new)
+    # vector_store.insert(documents, document_metadatas)
+
+    ###################################################################################
+    # delete data
+    # file_list = ["溫室氣體排放量盤查作業指引113年版.pdf"]
+    # vector_store.delete(file_list)
+
+    ###################################################################################
+    # get retriver
+    # retriever = vector_store.as_retriever(search_kwargs={"k": 6})

+ 80 - 0
conda_env.txt

@@ -0,0 +1,80 @@
+# This file may be used to create an environment using:
+# $ conda create --name llama3 --file conda_env.txt
+# platform: linux-64
+_libgcc_mutex=0.1=conda_forge
+_openmp_mutex=4.5=2_gnu
+aiosignal=1.3.1=pyhd8ed1ab_0
+annotated-types=0.7.0=pyhd8ed1ab_0
+anyio=4.2.0=py312h06a4308_0
+async-timeout=4.0.3=pyhd8ed1ab_0
+attrs=24.2.0=pyh71513ae_0
+blas=1.1=openblas
+bottleneck=1.3.7=py312ha883a20_0
+brotli-python=1.0.9=py312h6a678d5_8
+bzip2=1.0.8=h5eee18b_6
+ca-certificates=2024.7.4=hbcca054_0
+certifi=2024.7.4=py312h06a4308_0
+cffi=1.16.0=py312h5eee18b_1
+charset-normalizer=3.3.2=pyhd8ed1ab_0
+click=8.1.7=py312h06a4308_0
+deprecation=2.1.0=pyh9f0ad1d_0
+expat=2.6.2=h6a678d5_0
+fastapi=0.103.0=py312h06a4308_0
+frozenlist=1.4.0=py312h5eee18b_0
+greenlet=3.0.1=py312h6a678d5_0
+h11=0.14.0=py312h06a4308_0
+h2=4.1.0=pyhd8ed1ab_0
+hpack=4.0.0=pyh9f0ad1d_0
+httpcore=1.0.5=pyhd8ed1ab_0
+hyperframe=6.0.1=pyhd8ed1ab_0
+idna=3.7=pyhd8ed1ab_0
+jsonpatch=1.33=pyhd8ed1ab_0
+jsonpointer=2.0=py_0
+ld_impl_linux-64=2.38=h1181459_1
+libffi=3.4.4=h6a678d5_1
+libgcc-ng=14.1.0=h77fa898_0
+libgfortran-ng=14.1.0=h69a702a_0
+libgfortran5=14.1.0=hc5f4f2c_0
+libgomp=14.1.0=h77fa898_0
+libopenblas=0.3.28=pthreads_h94d23a6_0
+libstdcxx-ng=11.2.0=h1234567_1
+libuuid=1.41.5=h5eee18b_0
+lz4-c=1.9.4=h6a678d5_1
+multidict=6.0.4=py312h5eee18b_0
+ncurses=6.4=h6a678d5_0
+numexpr=2.8.7=py312he7dcb8a_0
+numpy=1.26.4=py312h2809609_0
+numpy-base=1.26.4=py312he1a6c75_0
+openssl=3.3.1=h4bc722e_2
+orjson=3.9.15=py312h97a8848_0
+packaging=24.1=pyhd8ed1ab_0
+pandas=2.2.2=py312h526ad5a_0
+pip=24.2=py312h06a4308_0
+pycparser=2.22=pyhd8ed1ab_0
+pysocks=1.7.1=pyha2e5f31_6
+python=3.12.4=h5148396_1
+python-dateutil=2.9.0post0=py312h06a4308_2
+python-tzdata=2023.3=pyhd3eb1b0_0
+pytz=2024.1=py312h06a4308_0
+readline=8.2=h5eee18b_0
+requests=2.32.3=pyhd8ed1ab_0
+setuptools=72.1.0=py312h06a4308_0
+six=1.16.0=pyhd3eb1b0_1
+sniffio=1.3.0=py312h06a4308_0
+sqlite=3.45.3=h5eee18b_0
+starlette=0.27.0=py312h06a4308_0
+tk=8.6.14=h39e8969_0
+typing-extensions=4.12.2=hd8ed1ab_0
+typing_extensions=4.12.2=pyha770c72_0
+tzdata=2024a=h04d1e81_0
+tzlocal=5.2=py312h06a4308_0
+urllib3=2.2.2=pyhd8ed1ab_1
+uvicorn=0.20.0=py312h06a4308_0
+websockets=10.4=py312h5eee18b_1
+wheel=0.43.0=py312h06a4308_0
+xz=5.4.6=h5eee18b_1
+yaml=0.2.5=h7f98852_2
+yarl=1.9.3=py312h5eee18b_0
+zlib=1.2.13=h5eee18b_1
+zstandard=0.22.0=py312h2c38b39_0
+zstd=1.5.5=hc292b87_2

+ 440 - 0
faiss_index.py

@@ -0,0 +1,440 @@
+import faiss
+import numpy as np
+import json
+from time import time
+import asyncio
+from datasets import Dataset
+from typing import List
+from dotenv import load_dotenv
+import os
+import pickle
+from supabase.client import Client, create_client
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+from langchain.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+import pandas as pd
+from langchain_core.documents import Document
+from langchain.load import dumps, loads
+
+# Import from the parent directory
+import sys
+
+from RAG_strategy import multi_query, multi_query_chain
+sys.path.append('..')
+# from RAG_strategy_Taide import taide_llm, system_prompt, multi_query
+system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+# llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+
+
+from langchain.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+from add_vectordb import GetVectorStore
+# from local_llm import ollama_, hf
+# # from local_llm import ollama_, taide_llm, hf
+# llm = hf()
+# llm = taide_llm
+
+
+# Import RAGAS metrics
+from ragas import evaluate
+from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision
+
+# Load environment variables
+load_dotenv('../../.env')
+supabase_url = os.getenv("SUPABASE_URL")
+supabase_key = os.getenv("SUPABASE_KEY")
+openai_api_key = os.getenv("OPENAI_API_KEY")
+document_table = "documents"
+
+# Initialize Supabase client
+supabase: Client = create_client(supabase_url, supabase_key)
+
+# Initialize embeddings and chat model
+embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
+# Default chat model for the standalone entry points below (run_evaluation, ask_question, __main__);
+# RAG_app.py passes its own llm (e.g. local_llm.hf()) into faiss_query() instead.
+llm = ChatOpenAI(openai_api_key=openai_api_key, model_name="gpt-4o-mini", temperature=0)
+
+def download_embeddings():
+    response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
+    embeddings = []
+    ids = []
+    metadatas = []
+    contents = []
+    for item in response.data:
+        embedding = json.loads(item['embedding'])
+        embeddings.append(embedding)
+        ids.append(item['id'])
+        metadatas.append(item['metadata'])
+        contents.append(item['content'])
+    return np.array(embeddings, dtype=np.float32), ids, metadatas, contents
+
+def create_faiss_index(embeddings):
+    dimension = embeddings.shape[1]
+    index = faiss.IndexFlatIP(dimension)  # Use Inner Product for cosine similarity
+    faiss.normalize_L2(embeddings)  # Normalize embeddings for cosine similarity
+    index.add(embeddings)
+    return index
+
+def save_faiss_index(index, file_path):
+    faiss.write_index(index, file_path)
+    print(f"FAISS index saved to {file_path}")
+
+def load_faiss_index(file_path):
+    if os.path.exists(file_path):
+        index = faiss.read_index(file_path)
+        print(f"FAISS index loaded from {file_path}")
+        return index
+    return None
+
+def save_metadata(ids, metadatas, contents, file_path):
+    with open(file_path, 'wb') as f:
+        pickle.dump((ids, metadatas, contents), f)
+    print(f"Metadata saved to {file_path}")
+
+def load_metadata(file_path):
+    if os.path.exists(file_path):
+        with open(file_path, 'rb') as f:
+            ids, metadatas, contents = pickle.load(f)
+        print(f"Metadata loaded from {file_path}")
+        return ids, metadatas, contents
+    return None, None, None
+
+def search_faiss(index, query_vector, k=4):
+    query_vector = np.array(query_vector, dtype=np.float32).reshape(1, -1)
+    faiss.normalize_L2(query_vector)
+    distances, indices = index.search(query_vector, k)
+    return distances[0], indices[0]
+
+class FAISSRetriever:
+    def __init__(self, index, ids, metadatas, contents, embeddings_model):
+        self.index = index
+        self.ids = ids
+        self.metadatas = metadatas
+        self.contents = contents
+        self.embeddings_model = embeddings_model
+
+    def get_relevant_documents(self, query: str, k: int = 4) -> List[Document]:
+        query_vector = self.embeddings_model.embed_query(query)
+        _, indices = search_faiss(self.index, query_vector, k=k)
+        return [
+            Document(page_content=self.contents[i], metadata=self.metadatas[i])
+            for i in indices
+        ]
+    def map(self, query_list: List[str]) -> List[Document]:
+        def get_unique_union(documents: List[list]):
+            """ Unique union of retrieved docs """
+            # Flatten list of lists, and convert each Document to string
+            flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
+            # Get unique documents
+            unique_docs = list(set(flattened_docs))
+            # Return
+            return [loads(doc) for doc in unique_docs]
+        
+        documents = []
+        for query in query_list:
+            if query != "":
+                docs = self.get_relevant_documents(query)
+                documents.extend(docs)
+
+        return get_unique_union(documents)
+
+def load_qa_pairs():
+    # df = pd.read_csv("../QA_database_rows.csv")
+    response = supabase.table('QA_database').select("Question, Answer").execute()
+    df = pd.DataFrame(response.data)
+
+    return df['Question'].tolist(), df['Answer'].tolist()
+
+def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):
+    generate_queries = multi_query_chain(llm)
+
+    questions = generate_queries.invoke(question)
+    questions = [item for item in questions if item != ""]
+
+    # docs = list(map(retriever.get_relevant_documents, questions))
+    docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))
+    docs = [item for sublist in docs for item in sublist]
+
+    return docs
+
+def faiss_query(question: str, retriever: FAISSRetriever, llm, multi_query: bool = False) -> str:
+    if multi_query:
+        docs = faiss_multiquery(question, retriever, llm)
+        # print(docs)
+    else:
+        docs = retriever.get_relevant_documents(question, k=10)
+        # print(docs)
+
+    context = "\n".join(doc.page_content for doc in docs)
+    
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    你是一個來自台灣的ESG的AI助理,
+    請用繁體中文回答問題 \n
+    You should not mention anything about "根據提供的文件內容" or other similar terms.
+    Use five sentences maximum and keep the answer concise.
+    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    勿回答無關資訊
+    <|eot_id|>
+    
+    <|start_header_id|>user<|end_header_id|>
+    Answer the following question based on this context:
+
+    {context}
+
+    Question: {question}
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
+    """
+    prompt = ChatPromptTemplate.from_template(
+        system_prompt + "\n\n" +
+        template
+    )
+    
+    # prompt = ChatPromptTemplate.from_template(
+    #     system_prompt + "\n\n" +
+    #     "Answer the following question based on this context:\n\n"
+    #     "{context}\n\n"
+    #     "Question: {question}\n"
+    #     "Answer in the same language as the question. If you don't know the answer, "
+    #     "say 'I'm sorry, I don't have enough information to answer that question.'"
+    # )
+
+    
+    # chain = prompt | taide_llm | StrOutputParser()
+    chain = prompt | llm | StrOutputParser()
+    return chain.invoke({"context": context, "question": question})
+
+
+def create_faiss_retriever():
+    faiss_index_path = "faiss_index.bin"
+    metadata_path = "faiss_metadata.pkl"
+
+    index = load_faiss_index(faiss_index_path)
+    ids, metadatas, contents = load_metadata(metadata_path)
+
+    if index is None or ids is None:
+        print("FAISS index or metadata not found. Creating new index...")
+        print("Downloading embeddings from Supabase...")
+        embeddings_array, ids, metadatas, contents = download_embeddings()
+
+        print("Creating FAISS index...")
+        index = create_faiss_index(embeddings_array)
+
+        save_faiss_index(index, faiss_index_path)
+        save_metadata(ids, metadatas, contents, metadata_path)
+    else:
+        print("Using existing FAISS index and metadata.")
+
+    print("Creating FAISS retriever...")
+    faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
+
+    return faiss_retriever
+
+
+async def run_evaluation():
+    faiss_index_path = "faiss_index.bin"
+    metadata_path = "faiss_metadata.pkl"
+
+    index = load_faiss_index(faiss_index_path)
+    ids, metadatas, contents = load_metadata(metadata_path)
+
+    if index is None or ids is None:
+        print("FAISS index or metadata not found. Creating new index...")
+        print("Downloading embeddings from Supabase...")
+        embeddings_array, ids, metadatas, contents = download_embeddings()
+
+        print("Creating FAISS index...")
+        index = create_faiss_index(embeddings_array)
+
+        save_faiss_index(index, faiss_index_path)
+        save_metadata(ids, metadatas, contents, metadata_path)
+    else:
+        print("Using existing FAISS index and metadata.")
+
+    print("Creating FAISS retriever...")
+    faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
+
+    print("Creating original vector store...")
+    original_vector_store = GetVectorStore(embeddings, supabase, document_table)
+    original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
+
+    questions, ground_truths = load_qa_pairs()
+
+    for question, ground_truth in zip(questions, ground_truths):
+        print(f"\nQuestion: {question}")
+
+        start_time = time()
+        faiss_answer = faiss_query(question, faiss_retriever, llm)
+        faiss_docs = faiss_retriever.get_relevant_documents(question)
+        faiss_time = time() - start_time
+        print(f"FAISS Answer: {faiss_answer}")
+        print(f"FAISS Time: {faiss_time:.4f} seconds")
+
+        start_time = time()
+        original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
+        original_time = time() - start_time
+        print(f"Original Answer: {original_answer}")
+        print(f"Original Time: {original_time:.4f} seconds")
+
+        # faiss_datasets = {
+        #     "question": [question],
+        #     "answer": [faiss_answer],
+        #     "contexts": [[doc.page_content for doc in faiss_docs]],
+        #     "ground_truth": [ground_truth]
+        # }
+        # faiss_evalsets = Dataset.from_dict(faiss_datasets)
+
+        # faiss_result = evaluate(
+        #     faiss_evalsets,
+        #     metrics=[
+        #         context_precision,
+        #         faithfulness,
+        #         answer_relevancy,
+        #         context_recall,
+        #     ],
+        # )
+
+        # print("FAISS RAGAS Evaluation:")
+        # print(faiss_result.to_pandas())
+
+        # original_datasets = {
+        #     "question": [question],
+        #     "answer": [original_answer],
+        #     "contexts": [[doc.page_content for doc in original_docs]],
+        #     "ground_truth": [ground_truth]
+        # }
+        # original_evalsets = Dataset.from_dict(original_datasets)
+
+        # original_result = evaluate(
+        #     original_evalsets,
+        #     metrics=[
+        #         context_precision,
+        #         faithfulness,
+        #         answer_relevancy,
+        #         context_recall,
+        #     ],
+        # )
+
+        # print("Original RAGAS Evaluation:")
+        # print(original_result.to_pandas())
+
+    print("\nPerformance comparison complete.")
+
+
+async def ask_question():
+    faiss_index_path = "faiss_index.bin"
+    metadata_path = "faiss_metadata.pkl"
+
+    index = load_faiss_index(faiss_index_path)
+    ids, metadatas, contents = load_metadata(metadata_path)
+
+    if index is None or ids is None:
+        print("FAISS index or metadata not found. Creating new index...")
+        print("Downloading embeddings from Supabase...")
+        embeddings_array, ids, metadatas, contents = download_embeddings()
+
+        print("Creating FAISS index...")
+        index = create_faiss_index(embeddings_array)
+
+        save_faiss_index(index, faiss_index_path)
+        save_metadata(ids, metadatas, contents, metadata_path)
+    else:
+        print("Using existing FAISS index and metadata.")
+
+    print("Creating FAISS retriever...")
+    faiss_retriever = FAISSRetriever(index, ids, metadatas, contents, embeddings)
+
+    # print("Creating original vector store...")
+    # original_vector_store = GetVectorStore(embeddings, supabase, document_table)
+    # original_retriever = original_vector_store.as_retriever(search_kwargs={"k": 4})
+
+    # questions, ground_truths = load_qa_pairs()
+
+    # for question, ground_truth in zip(questions, ground_truths):
+    question = ""
+    while question != "exit":
+        question = input("Question: ")
+        print(f"\nQuestion: {question}")
+
+        start_time = time()
+        faiss_answer = faiss_query(question, faiss_retriever, llm)
+        faiss_docs = faiss_retriever.get_relevant_documents(question)
+        faiss_time = time() - start_time
+        print(f"FAISS Answer: {faiss_answer}")
+        print(f"FAISS Time: {faiss_time:.4f} seconds")
+
+        # start_time = time()
+        # original_answer, original_docs = multi_query(question, original_retriever, chat_history=[])
+        # original_time = time() - start_time
+        # print(f"Original Answer: {original_answer}")
+        # print(f"Original Time: {original_time:.4f} seconds")
+
+if __name__ == "__main__":
+
+    global_retriever = create_faiss_retriever()
+
+    questions, ground_truths = load_qa_pairs()
+    results = []
+
+    for question, ground_truth in zip(questions, ground_truths):
+        # For multi_query=True
+        start = time()
+        final_answer_multi = faiss_query(question, global_retriever, llm, multi_query=True)
+        processing_time_multi = time() - start
+        # print(final_answer_multi)
+        # print(processing_time_multi)
+
+        # For multi_query=False
+        start = time()
+        final_answer_single = faiss_query(question, global_retriever, llm, multi_query=False)
+        processing_time_single = time() - start
+        # print(final_answer_single)
+        # print(processing_time_single)
+
+        # Store results in a dictionary
+        result = {
+            "question": question,
+            "ground_truth": ground_truth,
+            "final_answer_multi_query": final_answer_multi,
+            "processing_time_multi_query": processing_time_multi,
+            "final_answer_single_query": final_answer_single,
+            "processing_time_single_query": processing_time_single
+        }
+        print(result)
+        
+        results.append(result)
+
+        with open('qa_results.json', 'a', encoding='utf8') as outfile:
+            json.dump(result, outfile, indent=4, ensure_ascii=False)
+            outfile.write("\n")  # Ensure each result is on a new line
+        
+
+    # Save results to a JSON file
+    with open('qa_results+all.json', 'w', encoding='utf8') as outfile:
+        json.dump(results, outfile, indent=4, ensure_ascii=False)
+
+    print('All questions done!')
+    # question = ""
+    # while question != "exit":
+    #     # question = "國家溫室氣體長期減量目標" 
+    #     question = input("Question: ")
+    #     if question.strip().lower == "exit": break
+
+    #     start = time()
+    #     final_answer = faiss_query(question, global_retriever, multi_query=True)
+    #     print(final_answer)
+    #     processing_time = time() - start
+    #     print(processing_time)
+
+        
+    #     start = time() 
+    #     final_answer = faiss_query(question, global_retriever, multi_query=False)
+    #     print(final_answer)
+    #     processing_time = time() - start
+    #     print(processing_time)
+    # print("Chatbot closed!")
+
+    # asyncio.run(ask_question())

+ 119 - 0
local_llm.py

@@ -0,0 +1,119 @@
+from langchain_community.chat_models import ChatOllama
+from langchain_openai import ChatOpenAI
+from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
+import torch
+from langchain_huggingface import HuggingFacePipeline
+
+
+from typing import Any, List, Optional, Dict
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
+from langchain_core.outputs import ChatResult, ChatGeneration
+from pydantic import Field
+import subprocess
+
+import time
+
+from dotenv import load_dotenv
+load_dotenv()
+
+system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+
+def hf():
+    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    llm = HuggingFacePipeline.from_model_id(
+        model_id=model_id,
+        task="text-generation",
+        model_kwargs={"torch_dtype": torch.bfloat16},
+        pipeline_kwargs={"return_full_text": False,
+            "max_new_tokens": 512,
+            "repetition_penalty":1.03},
+        device=0, device_map='cuda')
+    # print(llm.pipeline)
+    llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
+
+    return llm
+
+
+def ollama_():
+    # model = "cwchang/llama3-taide-lx-8b-chat-alpha1"
+    model = "llama3.1:latest"
+    # model = "llama3.1:70b"
+    # model = "893379029/piccolo-large-zh-v2"
+    sys = "你是一個來自台灣的 AI 助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。請用 5 句話以內回答問題。"
+    # llm = ChatOllama(model=model, num_gpu=2, num_thread=32, temperature=0, system=sys, keep_alive="10m", verbose=True)
+    llm = ChatOllama(model=model, num_gpu=2, temperature=0, system=sys, keep_alive="10m")
+
+    return llm
+
+
+def openai_(): # not local
+    llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
+
+    return llm
+
+class OllamaChatModel(BaseChatModel):
+    model_name: str = Field(default="taide-local-llama3")
+
+    def _generate(
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+    ) -> ChatResult:
+        formatted_messages = []
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                formatted_messages.append({"role": "user", "content": msg.content})
+            elif isinstance(msg, AIMessage):
+                formatted_messages.append({"role": "assistant", "content": msg.content})
+            elif isinstance(msg, SystemMessage):
+                 formatted_messages.append({"role": "system", "content": msg.content})
+
+        # prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n" # TAIDE llama2
+        prompt = f"<|begin_of_text|><|start_header_id|>{system_prompt}<|end_header_id|>" # TAIDE llama3
+        for msg in formatted_messages:
+            if msg['role'] == 'user':
+                # prompt += f"{msg['content']} [/INST]" # TAIDE llama2
+                prompt += f"<|eot_id|><|start_header_id|>{msg['content']}<|end_header_id|>" # TAIDE llama3
+            elif msg['role'] == "assistant":
+                # prompt += f"{msg['content']} </s><s>[INST]" # TAIDE llama2
+                prompt += f"<|eot_id|><|start_header_id|>{msg['content']}<|end_header_id|>" # TAIDE llama3
+
+        command = ["docker", "exec", "-it", "ollama", "ollama", "run", self.model_name, prompt]
+        result = subprocess.run(command, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise Exception(f"Ollama command failed: {result.stderr}")
+        
+        content = result.stdout.strip()
+
+        message = AIMessage(content=content)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+    
+    @property
+    def _llm_type(self) -> str:
+        return "ollama-chat-model"
+    
+# taide_llm = OllamaChatModel(model_name="taide-local-llama2")    
+
+if __name__ == "__main__":
+    question = ""
+    while question.lower() != "exit": 
+        question = input("Question: ")
+        # sample question: 溫室氣體是什麼? (What are greenhouse gases?)
+
+
+        for function in [ollama_, hf, openai_]:
+            start = time.time()
+            llm = function()
+            answer = llm.invoke(question)
+            print(answer)
+
+            processing_time = time.time() - start
+            print(processing_time)

+ 107 - 0
pip_env.txt

@@ -0,0 +1,107 @@
+accelerate==0.33.0
+aenum==3.1.15
+aiohappyeyeballs==2.3.7
+aiohttp==3.10.3
+asgiref==3.8.1
+backoff==2.2.1
+bcrypt==4.2.0
+blobfile==2.1.1
+build==1.2.1
+cachetools==5.5.0
+chroma-hnswlib==0.7.6
+chromadb==0.5.5
+coloredlogs==15.0.1
+dataclasses-json==0.6.7
+deprecated==1.2.14
+distro==1.9.0
+fairscale==0.4.13
+filelock==3.15.4
+fire==0.6.0
+flatbuffers==24.3.25
+fsspec==2024.6.1
+future==1.0.0
+google-auth==2.34.0
+googleapis-common-protos==1.63.2
+grpcio==1.65.5
+httptools==0.6.1
+httpx==0.27.0
+huggingface-hub==0.24.5
+humanfriendly==10.0
+importlib-metadata==8.0.0
+importlib-resources==6.4.3
+jinja2==3.1.4
+jiter==0.5.0
+kubernetes==30.1.0
+langchain-community==0.2.12
+langchain-openai==0.1.22
+line-bot-sdk==3.11.1
+lxml==4.9.4
+markdown-it-py==3.0.0
+markupsafe==2.1.5
+marshmallow==3.21.3
+mdurl==0.1.2
+mmh3==4.1.0
+monotonic==1.6
+mpmath==1.3.0
+mypy-extensions==1.0.0
+networkx==3.3
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.6.20
+nvidia-nvtx-cu12==12.1.105
+oauthlib==3.2.2
+onnxruntime==1.19.0
+openai==1.41.0
+opentelemetry-api==1.26.0
+opentelemetry-exporter-otlp-proto-common==1.26.0
+opentelemetry-exporter-otlp-proto-grpc==1.26.0
+opentelemetry-instrumentation==0.47b0
+opentelemetry-instrumentation-asgi==0.47b0
+opentelemetry-instrumentation-fastapi==0.47b0
+opentelemetry-proto==1.26.0
+opentelemetry-sdk==1.26.0
+opentelemetry-semantic-conventions==0.47b0
+opentelemetry-util-http==0.47b0
+overrides==7.7.0
+posthog==3.5.0
+protobuf==4.25.4
+psutil==6.0.0
+pyasn1==0.6.0
+pyasn1-modules==0.4.0
+pycryptodomex==3.20.0
+pydantic==2.8.2
+pydantic-core==2.20.1
+pygments==2.18.0
+pypika==0.48.9
+pyproject-hooks==1.1.0
+python-dotenv==1.0.1
+pyyaml==6.0.2
+regex==2024.7.24
+requests-oauthlib==2.0.0
+rich==13.7.1
+rsa==4.9
+safetensors==0.4.4
+shellingham==1.5.4
+sympy==1.13.2
+termcolor==2.4.0
+tiktoken==0.7.0
+tokenizers==0.19.1
+torch==2.4.0
+tqdm==4.66.5
+transformers==4.44.0
+triton==3.0.0
+typer==0.12.4
+typing-inspect==0.9.0
+uvloop==0.20.0
+watchfiles==0.23.0
+websocket-client==1.8.0
+wrapt==1.16.0
+zipp==3.20.0

+ 73 - 0
semantic_search.py

@@ -0,0 +1,73 @@
+### Python = 3.9
+import os
+from dotenv import load_dotenv
+load_dotenv('.env')
+
+import openai 
+openai_api_key = os.getenv("OPENAI_API_KEY")
+openai.api_key = openai_api_key
+
+from langchain_openai import OpenAIEmbeddings
+embeddings_model = OpenAIEmbeddings()
+
+from langchain_community.document_loaders.csv_loader import CSVLoader
+from langchain_community.vectorstores import Chroma
+
+import pandas as pd
+import re
+
+from langchain_community.embeddings.openai import OpenAIEmbeddings
+from langchain_community.vectorstores import SupabaseVectorStore
+from supabase.client import create_client
+
+
+def create_qa_vectordb(supabase, vectordb_directory="./chroma_db_carbon_questions"):
+
+    if os.path.isdir(vectordb_directory):
+        vectorstore = Chroma(persist_directory=vectordb_directory, embedding_function=embeddings_model)
+        vectorstore.delete_collection()
+
+    response = supabase.table("QA_database").select("Question, Answer").execute()
+    questions = [row["Question"] for row in response.data]
+
+    ######### generate embedding ###########
+    # embedding = embeddings_model.embed_documents(questions)
+
+    ########## Write embedding to the supabase table  #######
+    # for id, new_embedding in zip(ids, embedding):
+    #     supabase.table("video_cache_rows_duplicate").insert({"embedding": embedding.tolist()}).eq("id", id).execute()
+
+    ########### Vector Store #############
+    # Put pre-compute embeddings to vector store. ## save to disk
+    vectorstore = Chroma.from_texts(
+        texts=questions,
+        embedding=embeddings_model,
+        persist_directory=vectordb_directory
+        )
+    
+    return vectorstore
+
+
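+# Note: the Chroma collection is only rebuilt when create_qa_vectordb() is called (or when the
+# persist directory is missing), so re-run it whenever new rows are added to QA_database.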
+# vectorstore = Chroma(persist_directory="./chroma_db_carbon_questions", embedding_function=embeddings_model)
+def semantic_cache(supabase, q, SIMILARITY_THRESHOLD=0.83, k=1, vectordb_directory="./chroma_db_carbon_questions"):
+    """Return a cached (question, answer) pair when the incoming question is close enough to a stored one."""
+
+    # Reuse the persisted vector store if it exists; otherwise build it from the QA table.
+    if os.path.isdir(vectordb_directory):
+        vectorstore = Chroma(persist_directory=vectordb_directory, embedding_function=embeddings_model)
+    else:
+        print("create new vector db ...")
+        vectorstore = create_qa_vectordb(supabase, vectordb_directory)
+
+    # Chroma returns relevance scores in [0, 1]; higher means more similar.
+    docs_and_scores = vectorstore.similarity_search_with_relevance_scores(q, k=k)
+    if not docs_and_scores:
+        return None, None
+    doc, score = docs_and_scores[0]
+
+    if score >= SIMILARITY_THRESHOLD:
+        # Cache hit: fetch the stored answer for the matched question.
+        cache_question = doc.page_content
+        response = supabase.table("QA_database").select("Question, Answer").eq("Question", cache_question).execute()
+        # qa_df = pd.DataFrame(response.data)
+        # print(response.data[0])
+        answer = response.data[0]["Answer"]
+        # video_cache = response.data[0]["video_cache"]
+        return cache_question, answer
+    else:
+        return None, None
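+
+
+# Minimal usage sketch. Assumption: SUPABASE_URL and SUPABASE_KEY are the env var names holding
+# the Supabase project URL and API key; adjust them to the actual deployment.
+if __name__ == "__main__":
+    supabase_client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_KEY"))
+    cached_question, cached_answer = semantic_cache(supabase_client, "去年的固定燃燒總排放量是多少?")
+    print(cached_question, cached_answer)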

+ 187 - 0
text_to_sql.py

@@ -0,0 +1,187 @@
+import re
+from dotenv import load_dotenv
+load_dotenv()
+
+from langchain_community.utilities import SQLDatabase
+import os
+URI: str = os.environ.get('SUPABASE_URI')
+db = SQLDatabase.from_uri(URI)
+
+# print(db.dialect)
+# print(db.get_usable_table_names())
+# db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
+
+context = db.get_context()
+# print(list(context))
+# print(context["table_info"])
+
+from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
+from langchain.chains import create_sql_query_chain
+from langchain_community.llms import Ollama
+
+from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
+from operator import itemgetter
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough
+
+# Load model directly
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import torch
+from langchain_huggingface import HuggingFacePipeline
+
+# model_id = "defog/llama-3-sqlcoder-8b"
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
+# sql_llm = HuggingFacePipeline.from_model_id(
+#     model_id=model_id,
+#     task="text-generation",
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     pipeline_kwargs={"return_full_text": False},
+#     device=0, device_map='cuda')
+
+
+model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+llm = HuggingFacePipeline.from_model_id(
+    model_id=model_id,
+    task="text-generation",
+    model_kwargs={"torch_dtype": torch.bfloat16},
+    pipeline_kwargs={"return_full_text": False,
+        "max_new_tokens": 512},
+    device=0, device_map='cuda')
+print(llm.pipeline)
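+# The model config here exposes eos_token_id as a list; reuse its first entry as the pad token so
+# the text-generation pipeline does not warn about a missing pad token.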
+llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
+# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
+
+# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1, 
+#                 model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
+#, device="auto", load_in_4bit=True
+# llm = HuggingFacePipeline(pipeline=pipe)
+
+# llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+def get_examples():
+    examples = [
+        {
+            "input": "去年的固定燃燒總排放量是多少?",
+            "query": 'SELECT SUM("高雄總部及運通廠" + "台北辦事處" + "昆山廣興廠" + "北海建準廠" + "北海立準廠" + "菲律賓建準廠" + "Inc" + "SAS" + "India") AS "固定燃燒總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "排放源" = \'固定燃燒\'',
+        },
+        {
+            "input": "建準廣興廠去年的類別1總排放量是多少?",
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'',
+        },
+        {
+            "input": "建準廣興廠去年的直接排放總排放量是多少?",
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%直接排放%\'',
+        },
+
+    ]
+
+    return examples
+
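+# Note: these few-shot examples assume "去年" (last year) maps to the 2023 inventory tables; they
+# are rendered into the prompt by the FewShotPromptTemplate in write_query_chain() below.
+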
+def write_query_chain(db):
+
+    template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+    Generate a SQL query to answer this question: `{input}`
+
+    You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
+    then look at the results of the query and return the answer to the input question.\n\
+    Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. 
+    You can order the results to return the most informative data in the database.\n\
+    Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
+    Wrap each column name in double quotation marks (") to denote them as delimited identifiers.\n\
+    
+    ***Pay attention to only return the PostgreSQL query WITHOUT "```sql", and DO NOT include any other words.\n\
+    ***Pay attention to only return the PostgreSQL query.\n\
+
+    DDL statements:
+    {table_info}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+    The following SQL query best answers the question `{input}`:
+    ```sql
+    """
+    # prompt_template = PromptTemplate.from_template(template)
+
+    example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
+    prompt = FewShotPromptTemplate(
+        examples=get_examples(),
+        example_prompt=example_prompt,
+        prefix=template,
+        suffix="User input: {input}\nSQL query: ",
+        input_variables=["input", "top_k", "table_info"],
+    )
+
+    # llm = Ollama(model = "mannix/defog-llama3-sqlcoder-8b", num_gpu=1)
+    # llm = HuggingFacePipeline(pipeline=pipe)
+    
+    
+    write_query = create_sql_query_chain(llm, db, prompt)
+
+
+    return write_query
+
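+# Standalone usage sketch (mirrors what run() below does; selected_table is a list of table names):
+# write_query = write_query_chain(db)
+# sql = write_query.invoke({"question": "去年的固定燃燒總排放量是多少?",
+#                           "table_names_to_use": selected_table, "top_k": 1000,
+#                           "table_info": context["table_info"]})
+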
+def sql_to_nl_chain():
+    # llm = Ollama(model = "llama3.1", num_gpu=1)
+    # llm = Ollama(model = "llama3.1:8b-instruct-q2_K", num_gpu=1)
+    # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+    answer_prompt = PromptTemplate.from_template(
+        """
+        <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+        Given the following user question, corresponding SQL query, and SQL result, answer the user question.
+        給定以下使用者問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
+
+        For example
+        Question: 建準廣興廠去年的類別1總排放量是多少?
+        SQL Query: SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'
+        SQL Result: [(1102.3712,)]
+        Answer: 建準廣興廠去年的類別1總排放量是1102.3712
+
+        Question: {question}
+        SQL Query: {query}
+        SQL Result: {result}
+        Answer: """
+        )
+
+    chain = answer_prompt | llm | StrOutputParser()
+
+    return chain
+
+def run(db, question, selected_table):
+
+    write_query = write_query_chain(db)
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"]})
+    
+    query = re.split('SQL query: ', query)[-1]
+    print(query)
+
+    execute_query = QuerySQLDataBaseTool(db=db)
+    result = execute_query.invoke(query)
+    print(result)
+
+    chain = sql_to_nl_chain()
+    answer = chain.invoke({"question": question, "query": query, "result": result})
+
+    return query, result, answer
+
+
+if __name__ == "__main__":
+    import time
+    
+    start = time.time()
+    
+    selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)']
+    question = "去年的固定燃燒總排放量是多少?"
+    query, result, answer = run(db, question, selected_table)
+    print("query: ", query)
+    print("result: ", result)
+    print("answer: ", answer)
+    
+    print(time.time()-start)
+