{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# LLM" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "### LLM\n", "from langchain_community.chat_models import ChatOllama\n", "# local_llm = \"llama3.1:8b-instruct-fp16\"\n", "local_llm = \"llama3-groq-tool-use:latest\"\n", "\n", "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "llm = ChatOllama(model=local_llm, temperature=0)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# RAG" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Retriever" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FAISS index loaded from faiss_index.bin\n", "Metadata loaded from faiss_metadata.pkl\n", "Using existing FAISS index and metadata.\n", "Creating FAISS retriever...\n" ] } ], "source": [ "from faiss_index import create_faiss_retriever, faiss_query\n", "retriever = create_faiss_retriever()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generation" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "from langchain.prompts import ChatPromptTemplate\n", "from langchain_core.output_parsers import StrOutputParser\n", "def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:\n", " context = docs\n", " # try:\n", " # context = \"\\n\".join(doc.page_content for doc in docs)\n", " # except:\n", " # context = \"\\n\".join(doc for doc in docs)\n", " \n", " system_prompt: str = \"你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。\"\n", " template = \"\"\"\n", " <|begin_of_text|>\n", " \n", " <|start_header_id|>system<|end_header_id|>\n", " 你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \\n\n", " You should not mention anything about \"根據提供的文件內容\" or other similar terms.\n", " Use five sentences maximum and keep the answer concise.\n", " 如果你不知道答案請回答:\"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 
test@systex.com 以便獲得更進一步的幫助,謝謝。\"\n", " 勿回答無關資訊\n", " <|eot_id|>\n", " \n", " <|start_header_id|>user<|end_header_id|>\n", " Answer the following question based on this context:\n", "\n", " {context}\n", "\n", " Question: {question}\n", " 用繁體中文回答問題\n", " <|eot_id|>\n", " \n", " <|start_header_id|>assistant<|end_header_id|>\n", " \"\"\"\n", " prompt = ChatPromptTemplate.from_template(\n", " system_prompt + \"\\n\\n\" +\n", " template\n", " )\n", " \n", " # prompt = ChatPromptTemplate.from_template(\n", " # system_prompt + \"\\n\\n\" +\n", " # \"Answer the following question based on this context:\\n\\n\"\n", " # \"{context}\\n\\n\"\n", " # \"Question: {question}\\n\"\n", " # \"Answer in the same language as the question. If you don't know the answer, \"\n", " # \"say 'I'm sorry, I don't have enough information to answer that question.'\"\n", " # )\n", "\n", " \n", " # chain = prompt | taide_llm | StrOutputParser()\n", " chain = prompt | llm | StrOutputParser()\n", " return chain.invoke({\"context\": context, \"question\": question})" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# docs = retriever.get_relevant_documents(question, k=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "question = \"誰需要繳交碳費?\"\n", "docs = retriever.get_relevant_documents(question, k=50)\n", "docs" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "### Generate\n", "# llm = ChatOllama(model=local_llm, temperature=0)\n", "\n", "# docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n", "# generation = faiss_query(question, docs_documents, llm)\n", "# generation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Retrieval Grader" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "### Retrieval Grader\n", "\n", "from langchain_community.chat_models import ChatOllama\n", "from 
langchain_core.output_parsers import JsonOutputParser\n", "from langchain_core.prompts import PromptTemplate\n", "\n", "# LLM\n", "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "\n", "prompt = PromptTemplate(\n", "    template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance \n", " of a retrieved document to a user question. If the document contains keywords related to the user question, \n", " grade it as relevant. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n", " Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question. \\n\n", " Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\n", " <|eot_id|><|start_header_id|>user<|end_header_id|>\n", " Here is the retrieved document: \\n\\n {document} \\n\\n\n", " Here is the user question: {question} \\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", " \"\"\",\n", "    input_variables=[\"question\", \"document\"],\n", ")\n", "\n", "retrieval_grader = prompt | llm_json | JsonOutputParser()\n", "# question = \"溫室氣體是什麼\"\n", "# # docs = retriever.invoke(question)\n", "# docs = retriever.get_relevant_documents(question, k=10)\n", "# doc_txt = docs[1].page_content\n", "# print(retrieval_grader.invoke({\"question\": question, \"document\": doc_txt}))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# for doc in docs:\n", "#     doc_txt = doc.page_content\n", "#     print(retrieval_grader.invoke({\"question\": question, \"document\": doc_txt}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hallucination Grader" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'score': 'yes'}" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "### Hallucination 
Grader\n", "\n", "from langchain_community.chat_models import ChatOllama\n", "from langchain_core.output_parsers import JsonOutputParser\n", "from langchain_core.prompts import PromptTemplate\n", "\n", "# LLM\n", "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "\n", "# Prompt\n", "prompt = PromptTemplate(\n", "    template=\"\"\" <|begin_of_text|><|start_header_id|>system<|end_header_id|> \n", " You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n", " Give a 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n", " Provide the 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation. \n", " <|eot_id|><|start_header_id|>user<|end_header_id|>\n", " Here are the facts:\n", " \\n ------- \\n\n", " {documents} \n", " \\n ------- \\n\n", " Here is the answer: {generation} \n", " Provide the 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation.\n", " <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n", "    input_variables=[\"generation\", \"documents\"],\n", ")\n", "\n", "hallucination_grader = prompt | llm_json | JsonOutputParser()\n", "\n", "question = \"誰需要繳交碳費?\"\n", "docs = retriever.get_relevant_documents(question, k=10)\n", "\n", "generation = faiss_query(question, docs, llm)\n", "# docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n", "\n", "hallucination_grader.invoke({\"documents\": docs, \"generation\": generation})" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Answer Grader" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "### Answer Grader\n", "\n", "# LLM\n", 
"llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "\n", "# Prompt\n", "prompt = PromptTemplate(\n", " template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an \n", " answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is \n", " useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\n", " <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:\n", " \\n ------- \\n\n", " {generation} \n", " \\n ------- \\n\n", " Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n", " input_variables=[\"generation\", \"question\"],\n", ")\n", "\n", "answer_grader = prompt | llm_json | JsonOutputParser()\n", "# answer_grader.invoke({\"question\": question, \"generation\": generation})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# SQL" ] }, { "cell_type": "code", "execution_count": 104, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n", " self._metadata.reflect(\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from importlib import reload # Python 3.4+\n", "import text_to_sql2\n", "reload(text_to_sql2)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n", " self._metadata.reflect(\n", 
"/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n", " self._metadata.reflect(\n" ] } ], "source": [ "\n", "from text_to_sql import run, get_query, query_to_nl\n", "from langchain_community.utilities import SQLDatabase\n", "import os\n", "URI: str = os.environ.get('SUPABASE_URI')\n", "db = SQLDatabase.from_uri(URI)\n", "\n", "def run_text_to_sql(question: str):\n", " selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)', '水電使用量(GHG)', '水電使用量(ISO)']\n", " # question = \"建準去年的固定燃燒總排放量是多少?\"\n", " query, result, answer = run(db, question, selected_table, llm)\n", " \n", " return answer, query\n", "\n", "def _get_query(question: str):\n", " selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)', '水電使用量(ISO)']\n", " query = get_query(db, question, selected_table, llm)\n", " return query\n", "\n", "def _query_to_nl(question: str, query: str):\n", " answer = query_to_nl(db, question, query, llm)\n", " return answer" ] }, { "cell_type": "code", "execution_count": 187, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n", " self._metadata.reflect(\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 187, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from importlib import reload # Python 3.4+\n", "import text_to_sql2\n", "reload(text_to_sql2)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 
'embedding'\n", " self._metadata.reflect(\n" ] } ], "source": [ "from text_to_sql2 import run, get_query, query_to_nl\n", "from langchain_community.utilities import SQLDatabase\n", "import os\n", "URI: str = os.environ.get('SUPABASE_URI')\n", "db = SQLDatabase.from_uri(URI)\n", "\n", "def run_text_to_sql(question: str):\n", " selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']\n", " # question = \"建準去年的固定燃燒總排放量是多少?\"\n", " query, result, answer = run(db, question, selected_table, llm)\n", " \n", " return answer, query\n", "\n", "def _get_query(question: str):\n", " selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']\n", " query = get_query(db, question, selected_table, llm)\n", " return query\n", "\n", "def _query_to_nl(question: str, query: str):\n", " answer = query_to_nl(db, question, query, llm)\n", " return answer" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SELECT SUM(\"昆山廣興廠\") AS \"建準廣興廠綠電使用量\"\n", "FROM \"2023 清冊數據(GHG)\"\n", "WHERE \"類別\" = '類別5-使用來自組織產品間接排放'\n" ] }, { "data": { "text/plain": [ "'SELECT SUM(\"昆山廣興廠\") AS \"建準廣興廠綠電使用量\"\\nFROM \"2023 清冊數據(GHG)\"\\nWHERE \"類別\" = \\'類別5-使用來自組織產品間接排放\\''" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "question = \"建準廣興廠去年的綠電使用量是多少?\"\n", "query = _get_query(question)\n", "query" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[(494312.5775,)]\n" ] }, { "data": { "text/plain": [ "'建準廣興廠去年的綠電使用量是494312.5775。'" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "answer = _query_to_nl(question, query)\n", "answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## SQL Grader" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "\n", "### SQL Grader\n", "\n", "from 
langchain_community.chat_models import ChatOllama\n", "from langchain_core.output_parsers import JsonOutputParser\n", "from langchain_core.prompts import PromptTemplate\n", "\n", "# LLM\n", "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "\n", "prompt = PromptTemplate(\n", " template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> \n", " You are a SQL query grader assessing correctness of PostgreSQL query to a user question. \n", " Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.\n", " \n", " Here is database description:\n", " {table_info}\n", " \n", " You need to check that each where statement is correctly filtered out what user question need.\n", " \n", " For example, if user question is \"建準去年的固定燃燒總排放量是多少?\", and the PostgreSQL query is \n", " \"SELECT SUM(\"排放量(公噸CO2e)\") AS \"下游租賃總排放量\"\n", " FROM \"104_112碳排放公開及建準資料\"\n", " WHERE \"事業名稱\" like '%建準%'\n", " AND \"排放源\" = '下游租賃'\n", " AND \"盤查標準\" = 'GHG'\n", " AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\"\n", " For the above example, we can find that user asked for \"固定燃燒\", but the PostgreSQL query gives \"排放源\" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.\n", " \n", " Another example like \"建準去年的固定燃燒總排放量是多少?\", and the PostgreSQL query is \n", " \"SELECT SUM(\"排放量(公噸CO2e)\") AS \"固定燃燒總排放量\"\n", " FROM \"104_112碳排放公開及建準資料\"\n", " WHERE \"事業名稱\" like '%台積電%'\n", " AND \"排放源\" = '固定燃燒'\n", " AND \"盤查標準\" = 'GHG'\n", " AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\"\n", " For the above example, we can find that user asked for \"建準\", but the PostgreSQL query gives \"事業名稱\" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.\n", " \n", " and so on. 
You need to strictly examine whether the PostgreSQL query matches the user question.\n", " \n", " If the PostgreSQL query does not exactly match the user question, grade it as incorrect. \n", " Give a binary score 'yes' or 'no' to indicate whether the PostgreSQL query is correct for the question. \\n\n", " Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\n", " <|eot_id|>\n", " \n", " <|start_header_id|>user<|end_header_id|>\n", " Here is the PostgreSQL query: \\n\\n {sql_query} \\n\\n\n", " Here is the user question: {question} \\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", " \"\"\",\n", "    input_variables=[\"table_info\", \"question\", \"sql_query\"],\n", ")\n", "\n", "sql_query_grader = prompt | llm_json | JsonOutputParser()" ] }, { "cell_type": "code", "execution_count": 180, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'score': 'no'}\n" ] } ], "source": [ "from text_to_sql2 import table_description\n", "question = \"建準去年的類別一排放量\"\n", "# sql_query = \"\"\"\n", "# SELECT SUM(\"高雄總部及運通廠\" + \"台北辦事處\" + \"昆山廣興廠\" + \"北海建準廠\" + \"北海立準廠\" + \"菲律賓建準廠\" + \"Inc\" + \"SAS\" + \"India\") AS \"類別一排放量\"\n", "# FROM \"2023 清冊數據(GHG)\"\n", "# WHERE \"類別\" = '類別一-直接排放'\n", "# \"\"\"\n", "# The question asks about 台積電 while the query filters on 建準, so the grader is expected to answer 'no'.\n", "question = \"台積電去年的固定燃燒總排放量是多少?\"\n", "sql_query = \"\"\"\n", "SELECT SUM(\"排放量(公噸CO2e)\") AS \"固定燃燒總排放量\"\n", "FROM \"104_112碳排放公開及建準資料\"\n", "WHERE \"事業名稱\" like '%建準%'\n", "AND \"排放源\" = '固定燃燒'\n", "AND \"盤查標準\" = 'GHG'\n", "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\n", "\"\"\"\n", "print(sql_query_grader.invoke({\"table_info\": table_description(), \"question\": question, \"sql_query\": sql_query}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional details" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from importlib import reload # Python 3.4+\n", "import post_processing_sqlparse\n", "reload(post_processing_sqlparse)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "sql_query = \"\"\"\n", "SELECT SUM(\"排放量(公噸CO2e)\") AS \"固定燃燒總排放量\"\n", "FROM \"104_112碳排放公開及建準資料\"\n", "WHERE \"事業名稱\" like '%建準%'\n", "AND \"類別\" = '類別1-直接排放'\n", "AND \"盤查標準\" = 'GHG'\n", "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1;\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from post_processing_sqlparse import get_query_columns, parse_sql_for_stock_info, get_table_name" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['SUM']\n", "('固定燃燒', None)\n", "\"104_112碳排放公開及建準資料\"\n" ] } ], "source": [ "print(get_query_columns(sql_query, get_real_name=True))\n", "print(parse_sql_for_stock_info(sql_query))\n", "print(get_table_name(sql_query))" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "def generate_additional_detail(sql_query):\n", " terms = parse_sql_for_stock_info(sql_query)\n", " answer = \"\"\n", " for term in terms:\n", " if term is None: continue\n", " question_format = [f\"什麼是{term}?\", f\"{term}的用途是什麼\", f\"如何計算{term}?\"]\n", " for question in question_format:\n", " # question = f\"什麼是{term}?\"\n", " documents = retriever.get_relevant_documents(question, k=30)\n", " generation = faiss_query(question, documents, llm)\n", " answer += generation\n", " answer += \"\\n\"\n", " # print(question)\n", " # print(generation)\n", " return answer" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'直接排放(Category 1)指的是固定燃燒排放源中使用天然氣的設備所產生的溫室氣體排放量。\\n直接排放的用途主要包括固定燃燒和製程中使用含氟氣體及 
N2O所產生之排放源。\\n直接排放的計算可以根據溫室氣體排放量盤查作業指引113年版進行。主要步驟如下:\\n\\n1. 依照活動數據,計算低位熱值和燃料用量。\\n2. 使用質量平衡法或直接監測法計算二氧化碳排放量。\\n3. 將排放係數乘以燃料用量和低位熱值,以取得單位產品用量。\\n\\n這些步驟可以幫助您計算類別1-直接排放的數據。\\n'" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "answer = generate_additional_detail(sql_query)\n", "answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Router" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'datasource': '自有數據'}\n" ] } ], "source": [ "### Router\n", "\n", "from langchain_community.chat_models import ChatOllama\n", "from langchain_core.output_parsers import JsonOutputParser\n", "from langchain_core.prompts import PromptTemplate\n", "\n", "# LLM\n", "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n", "\n", "prompt = PromptTemplate(\n", " template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> \n", " You are an expert at routing a user question to a 專業知識 or 自有數據. \n", " Use company private data for questions about the informations about a company's greenhouse gas emissions data.\n", " Otherwise, use the 專業知識 for questions on ESG field knowledge or news about ESG. \n", " You do not need to be stringent with the keywords in the question related to these topics. \n", " Give a binary choice '自有數據' or '專業知識' based on the question. \n", " Return the a JSON with a single key 'datasource' and no premable or explanation. 
\n", " \n", " Question to route: {question} \n", " <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n", "    input_variables=[\"question\"],\n", ")\n", "\n", "question_router = prompt | llm_json | JsonOutputParser()\n", "question = \"建準去年的類別1排放量是多少?\"\n", "question = \"建準去年的綠電使用量是多少?\"\n", "# docs = retriever.get_relevant_documents(question)\n", "# doc_txt = docs[1].page_content\n", "print(question_router.invoke({\"question\": question}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Node" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "# RAG + text-to-sql\n", "\n", "from pprint import pprint\n", "from typing import List\n", "\n", "from langchain_core.documents import Document\n", "from typing_extensions import TypedDict\n", "\n", "from langgraph.graph import END, StateGraph, START\n", "\n", "### State\n", "\n", "\n", "class GraphState(TypedDict):\n", "    \"\"\"\n", "    Represents the state of our graph.\n", "\n", "    Attributes:\n", "        question: user question\n", "        generation: LLM generation\n", "        documents: list of retrieved documents\n", "        retry: number of times the SQL query has been regenerated\n", "        sql_query: generated PostgreSQL query\n", "    \"\"\"\n", "\n", "    question: str\n", "    generation: str\n", "    documents: List[str]\n", "    retry: int\n", "    sql_query: str" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "def retrieve_and_generation(state):\n", " \"\"\"\n", " Retrieve documents from the vector store and generate an answer\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " state (dict): New keys added to state: documents, the retrieved documents, and generation, the answer generated by the LLM\n", " \"\"\"\n", " print(\"---RETRIEVE---\")\n", " question = state[\"question\"]\n", "\n", " # Retrieval\n", " # documents = retriever.invoke(question)\n", " # TODO: correct Retrieval function\n", " documents = retriever.get_relevant_documents(question, k=30)\n", " # docs_documents = 
\"\\n\\n\".join(doc.page_content for doc in documents)\n", " print(documents)\n", " generation = faiss_query(question, documents, llm)\n", " return {\"documents\": documents, \"question\": question, \"generation\": generation}\n", "\n", "def company_private_data_get_sql_query(state):\n", "    \"\"\"\n", "    Generate a PostgreSQL query for the question.\n", "\n", "    Args:\n", "        state (dict): The current graph state\n", "\n", "    Returns:\n", "        state (dict): The generated PostgreSQL query and the updated retry count\n", "    \"\"\"\n", "    print(\"---SQL QUERY---\")\n", "    question = state[\"question\"]\n", "\n", "    # Compare against None: 0 is a valid count and must not be treated as unset.\n", "    previous_retry = state.get(\"retry\")\n", "    retry = 0 if previous_retry is None else previous_retry + 1\n", "    print(\"RETRY: \", retry)\n", "\n", "    sql_query = _get_query(question)\n", "\n", "    return {\"sql_query\": sql_query, \"question\": question, \"retry\": retry}\n", "\n", "def company_private_data_search(state):\n", " \"\"\"\n", " Execute the PostgreSQL query and convert the result to natural language.\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " state (dict): Appended sql results to state\n", " \"\"\"\n", "\n", " print(\"---SQL TO NL---\")\n", " print(state)\n", " question = state[\"question\"]\n", " sql_query = state[\"sql_query\"]\n", " generation = _query_to_nl(question, sql_query)\n", " \n", " # generation = 
[company_private_data_result]\n", " \n", " return {\"sql_query\": sql_query, \"question\": question, \"generation\": generation}\n", "\n", "def additional_explanation(state):\n", "    \"\"\"\n", "    Generate additional explanation for the terms found in the SQL query.\n", "\n", "    Args:\n", "        state (dict): The current graph state\n", "\n", "    Returns:\n", "        state (dict): Appended additional explanation to state\n", "    \"\"\"\n", "\n", "    print(\"---ADDITIONAL EXPLANATION---\")\n", "    print(state)\n", "    question = state[\"question\"]\n", "    sql_query = state[\"sql_query\"]\n", "    generation = generate_additional_detail(sql_query)\n", "\n", "    # generation = [company_private_data_result]\n", "\n", "    return {\"sql_query\": sql_query, \"question\": question, \"generation\": generation}\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Conditional edge" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "### Conditional edge\n", "\n", "\n", "def route_question(state):\n", "    \"\"\"\n", "    Route question to text-to-SQL or RAG.\n", "\n", "    Args:\n", "        state (dict): The current graph state\n", "\n", "    Returns:\n", "        str: Next node to call\n", "    \"\"\"\n", "\n", "    print(\"---ROUTE QUESTION---\")\n", "    question = state[\"question\"]\n", "    # print(question)\n", "    # question_router is the routing chain built in the Router cell.\n", "    source = question_router.invoke({\"question\": question})\n", "    # print(source)\n", "    print(source[\"datasource\"])\n", "    if source[\"datasource\"] == \"自有數據\":\n", "        print(\"---ROUTE QUESTION TO TEXT-TO-SQL---\")\n", "        return \"自有數據\"\n", "    elif source[\"datasource\"] == \"專業知識\":\n", "        print(\"---ROUTE QUESTION TO RAG---\")\n", "        return \"專業知識\"\n", "\n", "def grade_generation_v_documents_and_question(state):\n", " \"\"\"\n", " Determines whether the generation is grounded in the document and answers question.\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " str: Decision for next node to call\n", " \"\"\"\n", "\n", " print(\"---CHECK HALLUCINATIONS---\")\n", " 
question = state[\"question\"]\n", " documents = state[\"documents\"]\n", " generation = state[\"generation\"]\n", "\n", " \n", " # print(docs_documents)\n", " # print(generation)\n", " score = hallucination_grader.invoke(\n", " {\"documents\": documents, \"generation\": generation}\n", " )\n", " print(score)\n", " grade = score[\"score\"]\n", "\n", " # Check hallucination\n", " if grade in [\"yes\", \"true\", 1, \"1\"]:\n", " print(\"---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\")\n", " # Check question-answering\n", " print(\"---GRADE GENERATION vs QUESTION---\")\n", " score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n", " grade = score[\"score\"]\n", " if grade in [\"yes\", \"true\", 1, \"1\"]:\n", " print(\"---DECISION: GENERATION ADDRESSES QUESTION---\")\n", " return \"useful\"\n", " else:\n", " print(\"---DECISION: GENERATION DOES NOT ADDRESS QUESTION---\")\n", " return \"not useful\"\n", " else:\n", " pprint(\"---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---\")\n", " return \"not supported\"\n", " \n", "def grade_sql_query(state):\n", " \"\"\"\n", " Determines whether the Postgresql query are correct to the question\n", "\n", " Args:\n", " state (dict): The current graph state\n", "\n", " Returns:\n", " state (dict): Decision for retry or continue\n", " \"\"\"\n", "\n", " print(\"---CHECK SQL CORRECTNESS TO QUESTION---\")\n", " question = state[\"question\"]\n", " sql_query = state[\"sql_query\"]\n", " retry = state[\"retry\"]\n", "\n", " # Score each doc\n", " \n", " score = sql_query_grader.invoke({\"table_info\": table_description(), \"question\": question, \"sql_query\": sql_query})\n", " grade = score[\"score\"]\n", " # Document relevant\n", " if grade in [\"yes\", \"true\", 1, \"1\"]:\n", " print(\"---GRADE: CORRECT SQL QUERY---\")\n", " return \"correct\"\n", " elif retry >= 5:\n", " print(\"---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---\")\n", " return \"failed\"\n", " else:\n", " 
print(\"---GRADE: INCORRECT SQL QUERY---\")\n", " return \"incorrect\"\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Graph" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "\n", "from langgraph.pregel import RetryPolicy\n", "\n", "workflow = StateGraph(GraphState)\n", "\n", "# Define the nodes\n", "workflow.add_node(\"Text-to-SQL\", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5))  # generate the SQL query\n", "workflow.add_node(\"SQL Answer\", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # run the query and phrase the answer\n", "workflow.add_node(\"Additional Explanation\", additional_explanation, retry=RetryPolicy(max_attempts=5))  # explain terms used in the query\n", "workflow.add_node(\"RAG\", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve and generate\n", "# workflow.add_node(\"grade_generation\", grade_documents) # grade documents\n", "# workflow.add_node(\"generate\", generate) # generate\n", "\n", "workflow.add_conditional_edges(\n", "    START,\n", "    route_question,\n", "    {\n", "        \"自有數據\": \"Text-to-SQL\",\n", "        \"專業知識\": \"RAG\",\n", "    },\n", ")\n", "\n", "workflow.add_conditional_edges(\n", "    \"RAG\",\n", "    grade_generation_v_documents_and_question,\n", "    {\n", "        \"not supported\": \"RAG\",\n", "        \"useful\": END,\n", "        \"not useful\": \"RAG\",\n", "    },\n", ")\n", "workflow.add_conditional_edges(\n", "    \"Text-to-SQL\",\n", "    grade_sql_query,\n", "    {\n", "        \"correct\": \"SQL Answer\",\n", "        \"incorrect\": \"Text-to-SQL\",\n", "        \"failed\": \"RAG\",\n", "    },\n", ")\n", "workflow.add_edge(\"SQL Answer\", \"Additional Explanation\")\n", "workflow.add_edge(\"Additional Explanation\", END)\n", "\n", "\n", "\n", "# workflow.add_edge(\"company_private_data_search\", END)\n", "\n", "app = workflow.compile()" ] }, { "cell_type": "code", "execution_count": 224, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "%%{init: {'flowchart': {'curve': 'linear'}}}%%\n", "graph 
TD;\n", "\t__start__([__start__]):::first\n", "\tcompany_private_data_query(company_private_data_query)\n", "\tcompany_private_data_search(company_private_data_search)\n", "\tretrieve_and_generation(retrieve_and_generation)\n", "\t__end__([__end__]):::last\n", "\tcompany_private_data_query --> company_private_data_search;\n", "\t__start__ -.  company_private_data  .-> company_private_data_query;\n", "\t__start__ -.  vectorstore  .-> retrieve_and_generation;\n", "\tretrieve_and_generation -.  not supported  .-> retrieve_and_generation;\n", "\tretrieve_and_generation -.  useful  .-> __end__;\n", "\tretrieve_and_generation -.  not useful  .-> retrieve_and_generation;\n", "\tcompany_private_data_search -.  not supported  .-> company_private_data_query;\n", "\tcompany_private_data_search -.  useful  .-> __end__;\n", "\tcompany_private_data_search -.  not useful  .-> retrieve_and_generation;\n", "\tclassDef default fill:#f2f0ff,line-height:1.2\n", "\tclassDef first fill-opacity:0\n", "\tclassDef last fill:#bfb6fc\n", "\n" ] } ], "source": [ "print(app.get_graph().draw_mermaid())" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "image/jpeg": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import Image, display\n", "from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles\n", "\n", "display(\n", " Image(\n", " app.get_graph().draw_mermaid_png(\n", " draw_method=MermaidDrawMethod.API,\n", " output_file_path=\"agent_workflow.png\",\n", " )\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test" ] }, { "cell_type": "code", "execution_count": 288, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---ROUTE QUESTION---\n", "建準去年的類別八排放量\n", "{'datasource': 'company_private_data'}\n", "company_private_data\n", "---ROUTE QUESTION TO TEXT-TO-SQL---\n", "---SQL QUERY---\n", "RETRY: 
0\n", "SELECT SUM(\"排放量(公噸CO2e)\") AS \"類別8總排放量\"\n", "FROM \"104_112碳排放公開及建準資料\"\n", "WHERE \"事業名稱\" like '%建準%'\n", "AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1\n", "AND \"類別\" = '類別8-';\n", "---CHECK SQL CORRECTNESS TO QUESTION---\n", "---GRADE: CORRECT SQL QUERY---\n", "'Finished running: company_private_data_query:'\n", "---SQL TO NL---\n", "{'question': '建準去年的類別八排放量', 'generation': None, 'documents': None, 'retry': 0, 'sql_query': 'SELECT SUM(\"排放量(公噸CO2e)\") AS \"類別8總排放量\"\\nFROM \"104_112碳排放公開及建準資料\"\\nWHERE \"事業名稱\" like \\'%建準%\\'\\nAND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1\\nAND \"類別\" = \\'類別8-\\';'}\n", "[(None,)]\n", "'Finished running: company_private_data_search:'\n", "('SELECT SUM(\"排放量(公噸CO2e)\") AS \"類別8總排放量\"\\n'\n", " 'FROM \"104_112碳排放公開及建準資料\"\\n'\n", " 'WHERE \"事業名稱\" like \\'%建準%\\'\\n'\n", " 'AND \"年度\" = EXTRACT(YEAR FROM CURRENT_DATE)-1\\n'\n", " 'AND \"類別\" = \\'類別8-\\';',\n", " '[(None,)]',\n", " '根據 SQL '\n", " '查詢和結果,去年建準的類別八排放量為空,可能是資料庫中沒有符合條件的資料。建議檢查資料庫中是否有符合條件的資料,或調整查詢條件以確保能夠正確查詢到相應的數據。')\n" ] } ], "source": [ "# Test\n", "\n", "inputs = {\"question\": \"建準去年的類別八排放量\"}\n", "for output in app.stream(inputs, {\"recursion_limit\": 10}):\n", " for key, value in output.items():\n", " pprint(f\"Finished running: {key}:\")\n", "pprint(value[\"generation\"])" ] }, { "cell_type": "code", "execution_count": 181, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'建準去年的類別一排放量是13.5953。'" ] }, "execution_count": 181, "metadata": {}, "output_type": "execute_result" } ], "source": [ "value[\"generation\"]" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'stop' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[27], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m 
\u001b[43mstop\u001b[49m\n", "\u001b[0;31mNameError\u001b[0m: name 'stop' is not defined" ] } ], "source": [ "stop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# from RAG_strategy import multi_query_chain\n", "llm = ChatOllama(model=local_llm, temperature=0)\n", "question = \"溫室氣體是什麼\"\n", "generate_queries = multi_query_chain(llm)\n", "\n", "questions = generate_queries.invoke(question)\n", "questions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def ask_more_detail_chain(llm):\n", " # Multi Query: Different Perspectives\n", " template = \"\"\"\n", " <|begin_of_text|>\n", " \n", " <|start_header_id|>system<|end_header_id|>\n", " 你是一個來自台灣的AI助理,你的專長是根據使用者提供的文本來進一步詢問當中的細節,例如名詞解釋,請用繁體中文。 \\n\n", " You are an AI language model assistant. \n", " Your task is to generate three questions about the given user context as additional explanation.\n", " By generating three in-depth questions about the user context, your goal is to help the user realize more details. \n", " Provide these questions separated by newlines.\n", " For example:\n", " context: 建準廣興廠去年2023年一共自產發電了684,508度綠電\n", " in-depth question:什麼是綠電?\\n為何要使用綠電? 
\n", "\n", " output must in user's language and no preamble or explanation.\n", " <|eot_id|>\n", " \n", " <|start_header_id|>user<|end_header_id|>\n", " \n", " \n", " \n", " Original context: {question}\n", " three questions:\n", " <|eot_id|>\n", " \n", " <|start_header_id|>assistant<|end_header_id|>\"\"\"\n", " prompt_perspectives = ChatPromptTemplate.from_template(template)\n", "\n", " \n", " # llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n", " # llm = ChatOllama(model=\"llama3\", num_gpu=1, temperature=0)\n", "\n", " generate_queries = (\n", " prompt_perspectives \n", " | llm\n", " | StrOutputParser() \n", " | (lambda x: x.split(\"\\n\"))\n", " )\n", "\n", " return generate_queries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "llm = ChatOllama(model=local_llm, temperature=0)\n", "question = \"建準廣興廠去年的固定燃燒排放量是多少?\"\n", "generate_queries = ask_more_detail_chain(llm)\n", "\n", "questions = generate_queries.invoke(question)\n", "questions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "llm = ChatOllama(model=local_llm, temperature=0)\n", "question = \"固定燃燒是什麼?\"\n", "docs = retriever.get_relevant_documents(question, k=10)\n", "generation = faiss_query(question, docs, llm)\n", "generation" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import datetime\n", "from typing import Literal, Optional, Tuple\n", "\n", "from langchain_core.pydantic_v1 import BaseModel, Field\n", "\n", "\n", "class SubQuery(BaseModel):\n", " \"\"\"Search over a database of tutorial videos about a software library.\"\"\"\n", "\n", " sub_query: str = Field(\n", " ...,\n", " description=\"A very specific query against the database.\",\n", " )" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[SubQuery(sub_query='建準去年的類別一排放量是多少?'), SubQuery(sub_query='溫室氣體是什麼?')]" ] }, 
"execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from langchain.output_parsers import PydanticToolsParser\n", "from langchain_core.prompts import ChatPromptTemplate\n", "from langchain_openai import ChatOpenAI\n", "from dotenv import load_dotenv\n", "\n", "load_dotenv()\n", "\n", "system = \"\"\"You are an expert at converting user questions into database queries. \\\n", "\n", "Perform query decomposition. Given a user question, break it down into distinct sub questions that \\\n", "you need to answer in order to answer the original question.\n", "\n", "If there are acronyms or words you are not familiar with, do not try to rephrase them.\n", "用繁體中文.\n", "\"\"\"\n", "prompt = ChatPromptTemplate.from_messages(\n", " [\n", " (\"system\", system),\n", " (\"human\", \"{question}用繁體中文\"),\n", " ]\n", ")\n", "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0.5)\n", "llm_with_tools = llm.bind_tools([SubQuery])\n", "parser = PydanticToolsParser(tools=[SubQuery])\n", "query_analyzer = prompt | llm_with_tools | parser\n", "\n", "query_analyzer.invoke({\"question\": \"建準去年的類別一排放量?溫室氣體是什麼?\"})" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[SubQuery(sub_query='什麼是溫室氣體'), SubQuery(sub_query='去年的類別一排放量是多少')]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "query_analyzer.invoke({\"question\": \"溫室氣體是什麼?建準去年的類別一排放量?\"})" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "query_analyzer.invoke({\"question\": \"建準去年的類別一排放量?\"})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "You have access to a database of tutorial videos about a software library for building LLM-powered applications. 
\\" ] } ], "metadata": { "kernelspec": { "display_name": "llama3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 2 }