|
@@ -35,7 +35,7 @@ retriever = create_faiss_retriever()
|
|
# text-to-sql usage
|
|
# text-to-sql usage
|
|
from text_to_sql_private import run, get_query, query_to_nl, table_description
|
|
from text_to_sql_private import run, get_query, query_to_nl, table_description
|
|
from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
|
|
from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
|
|
-
|
|
|
|
|
|
+progress_bar = []
|
|
|
|
|
|
def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
|
|
def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
|
|
|
|
|
|
@@ -141,7 +141,7 @@ def generate_additional_question(sql_query):
|
|
question_list = []
|
|
question_list = []
|
|
for term in terms:
|
|
for term in terms:
|
|
if term is None: continue
|
|
if term is None: continue
|
|
- question_format = [f"什麼是{term}?", f"{term}的用途是什麼", f"如何計算{term}?"]
|
|
|
|
|
|
+ question_format = [f"什麼是{term}?", f"{term}的用途是什麼"]
|
|
question_list.extend(question_format)
|
|
question_list.extend(question_format)
|
|
|
|
|
|
return question_list
|
|
return question_list
|
|
@@ -243,6 +243,7 @@ class GraphState(TypedDict):
|
|
documents: list of documents
|
|
documents: list of documents
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
+ progress_bar: List[str]
|
|
route: str
|
|
route: str
|
|
question: str
|
|
question: str
|
|
question_list: List[str]
|
|
question_list: List[str]
|
|
@@ -252,6 +253,15 @@ class GraphState(TypedDict):
|
|
sql_query: str
|
|
sql_query: str
|
|
|
|
|
|
# Node
|
|
# Node
|
|
|
|
+def show_progress(state, progress: str):
|
|
|
|
+ global progress_bar
|
|
|
|
+ # progress_bar = state["progress_bar"] if state["progress_bar"] else []
|
|
|
|
+
|
|
|
|
+ print(progress)
|
|
|
|
+ progress_bar.append(progress)
|
|
|
|
+
|
|
|
|
+ return progress_bar
|
|
|
|
+
|
|
def retrieve_and_generation(state):
|
|
def retrieve_and_generation(state):
|
|
"""
|
|
"""
|
|
Retrieve documents from vectorstore
|
|
Retrieve documents from vectorstore
|
|
@@ -262,13 +272,17 @@ def retrieve_and_generation(state):
|
|
Returns:
|
|
Returns:
|
|
state (dict): New key added to state, documents, that contains retrieved documents, and generation, genrating by LLM
|
|
state (dict): New key added to state, documents, that contains retrieved documents, and generation, genrating by LLM
|
|
"""
|
|
"""
|
|
- print("---RETRIEVE---")
|
|
|
|
|
|
+ progress_bar = show_progress(state, "---RETRIEVE---")
|
|
|
|
+ # progress_bar = state["progress"] if state["progress"] else []
|
|
|
|
+ # progress = "---RETRIEVE---"
|
|
|
|
+ # print(progress)
|
|
|
|
+ # progress_bar.append(progress)
|
|
if not state["route"]:
|
|
if not state["route"]:
|
|
route = "RAG"
|
|
route = "RAG"
|
|
else:
|
|
else:
|
|
route = state["route"]
|
|
route = state["route"]
|
|
question = state["question"]
|
|
question = state["question"]
|
|
- print(state)
|
|
|
|
|
|
+ # print(state)
|
|
question_list = state["question_list"]
|
|
question_list = state["question_list"]
|
|
|
|
|
|
# Retrieval
|
|
# Retrieval
|
|
@@ -282,13 +296,12 @@ def retrieve_and_generation(state):
|
|
else:
|
|
else:
|
|
generation = state["generation"]
|
|
generation = state["generation"]
|
|
|
|
|
|
- for sub_question in question_list:
|
|
|
|
- documents = retriever.get_relevant_documents(sub_question, k=30)
|
|
|
|
|
|
+ for sub_question in list(set(question_list)):
|
|
|
|
+ documents = retriever.get_relevant_documents(sub_question, k=10)
|
|
generation += faiss_query(sub_question, documents, llm)
|
|
generation += faiss_query(sub_question, documents, llm)
|
|
generation += "\n"
|
|
generation += "\n"
|
|
|
|
|
|
-
|
|
|
|
- return {"route": route, "documents": documents, "question": question, "generation": generation}
|
|
|
|
|
|
+ return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
|
|
|
|
|
|
def company_private_data_get_sql_query(state):
|
|
def company_private_data_get_sql_query(state):
|
|
"""
|
|
"""
|
|
@@ -300,7 +313,8 @@ def company_private_data_get_sql_query(state):
|
|
Returns:
|
|
Returns:
|
|
state (dict): return generated PostgreSQL query and record retry times
|
|
state (dict): return generated PostgreSQL query and record retry times
|
|
"""
|
|
"""
|
|
- print("---SQL QUERY---")
|
|
|
|
|
|
+ # print("---SQL QUERY---")
|
|
|
|
+ progress_bar = show_progress(state, "---SQL QUERY---")
|
|
if not state["route"]:
|
|
if not state["route"]:
|
|
route = "SQL"
|
|
route = "SQL"
|
|
else:
|
|
else:
|
|
@@ -316,7 +330,7 @@ def company_private_data_get_sql_query(state):
|
|
|
|
|
|
sql_query = _get_query(question)
|
|
sql_query = _get_query(question)
|
|
|
|
|
|
- return {"route": route,"sql_query": sql_query, "question": question, "retry": retry}
|
|
|
|
|
|
+ return {"progress_bar": progress_bar, "route": route,"sql_query": sql_query, "question": question, "retry": retry}
|
|
|
|
|
|
def company_private_data_search(state):
|
|
def company_private_data_search(state):
|
|
"""
|
|
"""
|
|
@@ -329,7 +343,8 @@ def company_private_data_search(state):
|
|
state (dict): Appended sql results to state
|
|
state (dict): Appended sql results to state
|
|
"""
|
|
"""
|
|
|
|
|
|
- print("---SQL TO NL---")
|
|
|
|
|
|
+ # print("---SQL TO NL---")
|
|
|
|
+ progress_bar = show_progress(state, "---SQL TO NL---")
|
|
# print(state)
|
|
# print(state)
|
|
question = state["question"]
|
|
question = state["question"]
|
|
sql_query = state["sql_query"]
|
|
sql_query = state["sql_query"]
|
|
@@ -337,7 +352,7 @@ def company_private_data_search(state):
|
|
|
|
|
|
# generation = [company_private_data_result]
|
|
# generation = [company_private_data_result]
|
|
|
|
|
|
- return {"sql_query": sql_query, "question": question, "generation": generation}
|
|
|
|
|
|
+ return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation}
|
|
|
|
|
|
def additional_explanation_question(state):
|
|
def additional_explanation_question(state):
|
|
"""
|
|
"""
|
|
@@ -349,21 +364,29 @@ def additional_explanation_question(state):
|
|
state (dict): Appended additional explanation to state
|
|
state (dict): Appended additional explanation to state
|
|
"""
|
|
"""
|
|
|
|
|
|
- print("---ADDITIONAL EXPLANATION---")
|
|
|
|
- print(state)
|
|
|
|
|
|
+ # print("---ADDITIONAL EXPLANATION---")
|
|
|
|
+ progress_bar = show_progress(state, "---ADDITIONAL EXPLANATION---")
|
|
|
|
+ # print(state)
|
|
question = state["question"]
|
|
question = state["question"]
|
|
sql_query = state["sql_query"]
|
|
sql_query = state["sql_query"]
|
|
- print(sql_query)
|
|
|
|
|
|
+ # print(sql_query)
|
|
generation = state["generation"]
|
|
generation = state["generation"]
|
|
question_list = generate_additional_question(sql_query)
|
|
question_list = generate_additional_question(sql_query)
|
|
- print(question_list)
|
|
|
|
|
|
+ # print(question_list)
|
|
# generation += "\n"
|
|
# generation += "\n"
|
|
# generation += generate_additional_detail(sql_query)
|
|
# generation += generate_additional_detail(sql_query)
|
|
|
|
|
|
|
|
|
|
# generation = [company_private_data_result]
|
|
# generation = [company_private_data_result]
|
|
|
|
|
|
- return {"sql_query": sql_query, "question": question, "generation": generation, "question_list": question_list}
|
|
|
|
|
|
+ return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation, "question_list": question_list}
|
|
|
|
+
|
|
|
|
+def error(state):
|
|
|
|
+ # print("---SOMETHING WENT WRONG---")
|
|
|
|
+ progress_bar = show_progress(state, "---SOMETHING WENT WRONG---")
|
|
|
|
+ generation = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
|
|
|
|
+
|
|
|
|
+ return {"progress_bar": progress_bar, "generation": generation}
|
|
|
|
|
|
### Conditional edge
|
|
### Conditional edge
|
|
|
|
|
|
@@ -379,7 +402,8 @@ def route_question(state):
|
|
str: Next node to call
|
|
str: Next node to call
|
|
"""
|
|
"""
|
|
|
|
|
|
- print("---ROUTE QUESTION---")
|
|
|
|
|
|
+ # print("---ROUTE QUESTION---")
|
|
|
|
+ progress_bar = show_progress(state, "---ROUTE QUESTION---")
|
|
question = state["question"]
|
|
question = state["question"]
|
|
# print(question)
|
|
# print(question)
|
|
question_router = Router()
|
|
question_router = Router()
|
|
@@ -387,10 +411,12 @@ def route_question(state):
|
|
# print(source)
|
|
# print(source)
|
|
print(source["datasource"])
|
|
print(source["datasource"])
|
|
if source["datasource"] == "自有數據":
|
|
if source["datasource"] == "自有數據":
|
|
- print("---ROUTE QUESTION TO TEXT-TO-SQL---")
|
|
|
|
|
|
+ # print("---ROUTE QUESTION TO TEXT-TO-SQL---")
|
|
|
|
+ progress_bar = show_progress(state, "---ROUTE QUESTION TO TEXT-TO-SQL---")
|
|
return "自有數據"
|
|
return "自有數據"
|
|
elif source["datasource"] == "專業知識":
|
|
elif source["datasource"] == "專業知識":
|
|
- print("---ROUTE QUESTION TO RAG---")
|
|
|
|
|
|
+ # print("---ROUTE QUESTION TO RAG---")
|
|
|
|
+ progress_bar = show_progress(state, "---ROUTE QUESTION TO RAG---")
|
|
return "專業知識"
|
|
return "專業知識"
|
|
|
|
|
|
def grade_generation_v_documents_and_question(state):
|
|
def grade_generation_v_documents_and_question(state):
|
|
@@ -404,7 +430,8 @@ def grade_generation_v_documents_and_question(state):
|
|
str: Decision for next node to call
|
|
str: Decision for next node to call
|
|
"""
|
|
"""
|
|
|
|
|
|
- print("---CHECK HALLUCINATIONS---")
|
|
|
|
|
|
+ # print("---CHECK HALLUCINATIONS---")
|
|
|
|
+ progress_bar = show_progress(state, "---CHECK HALLUCINATIONS---")
|
|
question = state["question"]
|
|
question = state["question"]
|
|
documents = state["documents"]
|
|
documents = state["documents"]
|
|
generation = state["generation"]
|
|
generation = state["generation"]
|
|
@@ -421,20 +448,25 @@ def grade_generation_v_documents_and_question(state):
|
|
|
|
|
|
# Check hallucination
|
|
# Check hallucination
|
|
if grade in ["yes", "true", 1, "1"]:
|
|
if grade in ["yes", "true", 1, "1"]:
|
|
- print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
|
|
|
|
|
|
+ # print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
|
|
|
|
+ progress_bar = show_progress(state, "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
|
|
# Check question-answering
|
|
# Check question-answering
|
|
- print("---GRADE GENERATION vs QUESTION---")
|
|
|
|
|
|
+ # print("---GRADE GENERATION vs QUESTION---")
|
|
|
|
+ progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
|
|
answer_grader = Answer_Grader()
|
|
answer_grader = Answer_Grader()
|
|
score = answer_grader.invoke({"question": question, "generation": generation})
|
|
score = answer_grader.invoke({"question": question, "generation": generation})
|
|
grade = score["score"]
|
|
grade = score["score"]
|
|
if grade in ["yes", "true", 1, "1"]:
|
|
if grade in ["yes", "true", 1, "1"]:
|
|
- print("---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
|
|
|
+ # print("---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
|
+ progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
return "useful"
|
|
return "useful"
|
|
else:
|
|
else:
|
|
- print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
|
|
|
+ # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
|
+ progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
return "not useful"
|
|
return "not useful"
|
|
else:
|
|
else:
|
|
- pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
|
|
|
|
|
|
+ # pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
|
|
|
|
+ progress_bar = show_progress(state, "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
|
|
return "not supported"
|
|
return "not supported"
|
|
|
|
|
|
def grade_sql_query(state):
|
|
def grade_sql_query(state):
|
|
@@ -448,7 +480,8 @@ def grade_sql_query(state):
|
|
state (dict): Decision for retry or continue
|
|
state (dict): Decision for retry or continue
|
|
"""
|
|
"""
|
|
|
|
|
|
- print("---CHECK SQL CORRECTNESS TO QUESTION---")
|
|
|
|
|
|
+ # print("---CHECK SQL CORRECTNESS TO QUESTION---")
|
|
|
|
+ progress_bar = show_progress(state, "---CHECK SQL CORRECTNESS TO QUESTION---")
|
|
question = state["question"]
|
|
question = state["question"]
|
|
sql_query = state["sql_query"]
|
|
sql_query = state["sql_query"]
|
|
retry = state["retry"]
|
|
retry = state["retry"]
|
|
@@ -459,13 +492,16 @@ def grade_sql_query(state):
|
|
grade = score["score"]
|
|
grade = score["score"]
|
|
# Document relevant
|
|
# Document relevant
|
|
if grade in ["yes", "true", 1, "1"]:
|
|
if grade in ["yes", "true", 1, "1"]:
|
|
- print("---GRADE: CORRECT SQL QUERY---")
|
|
|
|
|
|
+ # print("---GRADE: CORRECT SQL QUERY---")
|
|
|
|
+ progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
|
|
return "correct"
|
|
return "correct"
|
|
elif retry >= 5:
|
|
elif retry >= 5:
|
|
- print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
|
|
|
|
|
|
+ # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
|
|
|
|
+ progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
|
|
return "failed"
|
|
return "failed"
|
|
else:
|
|
else:
|
|
- print("---GRADE: INCORRECT SQL QUERY---")
|
|
|
|
|
|
+ # print("---GRADE: INCORRECT SQL QUERY---")
|
|
|
|
+ progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
|
|
return "incorrect"
|
|
return "incorrect"
|
|
|
|
|
|
def build_graph():
|
|
def build_graph():
|
|
@@ -476,6 +512,7 @@ def build_graph():
|
|
workflow.add_node("SQL Answer", company_private_data_search, retry=RetryPolicy(max_attempts=5)) # web search
|
|
workflow.add_node("SQL Answer", company_private_data_search, retry=RetryPolicy(max_attempts=5)) # web search
|
|
workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5)) # retrieve
|
|
workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5)) # retrieve
|
|
workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5)) # retrieve
|
|
workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5)) # retrieve
|
|
|
|
+ workflow.add_node("ERROR", error) # retrieve
|
|
|
|
|
|
workflow.add_conditional_edges(
|
|
workflow.add_conditional_edges(
|
|
START,
|
|
START,
|
|
@@ -490,9 +527,9 @@ def build_graph():
|
|
"RAG",
|
|
"RAG",
|
|
grade_generation_v_documents_and_question,
|
|
grade_generation_v_documents_and_question,
|
|
{
|
|
{
|
|
- "not supported": "RAG",
|
|
|
|
|
|
+ "not supported": "ERROR",
|
|
"useful": END,
|
|
"useful": END,
|
|
- "not useful": "RAG",
|
|
|
|
|
|
+ "not useful": "ERROR",
|
|
},
|
|
},
|
|
)
|
|
)
|
|
workflow.add_conditional_edges(
|
|
workflow.add_conditional_edges(
|
|
@@ -500,7 +537,7 @@ def build_graph():
|
|
grade_sql_query,
|
|
grade_sql_query,
|
|
{
|
|
{
|
|
"correct": "SQL Answer",
|
|
"correct": "SQL Answer",
|
|
- "incorrect": "Text-to-SQL",
|
|
|
|
|
|
+ "incorrect": "ERROR",
|
|
"failed": "RAG"
|
|
"failed": "RAG"
|
|
|
|
|
|
},
|
|
},
|
|
@@ -513,19 +550,23 @@ def build_graph():
|
|
return app
|
|
return app
|
|
|
|
|
|
def main(question: str):
|
|
def main(question: str):
|
|
|
|
+
|
|
app = build_graph()
|
|
app = build_graph()
|
|
#建準去年的類別一排放量?
|
|
#建準去年的類別一排放量?
|
|
# inputs = {"question": "溫室氣體是什麼"}
|
|
# inputs = {"question": "溫室氣體是什麼"}
|
|
- inputs = {"question": question}
|
|
|
|
|
|
+ inputs = {"question": question, "progress_bar": None}
|
|
for output in app.stream(inputs, {"recursion_limit": 10}):
|
|
for output in app.stream(inputs, {"recursion_limit": 10}):
|
|
for key, value in output.items():
|
|
for key, value in output.items():
|
|
pprint(f"Finished running: {key}:")
|
|
pprint(f"Finished running: {key}:")
|
|
# pprint(value["generation"])
|
|
# pprint(value["generation"])
|
|
# pprint(value)
|
|
# pprint(value)
|
|
|
|
+ value["progress_bar"] = progress_bar
|
|
|
|
+ pprint(value["progress_bar"])
|
|
|
|
|
|
return value["generation"]
|
|
return value["generation"]
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
- result = main("建準去年的逸散排放總排放量是多少?")
|
|
|
|
|
|
+ # result = main("建準去年的逸散排放總排放量是多少?")
|
|
|
|
+ result = main("建準去年的綠電使用量是多少?")
|
|
print("------------------------------------------------------")
|
|
print("------------------------------------------------------")
|
|
print(result)
|
|
print(result)
|