ai_agent_llama.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. from langchain_community.chat_models import ChatOllama
  2. from langchain_core.output_parsers import JsonOutputParser
  3. from langchain_core.prompts import PromptTemplate
  4. from langchain.prompts import ChatPromptTemplate
  5. from langchain_core.output_parsers import StrOutputParser
  6. # graph usage
  7. from pprint import pprint
  8. from typing import List
  9. from langchain_core.documents import Document
  10. from typing_extensions import TypedDict
  11. from langgraph.graph import END, StateGraph, START
  12. from langgraph.pregel import RetryPolicy
  13. # supabase db
  14. from langchain_community.utilities import SQLDatabase
  15. import os
  16. from dotenv import load_dotenv
# Load variables from a local .env file before any of them is read.
load_dotenv()
# Connection string for the database, taken from the SUPABASE_URI variable.
# NOTE(review): os.environ.get returns None when the variable is unset, so the
# `str` annotation is optimistic and from_uri would receive None — confirm the
# variable is always provided.
URI: str = os.environ.get('SUPABASE_URI')
db = SQLDatabase.from_uri(URI)
# LLM
# local_llm = "llama3.1:8b-instruct-fp16"
# local_llm = "llama3.1:8b-instruct-q2_K"
# Model behind `llm_json`, the JSON-constrained handle used by the router and
# grader chains in this module.
local_llm = "llama3-groq-tool-use:latest"
llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
# `local_llm` is rebound here: from this point on it names the chat model
# behind `llm`, not the one behind `llm_json`.
local_llm = "cwchang/llama3-taide-lx-8b-chat-alpha1:q3_k_s"
llm = ChatOllama(model=local_llm, temperature=0)
# Separate model handed to the text-to-SQL helper.
sql_llm = ChatOllama(model="codeqwen", temperature=0)
# sql_llm = ChatOllama(model="eramax/nxcode-cq-7b-orpo:q6", temperature=0)
  29. from langchain_openai import ChatOpenAI
  30. # sql_llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
# RAG usage
# NOTE: the `faiss_query` imported here is shadowed by the function of the same
# name defined further down in this module.
from faiss_index import create_faiss_retriever, faiss_multiquery, faiss_query
# Retriever shared by every RAG lookup in this module.
retriever = create_faiss_retriever()
# text-to-sql usage
from text_to_sql_private import run, get_query, query_to_nl, table_description
from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
# Module-level log of progress messages; show_progress() appends to it.
progress_bar = []
  38. def faiss_query(question: str, llm, docs=None, multi_query: bool = False) -> str:
  39. if multi_query:
  40. docs = faiss_multiquery(question, retriever, llm, k=4)
  41. # print(docs)
  42. elif docs:
  43. pass
  44. else:
  45. docs = retriever.get_relevant_documents(question, k=10)
  46. # print(docs)
  47. context = docs
  48. system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
  49. template = """
  50. <|begin_of_text|>
  51. <|start_header_id|>system<|end_header_id|>
  52. 你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
  53. You should not mention anything about "根據提供的文件內容" or other similar terms.
  54. 請盡可能的詳細回答問題。
  55. 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
  56. 勿回答無關資訊或任何與某特定公司相關的問題。
  57. <|eot_id|>
  58. <|start_header_id|>user<|end_header_id|>
  59. Answer the following question based on this context:
  60. {context}
  61. Question: {question}
  62. 用繁體中文回答問題,請用一段話詳細的回答。勿回答無關資訊或任何與某特定公司相關的問題。
  63. 如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
  64. <|eot_id|>
  65. <|start_header_id|>assistant<|end_header_id|>
  66. """
  67. prompt = ChatPromptTemplate.from_template(
  68. system_prompt + "\n\n" +
  69. template
  70. )
  71. rag_chain = prompt | llm | StrOutputParser()
  72. return docs, rag_chain.invoke({"context": context, "question": question})
  73. ### Hallucination Grader
  74. def Hallucination_Grader():
  75. # Prompt
  76. prompt = PromptTemplate(
  77. template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|>
  78. You are a grader assessing whether an answer is grounded in / supported by a set of facts.
  79. Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts.
  80. Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation.
  81. Return the a JSON with a single key 'score' and no premable or explanation.
  82. <|eot_id|><|start_header_id|>user<|end_header_id|>
  83. Here are the facts:
  84. \n ------- \n
  85. {documents}
  86. \n ------- \n
  87. Here is the answer: {generation}
  88. Provide 'yes' or 'no' score as a JSON with a single key 'score' and no premable or explanation.
  89. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
  90. input_variables=["generation", "documents"],
  91. )
  92. hallucination_grader = prompt | llm_json | JsonOutputParser()
  93. return hallucination_grader
  94. ### Answer Grader
  95. def Answer_Grader():
  96. # Prompt
  97. prompt = PromptTemplate(
  98. template="""
  99. <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
  100. answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
  101. useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
  102. <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
  103. \n ------- \n
  104. {generation}
  105. \n ------- \n
  106. Here is the question: {question}
  107. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
  108. input_variables=["generation", "question"],
  109. )
  110. answer_grader = prompt | llm_json | JsonOutputParser()
  111. return answer_grader
  112. # Text-to-SQL
  113. # def run_text_to_sql(question: str):
  114. # selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
  115. # # question = "建準去年的固定燃燒總排放量是多少?"
  116. # query, result, answer = run(db, question, selected_table, sql_llm)
  117. # return answer, query
  118. def _get_query(question: str):
  119. selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
  120. question = question.replace("美國", "美國 Inc")
  121. question = question.replace("法國", "法國 SAS")
  122. query, result = get_query(db, question, selected_table, sql_llm)
  123. return query, result
  124. def _query_to_nl(question: str, query: str, result):
  125. question = question.replace("美國", "美國 Inc")
  126. question = question.replace("法國", "法國 SAS")
  127. local_llm = "llama3-groq-tool-use:latest"
  128. llm = ChatOllama(model=local_llm, temperature=0)
  129. answer = query_to_nl(question, query, result, llm)
  130. return answer
  131. def generate_additional_question(sql_query):
  132. terms = parse_sql_where(sql_query)
  133. question_list = []
  134. for term in terms:
  135. if term is None: continue
  136. question_format = [f"什麼是{term}?", f"{term}的用途是什麼"]
  137. question_list.extend(question_format)
  138. return question_list
  139. def generate_additional_detail(sql_query):
  140. terms = parse_sql_where(sql_query)
  141. answer = ""
  142. all_documents = []
  143. for term in list(set(terms)):
  144. print(term)
  145. if term is None: continue
  146. question_format = [ f"溫室氣體排放源中的{term}是什麼意思?", f"{term}是什麼意思?"]
  147. for question in question_format:
  148. documents = retriever.get_relevant_documents(question, k=5)
  149. all_documents.extend(documents)
  150. all_question = "".join(question_format)
  151. documents, generation = faiss_query(all_question, llm, docs=all_documents, multi_query=True)
  152. if "test@systex.com" in generation:
  153. generation = ""
  154. answer += generation
  155. # print(question)
  156. # print(generation)
  157. return answer
  158. ### SQL Grader
  159. def SQL_Grader():
  160. prompt = PromptTemplate(
  161. template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
  162. You are a SQL query grader assessing correctness of PostgreSQL query to a user question.
  163. Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.
  164. Here is database description:
  165. {table_info}
  166. You need to check that each where statement is correctly filtered out what user question need.
  167. For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is
  168. "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
  169. FROM "建準碳排放清冊數據new"
  170. WHERE "事業名稱" like '%建準%'
  171. AND "排放源" = '下游租賃'
  172. AND "盤查標準" = 'GHG'
  173. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
  174. For the above example, we can find that user asked for "固定燃燒", but the PostgreSQL query gives "排放源" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
  175. Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is
  176. "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
  177. FROM "建準碳排放清冊數據new"
  178. WHERE "事業名稱" like '%台積電%'
  179. AND "排放源" = '固定燃燒'
  180. AND "盤查標準" = 'GHG'
  181. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
  182. For the above example, we can find that user asked for "建準", but the PostgreSQL query gives "事業名稱" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
  183. and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.
  184. If the PostgreSQL query do not exactly matches the user question, grade it as incorrect.
  185. You need to strictly examine whether the sql PostgreSQL query matches the user question.
  186. Give a binary score 'yes' or 'no' score to indicate whether the PostgreSQL query is correct to the question. \n
  187. Provide the binary score as a JSON with a single key 'score' and no premable or explanation.
  188. <|eot_id|>
  189. <|start_header_id|>user<|end_header_id|>
  190. Here is the PostgreSQL query: \n\n {sql_query} \n\n
  191. Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
  192. """,
  193. input_variables=["table_info", "question", "sql_query"],
  194. )
  195. sql_query_grader = prompt | llm_json | JsonOutputParser()
  196. return sql_query_grader
  197. ### Router
  198. def Router():
  199. prompt = PromptTemplate(
  200. template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
  201. You are an expert at routing a user question to a 專業知識 or 自有數據.
  202. 你需要分辨使用者問題是否在詢問某個公司與其據點廠房的自有數據或是尋求專業的碳盤查或碳管理等等的 ESG 知識和相關新聞,
  203. 如果問題是想了解某個公司與其據點廠房的碳排放源的排放量或用電、用水量等等,請使用"自有數據",
  204. 若使用者的問題是想了解碳盤查、碳交易或碳管理等等的 ESG 知識和相關新聞,請使用"專業知識"。
  205. You do not need to be stringent with the keywords in the question related to these topics.
  206. Give a binary choice '自有數據' or '專業知識' based on the question.
  207. Return the a JSON with a single key 'datasource' and no premable or explanation.
  208. Question to route: {question}
  209. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
  210. input_variables=["question"],
  211. )
  212. question_router = prompt | llm_json | JsonOutputParser()
  213. return question_router
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        progress_bar: progress messages collected while the graph runs
        route: "RAG" or "SQL", recorded by the first retrieval/SQL node that runs
        question: question
        question_list: follow-up questions (the explanation node currently stores an empty list)
        generation: LLM generation
        documents: list of documents
        retry: counter maintained by the text-to-SQL node
        sql_query: generated PostgreSQL query
        sql_result: result returned together with the generated query
    """
    progress_bar: List[str]
    route: str
    question: str
    question_list: List[str]
    generation: str
    documents: List[str]
    retry: int
    sql_query: str
    sql_result: str
  232. # Node
  233. def show_progress(state, progress: str):
  234. global progress_bar
  235. # progress_bar = state["progress_bar"] if state["progress_bar"] else []
  236. print(progress)
  237. progress_bar.append(progress)
  238. return progress_bar
  239. def retrieve_and_generation(state):
  240. """
  241. Retrieve documents from vectorstore
  242. Args:
  243. state (dict): The current graph state
  244. Returns:
  245. state (dict): New key added to state, documents, that contains retrieved documents, and generation, genrating by LLM
  246. """
  247. progress_bar = show_progress(state, "---RETRIEVE---")
  248. if not state["route"]:
  249. route = "RAG"
  250. else:
  251. route = state["route"]
  252. question = state["question"]
  253. documents, generation = faiss_query(question, llm, multi_query=True)
  254. print(generation)
  255. return {"progress_bar": progress_bar, "route": route, "documents": documents, "question": question, "generation": generation}
  256. def company_private_data_get_sql_query(state):
  257. """
  258. Get PostgreSQL query according to question
  259. Args:
  260. state (dict): The current graph state
  261. Returns:
  262. state (dict): return generated PostgreSQL query and record retry times
  263. """
  264. # print("---SQL QUERY---")
  265. progress_bar = show_progress(state, "---SQL QUERY---")
  266. if not state["route"]:
  267. route = "SQL"
  268. else:
  269. route = state["route"]
  270. question = state["question"]
  271. if state["retry"]:
  272. retry = state["retry"]
  273. retry += 1
  274. else:
  275. retry = 0
  276. # print("RETRY: ", retry)
  277. sql_query, sql_result = _get_query(question)
  278. print(type(sql_result))
  279. return {"progress_bar": progress_bar, "route": route, "sql_query": sql_query, "sql_result": sql_result, "question": question, "retry": retry}
  280. def company_private_data_search(state):
  281. """
  282. Execute PostgreSQL query and convert to nature language.
  283. Args:
  284. state (dict): The current graph state
  285. Returns:
  286. state (dict): Appended sql results to state
  287. """
  288. # print("---SQL TO NL---")
  289. progress_bar = show_progress(state, "---SQL TO NL---")
  290. # print(state)
  291. question = state["question"]
  292. sql_query = state["sql_query"]
  293. sql_result = state["sql_result"]
  294. generation = _query_to_nl(question, sql_query, sql_result)
  295. # generation = [company_private_data_result]
  296. return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation}
  297. def additional_explanation_question(state):
  298. """
  299. Args:
  300. state (_type_): _description_
  301. Returns:
  302. state (dict): Appended additional explanation to state
  303. """
  304. # print("---ADDITIONAL EXPLANATION---")
  305. progress_bar = show_progress(state, "---ADDITIONAL EXPLANATION---")
  306. # print(state)
  307. question = state["question"]
  308. sql_query = state["sql_query"]
  309. # print(sql_query)
  310. generation = state["generation"]
  311. generation += "\n"
  312. generation += generate_additional_detail(sql_query)
  313. question_list = []
  314. # question_list = generate_additional_question(sql_query)
  315. # print(question_list)
  316. # generation = [company_private_data_result]
  317. return {"progress_bar": progress_bar, "sql_query": sql_query, "question": question, "generation": generation, "question_list": question_list}
  318. def error(state):
  319. # print("---SOMETHING WENT WRONG---")
  320. progress_bar = show_progress(state, "---SOMETHING WENT WRONG---")
  321. generation = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
  322. return {"progress_bar": progress_bar, "generation": generation}
  323. ### Conditional edge
  324. def route_question(state):
  325. """
  326. Route question to web search or RAG.
  327. Args:
  328. state (dict): The current graph state
  329. Returns:
  330. str: Next node to call
  331. """
  332. # print("---ROUTE QUESTION---")
  333. progress_bar = show_progress(state, "---ROUTE QUESTION---")
  334. question = state["question"]
  335. # print(question)
  336. question_router = Router()
  337. source = question_router.invoke({"question": question})
  338. print("Original:", source["datasource"])
  339. # if "建準" in question:
  340. kw = ["建準", "北海", "廣興", "崑山廣興", "Inc", "SAS", "立準"]
  341. if any(char in question for char in kw):
  342. source["datasource"] = "自有數據"
  343. elif "範例" in question:
  344. source["datasource"] = "專業知識"
  345. # print(source)
  346. print(source["datasource"])
  347. if source["datasource"] == "自有數據":
  348. # print("---ROUTE QUESTION TO TEXT-TO-SQL---")
  349. progress_bar = show_progress(state, "---ROUTE QUESTION TO TEXT-TO-SQL---")
  350. return "自有數據"
  351. elif source["datasource"] == "專業知識":
  352. # print("---ROUTE QUESTION TO RAG---")
  353. progress_bar = show_progress(state, "---ROUTE QUESTION TO RAG---")
  354. return "專業知識"
  355. def grade_generation_v_documents_and_question(state):
  356. """
  357. Determines whether the generation is grounded in the document and answers question.
  358. Args:
  359. state (dict): The current graph state
  360. Returns:
  361. str: Decision for next node to call
  362. """
  363. # print("---CHECK HALLUCINATIONS---")
  364. question = state["question"]
  365. documents = state["documents"]
  366. generation = state["generation"]
  367. progress_bar = show_progress(state, "---GRADE GENERATION vs QUESTION---")
  368. answer_grader = Answer_Grader()
  369. score = answer_grader.invoke({"question": question, "generation": generation})
  370. print(score)
  371. grade = score["score"]
  372. if grade in ["yes", "true", 1, "1"]:
  373. # print("---DECISION: GENERATION ADDRESSES QUESTION---")
  374. progress_bar = show_progress(state, "---DECISION: GENERATION ADDRESSES QUESTION---")
  375. return "useful"
  376. else:
  377. # print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
  378. progress_bar = show_progress(state, "---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
  379. return "not useful"
  380. def grade_sql_query(state):
  381. """
  382. Determines whether the Postgresql query are correct to the question
  383. Args:
  384. state (dict): The current graph state
  385. Returns:
  386. state (dict): Decision for retry or continue
  387. """
  388. # print("---CHECK SQL CORRECTNESS TO QUESTION---")
  389. progress_bar = show_progress(state, "---CHECK SQL CORRECTNESS TO QUESTION---")
  390. question = state["question"]
  391. sql_query = state["sql_query"]
  392. sql_result = state["sql_result"]
  393. if "None" in sql_result or sql_result.startswith("Error:"):
  394. progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
  395. return "incorrect"
  396. else:
  397. print(sql_result)
  398. progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
  399. return "correct"
  400. # retry = state["retry"]
  401. # # Score each doc
  402. # sql_query_grader = SQL_Grader()
  403. # score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
  404. # grade = score["score"]
  405. # # Document relevant
  406. # if grade in ["yes", "true", 1, "1"]:
  407. # # print("---GRADE: CORRECT SQL QUERY---")
  408. # progress_bar = show_progress(state, "---GRADE: CORRECT SQL QUERY---")
  409. # return "correct"
  410. # elif retry >= 5:
  411. # # print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
  412. # progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
  413. # return "failed"
  414. # else:
  415. # # print("---GRADE: INCORRECT SQL QUERY---")
  416. # progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
  417. # return "incorrect"
  418. def check_sql_answer(state):
  419. progress_bar = show_progress(state, "---CHECK SQL ANSWER QUALITY---")
  420. generation = state["generation"]
  421. if "test@systex.com" in generation:
  422. progress_bar = show_progress(state, "---SQL CAN NOT GENERATE ANSWER---")
  423. return "bad"
  424. else:
  425. progress_bar = show_progress(state, "---SQL CAN GENERATE ANSWER---")
  426. return "good"
  427. def build_graph():
  428. workflow = StateGraph(GraphState)
  429. # Define the nodes
  430. workflow.add_node("Text-to-SQL", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5)) # web search
  431. workflow.add_node("SQL Answer", company_private_data_search, retry=RetryPolicy(max_attempts=5)) # web search
  432. workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5)) # retrieve
  433. workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5)) # retrieve
  434. workflow.add_node("ERROR", error) # retrieve
  435. company_private_data_search
  436. workflow.add_conditional_edges(
  437. START,
  438. route_question,
  439. {
  440. "自有數據": "Text-to-SQL",
  441. "專業知識": "RAG",
  442. },
  443. )
  444. workflow.add_conditional_edges(
  445. "RAG",
  446. grade_generation_v_documents_and_question,
  447. {
  448. "useful": END,
  449. "not useful": "ERROR",
  450. },
  451. )
  452. workflow.add_conditional_edges(
  453. "Text-to-SQL",
  454. grade_sql_query,
  455. {
  456. "correct": "SQL Answer",
  457. "incorrect": "RAG",
  458. },
  459. )
  460. workflow.add_conditional_edges(
  461. "SQL Answer",
  462. check_sql_answer,
  463. {
  464. "good": "Additoinal Explanation",
  465. "bad": "RAG",
  466. },
  467. )
  468. # workflow.add_edge("SQL Answer", "Additoinal Explanation")
  469. workflow.add_edge("Additoinal Explanation", END)
  470. app = workflow.compile()
  471. return app
# Compile the graph once at import time; main() streams questions through it.
app = build_graph()
# Render the graph as Mermaid source and print it. This is a module-level
# statement, so it runs on every import of this module.
draw_mermaid = app.get_graph().draw_mermaid()
print(draw_mermaid)
  475. def main(question: str):
  476. inputs = {"question": question, "progress_bar": None}
  477. for output in app.stream(inputs, {"recursion_limit": 10}):
  478. for key, value in output.items():
  479. pprint(f"Finished running: {key}:")
  480. # pprint(value["generation"])
  481. # pprint(value)
  482. value["progress_bar"] = progress_bar
  483. # pprint(value["progress_bar"])
  484. # return value["generation"]
  485. return value
if __name__ == "__main__":
    # Manual smoke test: run one sample question through the graph and print
    # the final node output. The commented-out calls are alternative samples.
    # result = main("建準去年的逸散排放總排放量是多少?")
    # result = main("建準廣興廠去年的上游運輸總排放量是多少?")
    result = main("建準北海廠去年的固定燃燒排放量是多少?")
    # result = main("溫室氣體是什麼?")
    # result = main("什麼是外購電力(綠電)?")
    print("------------------------------------------------------")
    print(result)