|
@@ -36,7 +36,6 @@ retriever = create_faiss_retriever()
|
|
|
from text_to_sql_private import run, get_query, query_to_nl, table_description
|
|
|
from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
|
|
|
progress_bar = []
|
|
|
-
|
|
|
def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
|
|
|
|
|
|
context = docs
|
|
@@ -60,6 +59,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
|
|
|
|
|
|
Question: {question}
|
|
|
用繁體中文回答問題
|
|
|
+ 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
|
|
|
<|eot_id|>
|
|
|
|
|
|
<|start_header_id|>assistant<|end_header_id|>
|
|
@@ -121,19 +121,19 @@ def Answer_Grader():
|
|
|
|
|
|
# Text-to-SQL
|
|
|
def run_text_to_sql(question: str):
|
|
|
- selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
|
|
|
+ selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
|
|
|
# question = "建準去年的固定燃燒總排放量是多少?"
|
|
|
query, result, answer = run(db, question, selected_table, llm)
|
|
|
|
|
|
return answer, query
|
|
|
|
|
|
def _get_query(question: str):
|
|
|
- selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
|
|
|
- query = get_query(db, question, selected_table, llm)
|
|
|
- return query
|
|
|
+ selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
|
|
|
+ query, result = get_query(db, question, selected_table, llm)
|
|
|
+ return query, result
|
|
|
|
|
|
-def _query_to_nl(question: str, query: str):
|
|
|
- answer = query_to_nl(db, question, query, llm)
|
|
|
+def _query_to_nl(question: str, query: str, result):
|
|
|
+ answer = query_to_nl(question, query, result, llm)
|
|
|
return answer
|
|
|
|
|
|
def generate_additional_question(sql_query):
|
|
@@ -150,15 +150,17 @@ def generate_additional_question(sql_query):
|
|
|
def generate_additional_detail(sql_query):
|
|
|
terms = parse_sql_where(sql_query)
|
|
|
answer = ""
|
|
|
- for term in terms:
|
|
|
+ for term in list(set(terms)):
|
|
|
if term is None: continue
|
|
|
- question_format = [f"什麼是{term}?", f"{term}的用途是什麼", f"如何計算{term}?"]
|
|
|
+ question_format = [f"請解釋什麼是{term}?"]
|
|
|
for question in question_format:
|
|
|
# question = f"什麼是{term}?"
|
|
|
- documents = retriever.get_relevant_documents(question, k=30)
|
|
|
- generation = faiss_query(question, documents, llm)
|
|
|
+ documents = retriever.get_relevant_documents(question, k=5)
|
|
|
+ generation = faiss_query(question, documents, llm) + "\n"
|
|
|
+ if "test@systex.com" in generation:
|
|
|
+ generation = ""
|
|
|
+
|
|
|
answer += generation
|
|
|
- answer += "\n"
|
|
|
# print(question)
|
|
|
# print(generation)
|
|
|
return answer
|
|
@@ -177,7 +179,7 @@ def SQL_Grader():
|
|
|
|
|
|
For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
|
|
|
"SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
|
|
|
- FROM "104_112碳排放公開及建準資料"
|
|
|
+ FROM "建準碳排放清冊數據new"
|
|
|
WHERE "事業名稱" like '%建準%'
|
|
|
AND "排放源" = '下游租賃'
|
|
|
AND "盤查標準" = 'GHG'
|
|
@@ -186,7 +188,7 @@ def SQL_Grader():
|
|
|
|
|
|
Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
|
|
|
"SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
|
|
|
- FROM "104_112碳排放公開及建準資料"
|
|
|
+ FROM "建準碳排放清冊數據new"
|
|
|
WHERE "事業名稱" like '%台積電%'
|
|
|
AND "排放源" = '固定燃燒'
|
|
|
AND "盤查標準" = 'GHG'
|
|
@@ -251,6 +253,7 @@ class GraphState(TypedDict):
|
|
|
documents: List[str]
|
|
|
retry: int
|
|
|
sql_query: str
|
|
|
+ sql_result: str
|
|
|
|
|
|
# Node
|
|
|
def show_progress(state, progress: str):
|
|
@@ -289,7 +292,10 @@ def retrieve_and_generation(state):
|
|
|
if not question_list:
|
|
|
# documents = retriever.invoke(question)
|
|
|
# TODO: correct Retrieval function
|
|
|
- documents = retriever.get_relevant_documents(question, k=30)
|
|
|
+ documents = retriever.get_relevant_documents(question, k=5)
|
|
|
+ for doc in documents:
|
|
|
+ print(doc)
|
|
|
+
|
|
|
# docs_documents = "\n\n".join(doc.page_content for doc in documents)
|
|
|
# print(documents)
|
|
|
generation = faiss_query(question, documents, llm)
|
|
@@ -297,10 +303,13 @@ def retrieve_and_generation(state):
|
|
|
generation = state["generation"]
|
|
|
|
|
|
for sub_question in list(set(question_list)):
|
|
|
+ print(sub_question)
|
|
|
documents = retriever.get_relevant_documents(sub_question, k=10)
|
|
|
generation += faiss_query(sub_question, documents, llm)
|
|
|
generation += "\n"
|
|
|
|
|
|
+ print(generation)
|
|
|
+
|
|
|
return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
|
|
|
|
|
|
def company_private_data_get_sql_query(state):
|
|
@@ -328,9 +337,10 @@ def company_private_data_get_sql_query(state):
|
|
|
retry = 0
|
|
|
# print("RETRY: ", retry)
|
|
|
|
|
|
- sql_query = _get_query(question)
|
|
|
+ sql_query, sql_result = _get_query(question)
|
|
|
+ print(type(sql_result))
|
|
|
|
|
|
- return {"progress_bar": progress_bar, "route": route,"sql_query": sql_query, "question": question, "retry": retry}
|
|
|
+ return {"progress_bar": progress_bar, "route": route, "sql_query": sql_query, "sql_result": sql_result, "question": question, "retry": retry}
|
|
|
|
|
|
def company_private_data_search(state):
|
|
|
"""
|
|
@@ -348,7 +358,8 @@ def company_private_data_search(state):
|
|
|
# print(state)
|
|
|
question = state["question"]
|
|
|
sql_query = state["sql_query"]
|
|
|
- generation = _query_to_nl(question, sql_query)
|
|
|
+ sql_result = state["sql_result"]
|
|
|
+ generation = _query_to_nl(question, sql_query, sql_result)
|
|
|
|
|
|
# generation = [company_private_data_result]
|
|
|
|
|
@@ -371,11 +382,12 @@ def additional_explanation_question(state):
|
|
|
sql_query = state["sql_query"]
|
|
|
# print(sql_query)
|
|
|
generation = state["generation"]
|
|
|
- question_list = generate_additional_question(sql_query)
|
|
|
- # print(question_list)
|
|
|
- # generation += "\n"
|
|
|
- # generation += generate_additional_detail(sql_query)
|
|
|
+ generation += "\n"
|
|
|
+ generation += generate_additional_detail(sql_query)
|
|
|
+ question_list = []
|
|
|
|
|
|
+ # question_list = generate_additional_question(sql_query)
|
|
|
+ # print(question_list)
|
|
|
|
|
|
# generation = [company_private_data_result]
|
|
|
|
|
@@ -408,6 +420,9 @@ def route_question(state):
|
|
|
# print(question)
|
|
|
question_router = Router()
|
|
|
source = question_router.invoke({"question": question})
|
|
|
+ if "建準" in question:
|
|
|
+ source["datasource"] = "自有數據"
|
|
|
+
|
|
|
# print(source)
|
|
|
print(source["datasource"])
|
|
|
if source["datasource"] == "自有數據":
|
|
@@ -431,43 +446,56 @@ def grade_generation_v_documents_and_question(state):
|
|
|
"""
|
|
|
|
|
|
# print("---CHECK HALLUCINATIONS---")
|
|
|
- progress_bar = show_progress(state, "---CHECK HALLUCINATIONS---")
|
|
|
question = state["question"]
|
|
|
documents = state["documents"]
|
|
|
generation = state["generation"]
|
|
|
|
|
|
-
|
|
|
- # print(docs_documents)
|
|
|
- # print(generation)
|
|
|
- hallucination_grader = Hallucination_Grader()
|
|
|
- score = hallucination_grader.invoke(
|
|
|
- {"documents": documents, "generation": generation}
|
|
|
- )
|
|
|
- # print(score)
|
|
|
+ progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
|
|
|
+ answer_grader = Answer_Grader()
|
|
|
+ score = answer_grader.invoke({"question": question, "generation": generation})
|
|
|
grade = score["score"]
|
|
|
-
|
|
|
- # Check hallucination
|
|
|
if grade in ["yes", "true", 1, "1"]:
|
|
|
- # print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
|
|
|
- progress_bar = show_progress(state, "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
|
|
|
- # Check question-answering
|
|
|
- # print("---GRADE GENERATION vs QUESTION---")
|
|
|
- progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
|
|
|
- answer_grader = Answer_Grader()
|
|
|
- score = answer_grader.invoke({"question": question, "generation": generation})
|
|
|
- grade = score["score"]
|
|
|
- if grade in ["yes", "true", 1, "1"]:
|
|
|
- # print("---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
- progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
- return "useful"
|
|
|
- else:
|
|
|
- # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
- progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
- return "not useful"
|
|
|
+ # print("---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
+ progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
+ return "useful"
|
|
|
else:
|
|
|
- # pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
|
|
|
- progress_bar = show_progress(state, "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
|
|
|
- return "not supported"
|
|
|
+ # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
+ progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
+ return "not useful"
|
|
|
+
|
|
|
+
|
|
|
+ # progress_bar = show_progress(state, "---CHECK HALLUCINATIONS---")
|
|
|
+ # # print(docs_documents)
|
|
|
+ # # print(generation)
|
|
|
+ # hallucination_grader = Hallucination_Grader()
|
|
|
+ # score = hallucination_grader.invoke(
|
|
|
+ # {"documents": documents, "generation": generation}
|
|
|
+ # )
|
|
|
+ # # print(score)
|
|
|
+ # grade = score["score"]
|
|
|
+
|
|
|
+ # # Check hallucination
|
|
|
+ # if grade in ["yes", "true", 1, "1"]:
|
|
|
+ # # print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
|
|
|
+ # progress_bar = show_progress(state, "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
|
|
|
+ # # Check question-answering
|
|
|
+ # # print("---GRADE GENERATION vs QUESTION---")
|
|
|
+ # progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
|
|
|
+ # answer_grader = Answer_Grader()
|
|
|
+ # score = answer_grader.invoke({"question": question, "generation": generation})
|
|
|
+ # grade = score["score"]
|
|
|
+ # if grade in ["yes", "true", 1, "1"]:
|
|
|
+ # # print("---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
+ # progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
+ # return "useful"
|
|
|
+ # else:
|
|
|
+ # # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
+ # progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
+ # return "not useful"
|
|
|
+ # else:
|
|
|
+ # # pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
|
|
|
+ # progress_bar = show_progress(state, "---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
|
|
|
+ # return "not supported"
|
|
|
|
|
|
def grade_sql_query(state):
|
|
|
"""
|
|
@@ -484,25 +512,34 @@ def grade_sql_query(state):
|
|
|
progress_bar = show_progress(state, "---CHECK SQL CORRECTNESS TO QUESTION---")
|
|
|
question = state["question"]
|
|
|
sql_query = state["sql_query"]
|
|
|
- retry = state["retry"]
|
|
|
-
|
|
|
- # Score each doc
|
|
|
- sql_query_grader = SQL_Grader()
|
|
|
- score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
|
|
|
- grade = score["score"]
|
|
|
- # Document relevant
|
|
|
- if grade in ["yes", "true", 1, "1"]:
|
|
|
- # print("---GRADE: CORRECT SQL QUERY---")
|
|
|
- progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
|
|
|
- return "correct"
|
|
|
- elif retry >= 5:
|
|
|
- # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
|
|
|
- progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
|
|
|
- return "failed"
|
|
|
- else:
|
|
|
- # print("---GRADE: INCORRECT SQL QUERY---")
|
|
|
- progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
|
|
|
+ sql_result = state["sql_result"]
|
|
|
+ if "None" in sql_result:
|
|
|
+ progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
|
|
|
return "incorrect"
|
|
|
+ else:
|
|
|
+ progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
|
|
|
+ return "correct"
|
|
|
+ # retry = state["retry"]
|
|
|
+
|
|
|
+ # # Score each doc
|
|
|
+ # sql_query_grader = SQL_Grader()
|
|
|
+ # score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
|
|
|
+ # grade = score["score"]
|
|
|
+
|
|
|
+
|
|
|
+ # # Document relevant
|
|
|
+ # if grade in ["yes", "true", 1, "1"]:
|
|
|
+ # # print("---GRADE: CORRECT SQL QUERY---")
|
|
|
+ # progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
|
|
|
+ # return "correct"
|
|
|
+ # elif retry >= 5:
|
|
|
+ # # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
|
|
|
+ # progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
|
|
|
+ # return "failed"
|
|
|
+ # else:
|
|
|
+ # # print("---GRADE: INCORRECT SQL QUERY---")
|
|
|
+ # progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
|
|
|
+ # return "incorrect"
|
|
|
|
|
|
def build_graph():
|
|
|
workflow = StateGraph(GraphState)
|
|
@@ -527,7 +564,6 @@ def build_graph():
|
|
|
"RAG",
|
|
|
grade_generation_v_documents_and_question,
|
|
|
{
|
|
|
- "not supported": "ERROR",
|
|
|
"useful": END,
|
|
|
"not useful": "ERROR",
|
|
|
},
|
|
@@ -537,21 +573,26 @@ def build_graph():
|
|
|
grade_sql_query,
|
|
|
{
|
|
|
"correct": "SQL Answer",
|
|
|
- "incorrect": "ERROR",
|
|
|
- "failed": "RAG"
|
|
|
+ "incorrect": "RAG",
|
|
|
|
|
|
},
|
|
|
)
|
|
|
workflow.add_edge("SQL Answer", "Additoinal Explanation")
|
|
|
- workflow.add_edge("Additoinal Explanation", "RAG")
|
|
|
+ workflow.add_edge("Additoinal Explanation", END)
|
|
|
|
|
|
app = workflow.compile()
|
|
|
|
|
|
return app
|
|
|
|
|
|
+app = build_graph()
|
|
|
+draw_mermaid = app.get_graph().draw_mermaid()
|
|
|
+print(draw_mermaid)
|
|
|
+
|
|
|
def main(question: str):
|
|
|
|
|
|
- app = build_graph()
|
|
|
+ # app = build_graph()
|
|
|
+ # draw_mermaid = app.get_graph().draw_mermaid()
|
|
|
+ # print(draw_mermaid)
|
|
|
#建準去年的類別一排放量?
|
|
|
# inputs = {"question": "溫室氣體是什麼"}
|
|
|
inputs = {"question": question, "progress_bar": None}
|
|
@@ -561,12 +602,14 @@ def main(question: str):
|
|
|
# pprint(value["generation"])
|
|
|
# pprint(value)
|
|
|
value["progress_bar"] = progress_bar
|
|
|
- pprint(value["progress_bar"])
|
|
|
+ # pprint(value["progress_bar"])
|
|
|
|
|
|
return value["generation"]
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
# result = main("建準去年的逸散排放總排放量是多少?")
|
|
|
- result = main("建準去年的綠電使用量是多少?")
|
|
|
+ result = main("建準夏威夷去年的綠電使用量是多少?")
|
|
|
+ # result = main("溫室氣體是什麼?")
|
|
|
+ # result = main("什麼是外購電力(綠電)?")
|
|
|
print("------------------------------------------------------")
|
|
|
print(result)
|