|
@@ -0,0 +1,453 @@
|
|
|
+
|
|
|
+from langchain_community.chat_models import ChatOllama
|
|
|
+from langchain_core.output_parsers import JsonOutputParser
|
|
|
+from langchain_core.prompts import PromptTemplate
|
|
|
+
|
|
|
+from langchain.prompts import ChatPromptTemplate
|
|
|
+from langchain_core.output_parsers import StrOutputParser
|
|
|
+
|
|
|
+# graph usage
|
|
|
+from pprint import pprint
|
|
|
+from typing import List
|
|
|
+from langchain_core.documents import Document
|
|
|
+from typing_extensions import TypedDict
|
|
|
+from langgraph.graph import END, StateGraph, START
|
|
|
+from langgraph.pregel import RetryPolicy
|
|
|
+
|
|
|
+# supabase db
|
|
|
+from langchain_community.utilities import SQLDatabase
|
|
|
+import os
|
|
|
+from dotenv import load_dotenv
|
|
|
# Load environment variables from a local .env file (if any) before reading them.
load_dotenv()

# Connection string for the Supabase (PostgreSQL) database, read from SUPABASE_URI.
URI: str = os.environ.get('SUPABASE_URI', '')
if not URI:
    # Fail fast with an actionable message instead of handing None/"" to
    # SQLDatabase.from_uri and surfacing an opaque URL-parsing error.
    raise RuntimeError(
        "Environment variable SUPABASE_URI is not set; it is required to "
        "connect to the database."
    )
db = SQLDatabase.from_uri(URI)

# LLM
# Alternative model: "llama3.1:8b-instruct-fp16"
local_llm = "llama3-groq-tool-use:latest"
# JSON-constrained client, used by the grader/router chains in this file.
llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
# Free-text client, used for answer generation and text-to-SQL.
llm = ChatOllama(model=local_llm, temperature=0)

# RAG usage
# NOTE(review): the imported `faiss_query` is shadowed by the function of the
# same name defined later in this module; only `create_faiss_retriever` is
# effectively used from this import.
from faiss_index import create_faiss_retriever, faiss_query
retriever = create_faiss_retriever()
|
|
|
+
|
|
|
+# text-to-sql usage
|
|
|
+from text_to_sql2 import run, get_query, query_to_nl, table_description
|
|
|
+
|
|
|
+
|
|
|
def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
    """Answer ``question`` with ``llm``, grounded on the retrieved ``docs``.

    Args:
        question: The user's question, inserted into the prompt.
        docs: Retrieved context. Either a ready-made string, or an iterable of
            documents; objects exposing ``page_content`` contribute that text,
            anything else is converted with ``str``.
        llm: The runnable placed between the prompt and ``StrOutputParser``.
        multi_query: Accepted for interface compatibility; not used here.

    Returns:
        The parsed string output of the ``prompt | llm | StrOutputParser()`` chain.
    """
    # Feed the model the document text itself. Interpolating the raw list would
    # put the Python repr of each document (metadata, escaped newlines) into
    # the prompt instead of readable context.
    if isinstance(docs, str):
        context = docs
    else:
        context = "\n\n".join(
            doc.page_content if hasattr(doc, "page_content") else str(doc)
            for doc in (docs or [])
        )

    system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
    template = """
 <|begin_of_text|>

 <|start_header_id|>system<|end_header_id|>
 你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
 You should not mention anything about "根據提供的文件內容" or other similar terms.
 Use five sentences maximum and keep the answer concise.
 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
 勿回答無關資訊
 <|eot_id|>

 <|start_header_id|>user<|end_header_id|>
 Answer the following question based on this context:

 {context}

 Question: {question}
 用繁體中文
 <|eot_id|>

 <|start_header_id|>assistant<|end_header_id|>
 """
    # The short system line is prepended to the Llama-style template, so both
    # end up in the single prompt text.
    prompt = ChatPromptTemplate.from_template(
        system_prompt + "\n\n" +
        template
    )

    rag_chain = prompt | llm | StrOutputParser()
    return rag_chain.invoke({"context": context, "question": question})
|
|
|
+
|
|
|
+### Hallucination Grader
|
|
|
+
|
|
|
def Hallucination_Grader():
    """Build the chain that grades whether an answer is grounded in given facts.

    Returns:
        A ``PromptTemplate | llm_json | JsonOutputParser()`` runnable. It takes
        the inputs ``documents`` and ``generation``; the prompt asks the model
        for a JSON object with the single key ``score`` ('yes' or 'no').
    """
    grounding_template = """ <|begin_of_text|><|start_header_id|>system<|end_header_id|>
 You are a grader assessing whether an answer is grounded in / supported by a set of facts.
 Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts.
 Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation.
 Return the a JSON with a single key 'score' and no premable or explanation.
 <|eot_id|><|start_header_id|>user<|end_header_id|>
 Here are the facts:
 \n ------- \n
 {documents}
 \n ------- \n
 Here is the answer: {generation}
 Provide 'yes' or 'no' score as a JSON with a single key 'score' and no premable or explanation.
 <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

    grounding_prompt = PromptTemplate(
        template=grounding_template,
        input_variables=["generation", "documents"],
    )

    # JSON-mode model followed by a JSON parser, so callers receive a dict.
    return grounding_prompt | llm_json | JsonOutputParser()
|
|
|
+
|
|
|
+### Answer Grader
|
|
|
+
|
|
|
def Answer_Grader():
    """Build the chain that grades whether an answer is useful for a question.

    Returns:
        A ``PromptTemplate | llm_json | JsonOutputParser()`` runnable. It takes
        the inputs ``generation`` and ``question``; the prompt asks the model
        for a JSON object with the single key ``score`` ('yes' or 'no').
    """
    usefulness_template = """
 <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
 answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
 useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
 <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
 \n ------- \n
 {generation}
 \n ------- \n
 Here is the question: {question}
 <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

    usefulness_prompt = PromptTemplate(
        template=usefulness_template,
        input_variables=["generation", "question"],
    )

    # JSON-mode model followed by a JSON parser, so callers receive a dict.
    return usefulness_prompt | llm_json | JsonOutputParser()
|
|
|
+
|
|
|
+# Text-to-SQL
|
|
|
def run_text_to_sql(question: str):
    """Run the text-to-SQL helper ``run`` for ``question`` on the fixed tables.

    Args:
        question: The user's question,
            e.g. "建準去年的固定燃燒總排放量是多少?".

    Returns:
        A ``(answer, query)`` tuple taken from the third and first values
        returned by ``run``; the middle value is discarded.
    """
    tables_in_scope = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
    generated_query, _unused_result, nl_answer = run(db, question, tables_in_scope, llm)
    return nl_answer, generated_query
|
|
|
+
|
|
|
def _get_query(question: str):
    """Return whatever ``get_query`` produces for ``question`` on the fixed tables.

    Args:
        question: The user's question.

    Returns:
        The result of ``get_query(db, question, tables, llm)``, passed through
        unchanged.
    """
    tables_in_scope = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
    return get_query(db, question, tables_in_scope, llm)
|
|
|
+
|
|
|
def _query_to_nl(question: str, query: str):
    """Delegate to ``query_to_nl`` with the module-level ``db`` and ``llm``.

    Args:
        question: The user's question.
        query: The SQL query associated with the question.

    Returns:
        The result of ``query_to_nl(db, question, query, llm)``, unchanged.
    """
    return query_to_nl(db, question, query, llm)
|
|
|
+
|
|
|
+
|
|
|
+### SQL Grader
|
|
|
+
|
|
|
def SQL_Grader():
    """Build the chain that grades whether a PostgreSQL query matches a question.

    Returns:
        A ``PromptTemplate | llm_json | JsonOutputParser()`` runnable. It takes
        the inputs ``table_info``, ``question`` and ``sql_query``; the prompt
        asks the model for a JSON object with the single key ``score``
        ('yes' or 'no').
    """
    # The prompt embeds two worked examples of queries whose WHERE clause does
    # not match the question (wrong emission source, wrong company name).
    grading_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 You are a SQL query grader assessing correctness of PostgreSQL query to a user question.
 Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.

 Here is database description:
 {table_info}

 You need to check that each where statement is correctly filtered out what user question need.

 For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
 "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
 FROM "104_112碳排放公開及建準資料"
 WHERE "事業名稱" like '%建準%'
 AND "排放源" = '下游租賃'
 AND "盤查標準" = 'GHG'
 AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
 For the above example, we can find that user asked for "固定燃燒", but the PostgreSQL query gives "排放源" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.

 Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
 "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
 FROM "104_112碳排放公開及建準資料"
 WHERE "事業名稱" like '%台積電%'
 AND "排放源" = '固定燃燒'
 AND "盤查標準" = 'GHG'
 AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
 For the above example, we can find that user asked for "建準", but the PostgreSQL query gives "事業名稱" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.

 and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.

 If the PostgreSQL query do not exactly matches the user question, grade it as incorrect.
 You need to strictly examine whether the sql PostgreSQL query matches the user question.
 Give a binary score 'yes' or 'no' score to indicate whether the PostgreSQL query is correct to the question. \n
 Provide the binary score as a JSON with a single key 'score' and no premable or explanation.
 <|eot_id|>

 <|start_header_id|>user<|end_header_id|>
 Here is the PostgreSQL query: \n\n {sql_query} \n\n
 Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
 """

    grading_prompt = PromptTemplate(
        template=grading_template,
        input_variables=["table_info", "question", "sql_query"],
    )

    # JSON-mode model followed by a JSON parser, so callers receive a dict.
    return grading_prompt | llm_json | JsonOutputParser()
|
|
|
+
|
|
|
+### Router
|
|
|
def Router():
    """Build the chain that routes a question to a data source.

    Returns:
        A ``PromptTemplate | llm_json | JsonOutputParser()`` runnable. It takes
        the input ``question``; the prompt asks the model for a JSON object
        with the single key ``datasource`` whose value is either
        'company_private_data' or 'vectorstore'.
    """
    routing_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 You are an expert at routing a user question to a vectorstore or company private data.
 Use company private data for questions about the informations about a company's greenhouse gas emissions data.
 Otherwise, use the vectorstore for questions on ESG field knowledge or news about ESG.
 You do not need to be stringent with the keywords in the question related to these topics.
 Give a binary choice 'company_private_data' or 'vectorstore' based on the question.
 Return the a JSON with a single key 'datasource' and no premable or explanation.

 Question to route: {question}
 <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

    routing_prompt = PromptTemplate(
        template=routing_template,
        input_variables=["question"],
    )

    # JSON-mode model followed by a JSON parser, so callers receive a dict.
    return routing_prompt | llm_json | JsonOutputParser()
|
|
|
+
|
|
|
class GraphState(TypedDict):
    """State dictionary passed between the nodes of the workflow graph.

    Keys:
        question: The user's question.
        generation: The answer text produced by a node.
        documents: Documents retrieved for the question.
        retry: Counter maintained by the SQL-query node.
        sql_query: The SQL query generated for the question.
    """

    question: str
    generation: str
    documents: List[str]
    retry: int
    sql_query: str
|
|
|
+
|
|
|
+# Node
|
|
|
def retrieve_and_generation(state):
    """Retrieve documents for the question and generate an answer from them.

    Args:
        state (dict): The current graph state; ``state["question"]`` is read.

    Returns:
        dict: State update with ``documents`` (the retriever's result),
        ``question`` (unchanged) and ``generation`` (the ``faiss_query``
        output).
    """
    print("---RETRIEVE---")
    user_question = state["question"]

    # TODO: correct Retrieval function
    retrieved_docs = retriever.get_relevant_documents(user_question, k=30)
    answer = faiss_query(user_question, retrieved_docs, llm)

    return {
        "documents": retrieved_docs,
        "question": user_question,
        "generation": answer,
    }
|
|
|
+
|
|
|
def company_private_data_get_sql_query(state):
    """Generate a PostgreSQL query for the question and track the attempt count.

    Args:
        state (dict): The current graph state. ``state["question"]`` is
            required; ``retry`` is optional and absent on the first attempt.

    Returns:
        dict: State update with ``sql_query`` (from ``_get_query``),
        ``question`` (unchanged) and ``retry``: 0 on the first attempt,
        otherwise the previous value plus one.
    """
    print("---SQL QUERY---")
    question = state["question"]

    # Distinguish "no attempt yet" (missing/None) from "first attempt done" (0).
    # A plain truthiness test treats 0 as unset and would reset the counter on
    # every pass, so the retry limit checked downstream could never be reached.
    previous_retry = state.get("retry")
    retry = 0 if previous_retry is None else previous_retry + 1

    sql_query = _get_query(question)

    return {"sql_query": sql_query, "question": question, "retry": retry}
|
|
|
+
|
|
|
def company_private_data_search(state):
    """Turn the stored SQL query into a natural-language answer.

    Args:
        state (dict): The current graph state; ``question`` and ``sql_query``
            are read.

    Returns:
        dict: State update with ``sql_query`` and ``question`` (both
        unchanged) and ``generation`` (the ``_query_to_nl`` output).
    """
    print("---SQL TO NL---")
    user_question = state["question"]
    stored_query = state["sql_query"]

    answer = _query_to_nl(user_question, stored_query)

    return {
        "sql_query": stored_query,
        "question": user_question,
        "generation": answer,
    }
|
|
|
+
|
|
|
+### Conditional edge
|
|
|
+
|
|
|
+
|
|
|
def route_question(state):
    """Route the question to text-to-SQL (company private data) or to RAG.

    Args:
        state (dict): The current graph state; ``state["question"]`` is read.

    Returns:
        str: ``"company_private_data"`` when the router chain selects it,
        otherwise ``"vectorstore"``. Unexpected or missing router output also
        falls back to ``"vectorstore"``, so a routing key is always returned.
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    question_router = Router()
    source = question_router.invoke({"question": question})

    # The model may return JSON without the expected key (or not a dict at
    # all); read defensively instead of raising KeyError.
    datasource = source.get("datasource") if isinstance(source, dict) else None
    print(datasource)

    if datasource == "company_private_data":
        print("---ROUTE QUESTION TO TEXT-TO-SQL---")
        return "company_private_data"

    if datasource != "vectorstore":
        # Returning None here would not match any conditional-edge key; the
        # router prompt itself names the vectorstore as the "otherwise" choice.
        print("---UNRECOGNIZED DATASOURCE, FALLING BACK TO RAG---")
    print("---ROUTE QUESTION TO RAG---")
    return "vectorstore"
|
|
|
+
|
|
|
def grade_generation_v_documents_and_question(state):
    """Grade the generation for grounding in the documents, then for usefulness.

    Args:
        state (dict): The current graph state; ``question``, ``documents`` and
            ``generation`` are read.

    Returns:
        str: ``"not supported"`` when the hallucination grader's score is not
        affirmative; otherwise ``"useful"`` or ``"not useful"`` depending on
        the answer grader's score.
    """
    print("---CHECK HALLUCINATIONS---")
    user_question = state["question"]
    retrieved_docs = state["documents"]
    answer = state["generation"]

    # Score values accepted as an affirmative verdict.
    affirmative = ["yes", "true", 1, "1"]

    grounding_verdict = Hallucination_Grader().invoke(
        {"documents": retrieved_docs, "generation": answer}
    )
    if grounding_verdict["score"] not in affirmative:
        # pprint (not print) is what the original emits for this message.
        pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    print("---GRADE GENERATION vs QUESTION---")
    usefulness_verdict = Answer_Grader().invoke(
        {"question": user_question, "generation": answer}
    )
    if usefulness_verdict["score"] in affirmative:
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        return "useful"

    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
    return "not useful"
|
|
|
+
|
|
|
def grade_sql_query(state):
    """Decide whether the generated PostgreSQL query is accepted, retried or abandoned.

    Args:
        state (dict): The current graph state; ``question``, ``sql_query`` and
            ``retry`` are read.

    Returns:
        str: ``"correct"`` when the SQL grader's score is affirmative,
        ``"failed"`` when it is not and ``retry`` is at least 5, otherwise
        ``"incorrect"``.
    """
    print("---CHECK SQL CORRECTNESS TO QUESTION---")
    user_question = state["question"]
    candidate_query = state["sql_query"]
    attempts = state["retry"]

    verdict = SQL_Grader().invoke(
        {
            "table_info": table_description(),
            "question": user_question,
            "sql_query": candidate_query,
        }
    )
    grade = verdict["score"]

    if grade in ["yes", "true", 1, "1"]:
        print("---GRADE: CORRECT SQL QUERY---")
        return "correct"

    if attempts >= 5:
        print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
        return "failed"

    print("---GRADE: INCORRECT SQL QUERY---")
    return "incorrect"
|
|
|
+
|
|
|
def build_graph():
    """Assemble and compile the question-answering workflow graph.

    The graph routes from START via ``route_question`` to either the SQL-query
    node or the retrieve-and-generate node, and loops each branch back on
    itself according to its grader's verdict.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    graph = StateGraph(GraphState)

    # Register the nodes; each one gets its own RetryPolicy(max_attempts=5).
    node_table = (
        ("company_private_data_query", company_private_data_get_sql_query),
        ("company_private_data_search", company_private_data_search),
        ("retrieve_and_generation", retrieve_and_generation),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn, retry=RetryPolicy(max_attempts=5))

    # Entry point: pick the branch based on the router's decision.
    entry_routes = {
        "company_private_data": "company_private_data_query",
        "vectorstore": "retrieve_and_generation",
    }
    graph.add_conditional_edges(START, route_question, entry_routes)

    # RAG branch: finish on a useful answer, otherwise run the node again.
    rag_routes = {
        "not supported": "retrieve_and_generation",
        "useful": END,
        "not useful": "retrieve_and_generation",
    }
    graph.add_conditional_edges(
        "retrieve_and_generation",
        grade_generation_v_documents_and_question,
        rag_routes,
    )

    # SQL branch: execute an accepted query, regenerate a rejected one, or stop.
    sql_routes = {
        "correct": "company_private_data_search",
        "incorrect": "company_private_data_query",
        "failed": END,
    }
    graph.add_conditional_edges(
        "company_private_data_query",
        grade_sql_query,
        sql_routes,
    )
    graph.add_edge("company_private_data_search", END)

    return graph.compile()
|
|
|
+
|
|
|
def main():
    """Run the workflow on a sample question and return the last generation.

    Streams the compiled graph with a recursion limit of 10, printing each
    finished node's name and, when that node produced one, its generation.

    Returns:
        The most recent ``generation`` emitted by any node, or ``None`` if no
        node produced one.
    """
    app = build_graph()
    # Another sample question: 建準去年的類別一排放量?
    inputs = {"question": "溫室氣體是什麼"}

    generation = None
    for output in app.stream(inputs, {"recursion_limit": 10}):
        for key, value in output.items():
            pprint(f"Finished running: {key}:")
            # Not every node emits a generation (the SQL-query node returns
            # only sql_query/question/retry), so do not index unconditionally.
            if isinstance(value, dict) and "generation" in value:
                generation = value["generation"]
                pprint(generation)

    return generation
|
|
|
+
|
|
|
# Script entry point: run the workflow once and print the final answer below a
# separator line.
if __name__ == "__main__":
    final_answer = main()
    print("------------------------------------------------------")
    print(final_answer)
|