|
@@ -24,20 +24,29 @@ db = SQLDatabase.from_uri(URI)
|
|
|
|
|
|
# LLM
|
|
|
# local_llm = "llama3.1:8b-instruct-fp16"
|
|
|
+# local_llm = "llama3.1:8b-instruct-q2_K"
|
|
|
local_llm = "llama3-groq-tool-use:latest"
|
|
|
llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
|
|
|
-llm = ChatOllama(model=local_llm, temperature=0)
|
|
|
+# llm = ChatOllama(model=local_llm, temperature=0)
|
|
|
+from langchain_openai import ChatOpenAI
|
|
|
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
|
|
|
|
|
|
# RAG usage
|
|
|
-from faiss_index import create_faiss_retriever, faiss_query
|
|
|
+from faiss_index import create_faiss_retriever, faiss_multiquery, faiss_query
|
|
|
retriever = create_faiss_retriever()
|
|
|
|
|
|
# text-to-sql usage
|
|
|
from text_to_sql_private import run, get_query, query_to_nl, table_description
|
|
|
from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
|
|
|
progress_bar = []
|
|
|
+
|
|
|
def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
|
|
|
-
|
|
|
+ if multi_query:
|
|
|
+ docs = faiss_multiquery(question, retriever, llm)
|
|
|
+ # print(docs)
|
|
|
+ else:
|
|
|
+ docs = retriever.get_relevant_documents(question, k=10)
|
|
|
+ # print(docs)
|
|
|
context = docs
|
|
|
|
|
|
system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
|
|
@@ -47,7 +56,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
|
|
|
<|start_header_id|>system<|end_header_id|>
|
|
|
你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
|
|
|
You should not mention anything about "根據提供的文件內容" or other similar terms.
|
|
|
- Use five sentences maximum and keep the answer concise.
|
|
|
+ 請盡可能的詳細回答問題。
|
|
|
如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
|
|
|
勿回答無關資訊
|
|
|
<|eot_id|>
|
|
@@ -58,8 +67,9 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
|
|
|
{context}
|
|
|
|
|
|
Question: {question}
|
|
|
- 用繁體中文回答問題
|
|
|
+ 用繁體中文回答問題,請用一段話詳細的回答。
|
|
|
如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
|
|
|
+
|
|
|
<|eot_id|>
|
|
|
|
|
|
<|start_header_id|>assistant<|end_header_id|>
|
|
@@ -129,10 +139,14 @@ def run_text_to_sql(question: str):
|
|
|
|
|
|
def _get_query(question: str):
|
|
|
selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
|
|
|
+ question = question.replace("美國", "美國 Inc")
|
|
|
+ question = question.replace("法國", "法國 SAS")
|
|
|
query, result = get_query(db, question, selected_table, llm)
|
|
|
return query, result
|
|
|
|
|
|
def _query_to_nl(question: str, query: str, result):
|
|
|
+ question = question.replace("美國", "美國 Inc")
|
|
|
+ question = question.replace("法國", "法國 SAS")
|
|
|
answer = query_to_nl(question, query, result, llm)
|
|
|
return answer
|
|
|
|
|
@@ -150,19 +164,24 @@ def generate_additional_question(sql_query):
|
|
|
def generate_additional_detail(sql_query):
|
|
|
terms = parse_sql_where(sql_query)
|
|
|
answer = ""
|
|
|
+ all_documents = []
|
|
|
for term in list(set(terms)):
|
|
|
if term is None: continue
|
|
|
- question_format = [f"請解釋什麼是{term}?"]
|
|
|
+ question_format = [f"溫室氣體排放源中的{term}是什麼意思?", f"{term}是什麼意思?"]
|
|
|
for question in question_format:
|
|
|
# question = f"什麼是{term}?"
|
|
|
documents = retriever.get_relevant_documents(question, k=5)
|
|
|
- generation = faiss_query(question, documents, llm) + "\n"
|
|
|
- if "test@systex.com" in generation:
|
|
|
- generation = ""
|
|
|
-
|
|
|
- answer += generation
|
|
|
- # print(question)
|
|
|
- # print(generation)
|
|
|
+ all_documents.extend(documents)
|
|
|
+ # for doc in documents:
|
|
|
+ # print(doc)
|
|
|
+ all_question = "\n".join(question_format)
|
|
|
+ generation = faiss_query(all_question, all_documents, llm, multi_query=True) + "\n"
|
|
|
+ if "test@systex.com" in generation:
|
|
|
+ generation = ""
|
|
|
+
|
|
|
+ answer += generation
|
|
|
+ # print(question)
|
|
|
+ # print(generation)
|
|
|
return answer
|
|
|
### SQL Grader
|
|
|
|
|
@@ -177,7 +196,7 @@ def SQL_Grader():
|
|
|
|
|
|
You need to check that each where statement is correctly filtered out what user question need.
|
|
|
|
|
|
- For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
|
|
|
+ For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is
|
|
|
"SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
|
|
|
FROM "建準碳排放清冊數據new"
|
|
|
WHERE "事業名稱" like '%建準%'
|
|
@@ -293,19 +312,21 @@ def retrieve_and_generation(state):
|
|
|
# documents = retriever.invoke(question)
|
|
|
# TODO: correct Retrieval function
|
|
|
documents = retriever.get_relevant_documents(question, k=5)
|
|
|
- for doc in documents:
|
|
|
- print(doc)
|
|
|
+ # for doc in documents:
|
|
|
+ # print(doc)
|
|
|
|
|
|
# docs_documents = "\n\n".join(doc.page_content for doc in documents)
|
|
|
# print(documents)
|
|
|
- generation = faiss_query(question, documents, llm)
|
|
|
+ generation = faiss_query(question, documents, llm, multi_query=True)
|
|
|
else:
|
|
|
generation = state["generation"]
|
|
|
|
|
|
for sub_question in list(set(question_list)):
|
|
|
print(sub_question)
|
|
|
- documents = retriever.get_relevant_documents(sub_question, k=10)
|
|
|
- generation += faiss_query(sub_question, documents, llm)
|
|
|
+ documents = retriever.get_relevant_documents(sub_question, k=5)
|
|
|
+ # for doc in documents:
|
|
|
+ # print(doc)
|
|
|
+ generation += faiss_query(sub_question, documents, llm, multi_query=True)
|
|
|
generation += "\n"
|
|
|
|
|
|
print(generation)
|
|
@@ -513,10 +534,11 @@ def grade_sql_query(state):
|
|
|
question = state["question"]
|
|
|
sql_query = state["sql_query"]
|
|
|
sql_result = state["sql_result"]
|
|
|
- if "None" in sql_result:
|
|
|
+ if "None" in sql_result or sql_result.startswith("Error:"):
|
|
|
progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
|
|
|
return "incorrect"
|
|
|
else:
|
|
|
+ print(sql_result)
|
|
|
progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
|
|
|
return "correct"
|
|
|
# retry = state["retry"]
|
|
@@ -540,7 +562,16 @@ def grade_sql_query(state):
|
|
|
# # print("---GRADE: INCORRECT SQL QUERY---")
|
|
|
# progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
|
|
|
# return "incorrect"
|
|
|
-
|
|
|
+def check_sql_answer(state):
|
|
|
+ progress_bar = show_progress(state, "---CHECK SQL ANSWER QUALITY---")
|
|
|
+ generation = state["generation"]
|
|
|
+ if "test@systex.com" in generation:
|
|
|
+ progress_bar = show_progress(state, "---SQL CAN NOT GENERATE ANSWER---")
|
|
|
+ return "bad"
|
|
|
+ else:
|
|
|
+ progress_bar = show_progress(state, "---SQL CAN GENERATE ANSWER---")
|
|
|
+ return "good"
|
|
|
+
|
|
|
def build_graph():
|
|
|
workflow = StateGraph(GraphState)
|
|
|
|
|
@@ -550,7 +581,7 @@ def build_graph():
|
|
|
workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5)) # retrieve
|
|
|
workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5)) # retrieve
|
|
|
workflow.add_node("ERROR", error) # retrieve
|
|
|
-
|
|
|
+ company_private_data_search
|
|
|
workflow.add_conditional_edges(
|
|
|
START,
|
|
|
route_question,
|
|
@@ -577,7 +608,17 @@ def build_graph():
|
|
|
|
|
|
},
|
|
|
)
|
|
|
- workflow.add_edge("SQL Answer", "Additoinal Explanation")
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "SQL Answer",
|
|
|
+ check_sql_answer,
|
|
|
+ {
|
|
|
+ "good": "Additoinal Explanation",
|
|
|
+ "bad": "RAG",
|
|
|
+
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ # workflow.add_edge("SQL Answer", "Additoinal Explanation")
|
|
|
workflow.add_edge("Additoinal Explanation", END)
|
|
|
|
|
|
app = workflow.compile()
|
|
@@ -604,11 +645,14 @@ def main(question: str):
|
|
|
value["progress_bar"] = progress_bar
|
|
|
# pprint(value["progress_bar"])
|
|
|
|
|
|
- return value["generation"]
|
|
|
+ # return value["generation"]
|
|
|
+ return value
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
# result = main("建準去年的逸散排放總排放量是多少?")
|
|
|
- result = main("建準夏威夷去年的綠電使用量是多少?")
|
|
|
+ # result = main("建準廣興廠去年的上游運輸總排放量是多少?")
|
|
|
+
|
|
|
+ result = main("建準北海廠去年的固定燃燒排放量是多少?")
|
|
|
# result = main("溫室氣體是什麼?")
|
|
|
# result = main("什麼是外購電力(綠電)?")
|
|
|
print("------------------------------------------------------")
|