update agent and add local llm agent

ling 3 months ago
parent
commit
bcdee032d2
10 changed files with 993 additions and 103 deletions
  1. 2 0
      .gitignore
  2. 51 0
      README.md
  3. 31 1
      ai_agent.ipynb
  4. 32 86
      ai_agent.py
  5. 609 0
      ai_agent_llama.py
  6. 64 8
      faiss_index.py
  7. 28 0
      semantic_search.py
  8. 73 0
      sql_qa_test.py
  9. 91 7
      systex_app.py
  10. 12 1
      text_to_sql_private.py

+ 2 - 0
.gitignore

@@ -3,3 +3,5 @@ chroma_db_carbon_questions/
 faiss_index.bin
 faiss_metadata.pkl
 .env
+*.csv
+chroma*/

+ 51 - 0
README.md

@@ -0,0 +1,51 @@
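+# SYSTEX Multi-agent
+
+This project builds a multi-agent AI chatbot made up of two agents: a Text-to-SQL agent that handles the customer's own data, and an agent that uses retrieval-augmented generation (RAG) to handle domain knowledge. The overall multi-agent architecture is built with `langgraph`.
+
+## Table of Contents
+- [Project Overview](#project-overview)
+- [Main Components](#main-components)
+- [Installation](#installation)
+- [Usage](#usage)
+- [File Descriptions](#file-descriptions)
+- [Contributing](#contributing)
+- [License](#license)
+
+## Project Overview
+This multi-agent system answers a user's question by deciding, from the question itself, whether to use the customer's own database or an external source of domain knowledge. The system contains two main agents:
+1. Customer-data agent: uses **Text-to-SQL** to handle the customer's own structured data.
+2. External domain-knowledge agent: uses **RAG**, with **FAISS** as the RAG retriever, to retrieve from external unstructured knowledge and generate answers.
+
+## Usage
+
+Start the app:
+```bash
+conda activate llama3
+python systex_app.py
+```
+FastAPI Link: https://cmm.ai:8989/docs
+
+There are four APIs:
+1. `/agent`
+2. `/knowledge`
+3. `local_agents`
+4. `history`
+
+## File Descriptions
+
+- **`systex_app.py`**: Main entry point for running the multi-agent system.
+- **`ai_agent.py`**: Defines the multi-agent architecture, including the Text-to-SQL and RAG agents. The architecture is driven by `langgraph`.
+- **`faiss_index.py`**: Manages the FAISS retriever, which provides document retrieval for RAG.
+- **`text_to_sql_private.py`**: Contains the logic that converts natural language into SQL queries and handles the customer's own data.
+
+## Contributing
+
+Contributions are welcome! If you have a suggestion for improvement or find a bug, please submit a Pull Request or open an Issue.
+
+## License
+
+This project is released under the MIT License. See the [LICENSE](LICENSE) file for details.
+
+---
+
+Does this README meet your needs? If there is anything else to add, just let me know!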
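
The routing step that the README overview describes can be sketched in plain Python. This is only an illustration under stated assumptions: `route_with_keywords` and its `model_choice` argument are hypothetical names, and the keyword lists are copied from `route_question` in `ai_agent.py` later in this commit. The model's choice is overridden when a keyword matches, with the private-data keywords checked first.

```python
# Keyword lists copied from route_question in ai_agent.py (this commit).
PRIVATE_KW = ["建準", "北海", "廣興", "崑山廣興", "Inc", "SAS", "立準"]
PUBLIC_KW = ["範例", "碳足跡"]


def route_with_keywords(question: str, model_choice: str) -> str:
    """Return the data source label, letting keyword hits override the model.

    `model_choice` stands in for the router LLM's answer (hypothetical
    parameter); it is used only when no keyword matches.
    """
    if any(kw in question for kw in PRIVATE_KW):
        return "自有數據"  # customer-owned data -> Text-to-SQL agent
    if any(kw in question for kw in PUBLIC_KW):
        return "專業知識"  # domain knowledge -> RAG agent
    return model_choice


if __name__ == "__main__":
    print(route_with_keywords("建準去年的固定燃燒排放量是多少?", "專業知識"))
    print(route_with_keywords("給我工業製程計算範例", "自有數據"))
    print(route_with_keywords("溫室氣體是什麼?", "專業知識"))
```

Because the override is plain substring matching, a question that mentions both a company name and "範例" goes to the private-data path.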

+ 31 - 1
ai_agent.ipynb

@@ -358,10 +358,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
+    "\n",
+    "from langchain_community.chat_models import ChatOllama\n",
+    "from langchain_core.output_parsers import JsonOutputParser\n",
+    "from langchain_core.prompts import PromptTemplate\n",
+    "\n",
     "### Answer Grader\n",
     "\n",
     "# LLM\n",
@@ -722,6 +727,31 @@
     "\"\"\""
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'yes'"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question = \"給我工業製程計算範例\"\n",
+    "generation = \"\"\"根據建準的資料,工業製程的排放範例顯示,直接排放的總排放量為0.3165公噸CO2e。\n",
+    "製程排放是指在工業製程過程中,由於物理或化學反應所產生的溫室氣體排放。這些排放源通常來自於特定的製程設備或過程,例如在半導體製造中使用的蝕刻設備,這些設備可能會釋放出二氧化碳(CO2)、甲烷(CH4)、氫氟碳化物(HFCs)、全氟碳化物(PFCs)、氧化亞氮(N2O)、六氟化硫(SF6)及三氟化氮(NF3)等多種溫室氣體。製程排放的管理和減少對於降低整體溫室氣體排放量及應對氣候變遷具有重要意義。\"\"\"\n",
+    "score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n",
+    "grade = score[\"score\"]\n",
+    "grade"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 27,

+ 32 - 86
ai_agent.py

@@ -40,10 +40,12 @@ from text_to_sql_private import run, get_query, query_to_nl, table_description
 from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
 progress_bar = []
 
-def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
+def faiss_query(question: str, llm, docs=None, multi_query: bool = False) -> tuple:
     if multi_query:
-        docs = faiss_multiquery(question, retriever, llm)
+        docs = faiss_multiquery(question, retriever, llm, k=4)
         # print(docs)
+    elif docs:
+        pass
     else:
         docs = retriever.get_relevant_documents(question, k=10)
         # print(docs)
@@ -58,7 +60,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     You should not mention anything about "根據提供的文件內容" or other similar terms.
     請盡可能的詳細回答問題。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
-    勿回答無關資訊
+    勿回答無關資訊或任何與某特定公司相關的問題
     <|eot_id|>
     
     <|start_header_id|>user<|end_header_id|>
@@ -67,7 +69,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     {context}
 
     Question: {question}
-    用繁體中文回答問題,請用一段話詳細的回答。
+    用繁體中文回答問題,請用一段話詳細的回答。勿回答無關資訊或任何與某特定公司相關的問題。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
     
     <|eot_id|>
@@ -169,18 +171,13 @@ def generate_additional_detail(sql_query):
         print(term)
         if term is None: continue
         question_format = [ f"溫室氣體排放源中的{term}是什麼意思?",  f"{term}是什麼意思?"]
-        # f"溫室氣體排放源中的{term}是什麼意思?",
         for question in question_format:
-            # question = f"什麼是{term}?"
             documents = retriever.get_relevant_documents(question, k=5)
             all_documents.extend(documents)
-            # for doc in documents:
-            #     print(doc)
+            
         all_question = "".join(question_format)
-        documents, generation = faiss_query(all_question, all_documents, llm, multi_query=True) 
-        # print(generation)
-        # print("-----------------------")
-        # generation = answer + "\n"
+        documents, generation = faiss_query(all_question, llm, docs=all_documents, multi_query=True) 
+        
         if "test@systex.com" in generation:
             generation = ""
         
@@ -301,43 +298,12 @@ def retrieve_and_generation(state):
         state (dict): New key added to state, documents, that contains retrieved documents, and generation, generated by the LLM
     """
     progress_bar = show_progress(state, "---RETRIEVE---")
-    # progress_bar = state["progress"] if state["progress"] else []
-    # progress = "---RETRIEVE---"
-    # print(progress)
-    # progress_bar.append(progress)
     if not state["route"]:
         route = "RAG"
     else:
         route = state["route"]
     question = state["question"]
-    # print(state)
-    question_list = state["question_list"]
-    
-    # Retrieval
-    if not question_list:
-        # documents = retriever.invoke(question)
-        # TODO: correct Retrieval function
-        documents = retriever.get_relevant_documents(question, k=5)
-        # for doc in documents:
-        #     print(doc)
-            
-        # docs_documents = "\n\n".join(doc.page_content for doc in documents)
-        # print(documents)
-        documents, generation = faiss_query(question, documents, llm, multi_query=True)
-        # for doc in documents:
-        #     print(doc)
-    else:
-        generation = state["generation"]
-        
-        for sub_question in list(set(question_list)):
-            print(sub_question)
-            documents = retriever.get_relevant_documents(sub_question, k=5)
-            # for doc in documents:
-            #     print(doc)
-            documents, answer = faiss_query(sub_question, documents, llm, multi_query=True)
-            generation += answer
-            generation += "\n"
-            
+    documents, generation = faiss_query(question, llm, multi_query=True)
     print(generation)
             
     return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
@@ -444,6 +410,9 @@ def route_question(state):
         str: Next node to call
     """
 
+    if "route" in state.keys():
+        return "專業知識"
+    
     # print("---ROUTE QUESTION---")
     progress_bar = show_progress(state, "---ROUTE QUESTION---")
     question = state["question"]
@@ -452,10 +421,11 @@ def route_question(state):
     source = question_router.invoke({"question": question})
     print("Original:", source["datasource"])
     # if "建準" in question:
-    kw = ["建準", "北海", "廣興", "崑山廣興", "Inc", "SAS", "立準"]
-    if any(char in question for char in kw):
+    private_kw = ["建準", "北海", "廣興", "崑山廣興", "Inc", "SAS", "立準"]
+    public_kw = ["範例", "碳足跡"]
+    if any(char in question for char in private_kw):
         source["datasource"] = "自有數據"
-    elif "範例" in question:
+    elif any(char in question for char in public_kw):
         source["datasource"] = "專業知識"
         
     # print(source)
@@ -498,40 +468,6 @@ def grade_generation_v_documents_and_question(state):
         progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
         return "not useful"
     
-    
-    # progress_bar = show_progress(state, "---CHECK HALLUCINATIONS---")
-    # # print(docs_documents)
-    # # print(generation)
-    # hallucination_grader = Hallucination_Grader()
-    # score = hallucination_grader.invoke(
-    #     {"documents": documents, "generation": generation}
-    # )
-    # # print(score)
-    # grade = score["score"]
-
-    # # Check hallucination
-    # if grade in ["yes", "true", 1, "1"]:
-    #     # print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
-    #     progress_bar = show_progress(state, "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
-    #     # Check question-answering
-    #     # print("---GRADE GENERATION vs QUESTION---")
-    #     progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
-    #     answer_grader = Answer_Grader()
-    #     score = answer_grader.invoke({"question": question, "generation": generation})
-    #     grade = score["score"]
-    #     if grade in ["yes", "true", 1, "1"]:
-    #         # print("---DECISION: GENERATION ADDRESSES QUESTION---")
-    #         progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
-    #         return "useful"
-    #     else:
-    #         # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
-    #         progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
-    #         return "not useful"
-    # else:
-    #     # pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
-    #     progress_bar = show_progress(state, "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
-    #     return "not supported"
-    
 def grade_sql_query(state):
     """
     Determines whether the PostgreSQL query is correct for the question
@@ -576,6 +512,7 @@ def grade_sql_query(state):
     #     # print("---GRADE: INCORRECT SQL QUERY---")
     #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
     #     return "incorrect"
+    
 def check_sql_answer(state):
     progress_bar = show_progress(state, "---CHECK SQL ANSWER QUALITY---")
     generation = state["generation"]
@@ -645,11 +582,6 @@ print(draw_mermaid)
 
 def main(question: str):
     
-    # app = build_graph()
-    # draw_mermaid = app.get_graph().draw_mermaid()
-    # print(draw_mermaid)
-    #建準去年的類別一排放量?
-    # inputs = {"question": "溫室氣體是什麼"}
     inputs = {"question": question, "progress_bar": None}
     for output in app.stream(inputs, {"recursion_limit": 10}):
         for key, value in output.items():
@@ -662,6 +594,20 @@ def main(question: str):
     # return value["generation"]
     return value
 
+def rag_main(question: str):
+    
+    inputs = {"question": question, "progress_bar": None, "route": "專業知識"}
+    for output in app.stream(inputs, {"recursion_limit": 10}):
+        for key, value in output.items():
+            pprint(f"Finished running: {key}:")
+    # pprint(value["generation"])
+    # pprint(value)
+    value["progress_bar"] = progress_bar
+    # pprint(value["progress_bar"])
+    
+    # return value["generation"]
+    return value
+
 if __name__ == "__main__":
     # result = main("建準去年的逸散排放總排放量是多少?")
     # result = main("建準廣興廠去年的上游運輸總排放量是多少?")
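
In `faiss_query` as changed above, the `multi_query` branch is checked first, so a call such as `faiss_query(all_question, llm, docs=all_documents, multi_query=True)` replaces the supplied `all_documents` with the multi-query result. The sketch below shows one possible alternative ordering in which explicitly supplied documents take precedence. It is not the committed behavior: `select_docs` and the two retrieval callables are hypothetical stand-ins, and only the `k` values (4 and 10) come from the diff.

```python
from typing import Callable, List, Optional

Retrieve = Callable[[str, int], List[str]]


def select_docs(
    question: str,
    retrieve: Retrieve,
    multi_retrieve: Retrieve,
    docs: Optional[List[str]] = None,
    multi_query: bool = False,
) -> List[str]:
    """Pick the context documents for a RAG call.

    Assumption: caller-supplied docs win over both retrieval paths,
    which differs from the order used in the committed faiss_query.
    """
    if docs:
        return docs
    if multi_query:
        return multi_retrieve(question, 4)  # k=4 as in the diff
    return retrieve(question, 10)  # k=10 as in the diff


# Stub retrievers so the sketch runs without FAISS or an LLM.
def fake_retrieve(question: str, k: int) -> List[str]:
    return [f"single:{question}"][:k]


def fake_multi(question: str, k: int) -> List[str]:
    return [f"multi:{question}"][:k]


if __name__ == "__main__":
    print(select_docs("q", fake_retrieve, fake_multi, docs=["given"], multi_query=True))
    print(select_docs("q", fake_retrieve, fake_multi, multi_query=True))
    print(select_docs("q", fake_retrieve, fake_multi))
```

Whether explicit documents or multi-query retrieval should win is a design decision for the authors; the sketch only makes the precedence explicit.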

+ 609 - 0
ai_agent_llama.py

@@ -0,0 +1,609 @@
+
+from langchain_community.chat_models import ChatOllama
+from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.prompts import PromptTemplate
+
+from langchain.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+
+# graph usage
+from pprint import pprint
+from typing import List
+from langchain_core.documents import Document
+from typing_extensions import TypedDict
+from langgraph.graph import END, StateGraph, START
+from langgraph.pregel import RetryPolicy
+
+# supabase db
+from langchain_community.utilities import SQLDatabase
+import os
+from dotenv import load_dotenv
+load_dotenv()
+URI: str =  os.environ.get('SUPABASE_URI')
+db = SQLDatabase.from_uri(URI)
+
+# LLM
+# local_llm = "llama3.1:8b-instruct-fp16"
+# local_llm = "llama3.1:8b-instruct-q2_K"
+local_llm = "llama3-groq-tool-use:latest"
+llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
+local_llm = "cwchang/llama3-taide-lx-8b-chat-alpha1:q3_k_s"
+llm = ChatOllama(model=local_llm, temperature=0)
+sql_llm = ChatOllama(model="codeqwen", temperature=0)
+# sql_llm = ChatOllama(model="eramax/nxcode-cq-7b-orpo:q6", temperature=0)
+
+from langchain_openai import ChatOpenAI
+# sql_llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+
+# RAG usage
+from faiss_index import create_faiss_retriever, faiss_multiquery, faiss_query
+retriever = create_faiss_retriever()
+
+# text-to-sql usage
+from text_to_sql_private import run, get_query, query_to_nl, table_description
+from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
+progress_bar = []
+
+def faiss_query(question: str, llm, docs=None, multi_query: bool = False) -> tuple:
+    if multi_query:
+        docs = faiss_multiquery(question, retriever, llm, k=4)
+        # print(docs)
+    elif docs:
+        pass
+    else:
+        docs = retriever.get_relevant_documents(question, k=10)
+        # print(docs)
+    context = docs
+    
+    system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
+    You should not mention anything about "根據提供的文件內容" or other similar terms.
+    請盡可能的詳細回答問題。
+    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    勿回答無關資訊或任何與某特定公司相關的問題。
+    <|eot_id|>
+    
+    <|start_header_id|>user<|end_header_id|>
+    Answer the following question based on this context:
+
+    {context}
+
+    Question: {question}
+    用繁體中文回答問題,請用一段話詳細的回答。勿回答無關資訊或任何與某特定公司相關的問題。
+    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
+    """
+    prompt = ChatPromptTemplate.from_template(
+        system_prompt + "\n\n" +
+        template
+    )
+
+    rag_chain = prompt | llm | StrOutputParser()
+    return docs, rag_chain.invoke({"context": context, "question": question})
+
+### Hallucination Grader
+
+def Hallucination_Grader():
+    # Prompt
+    prompt = PromptTemplate(
+        template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|> 
+        You are a grader assessing whether an answer is grounded in / supported by a set of facts. 
+        Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. 
+        Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation. 
+        Return a JSON with a single key 'score' and no preamble or explanation. 
+        <|eot_id|><|start_header_id|>user<|end_header_id|>
+        Here are the facts:
+        \n ------- \n
+        {documents} 
+        \n ------- \n
+        Here is the answer: {generation} 
+        Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation.
+        <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["generation", "documents"],
+    )
+
+    hallucination_grader = prompt | llm_json | JsonOutputParser()
+    
+    return hallucination_grader
+
+### Answer Grader
+
+def Answer_Grader():
+    # Prompt
+    prompt = PromptTemplate(
+        template="""
+        <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an 
+        answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is 
+        useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
+        <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
+        \n ------- \n
+        {generation} 
+        \n ------- \n
+        Here is the question: {question} 
+        <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["generation", "question"],
+    )
+
+    answer_grader = prompt | llm_json | JsonOutputParser()
+    
+    return answer_grader
+
+# Text-to-SQL
+# def run_text_to_sql(question: str):
+#     selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
+#     # question = "建準去年的固定燃燒總排放量是多少?"
+#     query, result, answer = run(db, question, selected_table, sql_llm)
+    
+#     return  answer, query
+
+def _get_query(question: str):
+    selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
+    question = question.replace("美國", "美國 Inc")
+    question = question.replace("法國", "法國 SAS")
+    
+    query, result = get_query(db, question, selected_table, sql_llm)
+    return  query, result
+
+def _query_to_nl(question: str, query: str, result):
+    question = question.replace("美國", "美國 Inc")
+    question = question.replace("法國", "法國 SAS")
+    local_llm = "llama3-groq-tool-use:latest"
+    llm = ChatOllama(model=local_llm, temperature=0)
+    answer = query_to_nl(question, query, result, llm)
+    return  answer
+
+def generate_additional_question(sql_query):
+    terms = parse_sql_where(sql_query)
+    question_list = []
+    for term in terms:
+        if term is None: continue
+        question_format = [f"什麼是{term}?", f"{term}的用途是什麼"]
+        question_list.extend(question_format)
+        
+    return question_list
+        
+    
+def generate_additional_detail(sql_query):
+    terms = parse_sql_where(sql_query)
+    answer = ""
+    all_documents = []
+    for term in list(set(terms)):
+        print(term)
+        if term is None: continue
+        question_format = [ f"溫室氣體排放源中的{term}是什麼意思?",  f"{term}是什麼意思?"]
+        for question in question_format:
+            documents = retriever.get_relevant_documents(question, k=5)
+            all_documents.extend(documents)
+            
+        all_question = "".join(question_format)
+        documents, generation = faiss_query(all_question, llm, docs=all_documents, multi_query=True) 
+        
+        if "test@systex.com" in generation:
+            generation = ""
+        
+        answer += generation
+        # print(question)
+        # print(generation)
+    return answer
+### SQL Grader
+
+def SQL_Grader():
+    prompt = PromptTemplate(
+        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
+        You are a SQL query grader assessing correctness of PostgreSQL query to a user question. 
+        Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.
+        
+        Here is database description:
+        {table_info}
+        
+        You need to check that each where statement is correctly filtered out what user question need.
+        
+        For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is 
+        "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
+        FROM "建準碳排放清冊數據new"
+        WHERE "事業名稱" like '%建準%'
+        AND "排放源" = '下游租賃'
+        AND "盤查標準" = 'GHG'
+        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
+        For the above example, we can find that user asked for "固定燃燒", but the PostgreSQL query gives "排放源" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
+        
+        Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
+        "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+        FROM "建準碳排放清冊數據new"
+        WHERE "事業名稱" like '%台積電%'
+        AND "排放源" = '固定燃燒'
+        AND "盤查標準" = 'GHG'
+        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
+        For the above example, we can find that user asked for "建準", but the PostgreSQL query gives "事業名稱" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
+        
+        and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.
+        
+        If the PostgreSQL query does not exactly match the user question, grade it as incorrect. 
+        You need to strictly examine whether the PostgreSQL query matches the user question.
+        Give a binary 'yes' or 'no' score to indicate whether the PostgreSQL query is correct for the question. \n
+        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
+        <|eot_id|>
+        
+        <|start_header_id|>user<|end_header_id|>
+        Here is the PostgreSQL query: \n\n {sql_query} \n\n
+        Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+        """,
+        input_variables=["table_info", "question", "sql_query"],
+    )
+
+    sql_query_grader = prompt | llm_json | JsonOutputParser()
+    
+    return sql_query_grader
+
+### Router
+def Router():
+    prompt = PromptTemplate(
+        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
+        You are an expert at routing a user question to a 專業知識 or 自有數據. 
+        你需要分辨使用者問題是否在詢問某個公司與其據點廠房的自有數據或是尋求專業的碳盤查或碳管理等等的 ESG 知識和相關新聞,
+        如果問題是想了解某個公司與其據點廠房的碳排放源的排放量或用電、用水量等等,請使用"自有數據",
+        若使用者的問題是想了解碳盤查、碳交易或碳管理等等的 ESG 知識和相關新聞,請使用"專業知識"。
+        You do not need to be stringent with the keywords in the question related to these topics. 
+        Give a binary choice '自有數據' or '專業知識' based on the question. 
+        Return a JSON with a single key 'datasource' and no preamble or explanation. 
+        
+        Question to route: {question} 
+        <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["question"],
+    )
+
+    question_router = prompt | llm_json | JsonOutputParser()
+    
+    return question_router
+
+class GraphState(TypedDict):
+    """
+    Represents the state of our graph.
+
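+    Attributes:
+        question, question_list: user question and follow-up questions
+        generation: LLM generation
+        route, retry, progress_bar: route label, SQL retry count, progress messages
+        documents, sql_query, sql_result: retrieved documents and the SQL query with its result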
+    """
+
+    progress_bar: List[str]
+    route: str
+    question: str
+    question_list: List[str]
+    generation: str
+    documents: List[str]
+    retry: int
+    sql_query: str
+    sql_result: str
+    
+# Node
+def show_progress(state, progress: str):
+    global progress_bar
+    # progress_bar = state["progress_bar"] if state["progress_bar"] else []
+    
+    print(progress)
+    progress_bar.append(progress)
+    
+    return progress_bar
+
+def retrieve_and_generation(state):
+    """
+    Retrieve documents from vectorstore
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): New key added to state, documents, that contains retrieved documents, and generation, generated by the LLM
+    """
+    progress_bar = show_progress(state, "---RETRIEVE---")
+    if not state["route"]:
+        route = "RAG"
+    else:
+        route = state["route"]
+    question = state["question"]
+    documents, generation = faiss_query(question, llm, multi_query=True)
+    print(generation)
+            
+    return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
+
+def company_private_data_get_sql_query(state):
+    """
+    Get PostgreSQL query according to question
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): return generated PostgreSQL query and record retry times
+    """
+    # print("---SQL QUERY---")
+    progress_bar = show_progress(state, "---SQL QUERY---")
+    if not state["route"]:
+        route = "SQL"
+    else:
+        route = state["route"]
+    question = state["question"]
+    
+    if state["retry"]:
+        retry = state["retry"]
+        retry += 1
+    else: 
+        retry = 0
+    # print("RETRY: ", retry)
+    
+    sql_query, sql_result = _get_query(question)
+    print(type(sql_result))
+    
+    return {"progress_bar": progress_bar, "route": route, "sql_query": sql_query, "sql_result": sql_result, "question": question, "retry": retry}
+    
+def company_private_data_search(state):
+    """
+    Execute PostgreSQL query and convert to natural language.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): Appended sql results to state
+    """
+
+    # print("---SQL TO NL---")
+    progress_bar = show_progress(state, "---SQL TO NL---")
+    # print(state)
+    question = state["question"]
+    sql_query = state["sql_query"]
+    sql_result = state["sql_result"]
+    generation = _query_to_nl(question, sql_query, sql_result)
+    
+    # generation = [company_private_data_result]
+    
+    return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation}
+
+def additional_explanation_question(state):
+    """Append explanations of the terms in the SQL query's WHERE clause to the generation.
+    
+    Args:
+        state (dict): The current graph state
+        
+    Returns:
+        state (dict): Appended additional explanation to state
+    """
+    
+    # print("---ADDITIONAL EXPLANATION---")
+    progress_bar = show_progress(state, "---ADDITIONAL EXPLANATION---")
+    # print(state)
+    question = state["question"]
+    sql_query = state["sql_query"]
+    # print(sql_query)
+    generation = state["generation"]
+    generation += "\n"
+    generation += generate_additional_detail(sql_query)
+    question_list = []    
+    
+    # question_list = generate_additional_question(sql_query)
+    # print(question_list)
+    
+    # generation = [company_private_data_result]
+    
+    return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation, "question_list": question_list}
+
+def error(state):
+    # print("---SOMETHING WENT WRONG---")
+    progress_bar = show_progress(state, "---SOMETHING WENT WRONG---")
+    generation = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    
+    return {"progress_bar": progress_bar,  "generation": generation}
+
+### Conditional edge
+
+
+def route_question(state):
+    """
+    Route question to Text-to-SQL or RAG.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        str: Next node to call
+    """
+
+    # print("---ROUTE QUESTION---")
+    progress_bar = show_progress(state, "---ROUTE QUESTION---")
+    question = state["question"]
+    # print(question)
+    question_router = Router()
+    source = question_router.invoke({"question": question})
+    print("Original:", source["datasource"])
+    # if "建準" in question:
+    kw = ["建準", "北海", "廣興", "崑山廣興", "Inc", "SAS", "立準"]
+    if any(char in question for char in kw):
+        source["datasource"] = "自有數據"
+    elif "範例" in question:
+        source["datasource"] = "專業知識"
+        
+    # print(source)
+    print(source["datasource"])
+    if source["datasource"] == "自有數據":
+        # print("---ROUTE QUESTION TO TEXT-TO-SQL---")
+        progress_bar = show_progress(state, "---ROUTE QUESTION TO TEXT-TO-SQL---")
+        return "自有數據"
+    elif source["datasource"] == "專業知識":
+        # print("---ROUTE QUESTION TO RAG---")
+        progress_bar = show_progress(state, "---ROUTE QUESTION TO RAG---")
+        return "專業知識"
+    
+def grade_generation_v_documents_and_question(state):
+    """
+    Determines whether the generation is grounded in the document and answers question.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        str: Decision for next node to call
+    """
+
+    # print("---CHECK HALLUCINATIONS---")
+    question = state["question"]
+    documents = state["documents"]
+    generation = state["generation"]
+
+    progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
+    answer_grader = Answer_Grader()
+    score = answer_grader.invoke({"question": question, "generation": generation})
+    print(score)
+    grade = score["score"]
+    if grade in ["yes", "true", 1, "1"]:
+        # print("---DECISION: GENERATION ADDRESSES QUESTION---")
+        progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
+        return "useful"
+    else:
+        # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+        progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+        return "not useful"
+    
+def grade_sql_query(state):
+    """
+    Determines whether the PostgreSQL query is correct for the question
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): Decision for retry or continue
+    """
+
+    # print("---CHECK SQL CORRECTNESS TO QUESTION---")
+    progress_bar = show_progress(state, "---CHECK SQL CORRECTNESS TO QUESTION---")
+    question = state["question"]
+    sql_query = state["sql_query"]
+    sql_result = state["sql_result"]
+    if "None" in sql_result or sql_result.startswith("Error:"):
+        progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
+        return "incorrect"
+    else:
+        print(sql_result)
+        progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
+        return "correct"
+    # retry = state["retry"]
+
+    # # Score each doc
+    # sql_query_grader = SQL_Grader()
+    # score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
+    # grade = score["score"]
+    
+    
+    # # Document relevant
+    # if grade in ["yes", "true", 1, "1"]:
+    #     # print("---GRADE: CORRECT SQL QUERY---")
+    #     progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
+    #     return "correct"
+    # elif retry >= 5:
+    #     # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
+    #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
+    #     return "failed"
+    # else:
+    #     # print("---GRADE: INCORRECT SQL QUERY---")
+    #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
+    #     return "incorrect"
+    
+def check_sql_answer(state):
+    progress_bar = show_progress(state, "---CHECK SQL ANSWER QUALITY---")
+    generation = state["generation"]
+    if "test@systex.com" in generation:
+        progress_bar = show_progress(state, "---SQL CAN NOT GENERATE ANSWER---")
+        return "bad"
+    else:
+        progress_bar = show_progress(state, "---SQL CAN GENERATE ANSWER---")
+        return "good"
+    
+def build_graph():
+    workflow = StateGraph(GraphState)
+
+    # Define the nodes
+    workflow.add_node("Text-to-SQL", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5))  # generate SQL query
+    workflow.add_node("SQL Answer", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # convert SQL result to natural language
+    workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5))  # explain terms in the SQL query
+    workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve and generate
+    workflow.add_node("ERROR", error)  # fallback apology message
+    company_private_data_search
+    workflow.add_conditional_edges(
+        START,
+        route_question,
+        {
+            "自有數據": "Text-to-SQL",
+            "專業知識": "RAG",
+        },
+    )
+
+    workflow.add_conditional_edges(
+        "RAG",
+        grade_generation_v_documents_and_question,
+        {
+            "useful": END,
+            "not useful": "ERROR",
+        },
+    )
+    workflow.add_conditional_edges(
+        "Text-to-SQL",
+        grade_sql_query,
+        {
+            "correct": "SQL Answer",
+            "incorrect": "RAG",
+            
+        },
+    )
+    workflow.add_conditional_edges(
+        "SQL Answer",
+        check_sql_answer,
+        {
+            "good": "Additoinal Explanation",
+            "bad": "RAG",
+            
+        },
+    )
+    
+    # workflow.add_edge("SQL Answer", "Additoinal Explanation")
+    workflow.add_edge("Additoinal Explanation", END)
+
+    app = workflow.compile()    
+    
+    return app
+
+app = build_graph()
+draw_mermaid = app.get_graph().draw_mermaid()
+print(draw_mermaid)
+
+def main(question: str):
+    
+    inputs = {"question": question, "progress_bar": None}
+    for output in app.stream(inputs, {"recursion_limit": 10}):
+        for key, value in output.items():
+            pprint(f"Finished running: {key}:")
+    # pprint(value["generation"])
+    # pprint(value)
+    value["progress_bar"] = progress_bar
+    # pprint(value["progress_bar"])
+    
+    # return value["generation"]
+    return value
+
+if __name__ == "__main__":
+    # result = main("建準去年的逸散排放總排放量是多少?")
+    # result = main("建準廣興廠去年的上游運輸總排放量是多少?")
+    
+    result = main("建準北海廠去年的固定燃燒排放量是多少?")
+    # result = main("溫室氣體是什麼?")
+    # result = main("什麼是外購電力(綠電)?")
+    print("------------------------------------------------------")
+    print(result)
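
The conditional edges in `build_graph()` can be read as a routing table. The following is a minimal, dependency-free sketch of that table, not the `langgraph` implementation itself: the node names and edge labels are copied from the diff, while `ROUTES` and `walk` are hypothetical helpers introduced only for illustration.

```python
# Hypothetical model of the routing in build_graph(); it does not use langgraph.
# Node names and edge labels come from the diff, the rest is illustrative.
START, END = "START", "END"

ROUTES = {
    START: {"自有數據": "Text-to-SQL", "專業知識": "RAG"},
    "Text-to-SQL": {"correct": "SQL Answer", "incorrect": "RAG"},
    "SQL Answer": {"good": "Additoinal Explanation", "bad": "RAG"},
    "Additoinal Explanation": {None: END},  # unconditional edge
    "RAG": {"useful": END, "not useful": "ERROR"},
}


def walk(decisions):
    """Follow the routing table, given a mapping of node -> chosen edge label."""
    node, path = START, []
    # Stop at END or at a node with no outgoing edges in the table (e.g. ERROR).
    while node != END and node in ROUTES:
        label = decisions.get(node)
        node = ROUTES[node][label]
        path.append(node)
    return path
```

Tracing a path this way makes the fallback behaviour easy to see: a rejected SQL query or a "bad" SQL answer both fall through to the RAG node.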

+ 64 - 8
faiss_index.py

@@ -16,6 +16,7 @@ import pandas as pd
 from langchain_core.documents import Document
 from langchain.load import dumps, loads
 from langchain_community.chat_models import ChatOllama
+from langchain.callbacks import get_openai_callback
 
 # Import from the parent directory
 import sys
@@ -146,7 +147,7 @@ def load_qa_pairs():
 
     return df['Question'].tolist(), df['Answer'].tolist()
 
-def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):
+def faiss_multiquery(question: str, retriever: FAISSRetriever, llm, k: int = 4):
     generate_queries = multi_query_chain(llm)
 
     questions = generate_queries.invoke(question)
@@ -156,17 +157,17 @@ def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):
         print(q)
 
     # docs = list(map(retriever.get_relevant_documents, questions))
-    docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))
+    docs = list(map(lambda query: retriever.get_relevant_documents(query, k=k), questions))
     docs = [item for sublist in docs for item in sublist]
 
     return docs
 
-def faiss_query(retriever, question: str, llm, multi_query: bool = False) -> str:
+def faiss_query(retriever, question: str, llm, k: int = 4, multi_query: bool = False) -> str:
     if multi_query:
-        docs = faiss_multiquery(question, retriever, llm)
+        docs = faiss_multiquery(question, retriever, llm, k)
         # print(docs)
     else:
-        docs = retriever.get_relevant_documents(question, k=10)
+        docs = retriever.get_relevant_documents(question, k=k)
         # print(docs)
     context = docs
     
@@ -179,7 +180,7 @@ def faiss_query(retriever, question: str, llm, multi_query: bool = False) -> str
     You should not mention anything about "根據提供的文件內容" or other similar terms.
     請盡可能的詳細回答問題。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
-    勿回答無關資訊
+    勿回答無關資訊或任何與某特定公司相關的問題
     <|eot_id|>
     
     <|start_header_id|>user<|end_header_id|>
@@ -188,7 +189,7 @@ def faiss_query(retriever, question: str, llm, multi_query: bool = False) -> str
     {context}
 
     Question: {question}
-    用繁體中文回答問題,請用一段話詳細的回答。
+    用繁體中文回答問題,請用一段話詳細的回答。勿回答無關資訊或任何與某特定公司相關的問題。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
     
     <|eot_id|>
@@ -309,8 +310,63 @@ async def run_evaluation():
 
     print("\nPerformance comparison complete.")
 
+def load_kg_q():
+    sheet_id = "1eVmO8fIjjtEQBlmqtipi2d0H-8VgJkI-icqhSTfdhHQ"
+    gid = "0"
+    df = pd.read_csv(f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?format=csv&gid={gid}")
+    
+    return df['Question'].tolist()
+    
+async def run_q_batch():
+    # local_llm = "llama3-groq-tool-use:latest"
+    # llama3 = ChatOllama(model=local_llm, temperature=0)
+    openai = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+
+    retriever = create_faiss_retriever()
+
+    questions = load_kg_q()
+    result_list = []
+
+    for question in questions:
+        print(f"\nQuestion: {question}")
+
+        with get_openai_callback() as cb:
+            start_time = time()
+            # openai_docs, openai_answer = faiss_query(retriever, question, openai, multi_query=True)
+            documents, answer = faiss_query(retriever, question, openai, multi_query=True)
+            openai_time = time() - start_time
+            print(f"Answer: {answer}")
+            print(f"Time: {openai_time:.4f} seconds")
+            save_history(question, answer, cb, openai_time)
+        
+        result = {'Question': question, 'Answer': answer}
+        result_list.append(result)
 
+    df = pd.DataFrame.from_records(result_list)
+    print(df)
+        
+    df.to_csv("kg_qa.csv", mode='w')
+
+    print("\nQA complete.")
+    
+def save_history(question, answer, cb, processing_time):
+    # reference = [doc.dict() for doc in reference]
+    record = {
+        'Question': question,
+        'Answer': answer,
+        'Total_Tokens': cb.total_tokens,
+        'Total_Cost': cb.total_cost,
+        'Processing_time': processing_time,
+    }
+    response = (
+        supabase.table("agent_records")
+        .insert(record)
+        .execute()
+    )
 if __name__ == "__main__":
-    asyncio.run(run_evaluation())
+    # asyncio.run(run_evaluation())
+    asyncio.run(run_q_batch())
+    # print(load_kg_q())
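
Because `faiss_multiquery` flattens the results of several generated queries into one list, the same document may appear more than once. The sketch below shows one way to drop such repeats while keeping first-seen order. `Doc` is a hypothetical stand-in for `langchain_core.documents.Document`, and comparing on `page_content` alone is an assumption, not something the diff does.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Doc:
    """Hypothetical stand-in for langchain_core.documents.Document."""
    page_content: str


def unique_docs(docs):
    """Drop repeated documents while keeping first-seen order."""
    seen = set()
    result = []
    for doc in docs:
        # Assumption: two documents with identical text are the same document.
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            result.append(doc)
    return result
```

Deduplicating before building the prompt context would keep the context shorter when the generated queries overlap heavily.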
 
     

+ 28 - 0
semantic_search.py

@@ -20,7 +20,35 @@ from langchain_community.embeddings.openai import OpenAIEmbeddings
 from langchain_community.vectorstores import SupabaseVectorStore
 from supabase.client import create_client
 
+def grandson_vectordb(vectordb_directory = "./chroma_grandson"):
+    questions = ['我的跨損啊', "我要看孫子"]
+    
+    vectorstore = Chroma.from_texts(
+        texts=questions,
+        embedding=embeddings_model,
+        persist_directory=vectordb_directory
+        )
+    return vectorstore
 
+def grandson_semantic_cache(q, SIMILARITY_THRESHOLD=0.83, k=1, vectordb_directory = "./chroma_grandson"):
+    if os.path.isdir(vectordb_directory):
+        vectorstore = Chroma(persist_directory=vectordb_directory, embedding_function=embeddings_model)
+    else:
+        print("create new vector db ...")
+        vectorstore = grandson_vectordb(vectordb_directory)
+
+    docs_and_scores = vectorstore.similarity_search_with_relevance_scores(q, k=k)
+    # No match at all: treat it as a cache miss instead of raising IndexError
+    if not docs_and_scores:
+        return None, None
+    doc, score = docs_and_scores[0]
+
+    if score >= SIMILARITY_THRESHOLD:
+        cache_question = doc.page_content
+        answer = "你有三個孫子,男生在國小念書,你要看他照片嗎"
+        return cache_question, answer
+    else:
+        return None, None
+    
 def create_qa_vectordb(supabase, vectordb_directory="./chroma_db_carbon_questions"):
 
     if os.path.isdir(vectordb_directory):
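
The threshold check in `grandson_semantic_cache` can be illustrated without a vector store. This is a minimal sketch that assumes cosine similarity as the score and a hypothetical in-memory `cache` of `(embedding, answer)` pairs; the relevance score Chroma returns may be computed differently.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors; 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def cache_lookup(query_vec, cache, threshold=0.83):
    """Return the answer of the closest cached entry, or None below the threshold.

    `cache` is a hypothetical list of (embedding, answer) pairs.
    """
    best_score, best_answer = 0.0, None
    for vec, answer in cache:
        score = cosine_similarity(query_vec, vec)
        if score > best_score:
            best_score, best_answer = score, answer
    return best_answer if best_score >= threshold else None
```

An empty cache, or a best match below the threshold, is treated as a miss so that the caller can fall back to the agent.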

+ 73 - 0
sql_qa_test.py

@@ -0,0 +1,73 @@
+
+
+import csv
+import time
+from langchain.callbacks import get_openai_callback
+from ai_agent import main
+from systex_app import save_history
+
+
+def agent(question):
+    print(question)
+    start = time.time()
+    with get_openai_callback() as cb:
+        result = main(question)
+        answer = result["generation"]
+    processing_time = time.time() - start
+    save_history(question, answer, cb, processing_time)
+    
+    if "test@systex.com" in answer:
+        answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    print(answer)
+    return result
+
+places = {
+    '高雄': ["高雄廠", "高雄總部及運通廠", "運通廠", "總部", "高雄總部", "高雄"], 
+    '台北': ["台北廠", "台北辦事處", "台北", "臺北"],
+    '廣興': ["廣興廠", "昆山廠", "廣興"],
+    '北海': ["北海建準廠", "北海廠", "立準廠", "立準", "北海"],
+    '菲律賓': ["菲律賓建準廠", "菲律賓"],
+    '印度': ["India", "印度"],
+    '美國': ["Inc", "美國"],
+    '法國': ["SAS", "法國"],
+}
+items = ['直接排放', '能源間接排放', '運輸間接排放', '組織使用產品間接排放', '使用來自組織產品間接排放',
+        '類別1', '類別2', '類別3', '類別4', '類別5', '類別6',
+        '固定燃燒', '移動燃燒', '製程排放', '逸散排放', '土地利用', 
+        '外購電力', '外購能源', 
+        '上游運輸', '下游運輸', '員工通勤', '商務旅行', '訪客運輸', 
+        '購買產品', '外購燃料及能資源', '資本貨物', '上游租賃', '廢棄物處理', '廢棄物清運', '其他委外業務', 
+        '產品加工', '產品使用', '產品最終處理', '下游租賃', '投資排放', 
+        '其他間接排放']
+items2 = ['綠電', "灰電", "自產電力", "外購電力"]
+times = ["2022", "2023", "去年", "前年"]
+import random
+import math
+
+question_list = []
+for place in places.keys():
+    print(place)
+    loc_random = random.sample(places[place], k=math.ceil(len(places[place])/2))
+    print(loc_random)
+    for loc in loc_random:
+        # item_random = random.sample(items, k=10)
+        item_random = items2
+        for item in item_random:
+            year_time = random.choice(times)
+            # questions_format = [f"建準{place}{year_time}的{item}是多少", f"{place}{year_time}的{item}排放量", f"{year_time}{place}的{item}是多少", 
+            #                     f"請問{year_time}的{place}{item}排放量是多少?", f"{year_time}建準{place}{item}"]
+            questions_format = [f"建準{place}{year_time}的{item}是多少", f"{place}{year_time}的{item}使用量", f"{year_time}{place}的{item}是多少", 
+                                f"請問{year_time}的{place}{item}使用量是多少?", f"{year_time}建準{place}{item}"]
+            
+            questions = random.sample(questions_format, k=2)
+            # question_list.extend(questions)
+            for question in questions:
+                print(question)
+                result = agent(question)
+                
+                labels = result.keys()
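+                with open("results2.csv", "a", newline="", encoding="utf-8") as f:
+                    writer = csv.DictWriter(f, fieldnames=labels)
+                    # Write the header only once, while the file is still empty
+                    if f.tell() == 0:
+                        writer.writeheader()
+                    writer.writerow(result)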
+            
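
The test script draws locations, years, and question templates with the global `random` module, so each run asks a different set of questions. The sketch below assumes a hypothetical `build_questions` helper and shows how a seeded `random.Random` instance could make a run repeatable; the templates are a subset of those in the script.

```python
import random


def build_questions(places, items, times, seed=0, per_item=2):
    """Generate templated questions reproducibly from a seeded RNG."""
    rng = random.Random(seed)  # local RNG; the global random state is untouched
    questions = []
    for place in places:
        for item in items:
            year = rng.choice(times)
            templates = [
                f"建準{place}{year}的{item}是多少",
                f"{place}{year}的{item}使用量",
                f"{year}{place}的{item}是多少",
            ]
            # Pick `per_item` distinct templates for this place/item pair.
            questions.extend(rng.sample(templates, k=per_item))
    return questions
```

With a fixed seed, a question that produced a bad answer can be regenerated and replayed against the agent.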

+ 91 - 7
systex_app.py

@@ -1,5 +1,6 @@
 import datetime
 from json import loads
+import threading
 import time
 from typing import List
 from fastapi import Body, FastAPI
@@ -7,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
 
 import pandas as pd
 from pydantic import BaseModel
+import requests
 import uvicorn
 
 from dotenv import load_dotenv
@@ -15,8 +17,9 @@ from supabase.client import Client, create_client
 
 from langchain.callbacks import get_openai_callback
 
-from ai_agent import main
-from semantic_search import semantic_cache
+from ai_agent import main, rag_main
+from ai_agent_llama import main as llama_main
+from semantic_search import semantic_cache, grandson_semantic_cache
 from RAG_strategy import get_search_query
 
 load_dotenv()
@@ -51,9 +54,14 @@ def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
     
     with get_openai_callback() as cb:
         # cache_question, cache_answer = semantic_cache(supabase, question)
-        cache_answer = None
+        cache_question, cache_answer = grandson_semantic_cache(question)
+        # cache_answer = None
         if cache_answer:
             answer = cache_answer
+            if "孫子" in answer:
+                path = "https://cmm.ai/systex-ai-chatbot/video_cache/"
+                video_cache = "grandson2.mp4"
+                return {"Answer": answer, "video_cache": path + video_cache}
         else:
             result = main(question)
             answer = result["generation"]
@@ -65,6 +73,44 @@ def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
     print(answer)
     return {"Answer": answer}  
 
+@app.post("/knowledge")
+def rag(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
+    print(question)
+    start = time.time()
+    
+    with get_openai_callback() as cb:
+        # cache_question, cache_answer = semantic_cache(supabase, question)
+        cache_answer = None
+        if cache_answer:
+            answer = cache_answer
+        else:
+            result = rag_main(question)
+            answer = result["generation"]
+    processing_time = time.time() - start
+    save_history(question, answer, cb, processing_time)
+    if "test@systex.com" in answer:
+        answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    print(answer)
+    return {"Answer": answer}  
+    
+@app.post("/local_agents")
+def local_agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
+    print(question)
+    start = time.time()
+    
+    with get_openai_callback() as cb:
+        # cache_question, cache_answer = semantic_cache(supabase, question)
+        cache_answer = None
+        if cache_answer:
+            answer = cache_answer
+        else:
+            result = llama_main(question)
+            answer = result["generation"]
+    processing_time = time.time() - start
+    save_history(question, answer, cb, processing_time)
+    
+    return {"Answer": answer}  
+
 def save_history(question, answer, cb, processing_time):
     # reference = [doc.dict() for doc in reference]
     record = {
@@ -101,9 +147,47 @@ async def get_history():
     result = loads(result)
     return result.values()  
 
-if __name__ == "__main__":
+
+def cleanup_files():
+    faiss_index_path = "faiss_index.bin"
+    metadata_path = "faiss_metadata.pkl"
+    try:
+        if os.path.exists(faiss_index_path):
+            os.remove(faiss_index_path)
+            print(f"{faiss_index_path} 已刪除")
+        if os.path.exists(metadata_path):
+            os.remove(metadata_path)
+            print(f"{metadata_path} 已刪除")
+    except Exception as e:
+        print(f"刪除檔案時出錯: {e}")
+        
+def send_heartbeat(url, sec=600):
+    while True:
+        try:
+            response = requests.get(url)
+            if response.status_code != 200:
+                print(f"Failed to send heartbeat, status code: {response.status_code}")
+        except requests.RequestException as e:
+            print(f"Error occurred: {e}")
+        # Wait `sec` seconds before sending the next heartbeat
+        time.sleep(sec)
+
+def start_heartbeat(url, sec=600):
+    heartbeat_thread = threading.Thread(target=send_heartbeat, args=(url, sec))
+    heartbeat_thread.daemon = True
+    heartbeat_thread.start()
     
-    uvicorn.run("systex_app:app", host='0.0.0.0', reload=True, port=8080, 
-                ssl_keyfile="/etc/ssl_file/key.pem", 
-                ssl_certfile="/etc/ssl_file/cert.pem")
+if __name__ == "__main__":
+    url = 'http://db.ptt.cx:3001/api/push/luX7WcY3Gz?status=up&msg=OK&ping='
+    start_heartbeat(url, sec=600)
+    # uvicorn.run("systex_app:app", host='0.0.0.0', reload=True, port=8080, 
+    #             ssl_keyfile="/etc/ssl_file/key.pem", 
+    #             ssl_certfile="/etc/ssl_file/cert.pem")
+    try:
+        uvicorn.run("systex_app:app", host='0.0.0.0', reload=True, port=8080, 
+                    ssl_keyfile="/etc/ssl_file/key.pem", ssl_certfile="/etc/ssl_file/cert.pem")
+    except KeyboardInterrupt:
+        print("收到 KeyboardInterrupt,正在清理...")
+    finally:
+        cleanup_files()
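
`send_heartbeat` loops forever and sleeps with `time.sleep`, so it only ends when the process exits. The following is a sketch of a stoppable variant built on `threading.Event`. The `ping` callable and the `stop_event` parameter are assumptions introduced for illustration; in the app, `ping` would wrap the `requests.get(url)` call.

```python
import threading


def start_heartbeat(ping, interval, stop_event):
    """Call `ping()` every `interval` seconds until `stop_event` is set."""
    def loop():
        # Event.wait returns True as soon as the event is set,
        # so shutdown does not have to wait out a full sleep.
        while not stop_event.wait(interval):
            try:
                ping()
            except Exception as exc:  # keep the thread alive on failures
                print(f"Heartbeat failed: {exc}")

    thread = threading.Thread(target=loop, daemon=True)
    thread.start()
    return thread
```

Unlike the loop in the diff, this variant waits one interval before the first ping. Calling `stop_event.set()` in the `finally` block next to `cleanup_files()` would end the thread cleanly.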
 

+ 12 - 1
text_to_sql_private.py

@@ -323,6 +323,7 @@ def sql_to_nl_chain(llm):
         Given the following user question, corresponding SQL query, and SQL result, answer the user question.
         根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
         ** 請務必在回答中表達是建準的資料,即便問句中並未提及建準。
+        如果有單位,請回答時使用單位。
         
         The following shows some example:
         Question: 建準廣興廠去年的類別1總排放量是多少?
@@ -334,7 +335,7 @@ def sql_to_nl_chain(llm):
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;,
         SQL Result: [(1102.3712,)]
-        Answer: 建準廣興廠去年的類別1總排放量是1102.3712
+        Answer: 建準廣興廠去年的類別1總排放量是1102.3712公噸CO2e
 
         如果你不知道答案或SQL query 出現Error請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
         若 SQL Result 為 0 代表數據為0。
@@ -363,6 +364,16 @@ def get_query(db, question, selected_table, llm):
     write_query = write_query_chain(db, llm)
     query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     
+    # Regular expression pattern to extract the first SELECT ... ; statement
+    sql_pattern = r'SELECT[\s\S]+?;'
+
+    # Extract SQL query using re.search (case-insensitive, in case the model emits lowercase SQL)
+    sql_query = re.search(sql_pattern, query, re.IGNORECASE)
+    if sql_query:
+        query = sql_query.group()
+        # print(sql_query.group())
+    else:
+        print("No SQL query found.")
     query = re.split('SQL query: ', query)[-1]
     query = query.replace("```sql","").replace("```","")
     query = query.replace("碰排","碳排")
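
The extraction step added to `get_query` can be tried in isolation. This sketch assumes the model output contains at most one statement of interest and matches case-insensitively, which is an assumption about model output; `extract_sql` is a hypothetical helper, not a function from the repository.

```python
import re

# Non-greedy match from SELECT up to the first semicolon, across newlines.
SQL_PATTERN = re.compile(r"SELECT[\s\S]+?;", re.IGNORECASE)


def extract_sql(text):
    """Return the first SELECT ... ; statement found in `text`, or None."""
    match = SQL_PATTERN.search(text)
    return match.group() if match else None
```

Since the match stops at the first semicolon, a query containing `;` inside a string literal would be truncated; that case would need a real SQL parser.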