# ai_agent.py
  1. from langchain_community.chat_models import ChatOllama
  2. from langchain_core.output_parsers import JsonOutputParser
  3. from langchain_core.prompts import PromptTemplate
  4. from langchain.prompts import ChatPromptTemplate
  5. from langchain_core.output_parsers import StrOutputParser
  6. # graph usage
  7. from pprint import pprint
  8. from typing import List
  9. from langchain_core.documents import Document
  10. from typing_extensions import TypedDict
  11. from langgraph.graph import END, StateGraph, START
  12. from langgraph.pregel import RetryPolicy
  13. # supabase db
  14. from langchain_community.utilities import SQLDatabase
  15. import os
  16. from dotenv import load_dotenv
  17. load_dotenv()
  18. URI: str = os.environ.get('SUPABASE_URI')
  19. db = SQLDatabase.from_uri(URI)
  20. # LLM
  21. # local_llm = "llama3.1:8b-instruct-fp16"
  22. local_llm = "llama3-groq-tool-use:latest"
  23. llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
  24. llm = ChatOllama(model=local_llm, temperature=0)
  25. # RAG usage
  26. from faiss_index import create_faiss_retriever, faiss_query
  27. retriever = create_faiss_retriever()
  28. # text-to-sql usage
  29. from text_to_sql2 import run, get_query, query_to_nl, table_description
  30. def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
  31. context = docs
  32. system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
  33. template = """
  34. <|begin_of_text|>
  35. <|start_header_id|>system<|end_header_id|>
  36. 你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
  37. You should not mention anything about "根據提供的文件內容" or other similar terms.
  38. Use five sentences maximum and keep the answer concise.
  39. 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
  40. 勿回答無關資訊
  41. <|eot_id|>
  42. <|start_header_id|>user<|end_header_id|>
  43. Answer the following question based on this context:
  44. {context}
  45. Question: {question}
  46. 用繁體中文
  47. <|eot_id|>
  48. <|start_header_id|>assistant<|end_header_id|>
  49. """
  50. prompt = ChatPromptTemplate.from_template(
  51. system_prompt + "\n\n" +
  52. template
  53. )
  54. rag_chain = prompt | llm | StrOutputParser()
  55. return rag_chain.invoke({"context": context, "question": question})
  56. ### Hallucination Grader
  57. def Hallucination_Grader():
  58. # Prompt
  59. prompt = PromptTemplate(
  60. template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|>
  61. You are a grader assessing whether an answer is grounded in / supported by a set of facts.
  62. Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts.
  63. Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation.
  64. Return the a JSON with a single key 'score' and no premable or explanation.
  65. <|eot_id|><|start_header_id|>user<|end_header_id|>
  66. Here are the facts:
  67. \n ------- \n
  68. {documents}
  69. \n ------- \n
  70. Here is the answer: {generation}
  71. Provide 'yes' or 'no' score as a JSON with a single key 'score' and no premable or explanation.
  72. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
  73. input_variables=["generation", "documents"],
  74. )
  75. hallucination_grader = prompt | llm_json | JsonOutputParser()
  76. return hallucination_grader
  77. ### Answer Grader
  78. def Answer_Grader():
  79. # Prompt
  80. prompt = PromptTemplate(
  81. template="""
  82. <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
  83. answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
  84. useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
  85. <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
  86. \n ------- \n
  87. {generation}
  88. \n ------- \n
  89. Here is the question: {question}
  90. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
  91. input_variables=["generation", "question"],
  92. )
  93. answer_grader = prompt | llm_json | JsonOutputParser()
  94. return answer_grader
  95. # Text-to-SQL
  96. def run_text_to_sql(question: str):
  97. selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
  98. # question = "建準去年的固定燃燒總排放量是多少?"
  99. query, result, answer = run(db, question, selected_table, llm)
  100. return answer, query
  101. def _get_query(question: str):
  102. selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
  103. query = get_query(db, question, selected_table, llm)
  104. return query
  105. def _query_to_nl(question: str, query: str):
  106. answer = query_to_nl(db, question, query, llm)
  107. return answer
  108. ### SQL Grader
  109. def SQL_Grader():
  110. prompt = PromptTemplate(
  111. template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
  112. You are a SQL query grader assessing correctness of PostgreSQL query to a user question.
  113. Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.
  114. Here is database description:
  115. {table_info}
  116. You need to check that each where statement is correctly filtered out what user question need.
  117. For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
  118. "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
  119. FROM "104_112碳排放公開及建準資料"
  120. WHERE "事業名稱" like '%建準%'
  121. AND "排放源" = '下游租賃'
  122. AND "盤查標準" = 'GHG'
  123. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
  124. For the above example, we can find that user asked for "固定燃燒", but the PostgreSQL query gives "排放源" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
  125. Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
  126. "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
  127. FROM "104_112碳排放公開及建準資料"
  128. WHERE "事業名稱" like '%台積電%'
  129. AND "排放源" = '固定燃燒'
  130. AND "盤查標準" = 'GHG'
  131. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
  132. For the above example, we can find that user asked for "建準", but the PostgreSQL query gives "事業名稱" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
  133. and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.
  134. If the PostgreSQL query do not exactly matches the user question, grade it as incorrect.
  135. You need to strictly examine whether the sql PostgreSQL query matches the user question.
  136. Give a binary score 'yes' or 'no' score to indicate whether the PostgreSQL query is correct to the question. \n
  137. Provide the binary score as a JSON with a single key 'score' and no premable or explanation.
  138. <|eot_id|>
  139. <|start_header_id|>user<|end_header_id|>
  140. Here is the PostgreSQL query: \n\n {sql_query} \n\n
  141. Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
  142. """,
  143. input_variables=["table_info", "question", "sql_query"],
  144. )
  145. sql_query_grader = prompt | llm_json | JsonOutputParser()
  146. return sql_query_grader
  147. ### Router
  148. def Router():
  149. prompt = PromptTemplate(
  150. template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
  151. You are an expert at routing a user question to a vectorstore or company private data.
  152. Use company private data for questions about the informations about a company's greenhouse gas emissions data.
  153. Otherwise, use the vectorstore for questions on ESG field knowledge or news about ESG.
  154. You do not need to be stringent with the keywords in the question related to these topics.
  155. Give a binary choice 'company_private_data' or 'vectorstore' based on the question.
  156. Return the a JSON with a single key 'datasource' and no premable or explanation.
  157. Question to route: {question}
  158. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
  159. input_variables=["question"],
  160. )
  161. question_router = prompt | llm_json | JsonOutputParser()
  162. return question_router
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user question being answered
        generation: LLM generation (the natural-language answer)
        documents: documents retrieved from the vectorstore
        retry: how many times the SQL query has been (re)generated
        sql_query: PostgreSQL query produced by the text-to-SQL step
    """
    question: str
    generation: str
    documents: List[str]
    retry: int
    sql_query: str
  177. # Node
  178. def retrieve_and_generation(state):
  179. """
  180. Retrieve documents from vectorstore
  181. Args:
  182. state (dict): The current graph state
  183. Returns:
  184. state (dict): New key added to state, documents, that contains retrieved documents, and generation, genrating by LLM
  185. """
  186. print("---RETRIEVE---")
  187. question = state["question"]
  188. # Retrieval
  189. # documents = retriever.invoke(question)
  190. # TODO: correct Retrieval function
  191. documents = retriever.get_relevant_documents(question, k=30)
  192. # docs_documents = "\n\n".join(doc.page_content for doc in documents)
  193. # print(documents)
  194. generation = faiss_query(question, documents, llm)
  195. return {"documents": documents, "question": question, "generation": generation}
  196. def company_private_data_get_sql_query(state):
  197. """
  198. Get PostgreSQL query according to question
  199. Args:
  200. state (dict): The current graph state
  201. Returns:
  202. state (dict): return generated PostgreSQL query and record retry times
  203. """
  204. print("---SQL QUERY---")
  205. question = state["question"]
  206. if state["retry"]:
  207. retry = state["retry"]
  208. retry += 1
  209. else:
  210. retry = 0
  211. # print("RETRY: ", retry)
  212. sql_query = _get_query(question)
  213. return {"sql_query": sql_query, "question": question, "retry": retry}
  214. def company_private_data_search(state):
  215. """
  216. Execute PostgreSQL query and convert to nature language.
  217. Args:
  218. state (dict): The current graph state
  219. Returns:
  220. state (dict): Appended sql results to state
  221. """
  222. print("---SQL TO NL---")
  223. # print(state)
  224. question = state["question"]
  225. sql_query = state["sql_query"]
  226. generation = _query_to_nl(question, sql_query)
  227. # generation = [company_private_data_result]
  228. return {"sql_query": sql_query, "question": question, "generation": generation}
  229. ### Conditional edge
  230. def route_question(state):
  231. """
  232. Route question to web search or RAG.
  233. Args:
  234. state (dict): The current graph state
  235. Returns:
  236. str: Next node to call
  237. """
  238. print("---ROUTE QUESTION---")
  239. question = state["question"]
  240. # print(question)
  241. question_router = Router()
  242. source = question_router.invoke({"question": question})
  243. # print(source)
  244. print(source["datasource"])
  245. if source["datasource"] == "company_private_data":
  246. print("---ROUTE QUESTION TO TEXT-TO-SQL---")
  247. return "company_private_data"
  248. elif source["datasource"] == "vectorstore":
  249. print("---ROUTE QUESTION TO RAG---")
  250. return "vectorstore"
  251. def grade_generation_v_documents_and_question(state):
  252. """
  253. Determines whether the generation is grounded in the document and answers question.
  254. Args:
  255. state (dict): The current graph state
  256. Returns:
  257. str: Decision for next node to call
  258. """
  259. print("---CHECK HALLUCINATIONS---")
  260. question = state["question"]
  261. documents = state["documents"]
  262. generation = state["generation"]
  263. # print(docs_documents)
  264. # print(generation)
  265. hallucination_grader = Hallucination_Grader()
  266. score = hallucination_grader.invoke(
  267. {"documents": documents, "generation": generation}
  268. )
  269. # print(score)
  270. grade = score["score"]
  271. # Check hallucination
  272. if grade in ["yes", "true", 1, "1"]:
  273. print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
  274. # Check question-answering
  275. print("---GRADE GENERATION vs QUESTION---")
  276. answer_grader = Answer_Grader()
  277. score = answer_grader.invoke({"question": question, "generation": generation})
  278. grade = score["score"]
  279. if grade in ["yes", "true", 1, "1"]:
  280. print("---DECISION: GENERATION ADDRESSES QUESTION---")
  281. return "useful"
  282. else:
  283. print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
  284. return "not useful"
  285. else:
  286. pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
  287. return "not supported"
  288. def grade_sql_query(state):
  289. """
  290. Determines whether the Postgresql query are correct to the question
  291. Args:
  292. state (dict): The current graph state
  293. Returns:
  294. state (dict): Decision for retry or continue
  295. """
  296. print("---CHECK SQL CORRECTNESS TO QUESTION---")
  297. question = state["question"]
  298. sql_query = state["sql_query"]
  299. retry = state["retry"]
  300. # Score each doc
  301. sql_query_grader = SQL_Grader()
  302. score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
  303. grade = score["score"]
  304. # Document relevant
  305. if grade in ["yes", "true", 1, "1"]:
  306. print("---GRADE: CORRECT SQL QUERY---")
  307. return "correct"
  308. elif retry >= 5:
  309. print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
  310. return "failed"
  311. else:
  312. print("---GRADE: INCORRECT SQL QUERY---")
  313. return "incorrect"
  314. def build_graph():
  315. workflow = StateGraph(GraphState)
  316. # Define the nodes
  317. workflow.add_node("company_private_data_query", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5)) # web search
  318. workflow.add_node("company_private_data_search", company_private_data_search, retry=RetryPolicy(max_attempts=5)) # web search
  319. workflow.add_node("retrieve_and_generation", retrieve_and_generation, retry=RetryPolicy(max_attempts=5)) # retrieve
  320. workflow.add_conditional_edges(
  321. START,
  322. route_question,
  323. {
  324. "company_private_data": "company_private_data_query",
  325. "vectorstore": "retrieve_and_generation",
  326. },
  327. )
  328. workflow.add_conditional_edges(
  329. "retrieve_and_generation",
  330. grade_generation_v_documents_and_question,
  331. {
  332. "not supported": "retrieve_and_generation",
  333. "useful": END,
  334. "not useful": "retrieve_and_generation",
  335. },
  336. )
  337. workflow.add_conditional_edges(
  338. "company_private_data_query",
  339. grade_sql_query,
  340. {
  341. "correct": "company_private_data_search",
  342. "incorrect": "company_private_data_query",
  343. "failed": END
  344. },
  345. )
  346. workflow.add_edge("company_private_data_search", END)
  347. app = workflow.compile()
  348. return app
  349. def main():
  350. app = build_graph()
  351. #建準去年的類別一排放量?
  352. inputs = {"question": "溫室氣體是什麼"}
  353. for output in app.stream(inputs, {"recursion_limit": 10}):
  354. for key, value in output.items():
  355. pprint(f"Finished running: {key}:")
  356. pprint(value["generation"])
  357. return value["generation"]
  358. if __name__ == "__main__":
  359. result = main()
  360. print("------------------------------------------------------")
  361. print(result)