2 Commits 68aa73c006 ... e46d4b62c3

Auteur SHA1 Bericht Datum
  ling e46d4b62c3 update agent flow 4 maanden geleden
  ling b3e8ddabe9 update text-to-sql prompt 4 maanden geleden
8 gewijzigde bestanden met 417 toevoegingen en 124 verwijderingen
  1. 119 13
      ai_agent.ipynb
  2. 120 77
      ai_agent.py
  3. 1 1
      faiss_index.py
  4. 1 1
      file_loader/news_vectordb.py
  5. 3 1
      post_processing_sqlparse.py
  6. 110 0
      rewrite_question.py
  7. 9 8
      systex_app.py
  8. 54 23
      text_to_sql_private.py

+ 119 - 13
ai_agent.ipynb

@@ -601,7 +601,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 8,
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
@@ -643,7 +643,7 @@
     "    AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\"\n",
     "    AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\"\n",
     "    For the above example, we can find that user asked for \"建準\", but the PostgreSQL query gives \"事業名稱\" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.\n",
     "    For the above example, we can find that user asked for \"建準\", but the PostgreSQL query gives \"事業名稱\" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.\n",
     "    \n",
     "    \n",
-    "    and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.\n",
+    "    and so on. You need to examine whether the sql PostgreSQL query matches the user question.\n",
     "    \n",
     "    \n",
     "    If the PostgreSQL query do not exactly matches the user question, grade it as incorrect. \n",
     "    If the PostgreSQL query do not exactly matches the user question, grade it as incorrect. \n",
     "    You need to strictly examine whether the sql PostgreSQL query matches the user question.\n",
     "    You need to strictly examine whether the sql PostgreSQL query matches the user question.\n",
@@ -663,30 +663,122 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "### SQL Grader\n",
+    "\n",
+    "from langchain_community.chat_models import ChatOllama\n",
+    "from langchain_core.output_parsers import JsonOutputParser\n",
+    "from langchain_core.prompts import PromptTemplate\n",
+    "\n",
+    "# LLM\n",
+    "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
+    "\n",
+    "prompt = PromptTemplate(\n",
+    "    template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> \n",
+    "    You are a SQL query grader assessing correctness of PostgreSQL query to a user question. \n",
+    "    Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.\n",
+    "    \n",
+    "    Here is database description:\n",
+    "    {table_info}\n",
+    "    \n",
+    "    You need to check that each where statement is correctly filtered out what user question need.\n",
+    "    You need to check if PostgreSQL query WHERE clause correctly filter records according to user question\n",
+    "    You need to examine whether the sql PostgreSQL query matches the user question.\n",
+    "    \n",
+    "    If the PostgreSQL query do not exactly matches the user question, grade it as incorrect. \n",
+    "    You need to strictly examine whether the sql PostgreSQL query matches the user question.\n",
+    "    Give a binary score 'yes' or 'no' score to indicate whether the PostgreSQL query is correct to the question. \\n\n",
+    "    Provide the binary score as a JSON with a single key 'score' and no premable or explanation.\n",
+    "    <|eot_id|>\n",
+    "    \n",
+    "    <|start_header_id|>user<|end_header_id|>\n",
+    "    Here is the PostgreSQL query: \\n\\n {sql_query} \\n\\n\n",
+    "    Here is the user question: {question} \\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+    "    \"\"\",\n",
+    "    input_variables=[\"table_info\", \"question\", \"sql_query\"],\n",
+    ")\n",
+    "\n",
+    "sql_query_grader = prompt | llm_json | JsonOutputParser()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "question = \"建準廣興廠去年的綠電使用量是多少?\"\n",
+    "sql_query = \"\"\"\n",
+    "\n",
+    "SELECT SUM(\"用電度數(kwh)\") AS \"自產電力綠電使用量\"\n",
+    "FROM \"用電度數\"\n",
+    "WHERE \"項目\" = '自產電力(綠電)'\n",
+    "AND \"盤查標準\" = 'GHG'\n",
+    "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from text_to_sql_private import get_query\n",
+    "selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'db' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[28], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m get_query(\u001b[43mdb\u001b[49m, question, selected_table, llm)\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'db' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "get_query(db, question, selected_table, llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
     {
     {
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "{'score': 'no'}\n"
+      "{'score': 'yes'}\n"
      ]
      ]
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "from text_to_sql2 import table_description\n",
-    "question = \"建準去年的類別一排放量\"\n",
+    "from text_to_sql_private import table_description\n",
+    "# question = \"建準去年的類別一排放量\"\n",
     "# sql_query = \"\"\"\n",
     "# sql_query = \"\"\"\n",
     "# SELECT SUM(\"高雄總部及運通廠\" + \"台北辦事處\" + \"昆山廣興廠\" + \"北海建準廠\" + \"北海立準廠\" + \"菲律賓建準廠\" + \"Inc\" + \"SAS\" + \"India\") AS \"類別一排放量\"\n",
     "# SELECT SUM(\"高雄總部及運通廠\" + \"台北辦事處\" + \"昆山廣興廠\" + \"北海建準廠\" + \"北海立準廠\" + \"菲律賓建準廠\" + \"Inc\" + \"SAS\" + \"India\") AS \"類別一排放量\"\n",
     "# FROM \"2023 清冊數據(GHG)\"\n",
     "# FROM \"2023 清冊數據(GHG)\"\n",
     "# WHERE \"類別\" = '類別一-直接排放'\n",
     "# WHERE \"類別\" = '類別一-直接排放'\n",
     "# \"\"\"\n",
     "# \"\"\"\n",
-    "question = \"台積電去年的固定燃燒總排放量是多少?\"\n",
+    "question = \"建準去年的固定燃燒總排放量是多少?\"\n",
     "sql_query = \"\"\"\n",
     "sql_query = \"\"\"\n",
     "SELECT SUM(\"排放量(公噸CO2e)\") AS \"固定燃燒總排放量\"\n",
     "SELECT SUM(\"排放量(公噸CO2e)\") AS \"固定燃燒總排放量\"\n",
-    "FROM \"104_112碳排放公開及建準資料\"\n",
-    "WHERE \"事業名稱\" like '%建準%'\n",
+    "FROM \"建準碳排放清冊數據\"\n",
+    "WHERE \"事業名稱\" like '%台積電%'\n",
     "AND \"排放源\" = '固定燃燒'\n",
     "AND \"排放源\" = '固定燃燒'\n",
     "AND \"盤查標準\" = 'GHG'\n",
     "AND \"盤查標準\" = 'GHG'\n",
     "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\n",
     "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\n",
@@ -874,7 +966,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 34,
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
@@ -911,7 +1003,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 42,
+   "execution_count": 35,
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
@@ -1061,7 +1153,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 43,
+   "execution_count": 36,
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
@@ -1191,7 +1283,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 44,
+   "execution_count": 37,
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
@@ -1337,6 +1429,20 @@
     "print(app.get_graph().draw_mermaid())"
     "print(app.get_graph().draw_mermaid())"
    ]
    ]
   },
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "Image(\n",
+    "    app.get_graph().draw_mermaid_png(\n",
+    "        draw_method=MermaidDrawMethod.API,\n",
+    "        output_file_path=\"agent_workflow.png\",\n",
+    "    )\n",
+    ")"
+   ]
+  },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
    "execution_count": null,
    "execution_count": null,

+ 120 - 77
ai_agent.py

@@ -36,7 +36,6 @@ retriever = create_faiss_retriever()
 from text_to_sql_private import run, get_query, query_to_nl, table_description
 from text_to_sql_private import run, get_query, query_to_nl, table_description
 from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
 from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
 progress_bar = []
 progress_bar = []
-
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     
     
     context = docs
     context = docs
@@ -60,6 +59,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
 
 
     Question: {question}
     Question: {question}
     用繁體中文回答問題
     用繁體中文回答問題
+    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
     <|eot_id|>
     <|eot_id|>
     
     
     <|start_header_id|>assistant<|end_header_id|>
     <|start_header_id|>assistant<|end_header_id|>
@@ -121,19 +121,19 @@ def Answer_Grader():
 
 
 # Text-to-SQL
 # Text-to-SQL
 def run_text_to_sql(question: str):
 def run_text_to_sql(question: str):
-    selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
+    selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
     # question = "建準去年的固定燃燒總排放量是多少?"
     # question = "建準去年的固定燃燒總排放量是多少?"
     query, result, answer = run(db, question, selected_table, llm)
     query, result, answer = run(db, question, selected_table, llm)
     
     
     return  answer, query
     return  answer, query
 
 
 def _get_query(question: str):
 def _get_query(question: str):
-    selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
-    query = get_query(db, question, selected_table, llm)
-    return  query
+    selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
+    query, result = get_query(db, question, selected_table, llm)
+    return  query, result
 
 
-def _query_to_nl(question: str, query: str):
-    answer = query_to_nl(db, question, query, llm)
+def _query_to_nl(question: str, query: str, result):
+    answer = query_to_nl(question, query, result, llm)
     return  answer
     return  answer
 
 
 def generate_additional_question(sql_query):
 def generate_additional_question(sql_query):
@@ -150,15 +150,17 @@ def generate_additional_question(sql_query):
 def generate_additional_detail(sql_query):
 def generate_additional_detail(sql_query):
     terms = parse_sql_where(sql_query)
     terms = parse_sql_where(sql_query)
     answer = ""
     answer = ""
-    for term in terms:
+    for term in list(set(terms)):
         if term is None: continue
         if term is None: continue
-        question_format = [f"什麼是{term}?", f"{term}的用途是什麼", f"如何計算{term}?"]
+        question_format = [f"請解釋什麼是{term}?"]
         for question in question_format:
         for question in question_format:
             # question = f"什麼是{term}?"
             # question = f"什麼是{term}?"
-            documents = retriever.get_relevant_documents(question, k=30)
-            generation = faiss_query(question, documents, llm)
+            documents = retriever.get_relevant_documents(question, k=5)
+            generation = faiss_query(question, documents, llm) + "\n"
+            if "test@systex.com" in generation:
+                generation = ""
+            
             answer += generation
             answer += generation
-            answer += "\n"
             # print(question)
             # print(question)
             # print(generation)
             # print(generation)
     return answer
     return answer
@@ -177,7 +179,7 @@ def SQL_Grader():
         
         
         For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
         For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
         "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
         "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
-        FROM "104_112碳排放公開及建準資料"
+        FROM "建準碳排放清冊數據new"
         WHERE "事業名稱" like '%建準%'
         WHERE "事業名稱" like '%建準%'
         AND "排放源" = '下游租賃'
         AND "排放源" = '下游租賃'
         AND "盤查標準" = 'GHG'
         AND "盤查標準" = 'GHG'
@@ -186,7 +188,7 @@ def SQL_Grader():
         
         
         Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
         Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
         "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
         "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
-        FROM "104_112碳排放公開及建準資料"
+        FROM "建準碳排放清冊數據new"
         WHERE "事業名稱" like '%台積電%'
         WHERE "事業名稱" like '%台積電%'
         AND "排放源" = '固定燃燒'
         AND "排放源" = '固定燃燒'
         AND "盤查標準" = 'GHG'
         AND "盤查標準" = 'GHG'
@@ -251,6 +253,7 @@ class GraphState(TypedDict):
     documents: List[str]
     documents: List[str]
     retry: int
     retry: int
     sql_query: str
     sql_query: str
+    sql_result: str
     
     
 # Node
 # Node
 def show_progress(state, progress: str):
 def show_progress(state, progress: str):
@@ -289,7 +292,10 @@ def retrieve_and_generation(state):
     if not question_list:
     if not question_list:
         # documents = retriever.invoke(question)
         # documents = retriever.invoke(question)
         # TODO: correct Retrieval function
         # TODO: correct Retrieval function
-        documents = retriever.get_relevant_documents(question, k=30)
+        documents = retriever.get_relevant_documents(question, k=5)
+        for doc in documents:
+            print(doc)
+            
         # docs_documents = "\n\n".join(doc.page_content for doc in documents)
         # docs_documents = "\n\n".join(doc.page_content for doc in documents)
         # print(documents)
         # print(documents)
         generation = faiss_query(question, documents, llm)
         generation = faiss_query(question, documents, llm)
@@ -297,10 +303,13 @@ def retrieve_and_generation(state):
         generation = state["generation"]
         generation = state["generation"]
         
         
         for sub_question in list(set(question_list)):
         for sub_question in list(set(question_list)):
+            print(sub_question)
             documents = retriever.get_relevant_documents(sub_question, k=10)
             documents = retriever.get_relevant_documents(sub_question, k=10)
             generation += faiss_query(sub_question, documents, llm)
             generation += faiss_query(sub_question, documents, llm)
             generation += "\n"
             generation += "\n"
             
             
+    print(generation)
+            
     return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
     return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
 
 
 def company_private_data_get_sql_query(state):
 def company_private_data_get_sql_query(state):
@@ -328,9 +337,10 @@ def company_private_data_get_sql_query(state):
         retry = 0
         retry = 0
     # print("RETRY: ", retry)
     # print("RETRY: ", retry)
     
     
-    sql_query = _get_query(question)
+    sql_query, sql_result = _get_query(question)
+    print(type(sql_result))
     
     
-    return {"progress_bar": progress_bar, "route": route,"sql_query": sql_query, "question": question, "retry": retry}
+    return {"progress_bar": progress_bar, "route": route, "sql_query": sql_query, "sql_result": sql_result, "question": question, "retry": retry}
     
     
 def company_private_data_search(state):
 def company_private_data_search(state):
     """
     """
@@ -348,7 +358,8 @@ def company_private_data_search(state):
     # print(state)
     # print(state)
     question = state["question"]
     question = state["question"]
     sql_query = state["sql_query"]
     sql_query = state["sql_query"]
-    generation = _query_to_nl(question, sql_query)
+    sql_result = state["sql_result"]
+    generation = _query_to_nl(question, sql_query, sql_result)
     
     
     # generation = [company_private_data_result]
     # generation = [company_private_data_result]
     
     
@@ -371,11 +382,12 @@ def additional_explanation_question(state):
     sql_query = state["sql_query"]
     sql_query = state["sql_query"]
     # print(sql_query)
     # print(sql_query)
     generation = state["generation"]
     generation = state["generation"]
-    question_list = generate_additional_question(sql_query)
-    # print(question_list)
-    # generation += "\n"
-    # generation += generate_additional_detail(sql_query)
+    generation += "\n"
+    generation += generate_additional_detail(sql_query)
+    question_list = []    
     
     
+    # question_list = generate_additional_question(sql_query)
+    # print(question_list)
     
     
     # generation = [company_private_data_result]
     # generation = [company_private_data_result]
     
     
@@ -408,6 +420,9 @@ def route_question(state):
     # print(question)
     # print(question)
     question_router = Router()
     question_router = Router()
     source = question_router.invoke({"question": question})
     source = question_router.invoke({"question": question})
+    if "建準" in question:
+        source["datasource"] = "自有數據"
+        
     # print(source)
     # print(source)
     print(source["datasource"])
     print(source["datasource"])
     if source["datasource"] == "自有數據":
     if source["datasource"] == "自有數據":
@@ -431,43 +446,56 @@ def grade_generation_v_documents_and_question(state):
     """
     """
 
 
     # print("---CHECK HALLUCINATIONS---")
     # print("---CHECK HALLUCINATIONS---")
-    progress_bar = show_progress(state, "---CHECK HALLUCINATIONS---")
     question = state["question"]
     question = state["question"]
     documents = state["documents"]
     documents = state["documents"]
     generation = state["generation"]
     generation = state["generation"]
 
 
-    
-    # print(docs_documents)
-    # print(generation)
-    hallucination_grader = Hallucination_Grader()
-    score = hallucination_grader.invoke(
-        {"documents": documents, "generation": generation}
-    )
-    # print(score)
+    progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
+    answer_grader = Answer_Grader()
+    score = answer_grader.invoke({"question": question, "generation": generation})
     grade = score["score"]
     grade = score["score"]
-
-    # Check hallucination
     if grade in ["yes", "true", 1, "1"]:
     if grade in ["yes", "true", 1, "1"]:
-        # print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
-        progress_bar = show_progress(state, "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
-        # Check question-answering
-        # print("---GRADE GENERATION vs QUESTION---")
-        progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
-        answer_grader = Answer_Grader()
-        score = answer_grader.invoke({"question": question, "generation": generation})
-        grade = score["score"]
-        if grade in ["yes", "true", 1, "1"]:
-            # print("---DECISION: GENERATION ADDRESSES QUESTION---")
-            progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
-            return "useful"
-        else:
-            # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
-            progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
-            return "not useful"
+        # print("---DECISION: GENERATION ADDRESSES QUESTION---")
+        progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
+        return "useful"
     else:
     else:
-        # pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
-        progress_bar = show_progress(state, "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
-        return "not supported"
+        # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+        progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+        return "not useful"
+    
+    
+    # progress_bar = show_progress(state, "---CHECK HALLUCINATIONS---")
+    # # print(docs_documents)
+    # # print(generation)
+    # hallucination_grader = Hallucination_Grader()
+    # score = hallucination_grader.invoke(
+    #     {"documents": documents, "generation": generation}
+    # )
+    # # print(score)
+    # grade = score["score"]
+
+    # # Check hallucination
+    # if grade in ["yes", "true", 1, "1"]:
+    #     # print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+    #     progress_bar = show_progress(state, "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+    #     # Check question-answering
+    #     # print("---GRADE GENERATION vs QUESTION---")
+    #     progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
+    #     answer_grader = Answer_Grader()
+    #     score = answer_grader.invoke({"question": question, "generation": generation})
+    #     grade = score["score"]
+    #     if grade in ["yes", "true", 1, "1"]:
+    #         # print("---DECISION: GENERATION ADDRESSES QUESTION---")
+    #         progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
+    #         return "useful"
+    #     else:
+    #         # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+    #         progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+    #         return "not useful"
+    # else:
+    #     # pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+    #     progress_bar = show_progress(state, "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+    #     return "not supported"
     
     
 def grade_sql_query(state):
 def grade_sql_query(state):
     """
     """
@@ -484,25 +512,34 @@ def grade_sql_query(state):
     progress_bar = show_progress(state, "---CHECK SQL CORRECTNESS TO QUESTION---")
     progress_bar = show_progress(state, "---CHECK SQL CORRECTNESS TO QUESTION---")
     question = state["question"]
     question = state["question"]
     sql_query = state["sql_query"]
     sql_query = state["sql_query"]
-    retry = state["retry"]
-
-    # Score each doc
-    sql_query_grader = SQL_Grader()
-    score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
-    grade = score["score"]
-    # Document relevant
-    if grade in ["yes", "true", 1, "1"]:
-        # print("---GRADE: CORRECT SQL QUERY---")
-        progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
-        return "correct"
-    elif retry >= 5:
-        # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
-        progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
-        return "failed"
-    else:
-        # print("---GRADE: INCORRECT SQL QUERY---")
-        progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
+    sql_result = state["sql_result"]
+    if "None" in sql_result:
+        progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
         return "incorrect"
         return "incorrect"
+    else:
+        progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
+        return "correct"
+    # retry = state["retry"]
+
+    # # Score each doc
+    # sql_query_grader = SQL_Grader()
+    # score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
+    # grade = score["score"]
+    
+    
+    # # Document relevant
+    # if grade in ["yes", "true", 1, "1"]:
+    #     # print("---GRADE: CORRECT SQL QUERY---")
+    #     progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
+    #     return "correct"
+    # elif retry >= 5:
+    #     # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
+    #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
+    #     return "failed"
+    # else:
+    #     # print("---GRADE: INCORRECT SQL QUERY---")
+    #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
+    #     return "incorrect"
 
 
 def build_graph():
 def build_graph():
     workflow = StateGraph(GraphState)
     workflow = StateGraph(GraphState)
@@ -527,7 +564,6 @@ def build_graph():
         "RAG",
         "RAG",
         grade_generation_v_documents_and_question,
         grade_generation_v_documents_and_question,
         {
         {
-            "not supported": "ERROR",
             "useful": END,
             "useful": END,
             "not useful": "ERROR",
             "not useful": "ERROR",
         },
         },
@@ -537,21 +573,26 @@ def build_graph():
         grade_sql_query,
         grade_sql_query,
         {
         {
             "correct": "SQL Answer",
             "correct": "SQL Answer",
-            "incorrect": "ERROR",
-            "failed": "RAG"
+            "incorrect": "RAG",
             
             
         },
         },
     )
     )
     workflow.add_edge("SQL Answer", "Additoinal Explanation")
     workflow.add_edge("SQL Answer", "Additoinal Explanation")
-    workflow.add_edge("Additoinal Explanation", "RAG")
+    workflow.add_edge("Additoinal Explanation", END)
 
 
     app = workflow.compile()    
     app = workflow.compile()    
     
     
     return app
     return app
 
 
+app = build_graph()
+draw_mermaid = app.get_graph().draw_mermaid()
+print(draw_mermaid)
+
 def main(question: str):
 def main(question: str):
     
     
-    app = build_graph()
+    # app = build_graph()
+    # draw_mermaid = app.get_graph().draw_mermaid()
+    # print(draw_mermaid)
     #建準去年的類別一排放量?
     #建準去年的類別一排放量?
     # inputs = {"question": "溫室氣體是什麼"}
     # inputs = {"question": "溫室氣體是什麼"}
     inputs = {"question": question, "progress_bar": None}
     inputs = {"question": question, "progress_bar": None}
@@ -561,12 +602,14 @@ def main(question: str):
     # pprint(value["generation"])
     # pprint(value["generation"])
     # pprint(value)
     # pprint(value)
     value["progress_bar"] = progress_bar
     value["progress_bar"] = progress_bar
-    pprint(value["progress_bar"])
+    # pprint(value["progress_bar"])
     
     
     return value["generation"]
     return value["generation"]
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     # result = main("建準去年的逸散排放總排放量是多少?")
     # result = main("建準去年的逸散排放總排放量是多少?")
-    result = main("建準去年的綠電使用量是多少?")
+    result = main("建準夏威夷去年的綠電使用量是多少?")
+    # result = main("溫室氣體是什麼?")
+    # result = main("什麼是外購電力(綠電)?")
     print("------------------------------------------------------")
     print("------------------------------------------------------")
     print(result)
     print(result)

+ 1 - 1
faiss_index.py

@@ -44,7 +44,7 @@ load_dotenv('../../.env')
 supabase_url = os.getenv("SUPABASE_URL")
 supabase_url = os.getenv("SUPABASE_URL")
 supabase_key = os.getenv("SUPABASE_KEY")
 supabase_key = os.getenv("SUPABASE_KEY")
 openai_api_key = os.getenv("OPENAI_API_KEY")
 openai_api_key = os.getenv("OPENAI_API_KEY")
-document_table = "documents"
+document_table = "documents2"
 
 
 # Initialize Supabase client
 # Initialize Supabase client
 supabase: Client = create_client(supabase_url, supabase_key)
 supabase: Client = create_client(supabase_url, supabase_key)

+ 1 - 1
file_loader/news_vectordb.py

@@ -13,7 +13,7 @@ from add_vectordb import GetVectorStore
 load_dotenv("../.env")
 load_dotenv("../.env")
 supabase_url = os.environ.get("SUPABASE_URL")
 supabase_url = os.environ.get("SUPABASE_URL")
 supabase_key = os.environ.get("SUPABASE_KEY")
 supabase_key = os.environ.get("SUPABASE_KEY")
-document_table = "documents"
+document_table = "documents2"
 supabase: Client = create_client(supabase_url, supabase_key)
 supabase: Client = create_client(supabase_url, supabase_key)
 
 
 embeddings = OpenAIEmbeddings()
 embeddings = OpenAIEmbeddings()

+ 3 - 1
post_processing_sqlparse.py

@@ -60,7 +60,9 @@ def parse_sql_where(sql):
     column_dict = {
     column_dict = {
         "排放源": None,
         "排放源": None,
         "類別": None,
         "類別": None,
-        "項目": None
+        "類別項目": None,
+        "項目": None,
+        
     }
     }
 
 
     def get_column_details(token, column_args):
     def get_column_details(token, column_args):

+ 110 - 0
rewrite_question.py

@@ -0,0 +1,110 @@
+from langchain_core.output_parsers import StrOutputParser
+from langchain_openai import ChatOpenAI
+from langchain_core.runnables import RunnablePassthrough
+from langchain import PromptTemplate
+from langchain_community.chat_models import ChatOllama
+
+
+from langchain_core.runnables import (
+    RunnableBranch,
+    RunnableLambda,
+    RunnableParallel,
+    RunnablePassthrough,
+)
+from typing import Tuple, List, Optional
+from langchain_core.messages import AIMessage, HumanMessage
+
+local_llm = "llama3-groq-tool-use:latest"
+# llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
+llm = ChatOllama(model=local_llm, temperature=0)
+
+def get_search_query():
+    # Condense a chat history and follow-up question into a standalone question
+    # 
+    # _template = """Given the following conversation and a follow up question, 
+    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
+    # Generate standalone question in its original language.
+    # Chat History:
+    # {chat_history}
+    # Follow Up Input: {question}
+
+    # Hint:
+    # * Refer to chat history and add the subject to the question
+    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
+    
+    # Standalone question:"""  # noqa: E501
+    _template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    Rewrite the following query by incorporating relevant context from the conversation history.
+    The rewritten query should:
+    
+    - Preserve the core intent and meaning of the original query
+    - Expand and clarify the query to make it more specific and informative for retrieving relevant context
+    - Avoid introducing new topics or queries that deviate from the original query
+    - DONT EVER ANSWER the Original query, but instead focus on rephrasing and expanding it into a new query
+    - The rewritten query should be in its original language.
+    
+    Return ONLY the rewritten query text, without any additional formatting or explanations.
+    
+    <|eot_id|>
+        
+    <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+    Conversation History:
+    {chat_history}
+    
+    Original query: [{question}]
+    
+    Hint:
+    * Refer to chat history and add the subject to the question
+    * Replace the pronouns in the question with the correct person or thing, please refer to chat history
+    
+    Rewritten query: 
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
+    """
+    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
+        buffer = []
+        for human, ai in chat_history:
+            buffer.append(HumanMessage(content=human))
+            buffer.append(AIMessage(content=ai))
+        return buffer
+
+    _search_query = RunnableBranch(
+        # If input includes chat_history, we condense it with the follow-up question
+        (
+            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
+                run_name="HasChatHistoryCheck"
+            ),  # Condense follow-up question and chat into a standalone_question
+            RunnablePassthrough.assign(
+                chat_history=lambda x: _format_chat_history(x["chat_history"])
+            )
+            | CONDENSE_QUESTION_PROMPT
+            | llm
+            | StrOutputParser(),
+        ),
+        # Else, we have no chat history, so just pass through the question
+        RunnableLambda(lambda x : x["question"]),
+    )
+
+    return _search_query
+
+if __name__ == "__main__":
+    _search_query = get_search_query()
+    chat_history = [
+        {
+            "q": "北海建準廠2023年的類別3排放量是多少?",
+            "a": """根據北海建準廠2023年的數據,類別3的排放量是2,162.62公噸CO2e。
+                類別3指的是溫室氣體排放量盤查作業中的一個範疇,該範疇涵蓋了事業之溫室氣體排放量的盤查和登錄。"""
+        }
+        ]
+    chat_history = [(history["q"] , history["a"] ) for history in chat_history if history["a"] != "" and history["a"]  != "string"]
+    print(chat_history)
+    
+    question = "類別2呢"
+    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
+    print(modified_question)

+ 9 - 8
systex_app.py

@@ -39,20 +39,21 @@ class ChatHistoryItem(BaseModel):
     
 @app.post("/agents")
 def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
+    print(question)
     start = time.time()
     
     with get_openai_callback() as cb:
-        cache_question, cache_answer = semantic_cache(supabase, question)
+        # cache_question, cache_answer = semantic_cache(supabase, question)
+        cache_answer = None
         if cache_answer:
-            processing_time = time.time() - start
-            save_history(question, cache_answer, cb, processing_time)
-
-            return {"Answer": cache_answer}
-    
-        answer = main(question)
-        
+            answer = cache_answer
+        else:
+            answer = main(question)
     processing_time = time.time() - start
     save_history(question, answer, cb, processing_time)
+    if "test@systex.com" in answer:
+        answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    print(answer)
     return {"Answer": answer}  
 
 def save_history(question, answer, cb, processing_time):

+ 54 - 23
text_to_sql_private.py

@@ -76,39 +76,51 @@ llm = ChatOllama(model=local_llm, temperature=0)
 def get_examples():
     examples = [
         {
-            "input": "建準廣興廠2023年的自產電力的綠電使用量是多少?",
-            "query": """SELECT SUM("用電度數(kwh)") AS "自產電力綠電使用量"
+            "input": "建準廣興廠年的自產電力的綠電使用量是多少?",
+            "query": """SELECT SUM("用電度數(kwh)") AS "綠電使用量"
                         FROM "用電度數"
-                        WHERE "項目" = '自產電力(綠電)'
+                        WHERE "項目" like '%綠電%'
+                        AND "事業名稱" like '%建準%'
+                        AND "事業名稱" like '%廣興廠%'
                         AND "盤查標準" = 'GHG'
-                        AND "年度" = 2023;""",
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
         },
         {
-            "input": "建準廣興廠去年的類別1總排放量是多少?",
+            "input": "建準北海廠去年的類別1總排放量是多少?",
             "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
-                        FROM "建準碳排放清冊數據"
+                        FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%廣興廠%'
-                        AND ("類別" like '%類別1-直接排放%' OR "排放源" like '%類別1-直接排放%')
+                        AND "事業名稱" like '%北海%'
+                        AND "類別" = '類別1'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
+        {
+            "input": "建準廣興廠去年的直接排放總排放量是多少?",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
+                        FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "事業名稱" like '%廣興%'
+                        AND ("類別項目" like '%直接排放%' OR "排放源" like '%直接排放%')
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
         },
         {
-            "input": "建準台北辦事處2022年的能源間接排放總排放量是多少?",
+            "input": "建準台北辦事處2022年的類別2總排放量是多少?",
             "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
-                        FROM "建準碳排放清冊數據"
+                        FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%台北辦事處%'
-                        AND ("類別" like '%類別2-能源間接排放%' OR "排放源" like '%類別2-能源間接排放%')
+                        AND "事業名稱" like '%台北%'
+                        AND "類別" = '類別2'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = 2022;""",
         },
         {
             "input": "建準去年的固定燃燒總排放量是多少?",
             "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
-                        FROM "建準碳排放清冊數據"
+                        FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
-                        AND ("類別" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
+                        AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
         },
@@ -120,9 +132,9 @@ def get_examples():
 
 def table_description():
     database_description = (
-        "The database consists of following table: `用水度數`, `用水度數`, `建準碳排放清冊數據`. "
+        "The database consists of following table: `用水度數`, `用水度數`, `建準碳排放清冊數據new`."
         "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
-        "The `建準碳排放清冊數據` table 描述了不同事業單位或廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
+        "The `建準碳排放清冊數據new` table 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
         "It includes the following columns:\n"
         "- `年度`: 盤查年度\n"
         "- `事業名稱`: 建準據點"
@@ -142,7 +154,7 @@ def table_description():
         "- `盤查標準`: ISO or GHG\n"
         
 
-        "The `用電度數` 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
+        "The `用電度數` 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
         "It includes the following columns:\n"
         "- `年度`: 盤查年度\n"
         "- `事業名稱`: 建準據點"
@@ -154,7 +166,7 @@ def table_description():
         "- `用電度數(kwh)`: 用電度數,單位為kwh\n"
         "- `盤查標準`: ISO or GHG\n"
         
-        "The `用水度數` 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
+        "The `用水度數` 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
         "It includes the following columns:\n"
         "- `年度`: 盤查年度\n"
         "- `事業名稱`: 建準據點"
@@ -230,8 +242,22 @@ def sql_to_nl_chain(llm):
         <|begin_of_text|><|start_header_id|>system<|end_header_id|>
         Given the following user question, corresponding SQL query, and SQL result, answer the user question.
         根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
-
+        ** 請務必在回答中表達是建準的資料,即便問句中並未提及建準。
         
+        The following shows some example:
+        Question: 廣興廠去年的類別1總排放量是多少?
+        SQL Query: SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
+                        FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "事業名稱" like '%廣興%'
+                        AND "類別" = '類別1'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;,
+        SQL Result: [(1102.3712,)]
+        Answer: 建準廣興廠去年的類別1總排放量是1102.3712
+
+        如果你不知道答案或SQL query 出現錯誤請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+        勿回答無關資訊
         <|eot_id|>
 
         <|begin_of_text|><|start_header_id|>user<|end_header_id|>
@@ -259,14 +285,19 @@ def get_query(db, question, selected_table, llm):
     # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
     print(query)
     
-    return query
-
-def query_to_nl(db, question, query, llm):
     execute_query = QuerySQLDataBaseTool(db=db)
     result = execute_query.invoke(query)
     print(result)
 
+    return query, result
+
+def query_to_nl(question, query, result, llm):
+    # execute_query = QuerySQLDataBaseTool(db=db)
+    # result = execute_query.invoke(query)
+    # print(result)
+
     chain = sql_to_nl_chain(llm)
+    print(result)
     answer = chain.invoke({"question": question, "query": query, "result": result})
 
     return answer
@@ -295,7 +326,7 @@ if __name__ == "__main__":
     
     start = time.time()
     
-    selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據']
+    selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
     question = "建準去年的上游運輸總排放量是多少?"
     # question = "台積電2022年的直接排放總排放量是多少?"
     # question = "建準廣興廠去年的灰電使用量"