Browse Source

update systex_app

ling 4 months ago
parent
commit
68aa73c006
4 changed files with 127 additions and 52 deletions
  1. 17 6
      ai_agent.ipynb
  2. 76 35
      ai_agent.py
  3. 13 5
      post_processing_sqlparse.py
  4. 21 6
      systex_app.py

File diff suppressed because it is too large
+ 17 - 6
ai_agent.ipynb


+ 76 - 35
ai_agent.py

@@ -35,7 +35,7 @@ retriever = create_faiss_retriever()
 # text-to-sql usage
 from text_to_sql_private import run, get_query, query_to_nl, table_description
 from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
-
+progress_bar = []
 
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     
@@ -141,7 +141,7 @@ def generate_additional_question(sql_query):
     question_list = []
     for term in terms:
         if term is None: continue
-        question_format = [f"什麼是{term}?", f"{term}的用途是什麼", f"如何計算{term}?"]
+        question_format = [f"什麼是{term}?", f"{term}的用途是什麼"]
         question_list.extend(question_format)
         
     return question_list
@@ -243,6 +243,7 @@ class GraphState(TypedDict):
         documents: list of documents
     """
 
+    progress_bar: List[str]
     route: str
     question: str
     question_list: List[str]
@@ -252,6 +253,15 @@ class GraphState(TypedDict):
     sql_query: str
     
 # Node
+def show_progress(state, progress: str):
+    global progress_bar
+    # progress_bar = state["progress_bar"] if state["progress_bar"] else []
+    
+    print(progress)
+    progress_bar.append(progress)
+    
+    return progress_bar
+
 def retrieve_and_generation(state):
     """
     Retrieve documents from vectorstore
@@ -262,13 +272,17 @@ def retrieve_and_generation(state):
     Returns:
         state (dict): New key added to state, documents, that contains retrieved documents, and generation, genrating by LLM
     """
-    print("---RETRIEVE---")
+    progress_bar = show_progress(state, "---RETRIEVE---")
+    # progress_bar = state["progress"] if state["progress"] else []
+    # progress = "---RETRIEVE---"
+    # print(progress)
+    # progress_bar.append(progress)
     if not state["route"]:
         route = "RAG"
     else:
         route = state["route"]
     question = state["question"]
-    print(state)
+    # print(state)
     question_list = state["question_list"]
     
     # Retrieval
@@ -282,13 +296,12 @@ def retrieve_and_generation(state):
     else:
         generation = state["generation"]
         
-        for sub_question in question_list:
-            documents = retriever.get_relevant_documents(sub_question, k=30)
+        for sub_question in list(set(question_list)):
+            documents = retriever.get_relevant_documents(sub_question, k=10)
             generation += faiss_query(sub_question, documents, llm)
             generation += "\n"
             
-    
-    return {"route": route, "documents": documents, "question": question, "generation": generation}
+    return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
 
 def company_private_data_get_sql_query(state):
     """
@@ -300,7 +313,8 @@ def company_private_data_get_sql_query(state):
     Returns:
         state (dict): return generated PostgreSQL query and record retry times
     """
-    print("---SQL QUERY---")
+    # print("---SQL QUERY---")
+    progress_bar = show_progress(state, "---SQL QUERY---")
     if not state["route"]:
         route = "SQL"
     else:
@@ -316,7 +330,7 @@ def company_private_data_get_sql_query(state):
     
     sql_query = _get_query(question)
     
-    return {"route": route,"sql_query": sql_query, "question": question, "retry": retry}
+    return {"progress_bar": progress_bar, "route": route,"sql_query": sql_query, "question": question, "retry": retry}
     
 def company_private_data_search(state):
     """
@@ -329,7 +343,8 @@ def company_private_data_search(state):
         state (dict): Appended sql results to state
     """
 
-    print("---SQL TO NL---")
+    # print("---SQL TO NL---")
+    progress_bar = show_progress(state, "---SQL TO NL---")
     # print(state)
     question = state["question"]
     sql_query = state["sql_query"]
@@ -337,7 +352,7 @@ def company_private_data_search(state):
     
     # generation = [company_private_data_result]
     
-    return {"sql_query": sql_query, "question": question, "generation": generation}
+    return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation}
 
 def additional_explanation_question(state):
     """
@@ -349,21 +364,29 @@ def additional_explanation_question(state):
         state (dict): Appended additional explanation to state
     """
     
-    print("---ADDITIONAL EXPLANATION---")
-    print(state)
+    # print("---ADDITIONAL EXPLANATION---")
+    progress_bar = show_progress(state, "---ADDITIONAL EXPLANATION---")
+    # print(state)
     question = state["question"]
     sql_query = state["sql_query"]
-    print(sql_query)
+    # print(sql_query)
     generation = state["generation"]
     question_list = generate_additional_question(sql_query)
-    print(question_list)
+    # print(question_list)
     # generation += "\n"
     # generation += generate_additional_detail(sql_query)
     
     
     # generation = [company_private_data_result]
     
-    return {"sql_query": sql_query, "question": question, "generation": generation, "question_list": question_list}
+    return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation, "question_list": question_list}
+
+def error(state):
+    # print("---SOMETHING WENT WRONG---")
+    progress_bar = show_progress(state, "---SOMETHING WENT WRONG---")
+    generation = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    
+    return {"progress_bar": progress_bar,  "generation": generation}
 
 ### Conditional edge
 
@@ -379,7 +402,8 @@ def route_question(state):
         str: Next node to call
     """
 
-    print("---ROUTE QUESTION---")
+    # print("---ROUTE QUESTION---")
+    progress_bar = show_progress(state, "---ROUTE QUESTION---")
     question = state["question"]
     # print(question)
     question_router = Router()
@@ -387,10 +411,12 @@ def route_question(state):
     # print(source)
     print(source["datasource"])
     if source["datasource"] == "自有數據":
-        print("---ROUTE QUESTION TO TEXT-TO-SQL---")
+        # print("---ROUTE QUESTION TO TEXT-TO-SQL---")
+        progress_bar = show_progress(state, "---ROUTE QUESTION TO TEXT-TO-SQL---")
         return "自有數據"
     elif source["datasource"] == "專業知識":
-        print("---ROUTE QUESTION TO RAG---")
+        # print("---ROUTE QUESTION TO RAG---")
+        progress_bar = show_progress(state, "---ROUTE QUESTION TO RAG---")
         return "專業知識"
     
 def grade_generation_v_documents_and_question(state):
@@ -404,7 +430,8 @@ def grade_generation_v_documents_and_question(state):
         str: Decision for next node to call
     """
 
-    print("---CHECK HALLUCINATIONS---")
+    # print("---CHECK HALLUCINATIONS---")
+    progress_bar = show_progress(state, "---CHECK HALLUCINATIONS---")
     question = state["question"]
     documents = state["documents"]
     generation = state["generation"]
@@ -421,20 +448,25 @@ def grade_generation_v_documents_and_question(state):
 
     # Check hallucination
     if grade in ["yes", "true", 1, "1"]:
-        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+        # print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+        progress_bar = show_progress(state, "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
         # Check question-answering
-        print("---GRADE GENERATION vs QUESTION---")
+        # print("---GRADE GENERATION vs QUESTION---")
+        progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
         answer_grader = Answer_Grader()
         score = answer_grader.invoke({"question": question, "generation": generation})
         grade = score["score"]
         if grade in ["yes", "true", 1, "1"]:
-            print("---DECISION: GENERATION ADDRESSES QUESTION---")
+            # print("---DECISION: GENERATION ADDRESSES QUESTION---")
+            progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
             return "useful"
         else:
-            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+            # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+            progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
             return "not useful"
     else:
-        pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+        # pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+        progress_bar = show_progress(state, "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
         return "not supported"
     
 def grade_sql_query(state):
@@ -448,7 +480,8 @@ def grade_sql_query(state):
         state (dict): Decision for retry or continue
     """
 
-    print("---CHECK SQL CORRECTNESS TO QUESTION---")
+    # print("---CHECK SQL CORRECTNESS TO QUESTION---")
+    progress_bar = show_progress(state, "---CHECK SQL CORRECTNESS TO QUESTION---")
     question = state["question"]
     sql_query = state["sql_query"]
     retry = state["retry"]
@@ -459,13 +492,16 @@ def grade_sql_query(state):
     grade = score["score"]
     # Document relevant
     if grade in ["yes", "true", 1, "1"]:
-        print("---GRADE: CORRECT SQL QUERY---")
+        # print("---GRADE: CORRECT SQL QUERY---")
+        progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
         return "correct"
     elif retry >= 5:
-        print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
+        # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
+        progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
         return "failed"
     else:
-        print("---GRADE: INCORRECT SQL QUERY---")
+        # print("---GRADE: INCORRECT SQL QUERY---")
+        progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
         return "incorrect"
 
 def build_graph():
@@ -476,6 +512,7 @@ def build_graph():
     workflow.add_node("SQL Answer", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # web search
     workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
+    workflow.add_node("ERROR", error)  # retrieve
     
     workflow.add_conditional_edges(
         START,
@@ -490,9 +527,9 @@ def build_graph():
         "RAG",
         grade_generation_v_documents_and_question,
         {
-            "not supported": "RAG",
+            "not supported": "ERROR",
             "useful": END,
-            "not useful": "RAG",
+            "not useful": "ERROR",
         },
     )
     workflow.add_conditional_edges(
@@ -500,7 +537,7 @@ def build_graph():
         grade_sql_query,
         {
             "correct": "SQL Answer",
-            "incorrect": "Text-to-SQL",
+            "incorrect": "ERROR",
             "failed": "RAG"
             
         },
@@ -513,19 +550,23 @@ def build_graph():
     return app
 
 def main(question: str):
+    
     app = build_graph()
     #建準去年的類別一排放量?
     # inputs = {"question": "溫室氣體是什麼"}
-    inputs = {"question": question}
+    inputs = {"question": question, "progress_bar": None}
     for output in app.stream(inputs, {"recursion_limit": 10}):
         for key, value in output.items():
             pprint(f"Finished running: {key}:")
     # pprint(value["generation"])
     # pprint(value)
+    value["progress_bar"] = progress_bar
+    pprint(value["progress_bar"])
     
     return value["generation"]
 
 if __name__ == "__main__":
-    result = main("建準去年的逸散排放總排放量是多少?")
+    # result = main("建準去年的逸散排放總排放量是多少?")
+    result = main("建準去年的綠電使用量是多少?")
     print("------------------------------------------------------")
     print(result)

+ 13 - 5
post_processing_sqlparse.py

@@ -59,12 +59,13 @@ def parse_sql_where(sql):
     stmt = sqlparse.parse(sql)[0]
     column_dict = {
         "排放源": None,
-        "類別": None
+        "類別": None,
+        "項目": None
     }
 
     def get_column_details(token, column_args):
         if isinstance(token, Comparison):
-            print(token, type(token))
+            # print(token, type(token))
             for column_name in column_args.keys():
                 if column_args[column_name] is None:
                     column_args[column_name] = extract_comparison_value(token.tokens, column_name)
@@ -83,7 +84,8 @@ def parse_sql_where(sql):
                     # print(token2, type(token2))
                     for token3 in token2.tokens:
                         column_dict = get_column_details(token3, column_dict)
-    column_values = [column_dict[column_name].replace("%", "") for column_name in column_dict.keys()]
+    column_values = [column_dict[column_name].replace("%", "") for column_name in column_dict.keys() if column_dict[column_name] is not None]
+    
     return column_values
 
 def get_table_name(sql):
@@ -109,7 +111,13 @@ if __name__ == "__main__":
         AND "盤查標準" = 'GHG'
         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
         """
-        
+    sql_query = """
+        SELECT SUM("用電度數(kwh)") AS "綠電使用量"
+        FROM "用電度數"
+        WHERE "項目" = '自產電力(綠電)'
+        AND "盤查標準" = 'GHG'
+        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1
+    """
     print(get_query_columns(sql_query, get_real_name=True))
-    print(parse_sql_for_stock_info(sql_query))
+    print(parse_sql_where(sql_query))
     print(get_table_name(sql_query))

+ 21 - 6
systex_app.py

@@ -2,7 +2,7 @@ import datetime
 from json import loads
 import time
 from typing import List
-from fastapi import FastAPI
+from fastapi import Body, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
 import pandas as pd
@@ -16,6 +16,7 @@ from supabase.client import Client, create_client
 from langchain.callbacks import get_openai_callback
 
 from ai_agent import main
+from semantic_search import semantic_cache
 
 load_dotenv()
 URI = os.getenv("SUPABASE_URI")
@@ -32,15 +33,27 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
-
-@app.get("/agents")
-def agent(question: str):
+class ChatHistoryItem(BaseModel):
+    q: str
+    a: str
+    
+@app.post("/agents")
+def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
     start = time.time()
+    
     with get_openai_callback() as cb:
+        cache_question, cache_answer = semantic_cache(supabase, question)
+        if cache_answer:
+            processing_time = time.time() - start
+            save_history(question, cache_answer, cb, processing_time)
+
+            return {"Answer": cache_answer}
+    
         answer = main(question)
+        
     processing_time = time.time() - start
     save_history(question, answer, cb, processing_time)
-    return {"answer": answer}  
+    return {"Answer": answer}  
 
 def save_history(question, answer, cb, processing_time):
     # reference = [doc.dict() for doc in reference]
@@ -80,5 +93,7 @@ async def get_history():
 
 if __name__ == "__main__":
     
-    uvicorn.run("systex_app:app", host='0.0.0.0', reload=True, port=8080)
+    uvicorn.run("systex_app:app", host='0.0.0.0', reload=True, port=8080, 
+                ssl_keyfile="/etc/ssl_file/key.pem", 
+                ssl_certfile="/etc/ssl_file/cert.pem")
 

Some files were not shown because too many files changed in this diff