{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# LLM" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "### LLM\n", "from langchain_community.chat_models import ChatOllama\n", "# local_llm = \"llama3.1:8b-instruct-fp16\"\n", "local_llm = \"llama3-groq-tool-use:latest\"\n", "\n", "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "llm = ChatOllama(model=local_llm, temperature=0)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# RAG" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FAISS index loaded from faiss_index.bin\n", "Metadata loaded from faiss_metadata.pkl\n", "Using existing FAISS index and metadata.\n", "Creating FAISS retriever...\n" ] } ], "source": [ "from faiss_index import create_faiss_retriever, faiss_query\n", "retriever = create_faiss_retriever()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from langchain.prompts import ChatPromptTemplate\n", "from langchain_core.output_parsers import StrOutputParser\n", "def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:\n", " context = docs\n", " # try:\n", " # context = \"\\n\".join(doc.page_content for doc in docs)\n", " # except:\n", " # context = \"\\n\".join(doc for doc in docs)\n", " \n", " system_prompt: str = \"你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。\"\n", " template = \"\"\"\n", " <|begin_of_text|>\n", " \n", " <|start_header_id|>system<|end_header_id|>\n", " 你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \\n\n", " You should not mention anything about \"根據提供的文件內容\" or other similar terms.\n", " Use five sentences maximum and keep the answer concise.\n", " 如果你不知道答案請回答:\"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。\"\n", " 勿回答無關資訊\n", " <|eot_id|>\n", " \n", " <|start_header_id|>user<|end_header_id|>\n", " Answer the following question based on 
this context:\n", "\n", " {context}\n", "\n", " Question: {question}\n", " 用繁體中文\n", " <|eot_id|>\n", " \n", " <|start_header_id|>assistant<|end_header_id|>\n", " \"\"\"\n", " prompt = ChatPromptTemplate.from_template(\n", " system_prompt + \"\\n\\n\" +\n", " template\n", " )\n", " \n", " # prompt = ChatPromptTemplate.from_template(\n", " # system_prompt + \"\\n\\n\" +\n", " # \"Answer the following question based on this context:\\n\\n\"\n", " # \"{context}\\n\\n\"\n", " # \"Question: {question}\\n\"\n", " # \"Answer in the same language as the question. If you don't know the answer, \"\n", " # \"say 'I'm sorry, I don't have enough information to answer that question.'\"\n", " # )\n", "\n", " \n", " # chain = prompt | taide_llm | StrOutputParser()\n", " chain = prompt | llm | StrOutputParser()\n", " return chain.invoke({\"context\": context, \"question\": question})" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# docs = retriever.get_relevant_documents(question, k=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "question = \"誰需要繳交碳費?\"\n", "docs = retriever.get_relevant_documents(question, k=50)\n", "docs" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "### Generate\n", "# llm = ChatOllama(model=local_llm, temperature=0)\n", "\n", "# docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n", "# generation = faiss_query(question, docs_documents, llm)\n", "# generation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Retrieval Grader" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "### Retrieval Grader\n", "\n", "from langchain_community.chat_models import ChatOllama\n", "from langchain_core.output_parsers import JsonOutputParser\n", "from langchain_core.prompts import PromptTemplate\n", "\n", "# LLM\n", "# llm_json = ChatOllama(model=local_llm, 
format=\"json\", temperature=0)\n", "\n", "prompt = PromptTemplate(\n", " template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance \n", " of a retrieved document to a user question. If the document contains keywords related to the user question, \n", " grade it as relevant. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n", " Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \\n\n", " Provide the binary score as a JSON with a single key 'score' and no premable or explanation.\n", " <|eot_id|><|start_header_id|>user<|end_header_id|>\n", " Here is the retrieved document: \\n\\n {document} \\n\\n\n", " Here is the user question: {question} \\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", " \"\"\",\n", " input_variables=[\"question\", \"document\"],\n", ")\n", "\n", "retrieval_grader = prompt | llm_json | JsonOutputParser()\n", "# question = \"溫室氣體是什麼\"\n", "# # docs = retriever.invoke(question)\n", "# docs = retriever.get_relevant_documents(question, k=10)\n", "# doc_txt = docs[1].page_content\n", "# print(retrieval_grader.invoke({\"question\": question, \"document\": doc_txt}))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# for doc in docs:\n", "# doc_txt = doc.page_content\n", "# print(retrieval_grader.invoke({\"question\": question, \"document\": doc_txt}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hallucination Grader" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'score': 'yes'}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "### Hallucination Grader\n", "\n", "from langchain_community.chat_models import ChatOllama\n", "from langchain_core.output_parsers import JsonOutputParser\n", "from langchain_core.prompts import 
PromptTemplate\n", "\n", "# LLM\n", "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "\n", "# Prompt\n", "prompt = PromptTemplate(\n", " template=\"\"\" <|begin_of_text|><|start_header_id|>system<|end_header_id|> \n", " You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n", " Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n", " Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation. \n", " Return the a JSON with a single key 'score' and no premable or explanation. \n", " <|eot_id|><|start_header_id|>user<|end_header_id|>\n", " Here are the facts:\n", " \\n ------- \\n\n", " {documents} \n", " \\n ------- \\n\n", " Here is the answer: {generation} \n", " Provide 'yes' or 'no' score as a JSON with a single key 'score' and no premable or explanation.\n", " <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n", " input_variables=[\"generation\", \"documents\"],\n", ")\n", "\n", "hallucination_grader = prompt | llm_json | JsonOutputParser()\n", "\n", "question = \"誰需要繳交碳費?\"\n", "docs = retriever.get_relevant_documents(question, k=10)\n", "\n", "generation = faiss_query(question, docs, llm)\n", "# docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n", "\n", "hallucination_grader.invoke({\"documents\": docs, \"generation\": generation})" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Answer Grader" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "### Answer Grader\n", "\n", "# LLM\n", "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "\n", "# Prompt\n", "prompt = PromptTemplate(\n", " 
template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an \n", " answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is \n", " useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\n", " <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:\n", " \\n ------- \\n\n", " {generation} \n", " \\n ------- \\n\n", " Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n", " input_variables=[\"generation\", \"question\"],\n", ")\n", "\n", "answer_grader = prompt | llm_json | JsonOutputParser()\n", "# answer_grader.invoke({\"question\": question, \"generation\": generation})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# SQL" ] }, { "cell_type": "code", "execution_count": 104, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n", " self._metadata.reflect(\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from importlib import reload # Python 3.4+\n", "import text_to_sql2\n", "reload(text_to_sql2)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n", " self._metadata.reflect(\n" ] } ], "source": [ "\n", "from text_to_sql import run\n", "from langchain_community.utilities import SQLDatabase\n", "import os\n", "URI: str = os.environ.get('SUPABASE_URI')\n", "db = 
# Single source of truth for the tables exposed to the text-to-SQL helpers.
# Previously this list was repeated inside each wrapper.
DEFAULT_SQL_TABLES = ('104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)')


def run_text_to_sql(question: str, selected_table=None):
    """Generate SQL for ``question``, execute it and phrase the result.

    Args:
        question: Natural-language question.
        selected_table: Optional iterable of table names to expose; defaults
            to ``DEFAULT_SQL_TABLES``.

    Returns:
        ``(answer, query)`` taken from the ``(query, result, answer)`` triple
        returned by ``run``.
    """
    tables = list(DEFAULT_SQL_TABLES if selected_table is None else selected_table)
    query, _result, answer = run(db, question, tables, llm)
    return answer, query


def _get_query(question: str, selected_table=None):
    """Return whatever ``get_query`` produces for ``question``.

    Args:
        question: Natural-language question.
        selected_table: Optional iterable of table names to expose; defaults
            to ``DEFAULT_SQL_TABLES``.
    """
    tables = list(DEFAULT_SQL_TABLES if selected_table is None else selected_table)
    return get_query(db, question, tables, llm)


def _query_to_nl(question: str, query: str):
    """Pass ``query`` to ``query_to_nl`` and return its result unchanged.

    NOTE(review): the recorded cell output shows a
    ``(query, result, answer)`` tuple rather than a bare string -- confirm
    against ``text_to_sql2.query_to_nl``.
    """
    return query_to_nl(db, question, query, llm)
\"台積電2021年類別2總排放量\"\\nFROM \"104_112碳排放公開及建準資料\"\\nWHERE \"事業名稱\" like \\'%台灣積體電路製造股份有限公司%\\'\\nAND \"類別\" = \\'類別2-能源間接排放\\'\\nAND \"盤查標準\" = \\'GHG\\'\\nAND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;')" ] }, "execution_count": 142, "metadata": {}, "output_type": "execute_result" } ], "source": [ "question = \"台積電去年的類別2總排放量是多少?\"\n", "answer, query = run_text_to_sql(question)\n", "answer, query" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## SQL Grader" ] }, { "cell_type": "code", "execution_count": 177, "metadata": {}, "outputs": [], "source": [ "\n", "### SQL Grader\n", "\n", "from langchain_community.chat_models import ChatOllama\n", "from langchain_core.output_parsers import JsonOutputParser\n", "from langchain_core.prompts import PromptTemplate\n", "\n", "# LLM\n", "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "\n", "prompt = PromptTemplate(\n", " template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> \n", " You are a SQL query grader assessing correctness of PostgreSQL query to a user question. 
\n", " Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.\n", " \n", " Here is database description:\n", " {table_info}\n", " \n", " You need to check that each where statement is correctly filtered out what user question need.\n", " \n", " For example, if user question is \"建準去年的固定燃燒總排放量是多少?\", and the PostgreSQL query is \n", " \"SELECT SUM(\"排放量(公噸CO2e)\") AS \"下游租賃總排放量\"\n", " FROM \"104_112碳排放公開及建準資料\"\n", " WHERE \"事業名稱\" like '%建準%'\n", " AND \"排放源\" = '下游租賃'\n", " AND \"盤查標準\" = 'GHG'\n", " AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\"\n", " For the above example, we can find that user asked for \"固定燃燒\", but the PostgreSQL query gives \"排放源\" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.\n", " \n", " Another example like \"建準去年的固定燃燒總排放量是多少?\", and the PostgreSQL query is \n", " \"SELECT SUM(\"排放量(公噸CO2e)\") AS \"固定燃燒總排放量\"\n", " FROM \"104_112碳排放公開及建準資料\"\n", " WHERE \"事業名稱\" like '%台積電%'\n", " AND \"排放源\" = '固定燃燒'\n", " AND \"盤查標準\" = 'GHG'\n", " AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\"\n", " For the above example, we can find that user asked for \"建準\", but the PostgreSQL query gives \"事業名稱\" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.\n", " \n", " and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.\n", " \n", " If the PostgreSQL query do not exactly matches the user question, grade it as incorrect. \n", " You need to strictly examine whether the sql PostgreSQL query matches the user question.\n", " Give a binary score 'yes' or 'no' score to indicate whether the PostgreSQL query is correct to the question. 
\\n\n", " Provide the binary score as a JSON with a single key 'score' and no premable or explanation.\n", " <|eot_id|>\n", " \n", " <|start_header_id|>user<|end_header_id|>\n", " Here is the PostgreSQL query: \\n\\n {sql_query} \\n\\n\n", " Here is the user question: {question} \\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", " \"\"\",\n", " input_variables=[\"table_info\", \"question\", \"sql_query\"],\n", ")\n", "\n", "sql_query_grader = prompt | llm_json | JsonOutputParser()" ] }, { "cell_type": "code", "execution_count": 180, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'score': 'no'}\n" ] } ], "source": [ "from text_to_sql2 import table_description\n", "question = \"建準去年的類別一排放量\"\n", "# sql_query = \"\"\"\n", "# SELECT SUM(\"高雄總部及運通廠\" + \"台北辦事處\" + \"昆山廣興廠\" + \"北海建準廠\" + \"北海立準廠\" + \"菲律賓建準廠\" + \"Inc\" + \"SAS\" + \"India\") AS \"類別一排放量\"\n", "# FROM \"2023 清冊數據(GHG)\"\n", "# WHERE \"類別\" = '類別一-直接排放'\n", "# \"\"\"\n", "question = \"台積電去年的固定燃燒總排放量是多少?\"\n", "sql_query = \"\"\"\n", "SELECT SUM(\"排放量(公噸CO2e)\") AS \"固定燃燒總排放量\"\n", "FROM \"104_112碳排放公開及建準資料\"\n", "WHERE \"事業名稱\" like '%建準%'\n", "AND \"排放源\" = '固定燃燒'\n", "AND \"盤查標準\" = 'GHG'\n", "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\n", "\"\"\"\n", "print(sql_query_grader.invoke({\"table_info\": table_description(), \"question\": question, \"sql_query\": sql_query}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Router" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'datasource': 'company_private_data'}\n" ] } ], "source": [ "### Router\n", "\n", "from langchain_community.chat_models import ChatOllama\n", "from langchain_core.output_parsers import JsonOutputParser\n", "from langchain_core.prompts import PromptTemplate\n", "\n", "# LLM\n", "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "\n", "prompt = 
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user question being answered
        generation: the answer produced by the RAG or text-to-SQL branch
        documents: documents returned by the retriever
        retry: SQL-generation retry counter, written by
            company_private_data_get_sql_query and read by grade_sql_query
        sql_query: the generated PostgreSQL query
    """

    question: str
    generation: str
    # NOTE(review): retrieve_and_generation stores the retriever's return
    # value here unchanged; confirm its items really are plain strings.
    documents: List[str]
    retry: int
    sql_query: str
"metadata": {}, "outputs": [], "source": [ "def retrieve_and_generation(state):\n", " \"\"\"\n", " Retrieve documents from vectorstore\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " state (dict): New key added to state, documents, that contains retrieved documents, and generation, genrating by LLM\n", " \"\"\"\n", " print(\"---RETRIEVE---\")\n", " question = state[\"question\"]\n", "\n", " # Retrieval\n", " # documents = retriever.invoke(question)\n", " # TODO: correct Retrieval function\n", " documents = retriever.get_relevant_documents(question, k=30)\n", " # docs_documents = \"\\n\\n\".join(doc.page_content for doc in documents)\n", " print(documents)\n", " generation = faiss_query(question, documents, llm)\n", " return {\"documents\": documents, \"question\": question, \"generation\": generation}\n", "\n", "def company_private_data_get_sql_query(state):\n", " \"\"\"\n", " Get PostgreSQL query according to question\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " state (dict): return generated PostgreSQL query and record retry times\n", " \"\"\"\n", " print(\"---SQL QUERY---\")\n", " question = state[\"question\"]\n", " \n", " if state[\"retry\"]:\n", " retry = state[\"retry\"]\n", " retry += 1\n", " else: \n", " retry = 0\n", " print(\"RETRY: \", retry)\n", " \n", " sql_query = _get_query(question)\n", " \n", " return {\"sql_query\": sql_query, \"question\": question, \"retry\": retry}\n", " \n", "def company_private_data_search(state):\n", " \"\"\"\n", " Execute PostgreSQL query and convert to nature language.\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " state (dict): Appended sql results to state\n", " \"\"\"\n", "\n", " print(\"---SQL TO NL---\")\n", " print(state)\n", " question = state[\"question\"]\n", " sql_query = state[\"sql_query\"]\n", " generation = _query_to_nl(question, sql_query)\n", " \n", " # generation = 
[company_private_data_result]\n", " \n", " return {\"sql_query\": sql_query, \"question\": question, \"generation\": generation}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Conditional edge" ] }, { "cell_type": "code", "execution_count": 229, "metadata": {}, "outputs": [], "source": [ "### Conditional edge\n", "\n", "\n", "def route_question(state):\n", " \"\"\"\n", " Route question to web search or RAG.\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " str: Next node to call\n", " \"\"\"\n", "\n", " print(\"---ROUTE QUESTION---\")\n", " question = state[\"question\"]\n", " print(question)\n", " source = question_router.invoke({\"question\": question})\n", " print(source)\n", " print(source[\"datasource\"])\n", " if source[\"datasource\"] == \"company_private_data\":\n", " print(\"---ROUTE QUESTION TO TEXT-TO-SQL---\")\n", " return \"company_private_data\"\n", " elif source[\"datasource\"] == \"vectorstore\":\n", " print(\"---ROUTE QUESTION TO RAG---\")\n", " return \"vectorstore\"\n", " \n", "def grade_generation_v_documents_and_question(state):\n", " \"\"\"\n", " Determines whether the generation is grounded in the document and answers question.\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " str: Decision for next node to call\n", " \"\"\"\n", "\n", " print(\"---CHECK HALLUCINATIONS---\")\n", " question = state[\"question\"]\n", " documents = state[\"documents\"]\n", " generation = state[\"generation\"]\n", "\n", " \n", " # print(docs_documents)\n", " # print(generation)\n", " score = hallucination_grader.invoke(\n", " {\"documents\": documents, \"generation\": generation}\n", " )\n", " print(score)\n", " grade = score[\"score\"]\n", "\n", " # Check hallucination\n", " if grade in [\"yes\", \"true\", 1, \"1\"]:\n", " print(\"---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\")\n", " # Check question-answering\n", " print(\"---GRADE GENERATION vs 
QUESTION---\")\n", " score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n", " grade = score[\"score\"]\n", " if grade in [\"yes\", \"true\", 1, \"1\"]:\n", " print(\"---DECISION: GENERATION ADDRESSES QUESTION---\")\n", " return \"useful\"\n", " else:\n", " print(\"---DECISION: GENERATION DOES NOT ADDRESS QUESTION---\")\n", " return \"not useful\"\n", " else:\n", " pprint(\"---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---\")\n", " return \"not supported\"\n", " \n", "def grade_sql_query(state):\n", " \"\"\"\n", " Determines whether the Postgresql query are correct to the question\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " state (dict): Decision for retry or continue\n", " \"\"\"\n", "\n", " print(\"---CHECK SQL CORRECTNESS TO QUESTION---\")\n", " question = state[\"question\"]\n", " sql_query = state[\"sql_query\"]\n", " retry = state[\"retry\"]\n", "\n", " # Score each doc\n", " \n", " score = sql_query_grader.invoke({\"table_info\": table_description(), \"question\": question, \"sql_query\": sql_query})\n", " grade = score[\"score\"]\n", " # Document relevant\n", " if grade in [\"yes\", \"true\", 1, \"1\"]:\n", " print(\"---GRADE: CORRECT SQL QUERY---\")\n", " return \"correct\"\n", " elif retry >= 5:\n", " print(\"---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---\")\n", " return \"failed\"\n", " else:\n", " print(\"---GRADE: INCORRECT SQL QUERY---\")\n", " return \"incorrect\"\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Graph" ] }, { "cell_type": "code", "execution_count": 233, "metadata": {}, "outputs": [], "source": [ "\n", "from langgraph.pregel import RetryPolicy\n", "\n", "workflow = StateGraph(GraphState)\n", "\n", "# Define the nodes\n", "workflow.add_node(\"company_private_data_query\", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5)) # web search\n", "workflow.add_node(\"company_private_data_search\", 
company_private_data_search, retry=RetryPolicy(max_attempts=5)) # web search\n", "workflow.add_node(\"retrieve_and_generation\", retrieve_and_generation, retry=RetryPolicy(max_attempts=5)) # retrieve\n", "# workflow.add_node(\"grade_generation\", grade_documents) # grade documents\n", "# workflow.add_node(\"generate\", generate) # generatae\n", "\n", "workflow.add_conditional_edges(\n", " START,\n", " route_question,\n", " {\n", " \"company_private_data\": \"company_private_data_query\",\n", " \"vectorstore\": \"retrieve_and_generation\",\n", " },\n", ")\n", "\n", "workflow.add_conditional_edges(\n", " \"retrieve_and_generation\",\n", " grade_generation_v_documents_and_question,\n", " {\n", " \"not supported\": \"retrieve_and_generation\",\n", " \"useful\": END,\n", " \"not useful\": \"retrieve_and_generation\",\n", " },\n", ")\n", "workflow.add_conditional_edges(\n", " \"company_private_data_query\",\n", " grade_sql_query,\n", " {\n", " \"correct\": \"company_private_data_search\",\n", " \"incorrect\": \"company_private_data_query\",\n", " \"failed\": END\n", " \n", " },\n", ")\n", "workflow.add_edge(\"company_private_data_search\", END)\n", "\n", "\n", "# workflow.add_edge(\"company_private_data_search\", END)\n", "\n", "app = workflow.compile()" ] }, { "cell_type": "code", "execution_count": 224, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "%%{init: {'flowchart': {'curve': 'linear'}}}%%\n", "graph TD;\n", "\t__start__([__start__]):::first\n", "\tcompany_private_data_query(company_private_data_query)\n", "\tcompany_private_data_search(company_private_data_search)\n", "\tretrieve_and_generation(retrieve_and_generation)\n", "\t__end__([__end__]):::last\n", "\tcompany_private_data_query --> company_private_data_search;\n", "\t__start__ -.  company_private_data  .-> company_private_data_query;\n", "\t__start__ -.  vectorstore  .-> retrieve_and_generation;\n", "\tretrieve_and_generation -.  
not supported  .-> retrieve_and_generation;\n", "\tretrieve_and_generation -.  useful  .-> __end__;\n", "\tretrieve_and_generation -.  not useful  .-> retrieve_and_generation;\n", "\tcompany_private_data_search -.  not supported  .-> company_private_data_query;\n", "\tcompany_private_data_search -.  useful  .-> __end__;\n", "\tcompany_private_data_search -.  not useful  .-> retrieve_and_generation;\n", "\tclassDef default fill:#f2f0ff,line-height:1.2\n", "\tclassDef first fill-opacity:0\n", "\tclassDef last fill:#bfb6fc\n", "\n" ] } ], "source": [ "print(app.get_graph().draw_mermaid())" ] }, { "cell_type": "code", "execution_count": 234, "metadata": {}, "outputs": [ { "data": { "image/jpeg": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import Image, display\n", "from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles\n", "\n", "display(\n", " Image(\n", " app.get_graph().draw_mermaid_png(\n", " draw_method=MermaidDrawMethod.API,\n", " output_file_path=\"agent_workflow.png\",\n", " )\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test" ] }, { "cell_type": "code", "execution_count": 288, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---ROUTE QUESTION---\n", "建準去年的類別八排放量\n", "{'datasource': 'company_private_data'}\n", "company_private_data\n", "---ROUTE QUESTION TO TEXT-TO-SQL---\n", "---SQL QUERY---\n", "RETRY: 0\n", "SELECT SUM(\"排放量(公噸CO2e)\") AS \"類別8總排放量\"\n", "FROM \"104_112碳排放公開及建準資料\"\n", "WHERE \"事業名稱\" like '%建準%'\n", "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1\n", "AND \"類別\" = '類別8-';\n", "---CHECK SQL CORRECTNESS TO QUESTION---\n", "---GRADE: CORRECT SQL QUERY---\n", "'Finished running: company_private_data_query:'\n", "---SQL TO NL---\n", "{'question': '建準去年的類別八排放量', 'generation': None, 'documents': None, 'retry': 0, 'sql_query': 'SELECT SUM(\"排放量(公噸CO2e)\") AS \"類別8總排放量\"\\nFROM 
\"104_112碳排放公開及建準資料\"\\nWHERE \"事業名稱\" like \\'%建準%\\'\\nAND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1\\nAND \"類別\" = \\'類別8-\\';'}\n", "[(None,)]\n", "'Finished running: company_private_data_search:'\n", "('SELECT SUM(\"排放量(公噸CO2e)\") AS \"類別8總排放量\"\\n'\n", " 'FROM \"104_112碳排放公開及建準資料\"\\n'\n", " 'WHERE \"事業名稱\" like \\'%建準%\\'\\n'\n", " 'AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1\\n'\n", " 'AND \"類別\" = \\'類別8-\\';',\n", " '[(None,)]',\n", " '根據 SQL '\n", " '查詢和結果,去年建準的類別八排放量為空,可能是資料庫中沒有符合條件的資料。建議檢查資料庫中是否有符合條件的資料,或調整查詢條件以確保能夠正確查詢到相應的數據。')\n" ] } ], "source": [ "# Test\n", "\n", "inputs = {\"question\": \"建準去年的類別八排放量\"}\n", "for output in app.stream(inputs, {\"recursion_limit\": 10}):\n", " for key, value in output.items():\n", " pprint(f\"Finished running: {key}:\")\n", "pprint(value[\"generation\"])" ] }, { "cell_type": "code", "execution_count": 181, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'建準去年的類別一排放量是13.5953。'" ] }, "execution_count": 181, "metadata": {}, "output_type": "execute_result" } ], "source": [ "value[\"generation\"]" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'stop' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[27], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mstop\u001b[49m\n", "\u001b[0;31mNameError\u001b[0m: name 'stop' is not defined" ] } ], "source": [ "stop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# from RAG_strategy import multi_query_chain\n", "llm = ChatOllama(model=local_llm, temperature=0)\n", "question = \"溫室氣體是什麼\"\n", "generate_queries = multi_query_chain(llm)\n", "\n", "questions = generate_queries.invoke(question)\n", "questions" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "def ask_more_detail_chain(llm):\n", " # Multi Query: Different Perspectives\n", " template = \"\"\"\n", " <|begin_of_text|>\n", " \n", " <|start_header_id|>system<|end_header_id|>\n", " 你是一個來自台灣的AI助理,你的專長是根據使用者提供的文本來進一步詢問當中的細節,例如名詞解釋,請用繁體中文。 \\n\n", " You are an AI language model assistant. \n", " Your task is to generate three questions about the given user context as additional explanation.\n", " By generating three in-depth questions about the user context, your goal is to help the user realize more details. \n", " Provide these questions separated by newlines.\n", " For example:\n", " context: 建準廣興廠去年2023年一共自產發電了684,508度綠電\n", " in-depth question:什麼是綠電?\\n為何要使用綠電? \n", "\n", " output must in user's language and no preamble or explanation.\n", " <|eot_id|>\n", " \n", " <|start_header_id|>user<|end_header_id|>\n", " \n", " \n", " \n", " Original context: {question}\n", " three questions:\n", " <|eot_id|>\n", " \n", " <|start_header_id|>assistant<|end_header_id|>\"\"\"\n", " prompt_perspectives = ChatPromptTemplate.from_template(template)\n", "\n", " \n", " # llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n", " # llm = ChatOllama(model=\"llama3\", num_gpu=1, temperature=0)\n", "\n", " generate_queries = (\n", " prompt_perspectives \n", " | llm\n", " | StrOutputParser() \n", " | (lambda x: x.split(\"\\n\"))\n", " )\n", "\n", " return generate_queries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "llm = ChatOllama(model=local_llm, temperature=0)\n", "question = \"建準廣興廠去年的固定燃燒排放量是多少?\"\n", "generate_queries = ask_more_detail_chain(llm)\n", "\n", "questions = generate_queries.invoke(question)\n", "questions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "llm = ChatOllama(model=local_llm, temperature=0)\n", "question = \"固定燃燒是什麼?\"\n", "docs = retriever.get_relevant_documents(question, 
k=10)\n", "generation = faiss_query(question, docs, llm)\n", "generation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import datetime\n", "from typing import Literal, Optional, Tuple\n", "\n", "from langchain_core.pydantic_v1 import BaseModel, Field\n", "\n", "\n", "class SubQuery(BaseModel):\n", " \"\"\"Search over a database of tutorial videos about a software library.\"\"\"\n", "\n", " sub_query: str = Field(\n", " ...,\n", " description=\"A very specific query against the database.\",\n", " )" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[SubQuery(sub_query='查詢去年類別一的排放量')]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from langchain.output_parsers import PydanticToolsParser\n", "from langchain_core.prompts import ChatPromptTemplate\n", "from langchain_openai import ChatOpenAI\n", "from dotenv import load_dotenv\n", "\n", "load_dotenv()\n", "\n", "system = \"\"\"You are an expert at converting user questions into database queries. \\\n", "You have access to a database of tutorial videos about a software library for building LLM-powered applications. \\\n", "\n", "Perform query decomposition. 
Given a user question, break it down into distinct sub questions that \\\n", "you need to answer in order to answer the original question.\n", "\n", "If there are acronyms or words you are not familiar with, do not try to rephrase them.\n", "用繁體中文.\n", "\"\"\"\n", "prompt = ChatPromptTemplate.from_messages(\n", " [\n", " (\"system\", system),\n", " (\"human\", \"{question}用繁體中文\"),\n", " ]\n", ")\n", "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n", "llm_with_tools = llm.bind_tools([SubQuery])\n", "parser = PydanticToolsParser(tools=[SubQuery])\n", "query_analyzer = prompt | llm_with_tools | parser\n", "\n", "query_analyzer.invoke({\"question\": \"建準去年的類別一排放量?溫室氣體是什麼?\"})" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "query_analyzer.invoke({\"question\": \"溫室氣體是什麼?建準去年的類別一排放量?\"})" ] } ], "metadata": { "kernelspec": { "display_name": "llama3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 2 }