|
@@ -0,0 +1,609 @@
|
|
|
+
|
|
|
+from langchain_community.chat_models import ChatOllama
|
|
|
+from langchain_core.output_parsers import JsonOutputParser
|
|
|
+from langchain_core.prompts import PromptTemplate
|
|
|
+
|
|
|
+from langchain.prompts import ChatPromptTemplate
|
|
|
+from langchain_core.output_parsers import StrOutputParser
|
|
|
+
|
|
|
+# graph usage
|
|
|
+from pprint import pprint
|
|
|
+from typing import List
|
|
|
+from langchain_core.documents import Document
|
|
|
+from typing_extensions import TypedDict
|
|
|
+from langgraph.graph import END, StateGraph, START
|
|
|
+from langgraph.pregel import RetryPolicy
|
|
|
+
|
|
|
+# supabase db
|
|
|
+from langchain_community.utilities import SQLDatabase
|
|
|
+import os
|
|
|
+from dotenv import load_dotenv
|
|
|
+load_dotenv()
|
|
|
+URI: str = os.environ.get('SUPABASE_URI')
|
|
|
+db = SQLDatabase.from_uri(URI)
|
|
|
+
|
|
|
+# LLM
|
|
|
+# local_llm = "llama3.1:8b-instruct-fp16"
|
|
|
+# local_llm = "llama3.1:8b-instruct-q2_K"
|
|
|
+local_llm = "llama3-groq-tool-use:latest"
|
|
|
+llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
|
|
|
+local_llm = "cwchang/llama3-taide-lx-8b-chat-alpha1:q3_k_s"
|
|
|
+llm = ChatOllama(model=local_llm, temperature=0)
|
|
|
+sql_llm = ChatOllama(model="codeqwen", temperature=0)
|
|
|
+# sql_llm = ChatOllama(model="eramax/nxcode-cq-7b-orpo:q6", temperature=0)
|
|
|
+
|
|
|
+from langchain_openai import ChatOpenAI
|
|
|
+# sql_llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
|
|
|
+
|
|
|
+# RAG usage
|
|
|
+from faiss_index import create_faiss_retriever, faiss_multiquery, faiss_query
|
|
|
+retriever = create_faiss_retriever()
|
|
|
+
|
|
|
+# text-to-sql usage
|
|
|
+from text_to_sql_private import run, get_query, query_to_nl, table_description
|
|
|
+from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
|
|
|
+progress_bar = []
|
|
|
+
|
|
|
+def faiss_query(question: str, llm, docs=None, multi_query: bool = False) -> str:
|
|
|
+ if multi_query:
|
|
|
+ docs = faiss_multiquery(question, retriever, llm, k=4)
|
|
|
+ # print(docs)
|
|
|
+ elif docs:
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ docs = retriever.get_relevant_documents(question, k=10)
|
|
|
+ # print(docs)
|
|
|
+ context = docs
|
|
|
+
|
|
|
+ system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
|
|
|
+ template = """
|
|
|
+ <|begin_of_text|>
|
|
|
+
|
|
|
+ <|start_header_id|>system<|end_header_id|>
|
|
|
+ 你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
|
|
|
+ You should not mention anything about "根據提供的文件內容" or other similar terms.
|
|
|
+ 請盡可能的詳細回答問題。
|
|
|
+ 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
|
|
|
+ 勿回答無關資訊或任何與某特定公司相關的問題。
|
|
|
+ <|eot_id|>
|
|
|
+
|
|
|
+ <|start_header_id|>user<|end_header_id|>
|
|
|
+ Answer the following question based on this context:
|
|
|
+
|
|
|
+ {context}
|
|
|
+
|
|
|
+ Question: {question}
|
|
|
+ 用繁體中文回答問題,請用一段話詳細的回答。勿回答無關資訊或任何與某特定公司相關的問題。
|
|
|
+ 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
|
|
|
+
|
|
|
+ <|eot_id|>
|
|
|
+
|
|
|
+ <|start_header_id|>assistant<|end_header_id|>
|
|
|
+ """
|
|
|
+ prompt = ChatPromptTemplate.from_template(
|
|
|
+ system_prompt + "\n\n" +
|
|
|
+ template
|
|
|
+ )
|
|
|
+
|
|
|
+ rag_chain = prompt | llm | StrOutputParser()
|
|
|
+ return docs, rag_chain.invoke({"context": context, "question": question})
|
|
|
+
|
|
|
+### Hallucination Grader
|
|
|
+
|
|
|
+def Hallucination_Grader():
|
|
|
+ # Prompt
|
|
|
+ prompt = PromptTemplate(
|
|
|
+ template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
|
|
+ You are a grader assessing whether an answer is grounded in / supported by a set of facts.
|
|
|
+ Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts.
|
|
|
+ Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation.
|
|
|
+ Return the a JSON with a single key 'score' and no premable or explanation.
|
|
|
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
|
|
|
+ Here are the facts:
|
|
|
+ \n ------- \n
|
|
|
+ {documents}
|
|
|
+ \n ------- \n
|
|
|
+ Here is the answer: {generation}
|
|
|
+ Provide 'yes' or 'no' score as a JSON with a single key 'score' and no premable or explanation.
|
|
|
+ <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
|
|
|
+ input_variables=["generation", "documents"],
|
|
|
+ )
|
|
|
+
|
|
|
+ hallucination_grader = prompt | llm_json | JsonOutputParser()
|
|
|
+
|
|
|
+ return hallucination_grader
|
|
|
+
|
|
|
+### Answer Grader
|
|
|
+
|
|
|
+def Answer_Grader():
|
|
|
+ # Prompt
|
|
|
+ prompt = PromptTemplate(
|
|
|
+ template="""
|
|
|
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
|
|
|
+ answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
|
|
|
+ useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
|
|
|
+ <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
|
|
|
+ \n ------- \n
|
|
|
+ {generation}
|
|
|
+ \n ------- \n
|
|
|
+ Here is the question: {question}
|
|
|
+ <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
|
|
|
+ input_variables=["generation", "question"],
|
|
|
+ )
|
|
|
+
|
|
|
+ answer_grader = prompt | llm_json | JsonOutputParser()
|
|
|
+
|
|
|
+ return answer_grader
|
|
|
+
|
|
|
+# Text-to-SQL
|
|
|
+# def run_text_to_sql(question: str):
|
|
|
+# selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
|
|
|
+# # question = "建準去年的固定燃燒總排放量是多少?"
|
|
|
+# query, result, answer = run(db, question, selected_table, sql_llm)
|
|
|
+
|
|
|
+# return answer, query
|
|
|
+
|
|
|
+def _get_query(question: str):
|
|
|
+ selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
|
|
|
+ question = question.replace("美國", "美國 Inc")
|
|
|
+ question = question.replace("法國", "法國 SAS")
|
|
|
+
|
|
|
+ query, result = get_query(db, question, selected_table, sql_llm)
|
|
|
+ return query, result
|
|
|
+
|
|
|
+def _query_to_nl(question: str, query: str, result):
|
|
|
+ question = question.replace("美國", "美國 Inc")
|
|
|
+ question = question.replace("法國", "法國 SAS")
|
|
|
+ local_llm = "llama3-groq-tool-use:latest"
|
|
|
+ llm = ChatOllama(model=local_llm, temperature=0)
|
|
|
+ answer = query_to_nl(question, query, result, llm)
|
|
|
+ return answer
|
|
|
+
|
|
|
+def generate_additional_question(sql_query):
|
|
|
+ terms = parse_sql_where(sql_query)
|
|
|
+ question_list = []
|
|
|
+ for term in terms:
|
|
|
+ if term is None: continue
|
|
|
+ question_format = [f"什麼是{term}?", f"{term}的用途是什麼"]
|
|
|
+ question_list.extend(question_format)
|
|
|
+
|
|
|
+ return question_list
|
|
|
+
|
|
|
+
|
|
|
+def generate_additional_detail(sql_query):
|
|
|
+ terms = parse_sql_where(sql_query)
|
|
|
+ answer = ""
|
|
|
+ all_documents = []
|
|
|
+ for term in list(set(terms)):
|
|
|
+ print(term)
|
|
|
+ if term is None: continue
|
|
|
+ question_format = [ f"溫室氣體排放源中的{term}是什麼意思?", f"{term}是什麼意思?"]
|
|
|
+ for question in question_format:
|
|
|
+ documents = retriever.get_relevant_documents(question, k=5)
|
|
|
+ all_documents.extend(documents)
|
|
|
+
|
|
|
+ all_question = "".join(question_format)
|
|
|
+ documents, generation = faiss_query(all_question, llm, docs=all_documents, multi_query=True)
|
|
|
+
|
|
|
+ if "test@systex.com" in generation:
|
|
|
+ generation = ""
|
|
|
+
|
|
|
+ answer += generation
|
|
|
+ # print(question)
|
|
|
+ # print(generation)
|
|
|
+ return answer
|
|
|
+### SQL Grader
|
|
|
+
|
|
|
+def SQL_Grader():
|
|
|
+ prompt = PromptTemplate(
|
|
|
+ template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
|
|
+ You are a SQL query grader assessing correctness of PostgreSQL query to a user question.
|
|
|
+ Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.
|
|
|
+
|
|
|
+ Here is database description:
|
|
|
+ {table_info}
|
|
|
+
|
|
|
+ You need to check that each where statement is correctly filtered out what user question need.
|
|
|
+
|
|
|
+ For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is
|
|
|
+ "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
|
|
|
+ FROM "建準碳排放清冊數據new"
|
|
|
+ WHERE "事業名稱" like '%建準%'
|
|
|
+ AND "排放源" = '下游租賃'
|
|
|
+ AND "盤查標準" = 'GHG'
|
|
|
+ AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
|
|
|
+ For the above example, we can find that user asked for "固定燃燒", but the PostgreSQL query gives "排放源" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
|
|
|
+
|
|
|
+ Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
|
|
|
+ "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
|
|
|
+ FROM "建準碳排放清冊數據new"
|
|
|
+ WHERE "事業名稱" like '%台積電%'
|
|
|
+ AND "排放源" = '固定燃燒'
|
|
|
+ AND "盤查標準" = 'GHG'
|
|
|
+ AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
|
|
|
+ For the above example, we can find that user asked for "建準", but the PostgreSQL query gives "事業名稱" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
|
|
|
+
|
|
|
+ and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.
|
|
|
+
|
|
|
+ If the PostgreSQL query do not exactly matches the user question, grade it as incorrect.
|
|
|
+ You need to strictly examine whether the sql PostgreSQL query matches the user question.
|
|
|
+ Give a binary score 'yes' or 'no' score to indicate whether the PostgreSQL query is correct to the question. \n
|
|
|
+ Provide the binary score as a JSON with a single key 'score' and no premable or explanation.
|
|
|
+ <|eot_id|>
|
|
|
+
|
|
|
+ <|start_header_id|>user<|end_header_id|>
|
|
|
+ Here is the PostgreSQL query: \n\n {sql_query} \n\n
|
|
|
+ Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
|
|
+ """,
|
|
|
+ input_variables=["table_info", "question", "sql_query"],
|
|
|
+ )
|
|
|
+
|
|
|
+ sql_query_grader = prompt | llm_json | JsonOutputParser()
|
|
|
+
|
|
|
+ return sql_query_grader
|
|
|
+
|
|
|
+### Router
|
|
|
+def Router():
|
|
|
+ prompt = PromptTemplate(
|
|
|
+ template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
|
|
+ You are an expert at routing a user question to a 專業知識 or 自有數據.
|
|
|
+ 你需要分辨使用者問題是否在詢問某個公司與其據點廠房的自有數據或是尋求專業的碳盤查或碳管理等等的 ESG 知識和相關新聞,
|
|
|
+ 如果問題是想了解某個公司與其據點廠房的碳排放源的排放量或用電、用水量等等,請使用"自有數據",
|
|
|
+ 若使用者的問題是想了解碳盤查、碳交易或碳管理等等的 ESG 知識和相關新聞,請使用"專業知識"。
|
|
|
+ You do not need to be stringent with the keywords in the question related to these topics.
|
|
|
+ Give a binary choice '自有數據' or '專業知識' based on the question.
|
|
|
+ Return the a JSON with a single key 'datasource' and no premable or explanation.
|
|
|
+
|
|
|
+ Question to route: {question}
|
|
|
+ <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
|
|
|
+ input_variables=["question"],
|
|
|
+ )
|
|
|
+
|
|
|
+ question_router = prompt | llm_json | JsonOutputParser()
|
|
|
+
|
|
|
+ return question_router
|
|
|
+
|
|
|
+class GraphState(TypedDict):
|
|
|
+ """
|
|
|
+ Represents the state of our graph.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ question: question
|
|
|
+ generation: LLM generation
|
|
|
+ company_private_data: whether to search company private data
|
|
|
+ documents: list of documents
|
|
|
+ """
|
|
|
+
|
|
|
+ progress_bar: List[str]
|
|
|
+ route: str
|
|
|
+ question: str
|
|
|
+ question_list: List[str]
|
|
|
+ generation: str
|
|
|
+ documents: List[str]
|
|
|
+ retry: int
|
|
|
+ sql_query: str
|
|
|
+ sql_result: str
|
|
|
+
|
|
|
+# Node
|
|
|
+def show_progress(state, progress: str):
|
|
|
+ global progress_bar
|
|
|
+ # progress_bar = state["progress_bar"] if state["progress_bar"] else []
|
|
|
+
|
|
|
+ print(progress)
|
|
|
+ progress_bar.append(progress)
|
|
|
+
|
|
|
+ return progress_bar
|
|
|
+
|
|
|
+def retrieve_and_generation(state):
|
|
|
+ """
|
|
|
+ Retrieve documents from vectorstore
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state (dict): The current graph state
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ state (dict): New key added to state, documents, that contains retrieved documents, and generation, genrating by LLM
|
|
|
+ """
|
|
|
+ progress_bar = show_progress(state, "---RETRIEVE---")
|
|
|
+ if not state["route"]:
|
|
|
+ route = "RAG"
|
|
|
+ else:
|
|
|
+ route = state["route"]
|
|
|
+ question = state["question"]
|
|
|
+ documents, generation = faiss_query(question, llm, multi_query=True)
|
|
|
+ print(generation)
|
|
|
+
|
|
|
+ return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
|
|
|
+
|
|
|
+def company_private_data_get_sql_query(state):
|
|
|
+ """
|
|
|
+ Get PostgreSQL query according to question
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state (dict): The current graph state
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ state (dict): return generated PostgreSQL query and record retry times
|
|
|
+ """
|
|
|
+ # print("---SQL QUERY---")
|
|
|
+ progress_bar = show_progress(state, "---SQL QUERY---")
|
|
|
+ if not state["route"]:
|
|
|
+ route = "SQL"
|
|
|
+ else:
|
|
|
+ route = state["route"]
|
|
|
+ question = state["question"]
|
|
|
+
|
|
|
+ if state["retry"]:
|
|
|
+ retry = state["retry"]
|
|
|
+ retry += 1
|
|
|
+ else:
|
|
|
+ retry = 0
|
|
|
+ # print("RETRY: ", retry)
|
|
|
+
|
|
|
+ sql_query, sql_result = _get_query(question)
|
|
|
+ print(type(sql_result))
|
|
|
+
|
|
|
+ return {"progress_bar": progress_bar, "route": route, "sql_query": sql_query, "sql_result": sql_result, "question": question, "retry": retry}
|
|
|
+
|
|
|
+def company_private_data_search(state):
|
|
|
+ """
|
|
|
+ Execute PostgreSQL query and convert to nature language.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state (dict): The current graph state
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ state (dict): Appended sql results to state
|
|
|
+ """
|
|
|
+
|
|
|
+ # print("---SQL TO NL---")
|
|
|
+ progress_bar = show_progress(state, "---SQL TO NL---")
|
|
|
+ # print(state)
|
|
|
+ question = state["question"]
|
|
|
+ sql_query = state["sql_query"]
|
|
|
+ sql_result = state["sql_result"]
|
|
|
+ generation = _query_to_nl(question, sql_query, sql_result)
|
|
|
+
|
|
|
+ # generation = [company_private_data_result]
|
|
|
+
|
|
|
+ return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation}
|
|
|
+
|
|
|
+def additional_explanation_question(state):
|
|
|
+ """
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state (_type_): _description_
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ state (dict): Appended additional explanation to state
|
|
|
+ """
|
|
|
+
|
|
|
+ # print("---ADDITIONAL EXPLANATION---")
|
|
|
+ progress_bar = show_progress(state, "---ADDITIONAL EXPLANATION---")
|
|
|
+ # print(state)
|
|
|
+ question = state["question"]
|
|
|
+ sql_query = state["sql_query"]
|
|
|
+ # print(sql_query)
|
|
|
+ generation = state["generation"]
|
|
|
+ generation += "\n"
|
|
|
+ generation += generate_additional_detail(sql_query)
|
|
|
+ question_list = []
|
|
|
+
|
|
|
+ # question_list = generate_additional_question(sql_query)
|
|
|
+ # print(question_list)
|
|
|
+
|
|
|
+ # generation = [company_private_data_result]
|
|
|
+
|
|
|
+ return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation, "question_list": question_list}
|
|
|
+
|
|
|
+def error(state):
|
|
|
+ # print("---SOMETHING WENT WRONG---")
|
|
|
+ progress_bar = show_progress(state, "---SOMETHING WENT WRONG---")
|
|
|
+ generation = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
|
|
|
+
|
|
|
+ return {"progress_bar": progress_bar, "generation": generation}
|
|
|
+
|
|
|
+### Conditional edge
|
|
|
+
|
|
|
+
|
|
|
+def route_question(state):
|
|
|
+ """
|
|
|
+ Route question to web search or RAG.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state (dict): The current graph state
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: Next node to call
|
|
|
+ """
|
|
|
+
|
|
|
+ # print("---ROUTE QUESTION---")
|
|
|
+ progress_bar = show_progress(state, "---ROUTE QUESTION---")
|
|
|
+ question = state["question"]
|
|
|
+ # print(question)
|
|
|
+ question_router = Router()
|
|
|
+ source = question_router.invoke({"question": question})
|
|
|
+ print("Original:", source["datasource"])
|
|
|
+ # if "建準" in question:
|
|
|
+ kw = ["建準", "北海", "廣興", "崑山廣興", "Inc", "SAS", "立準"]
|
|
|
+ if any(char in question for char in kw):
|
|
|
+ source["datasource"] = "自有數據"
|
|
|
+ elif "範例" in question:
|
|
|
+ source["datasource"] = "專業知識"
|
|
|
+
|
|
|
+ # print(source)
|
|
|
+ print(source["datasource"])
|
|
|
+ if source["datasource"] == "自有數據":
|
|
|
+ # print("---ROUTE QUESTION TO TEXT-TO-SQL---")
|
|
|
+ progress_bar = show_progress(state, "---ROUTE QUESTION TO TEXT-TO-SQL---")
|
|
|
+ return "自有數據"
|
|
|
+ elif source["datasource"] == "專業知識":
|
|
|
+ # print("---ROUTE QUESTION TO RAG---")
|
|
|
+ progress_bar = show_progress(state, "---ROUTE QUESTION TO RAG---")
|
|
|
+ return "專業知識"
|
|
|
+
|
|
|
+def grade_generation_v_documents_and_question(state):
|
|
|
+ """
|
|
|
+ Determines whether the generation is grounded in the document and answers question.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state (dict): The current graph state
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: Decision for next node to call
|
|
|
+ """
|
|
|
+
|
|
|
+ # print("---CHECK HALLUCINATIONS---")
|
|
|
+ question = state["question"]
|
|
|
+ documents = state["documents"]
|
|
|
+ generation = state["generation"]
|
|
|
+
|
|
|
+ progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
|
|
|
+ answer_grader = Answer_Grader()
|
|
|
+ score = answer_grader.invoke({"question": question, "generation": generation})
|
|
|
+ print(score)
|
|
|
+ grade = score["score"]
|
|
|
+ if grade in ["yes", "true", 1, "1"]:
|
|
|
+ # print("---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
+ progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
|
|
|
+ return "useful"
|
|
|
+ else:
|
|
|
+ # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
+ progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
|
|
|
+ return "not useful"
|
|
|
+
|
|
|
+def grade_sql_query(state):
|
|
|
+ """
|
|
|
+ Determines whether the Postgresql query are correct to the question
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state (dict): The current graph state
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ state (dict): Decision for retry or continue
|
|
|
+ """
|
|
|
+
|
|
|
+ # print("---CHECK SQL CORRECTNESS TO QUESTION---")
|
|
|
+ progress_bar = show_progress(state, "---CHECK SQL CORRECTNESS TO QUESTION---")
|
|
|
+ question = state["question"]
|
|
|
+ sql_query = state["sql_query"]
|
|
|
+ sql_result = state["sql_result"]
|
|
|
+ if "None" in sql_result or sql_result.startswith("Error:"):
|
|
|
+ progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
|
|
|
+ return "incorrect"
|
|
|
+ else:
|
|
|
+ print(sql_result)
|
|
|
+ progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
|
|
|
+ return "correct"
|
|
|
+ # retry = state["retry"]
|
|
|
+
|
|
|
+ # # Score each doc
|
|
|
+ # sql_query_grader = SQL_Grader()
|
|
|
+ # score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
|
|
|
+ # grade = score["score"]
|
|
|
+
|
|
|
+
|
|
|
+ # # Document relevant
|
|
|
+ # if grade in ["yes", "true", 1, "1"]:
|
|
|
+ # # print("---GRADE: CORRECT SQL QUERY---")
|
|
|
+ # progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
|
|
|
+ # return "correct"
|
|
|
+ # elif retry >= 5:
|
|
|
+ # # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
|
|
|
+ # progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
|
|
|
+ # return "failed"
|
|
|
+ # else:
|
|
|
+ # # print("---GRADE: INCORRECT SQL QUERY---")
|
|
|
+ # progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
|
|
|
+ # return "incorrect"
|
|
|
+
|
|
|
+def check_sql_answer(state):
|
|
|
+ progress_bar = show_progress(state, "---CHECK SQL ANSWER QUALITY---")
|
|
|
+ generation = state["generation"]
|
|
|
+ if "test@systex.com" in generation:
|
|
|
+ progress_bar = show_progress(state, "---SQL CAN NOT GENERATE ANSWER---")
|
|
|
+ return "bad"
|
|
|
+ else:
|
|
|
+ progress_bar = show_progress(state, "---SQL CAN GENERATE ANSWER---")
|
|
|
+ return "good"
|
|
|
+
|
|
|
+def build_graph():
|
|
|
+ workflow = StateGraph(GraphState)
|
|
|
+
|
|
|
+ # Define the nodes
|
|
|
+ workflow.add_node("Text-to-SQL", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5)) # web search
|
|
|
+ workflow.add_node("SQL Answer", company_private_data_search, retry=RetryPolicy(max_attempts=5)) # web search
|
|
|
+ workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5)) # retrieve
|
|
|
+ workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5)) # retrieve
|
|
|
+ workflow.add_node("ERROR", error) # retrieve
|
|
|
+ company_private_data_search
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ START,
|
|
|
+ route_question,
|
|
|
+ {
|
|
|
+ "自有數據": "Text-to-SQL",
|
|
|
+ "專業知識": "RAG",
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "RAG",
|
|
|
+ grade_generation_v_documents_and_question,
|
|
|
+ {
|
|
|
+ "useful": END,
|
|
|
+ "not useful": "ERROR",
|
|
|
+ },
|
|
|
+ )
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "Text-to-SQL",
|
|
|
+ grade_sql_query,
|
|
|
+ {
|
|
|
+ "correct": "SQL Answer",
|
|
|
+ "incorrect": "RAG",
|
|
|
+
|
|
|
+ },
|
|
|
+ )
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "SQL Answer",
|
|
|
+ check_sql_answer,
|
|
|
+ {
|
|
|
+ "good": "Additoinal Explanation",
|
|
|
+ "bad": "RAG",
|
|
|
+
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ # workflow.add_edge("SQL Answer", "Additoinal Explanation")
|
|
|
+ workflow.add_edge("Additoinal Explanation", END)
|
|
|
+
|
|
|
+ app = workflow.compile()
|
|
|
+
|
|
|
+ return app
|
|
|
+
|
|
|
+app = build_graph()
|
|
|
+draw_mermaid = app.get_graph().draw_mermaid()
|
|
|
+print(draw_mermaid)
|
|
|
+
|
|
|
+def main(question: str):
|
|
|
+
|
|
|
+ inputs = {"question": question, "progress_bar": None}
|
|
|
+ for output in app.stream(inputs, {"recursion_limit": 10}):
|
|
|
+ for key, value in output.items():
|
|
|
+ pprint(f"Finished running: {key}:")
|
|
|
+ # pprint(value["generation"])
|
|
|
+ # pprint(value)
|
|
|
+ value["progress_bar"] = progress_bar
|
|
|
+ # pprint(value["progress_bar"])
|
|
|
+
|
|
|
+ # return value["generation"]
|
|
|
+ return value
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ # result = main("建準去年的逸散排放總排放量是多少?")
|
|
|
+ # result = main("建準廣興廠去年的上游運輸總排放量是多少?")
|
|
|
+
|
|
|
+ result = main("建準北海廠去年的固定燃燒排放量是多少?")
|
|
|
+ # result = main("溫室氣體是什麼?")
|
|
|
+ # result = main("什麼是外購電力(綠電)?")
|
|
|
+ print("------------------------------------------------------")
|
|
|
+ print(result)
|