Browse Source

update agent flow

ling 4 months ago
parent
commit
e46d4b62c3
6 changed files with 360 additions and 100 deletions
  1. 119 13
      ai_agent.ipynb
  2. 120 77
      ai_agent.py
  3. 1 1
      faiss_index.py
  4. 1 1
      file_loader/news_vectordb.py
  5. 110 0
      rewrite_question.py
  6. 9 8
      systex_app.py

+ 119 - 13
ai_agent.ipynb

@@ -601,7 +601,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -643,7 +643,7 @@
     "    AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\"\n",
     "    For the above example, we can find that user asked for \"建準\", but the PostgreSQL query gives \"事業名稱\" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.\n",
     "    \n",
-    "    and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.\n",
+    "    and so on. You need to examine whether the sql PostgreSQL query matches the user question.\n",
     "    \n",
     "    If the PostgreSQL query do not exactly matches the user question, grade it as incorrect. \n",
     "    You need to strictly examine whether the sql PostgreSQL query matches the user question.\n",
@@ -663,30 +663,122 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "### SQL Grader\n",
+    "\n",
+    "from langchain_community.chat_models import ChatOllama\n",
+    "from langchain_core.output_parsers import JsonOutputParser\n",
+    "from langchain_core.prompts import PromptTemplate\n",
+    "\n",
+    "# LLM\n",
+    "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
+    "\n",
+    "prompt = PromptTemplate(\n",
+    "    template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> \n",
+    "    You are a SQL query grader assessing correctness of PostgreSQL query to a user question. \n",
+    "    Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.\n",
+    "    \n",
+    "    Here is database description:\n",
+    "    {table_info}\n",
+    "    \n",
+    "    You need to check that each where statement is correctly filtered out what user question need.\n",
+    "    You need to check if PostgreSQL query WHERE clause correctly filter records according to user question\n",
+    "    You need to examine whether the sql PostgreSQL query matches the user question.\n",
+    "    \n",
+    "    If the PostgreSQL query do not exactly matches the user question, grade it as incorrect. \n",
+    "    You need to strictly examine whether the sql PostgreSQL query matches the user question.\n",
+    "    Give a binary score 'yes' or 'no' score to indicate whether the PostgreSQL query is correct to the question. \\n\n",
+    "    Provide the binary score as a JSON with a single key 'score' and no premable or explanation.\n",
+    "    <|eot_id|>\n",
+    "    \n",
+    "    <|start_header_id|>user<|end_header_id|>\n",
+    "    Here is the PostgreSQL query: \\n\\n {sql_query} \\n\\n\n",
+    "    Here is the user question: {question} \\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+    "    \"\"\",\n",
+    "    input_variables=[\"table_info\", \"question\", \"sql_query\"],\n",
+    ")\n",
+    "\n",
+    "sql_query_grader = prompt | llm_json | JsonOutputParser()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "question = \"建準廣興廠去年的綠電使用量是多少?\"\n",
+    "sql_query = \"\"\"\n",
+    "\n",
+    "SELECT SUM(\"用電度數(kwh)\") AS \"自產電力綠電使用量\"\n",
+    "FROM \"用電度數\"\n",
+    "WHERE \"項目\" = '自產電力(綠電)'\n",
+    "AND \"盤查標準\" = 'GHG'\n",
+    "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from text_to_sql_private import get_query\n",
+    "selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'db' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[28], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m get_query(\u001b[43mdb\u001b[49m, question, selected_table, llm)\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'db' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "get_query(db, question, selected_table, llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "{'score': 'no'}\n"
+      "{'score': 'yes'}\n"
      ]
     }
    ],
    "source": [
-    "from text_to_sql2 import table_description\n",
-    "question = \"建準去年的類別一排放量\"\n",
+    "from text_to_sql_private import table_description\n",
+    "# question = \"建準去年的類別一排放量\"\n",
     "# sql_query = \"\"\"\n",
     "# SELECT SUM(\"高雄總部及運通廠\" + \"台北辦事處\" + \"昆山廣興廠\" + \"北海建準廠\" + \"北海立準廠\" + \"菲律賓建準廠\" + \"Inc\" + \"SAS\" + \"India\") AS \"類別一排放量\"\n",
     "# FROM \"2023 清冊數據(GHG)\"\n",
     "# WHERE \"類別\" = '類別一-直接排放'\n",
     "# \"\"\"\n",
-    "question = \"台積電去年的固定燃燒總排放量是多少?\"\n",
+    "question = \"建準去年的固定燃燒總排放量是多少?\"\n",
     "sql_query = \"\"\"\n",
     "SELECT SUM(\"排放量(公噸CO2e)\") AS \"固定燃燒總排放量\"\n",
-    "FROM \"104_112碳排放公開及建準資料\"\n",
-    "WHERE \"事業名稱\" like '%建準%'\n",
+    "FROM \"建準碳排放清冊數據\"\n",
+    "WHERE \"事業名稱\" like '%台積電%'\n",
     "AND \"排放源\" = '固定燃燒'\n",
     "AND \"盤查標準\" = 'GHG'\n",
     "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\n",
@@ -874,7 +966,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 34,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -911,7 +1003,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 42,
+   "execution_count": 35,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1061,7 +1153,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 43,
+   "execution_count": 36,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1191,7 +1283,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 44,
+   "execution_count": 37,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1337,6 +1429,20 @@
     "print(app.get_graph().draw_mermaid())"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "Image(\n",
+    "    app.get_graph().draw_mermaid_png(\n",
+    "        draw_method=MermaidDrawMethod.API,\n",
+    "        output_file_path=\"agent_workflow.png\",\n",
+    "    )\n",
+    ")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,

+ 120 - 77
ai_agent.py

@@ -36,7 +36,6 @@ retriever = create_faiss_retriever()
 from text_to_sql_private import run, get_query, query_to_nl, table_description
 from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
 progress_bar = []
-
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     
     context = docs
@@ -60,6 +59,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
 
     Question: {question}
     用繁體中文回答問題
+    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
     <|eot_id|>
     
     <|start_header_id|>assistant<|end_header_id|>
@@ -121,19 +121,19 @@ def Answer_Grader():
 
 # Text-to-SQL
 def run_text_to_sql(question: str):
-    selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
+    selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
     # question = "建準去年的固定燃燒總排放量是多少?"
     query, result, answer = run(db, question, selected_table, llm)
     
     return  answer, query
 
 def _get_query(question: str):
-    selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
-    query = get_query(db, question, selected_table, llm)
-    return  query
+    selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
+    query, result = get_query(db, question, selected_table, llm)
+    return  query, result
 
-def _query_to_nl(question: str, query: str):
-    answer = query_to_nl(db, question, query, llm)
+def _query_to_nl(question: str, query: str, result):
+    answer = query_to_nl(question, query, result, llm)
     return  answer
 
 def generate_additional_question(sql_query):
@@ -150,15 +150,17 @@ def generate_additional_question(sql_query):
 def generate_additional_detail(sql_query):
     terms = parse_sql_where(sql_query)
     answer = ""
-    for term in terms:
+    for term in list(set(terms)):
         if term is None: continue
-        question_format = [f"什麼是{term}?", f"{term}的用途是什麼", f"如何計算{term}?"]
+        question_format = [f"請解釋什麼是{term}?"]
         for question in question_format:
             # question = f"什麼是{term}?"
-            documents = retriever.get_relevant_documents(question, k=30)
-            generation = faiss_query(question, documents, llm)
+            documents = retriever.get_relevant_documents(question, k=5)
+            generation = faiss_query(question, documents, llm) + "\n"
+            if "test@systex.com" in generation:
+                generation = ""
+            
             answer += generation
-            answer += "\n"
             # print(question)
             # print(generation)
     return answer
@@ -177,7 +179,7 @@ def SQL_Grader():
         
         For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
         "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
-        FROM "104_112碳排放公開及建準資料"
+        FROM "建準碳排放清冊數據new"
         WHERE "事業名稱" like '%建準%'
         AND "排放源" = '下游租賃'
         AND "盤查標準" = 'GHG'
@@ -186,7 +188,7 @@ def SQL_Grader():
         
         Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
         "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
-        FROM "104_112碳排放公開及建準資料"
+        FROM "建準碳排放清冊數據new"
         WHERE "事業名稱" like '%台積電%'
         AND "排放源" = '固定燃燒'
         AND "盤查標準" = 'GHG'
@@ -251,6 +253,7 @@ class GraphState(TypedDict):
     documents: List[str]
     retry: int
     sql_query: str
+    sql_result: str
     
 # Node
 def show_progress(state, progress: str):
@@ -289,7 +292,10 @@ def retrieve_and_generation(state):
     if not question_list:
         # documents = retriever.invoke(question)
         # TODO: correct Retrieval function
-        documents = retriever.get_relevant_documents(question, k=30)
+        documents = retriever.get_relevant_documents(question, k=5)
+        for doc in documents:
+            print(doc)
+            
         # docs_documents = "\n\n".join(doc.page_content for doc in documents)
         # print(documents)
         generation = faiss_query(question, documents, llm)
@@ -297,10 +303,13 @@ def retrieve_and_generation(state):
         generation = state["generation"]
         
         for sub_question in list(set(question_list)):
+            print(sub_question)
             documents = retriever.get_relevant_documents(sub_question, k=10)
             generation += faiss_query(sub_question, documents, llm)
             generation += "\n"
             
+    print(generation)
+            
     return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
 
 def company_private_data_get_sql_query(state):
@@ -328,9 +337,10 @@ def company_private_data_get_sql_query(state):
         retry = 0
     # print("RETRY: ", retry)
     
-    sql_query = _get_query(question)
+    sql_query, sql_result = _get_query(question)
+    print(type(sql_result))
     
-    return {"progress_bar": progress_bar, "route": route,"sql_query": sql_query, "question": question, "retry": retry}
+    return {"progress_bar": progress_bar, "route": route, "sql_query": sql_query, "sql_result": sql_result, "question": question, "retry": retry}
     
 def company_private_data_search(state):
     """
@@ -348,7 +358,8 @@ def company_private_data_search(state):
     # print(state)
     question = state["question"]
     sql_query = state["sql_query"]
-    generation = _query_to_nl(question, sql_query)
+    sql_result = state["sql_result"]
+    generation = _query_to_nl(question, sql_query, sql_result)
     
     # generation = [company_private_data_result]
     
@@ -371,11 +382,12 @@ def additional_explanation_question(state):
     sql_query = state["sql_query"]
     # print(sql_query)
     generation = state["generation"]
-    question_list = generate_additional_question(sql_query)
-    # print(question_list)
-    # generation += "\n"
-    # generation += generate_additional_detail(sql_query)
+    generation += "\n"
+    generation += generate_additional_detail(sql_query)
+    question_list = []    
     
+    # question_list = generate_additional_question(sql_query)
+    # print(question_list)
     
     # generation = [company_private_data_result]
     
@@ -408,6 +420,9 @@ def route_question(state):
     # print(question)
     question_router = Router()
     source = question_router.invoke({"question": question})
+    if "建準" in question:
+        source["datasource"] = "自有數據"
+        
     # print(source)
     print(source["datasource"])
     if source["datasource"] == "自有數據":
@@ -431,43 +446,56 @@ def grade_generation_v_documents_and_question(state):
     """
 
     # print("---CHECK HALLUCINATIONS---")
-    progress_bar = show_progress(state, "---CHECK HALLUCINATIONS---")
     question = state["question"]
     documents = state["documents"]
     generation = state["generation"]
 
-    
-    # print(docs_documents)
-    # print(generation)
-    hallucination_grader = Hallucination_Grader()
-    score = hallucination_grader.invoke(
-        {"documents": documents, "generation": generation}
-    )
-    # print(score)
+    progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
+    answer_grader = Answer_Grader()
+    score = answer_grader.invoke({"question": question, "generation": generation})
     grade = score["score"]
-
-    # Check hallucination
     if grade in ["yes", "true", 1, "1"]:
-        # print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
-        progress_bar = show_progress(state, "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
-        # Check question-answering
-        # print("---GRADE GENERATION vs QUESTION---")
-        progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
-        answer_grader = Answer_Grader()
-        score = answer_grader.invoke({"question": question, "generation": generation})
-        grade = score["score"]
-        if grade in ["yes", "true", 1, "1"]:
-            # print("---DECISION: GENERATION ADDRESSES QUESTION---")
-            progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
-            return "useful"
-        else:
-            # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
-            progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
-            return "not useful"
+        # print("---DECISION: GENERATION ADDRESSES QUESTION---")
+        progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
+        return "useful"
     else:
-        # pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
-        progress_bar = show_progress(state, "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
-        return "not supported"
+        # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+        progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+        return "not useful"
+    
+    
+    # progress_bar = show_progress(state, "---CHECK HALLUCINATIONS---")
+    # # print(docs_documents)
+    # # print(generation)
+    # hallucination_grader = Hallucination_Grader()
+    # score = hallucination_grader.invoke(
+    #     {"documents": documents, "generation": generation}
+    # )
+    # # print(score)
+    # grade = score["score"]
+
+    # # Check hallucination
+    # if grade in ["yes", "true", 1, "1"]:
+    #     # print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+    #     progress_bar = show_progress(state, "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+    #     # Check question-answering
+    #     # print("---GRADE GENERATION vs QUESTION---")
+    #     progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
+    #     answer_grader = Answer_Grader()
+    #     score = answer_grader.invoke({"question": question, "generation": generation})
+    #     grade = score["score"]
+    #     if grade in ["yes", "true", 1, "1"]:
+    #         # print("---DECISION: GENERATION ADDRESSES QUESTION---")
+    #         progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
+    #         return "useful"
+    #     else:
+    #         # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+    #         progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+    #         return "not useful"
+    # else:
+    #     # pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+    #     progress_bar = show_progress(state, "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+    #     return "not supported"
     
 def grade_sql_query(state):
     """
@@ -484,25 +512,34 @@ def grade_sql_query(state):
     progress_bar = show_progress(state, "---CHECK SQL CORRECTNESS TO QUESTION---")
     question = state["question"]
     sql_query = state["sql_query"]
-    retry = state["retry"]
-
-    # Score each doc
-    sql_query_grader = SQL_Grader()
-    score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
-    grade = score["score"]
-    # Document relevant
-    if grade in ["yes", "true", 1, "1"]:
-        # print("---GRADE: CORRECT SQL QUERY---")
-        progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
-        return "correct"
-    elif retry >= 5:
-        # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
-        progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
-        return "failed"
-    else:
-        # print("---GRADE: INCORRECT SQL QUERY---")
-        progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
+    sql_result = state["sql_result"]
+    if "None" in sql_result:
+        progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
         return "incorrect"
+    else:
+        progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
+        return "correct"
+    # retry = state["retry"]
+
+    # # Score each doc
+    # sql_query_grader = SQL_Grader()
+    # score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
+    # grade = score["score"]
+    
+    
+    # # Document relevant
+    # if grade in ["yes", "true", 1, "1"]:
+    #     # print("---GRADE: CORRECT SQL QUERY---")
+    #     progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
+    #     return "correct"
+    # elif retry >= 5:
+    #     # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
+    #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
+    #     return "failed"
+    # else:
+    #     # print("---GRADE: INCORRECT SQL QUERY---")
+    #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
+    #     return "incorrect"
 
 def build_graph():
     workflow = StateGraph(GraphState)
@@ -527,7 +564,6 @@ def build_graph():
         "RAG",
         grade_generation_v_documents_and_question,
         {
-            "not supported": "ERROR",
             "useful": END,
             "not useful": "ERROR",
         },
@@ -537,21 +573,26 @@ def build_graph():
         grade_sql_query,
         {
             "correct": "SQL Answer",
-            "incorrect": "ERROR",
-            "failed": "RAG"
+            "incorrect": "RAG",
             
         },
     )
     workflow.add_edge("SQL Answer", "Additoinal Explanation")
-    workflow.add_edge("Additoinal Explanation", "RAG")
+    workflow.add_edge("Additoinal Explanation", END)
 
     app = workflow.compile()    
     
     return app
 
+app = build_graph()
+draw_mermaid = app.get_graph().draw_mermaid()
+print(draw_mermaid)
+
 def main(question: str):
     
-    app = build_graph()
+    # app = build_graph()
+    # draw_mermaid = app.get_graph().draw_mermaid()
+    # print(draw_mermaid)
     #建準去年的類別一排放量?
     # inputs = {"question": "溫室氣體是什麼"}
     inputs = {"question": question, "progress_bar": None}
@@ -561,12 +602,14 @@ def main(question: str):
     # pprint(value["generation"])
     # pprint(value)
     value["progress_bar"] = progress_bar
-    pprint(value["progress_bar"])
+    # pprint(value["progress_bar"])
     
     return value["generation"]
 
 if __name__ == "__main__":
     # result = main("建準去年的逸散排放總排放量是多少?")
-    result = main("建準去年的綠電使用量是多少?")
+    result = main("建準夏威夷去年的綠電使用量是多少?")
+    # result = main("溫室氣體是什麼?")
+    # result = main("什麼是外購電力(綠電)?")
     print("------------------------------------------------------")
     print(result)

+ 1 - 1
faiss_index.py

@@ -44,7 +44,7 @@ load_dotenv('../../.env')
 supabase_url = os.getenv("SUPABASE_URL")
 supabase_key = os.getenv("SUPABASE_KEY")
 openai_api_key = os.getenv("OPENAI_API_KEY")
-document_table = "documents"
+document_table = "documents2"
 
 # Initialize Supabase client
 supabase: Client = create_client(supabase_url, supabase_key)

+ 1 - 1
file_loader/news_vectordb.py

@@ -13,7 +13,7 @@ from add_vectordb import GetVectorStore
 load_dotenv("../.env")
 supabase_url = os.environ.get("SUPABASE_URL")
 supabase_key = os.environ.get("SUPABASE_KEY")
-document_table = "documents"
+document_table = "documents2"
 supabase: Client = create_client(supabase_url, supabase_key)
 
 embeddings = OpenAIEmbeddings()

+ 110 - 0
rewrite_question.py

@@ -0,0 +1,110 @@
+from langchain_core.output_parsers import StrOutputParser
+from langchain_openai import ChatOpenAI
+from langchain_core.runnables import RunnablePassthrough
+from langchain import PromptTemplate
+from langchain_community.chat_models import ChatOllama
+
+
+from langchain_core.runnables import (
+    RunnableBranch,
+    RunnableLambda,
+    RunnableParallel,
+    RunnablePassthrough,
+)
+from typing import Tuple, List, Optional
+from langchain_core.messages import AIMessage, HumanMessage
+
+local_llm = "llama3-groq-tool-use:latest"
+# llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
+llm = ChatOllama(model=local_llm, temperature=0)
+
+def get_search_query():
+    # Condense a chat history and follow-up question into a standalone question
+    # 
+    # _template = """Given the following conversation and a follow up question, 
+    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
+    # Generate standalone question in its original language.
+    # Chat History:
+    # {chat_history}
+    # Follow Up Input: {question}
+
+    # Hint:
+    # * Refer to chat history and add the subject to the question
+    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
+    
+    # Standalone question:"""  # noqa: E501
+    _template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    Rewrite the following query by incorporating relevant context from the conversation history.
+    The rewritten query should:
+    
+    - Preserve the core intent and meaning of the original query
+    - Expand and clarify the query to make it more specific and informative for retrieving relevant context
+    - Avoid introducing new topics or queries that deviate from the original query
+    - DONT EVER ANSWER the Original query, but instead focus on rephrasing and expanding it into a new query
+    - The rewritten query should be in its original language.
+    
+    Return ONLY the rewritten query text, without any additional formatting or explanations.
+    
+    <|eot_id|>
+        
+    <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+    Conversation History:
+    {chat_history}
+    
+    Original query: [{question}]
+    
+    Hint:
+    * Refer to chat history and add the subject to the question
+    * Replace the pronouns in the question with the correct person or thing, please refer to chat history
+    
+    Rewritten query: 
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
+    """
+    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+
+    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
+        buffer = []
+        for human, ai in chat_history:
+            buffer.append(HumanMessage(content=human))
+            buffer.append(AIMessage(content=ai))
+        return buffer
+
+    _search_query = RunnableBranch(
+        # If input includes chat_history, we condense it with the follow-up question
+        (
+            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
+                run_name="HasChatHistoryCheck"
+            ),  # Condense follow-up question and chat into a standalone_question
+            RunnablePassthrough.assign(
+                chat_history=lambda x: _format_chat_history(x["chat_history"])
+            )
+            | CONDENSE_QUESTION_PROMPT
+            | llm
+            | StrOutputParser(),
+        ),
+        # Else, we have no chat history, so just pass through the question
+        RunnableLambda(lambda x : x["question"]),
+    )
+
+    return _search_query
+
+if __name__ == "__main__":
+    _search_query = get_search_query()
+    chat_history = [
+        {
+            "q": "北海建準廠2023年的類別3排放量是多少?",
+            "a": """根據北海建準廠2023年的數據,類別3的排放量是2,162.62公噸CO2e。
+                類別3指的是溫室氣體排放量盤查作業中的一個範疇,該範疇涵蓋了事業之溫室氣體排放量的盤查和登錄。"""
+        }
+        ]
+    chat_history = [(history["q"] , history["a"] ) for history in chat_history if history["a"] != "" and history["a"]  != "string"]
+    print(chat_history)
+    
+    question = "類別2呢"
+    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
+    print(modified_question)

+ 9 - 8
systex_app.py

@@ -39,20 +39,21 @@ class ChatHistoryItem(BaseModel):
     
 @app.post("/agents")
 def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
+    print(question)
     start = time.time()
     
     with get_openai_callback() as cb:
-        cache_question, cache_answer = semantic_cache(supabase, question)
+        # cache_question, cache_answer = semantic_cache(supabase, question)
+        cache_answer = None
         if cache_answer:
-            processing_time = time.time() - start
-            save_history(question, cache_answer, cb, processing_time)
-
-            return {"Answer": cache_answer}
-    
-        answer = main(question)
-        
+            answer = cache_answer
+        else:
+            answer = main(question)
     processing_time = time.time() - start
     save_history(question, answer, cb, processing_time)
+    if "test@systex.com" in answer:
+        answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    print(answer)
     return {"Answer": answer}  
 
 def save_history(question, answer, cb, processing_time):