瀏覽代碼

新增測試 multi-agents 的 ipynb

ling 5 月之前
父節點
當前提交
18f665deb2
共有 2 個文件被更改,包括 1647 次插入20 次删除
  1. 1619 0
      ai_agent.ipynb
  2. 28 20
      text_to_sql.py

+ 1619 - 0
ai_agent.ipynb

@@ -0,0 +1,1619 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# LLM"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### LLM\n",
+    "from langchain_community.chat_models import ChatOllama\n",
+    "# local_llm = \"llama3.1:8b-instruct-fp16\"\n",
+    "local_llm = \"llama3-groq-tool-use:latest\"\n",
+    "\n",
+    "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
+    "llm = ChatOllama(model=local_llm, temperature=0)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# RAG"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "FAISS index loaded from faiss_index.bin\n",
+      "Metadata loaded from faiss_metadata.pkl\n",
+      "Using existing FAISS index and metadata.\n",
+      "Creating FAISS retriever...\n"
+     ]
+    }
+   ],
+   "source": [
+    "from faiss_index import create_faiss_retriever, faiss_query\n",
+    "retriever = create_faiss_retriever()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.prompts import ChatPromptTemplate\n",
+    "from langchain_core.output_parsers import StrOutputParser\n",
+    "def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:\n",
+    "    context = docs\n",
+    "    # try:\n",
+    "    #     context = \"\\n\".join(doc.page_content for doc in docs)\n",
+    "    # except:\n",
+    "    #     context = \"\\n\".join(doc for doc in docs)\n",
+    "    \n",
+    "    system_prompt: str = \"你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。\"\n",
+    "    template = \"\"\"\n",
+    "    <|begin_of_text|>\n",
+    "    \n",
+    "    <|start_header_id|>system<|end_header_id|>\n",
+    "    你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \\n\n",
+    "    You should not mention anything about \"根據提供的文件內容\" or other similar terms.\n",
+    "    Use five sentences maximum and keep the answer concise.\n",
+    "    如果你不知道答案請回答:\"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。\"\n",
+    "    勿回答無關資訊\n",
+    "    <|eot_id|>\n",
+    "    \n",
+    "    <|start_header_id|>user<|end_header_id|>\n",
+    "    Answer the following question based on this context:\n",
+    "\n",
+    "    {context}\n",
+    "\n",
+    "    Question: {question}\n",
+    "    用繁體中文\n",
+    "    <|eot_id|>\n",
+    "    \n",
+    "    <|start_header_id|>assistant<|end_header_id|>\n",
+    "    \"\"\"\n",
+    "    prompt = ChatPromptTemplate.from_template(\n",
+    "        system_prompt + \"\\n\\n\" +\n",
+    "        template\n",
+    "    )\n",
+    "    \n",
+    "    # prompt = ChatPromptTemplate.from_template(\n",
+    "    #     system_prompt + \"\\n\\n\" +\n",
+    "    #     \"Answer the following question based on this context:\\n\\n\"\n",
+    "    #     \"{context}\\n\\n\"\n",
+    "    #     \"Question: {question}\\n\"\n",
+    "    #     \"Answer in the same language as the question. If you don't know the answer, \"\n",
+    "    #     \"say 'I'm sorry, I don't have enough information to answer that question.'\"\n",
+    "    # )\n",
+    "\n",
+    "    \n",
+    "    # chain = prompt | taide_llm | StrOutputParser()\n",
+    "    chain = prompt | llm | StrOutputParser()\n",
+    "    return chain.invoke({\"context\": context, \"question\": question})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Generate\n",
+    "# llm = ChatOllama(model=local_llm, temperature=0)\n",
+    "\n",
+    "# docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n",
+    "# generation = faiss_query(question, docs_documents, llm)\n",
+    "# generation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Retrieval Grader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Retrieval Grader\n",
+    "\n",
+    "from langchain_community.chat_models import ChatOllama\n",
+    "from langchain_core.output_parsers import JsonOutputParser\n",
+    "from langchain_core.prompts import PromptTemplate\n",
+    "\n",
+    "# LLM\n",
+    "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
+    "\n",
+    "prompt = PromptTemplate(\n",
+    "    template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance \n",
+    "    of a retrieved document to a user question. If the document contains keywords related to the user question, \n",
+    "    grade it as relevant. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n",
+    "    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \\n\n",
+    "    Provide the binary score as a JSON with a single key 'score' and no premable or explanation.\n",
+    "     <|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+    "    Here is the retrieved document: \\n\\n {document} \\n\\n\n",
+    "    Here is the user question: {question} \\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
+    "    \"\"\",\n",
+    "    input_variables=[\"question\", \"document\"],\n",
+    ")\n",
+    "\n",
+    "retrieval_grader = prompt | llm_json | JsonOutputParser()\n",
+    "# question = \"溫室氣體是什麼\"\n",
+    "# # docs = retriever.invoke(question)\n",
+    "# docs = retriever.get_relevant_documents(question, k=10)\n",
+    "# doc_txt = docs[1].page_content\n",
+    "# print(retrieval_grader.invoke({\"question\": question, \"document\": doc_txt}))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# for doc in docs:\n",
+    "#     doc_txt = doc.page_content\n",
+    "#     print(retrieval_grader.invoke({\"question\": question, \"document\": doc_txt}))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Hallucination Grader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Hallucination Grader\n",
+    "\n",
+    "from langchain_community.chat_models import ChatOllama\n",
+    "from langchain_core.output_parsers import JsonOutputParser\n",
+    "from langchain_core.prompts import PromptTemplate\n",
+    "\n",
+    "# LLM\n",
+    "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
+    "\n",
+    "# Prompt\n",
+    "prompt = PromptTemplate(\n",
+    "    template=\"\"\" <|begin_of_text|><|start_header_id|>system<|end_header_id|> \n",
+    "    You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n",
+    "    Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n",
+    "    Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation. \n",
+    "    Return the a JSON with a single key 'score' and no premable or explanation. \n",
+    "    <|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+    "    Here are the facts:\n",
+    "    \\n ------- \\n\n",
+    "    {documents} \n",
+    "    \\n ------- \\n\n",
+    "    Here is the answer: {generation}  <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n",
+    "    input_variables=[\"generation\", \"documents\"],\n",
+    ")\n",
+    "\n",
+    "hallucination_grader = prompt | llm_json | JsonOutputParser()\n",
+    "\n",
+    "# docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n",
+    "\n",
+    "# hallucination_grader.invoke({\"documents\": docs_documents, \"generation\": generation})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Answer Grader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Answer Grader\n",
+    "\n",
+    "# LLM\n",
+    "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
+    "\n",
+    "# Prompt\n",
+    "prompt = PromptTemplate(\n",
+    "    template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an \n",
+    "    answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is \n",
+    "    useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\n",
+    "     <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:\n",
+    "    \\n ------- \\n\n",
+    "    {generation} \n",
+    "    \\n ------- \\n\n",
+    "    Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n",
+    "    input_variables=[\"generation\", \"question\"],\n",
+    ")\n",
+    "\n",
+    "answer_grader = prompt | llm_json | JsonOutputParser()\n",
+    "# answer_grader.invoke({\"question\": question, \"generation\": generation})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# SQL"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n",
+      "  self._metadata.reflect(\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<module 'text_to_sql' from '/home/ling/systex/text_to_sql.py'>"
+      ]
+     },
+     "execution_count": 32,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from importlib import reload  # Python 3.4+\n",
+    "import text_to_sql\n",
+    "reload(text_to_sql)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 49,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n",
+      "  self._metadata.reflect(\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "from text_to_sql import run\n",
+    "from langchain_community.utilities import SQLDatabase\n",
+    "import os\n",
+    "URI: str =  os.environ.get('SUPABASE_URI')\n",
+    "db = SQLDatabase.from_uri(URI)\n",
+    "\n",
+    "def run_text_to_sql(question: str):\n",
+    "    selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)', '水電使用量(GHG)', '水電使用量(ISO)']\n",
+    "    # question = \"建準去年的固定燃燒總排放量是多少?\"\n",
+    "    query, result, answer = run(db, question, selected_table, llm)\n",
+    "    \n",
+    "    return  answer, query"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "question = \"建準去年的固定燃燒總排放量是多少?\"\n",
+    "run_text_to_sql(question)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Hallucination Grader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Hallucination Grader\n",
+    "\n",
+    "from langchain_community.chat_models import ChatOllama\n",
+    "from langchain_core.output_parsers import JsonOutputParser\n",
+    "from langchain_core.prompts import PromptTemplate\n",
+    "\n",
+    "# LLM\n",
+    "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
+    "\n",
+    "# Prompt\n",
+    "prompt = PromptTemplate(\n",
+    "    template=\"\"\" <|begin_of_text|><|start_header_id|>system<|end_header_id|> \n",
+    "    You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n",
+    "    Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n",
+    "    Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation. \n",
+    "    Return the a JSON with a single key 'score' and no premable or explanation. \n",
+    "    <|eot_id|><|start_header_id|>user<|end_header_id|>\n",
+    "    Here are the facts:\n",
+    "    \\n ------- \\n\n",
+    "    {documents} \n",
+    "    \\n ------- \\n\n",
+    "    Here is the answer: {generation}  <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n",
+    "    input_variables=[\"generation\", \"documents\"],\n",
+    ")\n",
+    "\n",
+    "hallucination_grader = prompt | llm_json | JsonOutputParser()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Router"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'datasource': 'company_private_data'}\n"
+     ]
+    }
+   ],
+   "source": [
+    "### Router\n",
+    "\n",
+    "from langchain_community.chat_models import ChatOllama\n",
+    "from langchain_core.output_parsers import JsonOutputParser\n",
+    "from langchain_core.prompts import PromptTemplate\n",
+    "\n",
+    "# LLM\n",
+    "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
+    "\n",
+    "prompt = PromptTemplate(\n",
+    "    template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> \n",
+    "    You are an expert at routing a user question to a vectorstore or company private data. \n",
+    "    Use company private data for questions about the informations about a company's greenhouse gas emissions data.\n",
+    "    Otherwise, use the vectorstore for questions on ESG field knowledge or news about ESG. \n",
+    "    You do not need to be stringent with the keywords in the question related to these topics. \n",
+    "    Give a binary choice 'company_private_data' or 'vectorstore' based on the question. \n",
+    "    Return the a JSON with a single key 'datasource' and no premable or explanation. \n",
+    "    \n",
+    "    Question to route: {question} \n",
+    "    <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n",
+    "    input_variables=[\"question\"],\n",
+    ")\n",
+    "\n",
+    "question_router = prompt | llm_json | JsonOutputParser()\n",
+    "question = \"建準去年的類別1排放量是多少?\"\n",
+    "# docs = retriever.get_relevant_documents(question)\n",
+    "# doc_txt = docs[1].page_content\n",
+    "print(question_router.invoke({\"question\": question}))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Node"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# RAG + text-to-sql\n",
+    "\n",
+    "from pprint import pprint\n",
+    "from typing import List\n",
+    "\n",
+    "from langchain_core.documents import Document\n",
+    "from typing_extensions import TypedDict\n",
+    "\n",
+    "from langgraph.graph import END, StateGraph, START\n",
+    "\n",
+    "### State\n",
+    "\n",
+    "\n",
+    "class GraphState(TypedDict):\n",
+    "    \"\"\"\n",
+    "    Represents the state of our graph.\n",
+    "\n",
+    "    Attributes:\n",
+    "        question: question\n",
+    "        generation: LLM generation\n",
+    "        company_private_data: whether to search company private data\n",
+    "        documents: list of documents\n",
+    "    \"\"\"\n",
+    "\n",
+    "    question: str\n",
+    "    generation: str\n",
+    "    company_private_data: str\n",
+    "    documents: List[str]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def retrieve_and_generation(state):\n",
+    "    \"\"\"\n",
+    "    Retrieve documents from the vectorstore and generate an answer from them\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        state (dict): New keys added to state: documents, the retrieved page contents joined into one string, and generation, the answer generated by the LLM\n",
+    "    \"\"\"\n",
+    "    print(\"---RETRIEVE---\")\n",
+    "    question = state[\"question\"]\n",
+    "\n",
+    "    # Retrieval\n",
+    "    # documents = retriever.invoke(question)\n",
+    "    # TODO: correct Retrieval function\n",
+    "    documents = retriever.get_relevant_documents(question, k=10)\n",
+    "    docs_documents = \"\\n\\n\".join(doc.page_content for doc in documents)\n",
+    "    generation = faiss_query(question, docs_documents, llm)\n",
+    "    return {\"documents\": docs_documents, \"question\": question, \"generation\": generation}\n",
+    "\n",
+    "\n",
+    "def company_private_data_search(state):\n",
+    "    \"\"\"\n",
+    "    Query company private data via text-to-SQL based on the question\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        state (dict): documents set to the generated SQL query, generation set to the text-to-SQL answer\n",
+    "    \"\"\"\n",
+    "\n",
+    "    print(\"---SQL---\")\n",
+    "    question = state[\"question\"]\n",
+    "    # documents = state[\"documents\"]\n",
+    "\n",
+    "    # company_private_data_search\n",
+    "    # docs = web_search_tool.invoke({\"query\": question})\n",
+    "    # web_results = \"\\n\".join([d[\"content\"] for d in docs])\n",
+    "    # web_results = Document(page_content=web_results)\n",
+    "    # TODO: correct company_private_data_search function\n",
+    "    company_private_data_result, sql_query = run_text_to_sql(question)\n",
+    "    company_private_data_result\n",
+    "    \n",
+    "    # generation = [company_private_data_result]\n",
+    "    \n",
+    "    return {\"documents\": sql_query, \"question\": question, \"generation\": company_private_data_result}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Conditional edge"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Conditional edge\n",
+    "\n",
+    "\n",
+    "def route_question(state):\n",
+    "    \"\"\"\n",
+    "    Route question to text-to-SQL or RAG.\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        str: Next node to call\n",
+    "    \"\"\"\n",
+    "\n",
+    "    print(\"---ROUTE QUESTION---\")\n",
+    "    question = state[\"question\"]\n",
+    "    print(question)\n",
+    "    source = question_router.invoke({\"question\": question})\n",
+    "    print(source)\n",
+    "    print(source[\"datasource\"])\n",
+    "    if source[\"datasource\"] == \"company_private_data\":\n",
+    "        print(\"---ROUTE QUESTION TO TEXT-TO-SQL---\")\n",
+    "        return \"company_private_data\"\n",
+    "    elif source[\"datasource\"] == \"vectorstore\":\n",
+    "        print(\"---ROUTE QUESTION TO RAG---\")\n",
+    "        return \"vectorstore\"\n",
+    "    \n",
+    "def grade_generation_v_documents_and_question(state):\n",
+    "    \"\"\"\n",
+    "    Determines whether the generation is grounded in the document and answers question.\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        str: Decision for next node to call\n",
+    "    \"\"\"\n",
+    "\n",
+    "    print(\"---CHECK HALLUCINATIONS---\")\n",
+    "    question = state[\"question\"]\n",
+    "    documents = state[\"documents\"]\n",
+    "    generation = state[\"generation\"]\n",
+    "\n",
+    "    \n",
+    "    # print(docs_documents)\n",
+    "    # print(generation)\n",
+    "    score = hallucination_grader.invoke(\n",
+    "        {\"documents\": documents, \"generation\": generation}\n",
+    "    )\n",
+    "    print(score)\n",
+    "    grade = score[\"score\"]\n",
+    "\n",
+    "    # Check hallucination\n",
+    "    if grade in [\"yes\", \"true\", 1, \"1\"]:\n",
+    "        print(\"---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\")\n",
+    "        # Check question-answering\n",
+    "        print(\"---GRADE GENERATION vs QUESTION---\")\n",
+    "        score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n",
+    "        grade = score[\"score\"]\n",
+    "        if grade in [\"yes\", \"true\", 1, \"1\"]:\n",
+    "            print(\"---DECISION: GENERATION ADDRESSES QUESTION---\")\n",
+    "            return \"useful\"\n",
+    "        else:\n",
+    "            print(\"---DECISION: GENERATION DOES NOT ADDRESS QUESTION---\")\n",
+    "            return \"not useful\"\n",
+    "    else:\n",
+    "        pprint(\"---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---\")\n",
+    "        return \"not supported\"\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Graph"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 65,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "workflow = StateGraph(GraphState)\n",
+    "\n",
+    "# Define the nodes\n",
+    "workflow.add_node(\"company_private_data_search\", company_private_data_search)  # text-to-SQL over company private data\n",
+    "workflow.add_node(\"retrieve_and_generation\", retrieve_and_generation)  # retrieve\n",
+    "# workflow.add_node(\"grade_generation\", grade_documents)  # grade documents\n",
+    "# workflow.add_node(\"generate\", generate)  # generate\n",
+    "\n",
+    "workflow.add_conditional_edges(\n",
+    "    START,\n",
+    "    route_question,\n",
+    "    {\n",
+    "        \"company_private_data\": \"company_private_data_search\",\n",
+    "        \"vectorstore\": \"retrieve_and_generation\",\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "workflow.add_conditional_edges(\n",
+    "    \"retrieve_and_generation\",\n",
+    "    grade_generation_v_documents_and_question,\n",
+    "    {\n",
+    "        \"not supported\": \"retrieve_and_generation\",\n",
+    "        \"useful\": END,\n",
+    "        \"not useful\": \"retrieve_and_generation\",\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "workflow.add_conditional_edges(\n",
+    "    \"company_private_data_search\",\n",
+    "    grade_generation_v_documents_and_question,\n",
+    "    {\n",
+    "        \"not supported\": \"company_private_data_search\",\n",
+    "        \"useful\": END,\n",
+    "        \"not useful\": \"retrieve_and_generation\",\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "# workflow.add_edge(\"company_private_data_search\", END)\n",
+    "\n",
+    "app = workflow.compile()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "%%{init: {'flowchart': {'curve': 'linear'}}}%%\n",
+      "graph TD;\n",
+      "\t__start__([__start__]):::first\n",
+      "\tcompany_private_data_search(company_private_data_search)\n",
+      "\tretrieve_and_generation(retrieve_and_generation)\n",
+      "\t__end__([__end__]):::last\n",
+      "\tcompany_private_data_search --> __end__;\n",
+      "\t__start__ -. &nbspcompany_private_data&nbsp .-> company_private_data_search;\n",
+      "\t__start__ -. &nbspvectorstore&nbsp .-> retrieve_and_generation;\n",
+      "\tretrieve_and_generation -. &nbspnot supported&nbsp .-> retrieve_and_generation;\n",
+      "\tretrieve_and_generation -. &nbspuseful&nbsp .-> __end__;\n",
+      "\tretrieve_and_generation -. &nbspnot useful&nbsp .-> company_private_data_search;\n",
+      "\tclassDef default fill:#f2f0ff,line-height:1.2\n",
+      "\tclassDef first fill-opacity:0\n",
+      "\tclassDef last fill:#bfb6fc\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(app.get_graph().draw_mermaid())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'app' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[1], line 6\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mIPython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdisplay\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Image, display\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlangchain_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrunnables\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mgraph\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CurveStyle, MermaidDrawMethod, NodeStyles\n\u001b[1;32m      4\u001b[0m display(\n\u001b[1;32m      5\u001b[0m     Image(\n\u001b[0;32m----> 6\u001b[0m         \u001b[43mapp\u001b[49m\u001b[38;5;241m.\u001b[39mget_graph()\u001b[38;5;241m.\u001b[39mdraw_mermaid_png(\n\u001b[1;32m      7\u001b[0m             draw_method\u001b[38;5;241m=\u001b[39mMermaidDrawMethod\u001b[38;5;241m.\u001b[39mAPI,\n\u001b[1;32m      8\u001b[0m             output_file_path\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124magent_workflow.png\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m      9\u001b[0m         )\n\u001b[1;32m     10\u001b[0m     )\n\u001b[1;32m     11\u001b[0m )\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'app' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "from IPython.display import Image, display\n",
+    "from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles\n",
+    "\n",
+    "display(\n",
+    "    Image(\n",
+    "        app.get_graph().draw_mermaid_png(\n",
+    "            draw_method=MermaidDrawMethod.API,\n",
+    "            output_file_path=\"agent_workflow.png\",\n",
+    "        )\n",
+    "    )\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 67,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "---ROUTE QUESTION---\n",
+      "建準2024的固定燃燒總排放量是多少?\n",
+      "{'datasource': 'company_private_data'}\n",
+      "company_private_data\n",
+      "---ROUTE QUESTION TO TEXT-TO-SQL---\n",
+      "---SQL---\n",
+      "SELECT SUM(\"高雄總部及運通廠\" + \"台北辦事處\" + \"昆山廣興廠\" + \"北海建準廠\" + \"北海立準廠\" + \"菲律賓建準廠\" + \"Inc\" + \"SAS\" + \"India\") AS \"固定燃燒總排放量\"\n",
+      "FROM \"2023 清冊數據(GHG)\"\n",
+      "WHERE \"排放源\" = '固定燃燒'\n",
+      "[(13.5953,)]\n",
+      "---CHECK HALLUCINATIONS---\n",
+      "{'score': True}\n",
+      "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\n",
+      "---GRADE GENERATION vs QUESTION---\n",
+      "---DECISION: GENERATION ADDRESSES QUESTION---\n",
+      "'Finished running: company_private_data_search:'\n",
+      "'建準2024的固定燃燒總排放量是13.5953。'\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Test\n",
+    "\n",
+    "inputs = {\"question\": \"建準2024的固定燃燒總排放量是多少?\"}\n",
+    "for output in app.stream(inputs):\n",
+    "    for key, value in output.items():\n",
+    "        pprint(f\"Finished running: {key}:\")\n",
+    "pprint(value[\"generation\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# RAG + text-to-sql\n",
+    "\n",
+    "from pprint import pprint\n",
+    "from typing import List\n",
+    "\n",
+    "from langchain_core.documents import Document\n",
+    "from typing_extensions import TypedDict\n",
+    "\n",
+    "from langgraph.graph import END, StateGraph, START\n",
+    "\n",
+    "### State\n",
+    "\n",
+    "\n",
+    "class GraphState(TypedDict):\n",
+    "    \"\"\"\n",
+    "    Represents the state of our graph.\n",
+    "\n",
+    "    Attributes:\n",
+    "        question: question\n",
+    "        generation: LLM generation\n",
+    "        company_private_data: whether to search company private data\n",
+    "        documents: list of documents\n",
+    "    \"\"\"\n",
+    "\n",
+    "    question: str\n",
+    "    generation: str\n",
+    "    company_private_data: str\n",
+    "    documents: List[str]\n",
+    "\n",
+    "\n",
+    "### Nodes\n",
+    "\n",
+    "\n",
+    "def retrieve(state):\n",
+    "    \"\"\"\n",
+    "    Retrieve documents from vectorstore\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        state (dict): New key added to state, documents, that contains retrieved documents\n",
+    "    \"\"\"\n",
+    "    print(\"---RETRIEVE---\")\n",
+    "    question = state[\"question\"]\n",
+    "\n",
+    "    # Retrieval\n",
+    "    # documents = retriever.invoke(question)\n",
+    "    # TODO: correct Retrieval function\n",
+    "    documents = retriever.get_relevant_documents(question, k=10)\n",
+    "    docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n",
+    "    generation = faiss_query(question, docs_documents, llm)\n",
+    "    return {\"documents\": documents, \"question\": question}\n",
+    "\n",
+    "\n",
+    "def generate(state):\n",
+    "    \"\"\"\n",
+    "    Generate answer using RAG on retrieved documents\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        state (dict): New key added to state, generation, that contains LLM generation\n",
+    "    \"\"\"\n",
+    "    print(\"---GENERATE---\")\n",
+    "    question = state[\"question\"]\n",
+    "    documents = state[\"documents\"]\n",
+    "\n",
+    "    # RAG generation\n",
+    "    # generation = rag_chain.invoke({\"context\": documents, \"question\": question})\n",
+    "    # TODO: correct generation function\n",
+    "    print(documents)\n",
+    "    print(question)\n",
+    "    generation = faiss_query(question, documents, llm)\n",
+    "    generation = eval(generation)['data']['answer']\n",
+    "    return {\"documents\": documents, \"question\": question, \"generation\": generation}\n",
+    "\n",
+    "\n",
+    "def grade_documents(state):\n",
+    "    \"\"\"\n",
+    "    Determines whether the retrieved documents are relevant to the question\n",
+    "    Irrelevant documents are dropped; the company_private_data flag is currently always left as 'No'\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        state (dict): Filtered out irrelevant documents and updated company_private_data state\n",
+    "    \"\"\"\n",
+    "\n",
+    "    print(\"---CHECK DOCUMENT RELEVANCE TO QUESTION---\")\n",
+    "    question = state[\"question\"]\n",
+    "    documents = state[\"documents\"]\n",
+    "\n",
+    "    # Score each doc\n",
+    "    filtered_docs = []\n",
+    "    company_private_data = \"No\"\n",
+    "    for d in documents:\n",
+    "        score = retrieval_grader.invoke(\n",
+    "            {\"question\": question, \"document\": d.page_content}\n",
+    "        )\n",
+    "        grade = score[\"score\"]\n",
+    "        # Document relevant\n",
+    "        if grade.lower() == \"yes\":\n",
+    "            print(\"---GRADE: DOCUMENT RELEVANT---\")\n",
+    "            filtered_docs.append(d)\n",
+    "        # Document not relevant\n",
+    "        else:\n",
+    "            print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n",
+    "            # We do not include the document in filtered_docs\n",
+    "            # Setting the flag to fall back to company private data search is currently disabled\n",
+    "            # company_private_data = \"Yes\"\n",
+    "            continue\n",
+    "    return {\"documents\": filtered_docs, \"question\": question, \"company_private_data\": company_private_data}\n",
+    "\n",
+    "\n",
+    "def company_private_data_search(state):\n",
+    "    \"\"\"\n",
+    "    Query company private data (text-to-SQL) based on the question\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        state (dict): Appended text-to-SQL result to documents\n",
+    "    \"\"\"\n",
+    "\n",
+    "    print(\"---WEB SEARCH---\")\n",
+    "    question = state[\"question\"]\n",
+    "    documents = state[\"documents\"]\n",
+    "\n",
+    "    # company_private_data_search\n",
+    "    # docs = web_search_tool.invoke({\"query\": question})\n",
+    "    # web_results = \"\\n\".join([d[\"content\"] for d in docs])\n",
+    "    # web_results = Document(page_content=web_results)\n",
+    "    # TODO: correct company_private_data_search function\n",
+    "    company_private_data_result = run_text_to_sql(question)\n",
+    "    if documents is not None:\n",
+    "        documents.append(company_private_data_result)\n",
+    "    else:\n",
+    "        documents = [company_private_data_result]\n",
+    "    return {\"documents\": documents, \"question\": question}\n",
+    "\n",
+    "\n",
+    "### Conditional edge\n",
+    "\n",
+    "\n",
+    "def route_question(state):\n",
+    "    \"\"\"\n",
+    "    Route question to company private data (text-to-SQL) or RAG.\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        str: Next node to call\n",
+    "    \"\"\"\n",
+    "\n",
+    "    print(\"---ROUTE QUESTION---\")\n",
+    "    question = state[\"question\"]\n",
+    "    print(question)\n",
+    "    source = question_router.invoke({\"question\": question})\n",
+    "    print(source)\n",
+    "    print(source[\"datasource\"])\n",
+    "    if source[\"datasource\"] == \"company_private_data\":\n",
+    "        print(\"---ROUTE QUESTION TO TEXT-TO-SQL---\")\n",
+    "        return \"company_private_data\"\n",
+    "    elif source[\"datasource\"] == \"vectorstore\":\n",
+    "        print(\"---ROUTE QUESTION TO RAG---\")\n",
+    "        return \"vectorstore\"\n",
+    "\n",
+    "\n",
+    "def decide_to_generate(state):\n",
+    "    \"\"\"\n",
+    "    Determines whether to generate an answer, or add company_private_data\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        str: Binary decision for next node to call\n",
+    "    \"\"\"\n",
+    "\n",
+    "    print(\"---ASSESS GRADED DOCUMENTS---\")\n",
+    "    state[\"question\"]\n",
+    "    company_private_data = state[\"company_private_data\"]\n",
+    "    state[\"documents\"]\n",
+    "    \n",
+    "\n",
+    "    if company_private_data == \"Yes\":\n",
+    "        # The company_private_data flag is set\n",
+    "        # We will route to the company private data search\n",
+    "        print(\n",
+    "            \"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE company_private_data---\"\n",
+    "        )\n",
+    "        return \"company_private_data\"\n",
+    "    else:\n",
+    "        # We have relevant documents, so generate answer\n",
+    "        print(\"---DECISION: GENERATE---\")\n",
+    "        return \"generate\"\n",
+    "\n",
+    "\n",
+    "### Conditional edge\n",
+    "\n",
+    "\n",
+    "def grade_generation_v_documents_and_question(state):\n",
+    "    \"\"\"\n",
+    "    Determines whether the generation is grounded in the document and answers question.\n",
+    "\n",
+    "    Args:\n",
+    "        state (dict): The current graph state\n",
+    "\n",
+    "    Returns:\n",
+    "        str: Decision for next node to call\n",
+    "    \"\"\"\n",
+    "\n",
+    "    print(\"---CHECK HALLUCINATIONS---\")\n",
+    "    question = state[\"question\"]\n",
+    "    documents = state[\"documents\"]\n",
+    "    generation = state[\"generation\"]\n",
+    "\n",
+    "    try:\n",
+    "        docs_documents = \"\\n\\n\".join(doc.page_content for doc in documents)\n",
+    "    except AttributeError:\n",
+    "        docs_documents = \"\\n\\n\".join(doc for doc in documents)\n",
+    "    # print(docs_documents)\n",
+    "    # print(generation)\n",
+    "    score = hallucination_grader.invoke(\n",
+    "        {\"documents\": docs_documents, \"generation\": generation}\n",
+    "    )\n",
+    "    print(score)\n",
+    "    grade = score[\"score\"]\n",
+    "\n",
+    "    # Check hallucination\n",
+    "    if grade == \"yes\":\n",
+    "        print(\"---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\")\n",
+    "        # Check question-answering\n",
+    "        print(\"---GRADE GENERATION vs QUESTION---\")\n",
+    "        score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n",
+    "        grade = score[\"score\"]\n",
+    "        if grade == \"yes\":\n",
+    "            print(\"---DECISION: GENERATION ADDRESSES QUESTION---\")\n",
+    "            return \"useful\"\n",
+    "        else:\n",
+    "            print(\"---DECISION: GENERATION DOES NOT ADDRESS QUESTION---\")\n",
+    "            return \"not useful\"\n",
+    "    else:\n",
+    "        pprint(\"---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---\")\n",
+    "        return \"not supported\"\n",
+    "\n",
+    "\n",
+    "workflow = StateGraph(GraphState)\n",
+    "\n",
+    "# Define the nodes\n",
+    "workflow.add_node(\"company_private_data_search\", company_private_data_search)  # text-to-SQL\n",
+    "workflow.add_node(\"retrieve\", retrieve)  # retrieve\n",
+    "workflow.add_node(\"grade_documents\", grade_documents)  # grade documents\n",
+    "workflow.add_node(\"generate\", generate)  # generate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Build graph\n",
+    "workflow.add_conditional_edges(\n",
+    "    START,\n",
+    "    route_question,\n",
+    "    {\n",
+    "        \"company_private_data\": \"company_private_data_search\",\n",
+    "        \"vectorstore\": \"retrieve\",\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "workflow.add_edge(\"retrieve\", \"grade_documents\")\n",
+    "workflow.add_conditional_edges(\n",
+    "    \"grade_documents\",\n",
+    "    decide_to_generate,\n",
+    "    {\n",
+    "        \"company_private_data\": \"company_private_data_search\",\n",
+    "        \"generate\": \"generate\",\n",
+    "    },\n",
+    ")\n",
+    "# workflow.add_edge(\"retrieve\", \"generate\")\n",
+    "workflow.add_edge(\"company_private_data_search\", \"generate\")\n",
+    "workflow.add_conditional_edges(\n",
+    "    \"generate\",\n",
+    "    grade_generation_v_documents_and_question,\n",
+    "    {\n",
+    "        \"not supported\": \"generate\",\n",
+    "        \"useful\": END,\n",
+    "        \"not useful\": \"company_private_data_search\",\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "app = workflow.compile()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ValueError",
+     "evalue": "no intersection found (point inside ?!). view: <langchain_core.runnables.graph_ascii.VertexViewer object at 0x70acd7938500> topt: (0.0, 25.5)",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_graph\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprint_ascii\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/runnables/graph.py:485\u001b[0m, in \u001b[0;36mGraph.print_ascii\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    483\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprint_ascii\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    484\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"Print the graph as an ASCII art string.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 485\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdraw_ascii\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n",
+      "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/runnables/graph.py:478\u001b[0m, in \u001b[0;36mGraph.draw_ascii\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    475\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Draw the graph as an ASCII art string.\"\"\"\u001b[39;00m\n\u001b[1;32m    476\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlangchain_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrunnables\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mgraph_ascii\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m draw_ascii\n\u001b[0;32m--> 478\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mdraw_ascii\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    479\u001b[0m \u001b[43m    \u001b[49m\u001b[43m{\u001b[49m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnodes\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    480\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medges\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    481\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/runnables/graph_ascii.py:251\u001b[0m, in \u001b[0;36mdraw_ascii\u001b[0;34m(vertices, edges)\u001b[0m\n\u001b[1;32m    248\u001b[0m Xs \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m    249\u001b[0m Ys \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m--> 251\u001b[0m sug \u001b[38;5;241m=\u001b[39m \u001b[43m_build_sugiyama_layout\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvertices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medges\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    253\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m vertex \u001b[38;5;129;01min\u001b[39;00m sug\u001b[38;5;241m.\u001b[39mg\u001b[38;5;241m.\u001b[39msV:\n\u001b[1;32m    254\u001b[0m     \u001b[38;5;66;03m# NOTE: moving boxes w/2 to the left\u001b[39;00m\n\u001b[1;32m    255\u001b[0m     Xs\u001b[38;5;241m.\u001b[39mappend(vertex\u001b[38;5;241m.\u001b[39mview\u001b[38;5;241m.\u001b[39mxy[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m-\u001b[39m vertex\u001b[38;5;241m.\u001b[39mview\u001b[38;5;241m.\u001b[39mw \u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2.0\u001b[39m)\n",
+      "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/runnables/graph_ascii.py:210\u001b[0m, in \u001b[0;36m_build_sugiyama_layout\u001b[0;34m(vertices, edges)\u001b[0m\n\u001b[1;32m    207\u001b[0m sug\u001b[38;5;241m.\u001b[39mxspace \u001b[38;5;241m=\u001b[39m minw\n\u001b[1;32m    208\u001b[0m sug\u001b[38;5;241m.\u001b[39mroute_edge \u001b[38;5;241m=\u001b[39m route_with_lines\n\u001b[0;32m--> 210\u001b[0m \u001b[43msug\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdraw\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    212\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m sug\n",
+      "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/grandalf/layouts.py:442\u001b[0m, in \u001b[0;36mSugiyamaLayout.draw\u001b[0;34m(self, N)\u001b[0m\n\u001b[1;32m    440\u001b[0m         \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m    441\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msetxy()\n\u001b[0;32m--> 442\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdraw_edges\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
+      "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/grandalf/layouts.py:814\u001b[0m, in \u001b[0;36mSugiyamaLayout.draw_edges\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    812\u001b[0m l\u001b[38;5;241m.\u001b[39mappend(e\u001b[38;5;241m.\u001b[39mv[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mview\u001b[38;5;241m.\u001b[39mxy)\n\u001b[1;32m    813\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 814\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mroute_edge\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ml\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    815\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n\u001b[1;32m    816\u001b[0m     \u001b[38;5;28;01mpass\u001b[39;00m\n",
+      "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/grandalf/routing.py:31\u001b[0m, in \u001b[0;36mroute_with_lines\u001b[0;34m(e, pts)\u001b[0m\n\u001b[1;32m     29\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mroute_with_lines\u001b[39m(e, pts):\n\u001b[1;32m     30\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(e, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mview\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 31\u001b[0m     tail_pos \u001b[38;5;241m=\u001b[39m \u001b[43mintersectR\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mv\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtopt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpts\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     32\u001b[0m     head_pos \u001b[38;5;241m=\u001b[39m intersectR(e\u001b[38;5;241m.\u001b[39mv[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mview, topt\u001b[38;5;241m=\u001b[39mpts[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m])\n\u001b[1;32m     33\u001b[0m     pts[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m=\u001b[39m tail_pos\n",
+      "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/grandalf/utils/geometry.py:69\u001b[0m, in \u001b[0;36mintersectR\u001b[0;34m(view, topt)\u001b[0m\n\u001b[1;32m     66\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m (x, y)\n\u001b[1;32m     67\u001b[0m \u001b[38;5;66;03m# there can't be no intersection unless the endpoint was\u001b[39;00m\n\u001b[1;32m     68\u001b[0m \u001b[38;5;66;03m# inside the bb !\u001b[39;00m\n\u001b[0;32m---> 69\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m     70\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mno intersection found (point inside ?!). view: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m topt: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (view, topt)\n\u001b[1;32m     71\u001b[0m )\n",
+      "\u001b[0;31mValueError\u001b[0m: no intersection found (point inside ?!). view: <langchain_core.runnables.graph_ascii.VertexViewer object at 0x70acd7938500> topt: (0.0, 25.5)"
+     ]
+    }
+   ],
+   "source": [
+    "app.get_graph().print_ascii()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def multi_query_chain(llm):\n",
+    "    # Multi Query: Different Perspectives\n",
+    "    template = \"\"\"\n",
+    "    <|begin_of_text|>\n",
+    "    \n",
+    "    <|start_header_id|>system<|end_header_id|>\n",
+    "    你是一個來自台灣的AI助理,你的專長是根據使用者的問題來產生三種不同版本的問題,請用繁體中文。 \\n\n",
+    "    You are an AI language model assistant. Your task is to generate three \n",
+    "    different versions of the given user question to retrieve relevant documents from a vector \n",
+    "    database. By generating multiple perspectives on the user question, your goal is to help\n",
+    "    the user overcome some of the limitations of the distance-based similarity search. \n",
+    "    Provide these alternative questions separated by newlines. \n",
+    "\n",
+    "    \n",
+    "    output must in user's language and no preamble or explanation..\n",
+    "    <|eot_id|>\n",
+    "    \n",
+    "    <|start_header_id|>user<|end_header_id|>\n",
+    "    Original question: {question}\n",
+    "    <|eot_id|>\n",
+    "    \n",
+    "    <|start_header_id|>assistant<|end_header_id|>\"\"\"\n",
+    "    prompt_perspectives = ChatPromptTemplate.from_template(template)\n",
+    "\n",
+    "    # You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.\n",
+    "    # llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
+    "    # llm = ChatOllama(model=\"llama3\", num_gpu=1, temperature=0)\n",
+    "\n",
+    "    generate_queries = (\n",
+    "        prompt_perspectives \n",
+    "        | llm\n",
+    "        | StrOutputParser() \n",
+    "        | (lambda x: x.split(\"\\n\"))\n",
+    "    )\n",
+    "\n",
+    "    return generate_queries"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['1. 溫室氣體的定義是什麼?', '2. 溫室氣體有哪些種類?', '3. 溫室氣體對環境的影響是什麼?']"
+      ]
+     },
+     "execution_count": 135,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# from RAG_strategy import multi_query_chain\n",
+    "llm = ChatOllama(model=local_llm, temperature=0)\n",
+    "question = \"溫室氣體是什麼\"\n",
+    "generate_queries = multi_query_chain(llm)\n",
+    "\n",
+    "questions = generate_queries.invoke(question)\n",
+    "questions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def ask_more_detail_chain(llm):\n",
+    "    # Follow-up questions: ask for more detail about the given context\n",
+    "    template = \"\"\"\n",
+    "    <|begin_of_text|>\n",
+    "    \n",
+    "    <|start_header_id|>system<|end_header_id|>\n",
+    "    你是一個來自台灣的AI助理,你的專長是根據使用者提供的文本來進一步詢問當中的細節,例如名詞解釋,請用繁體中文。 \\n\n",
+    "    You are an AI language model assistant. \n",
+    "    Your task is to generate three questions about the given user context as additional explanation.\n",
+    "    By generating three in-depth questions about the user context, your goal is to help the user realize more details. \n",
+    "    Provide these questions separated by newlines.\n",
+    "    For example:\n",
+    "    context: 建準廣興廠去年2023年一共自產發電了684,508度綠電\n",
+    "    in-depth question:什麼是綠電?\\n為何要使用綠電? \n",
+    "\n",
+    "    output must in user's language and no preamble or explanation.\n",
+    "    <|eot_id|>\n",
+    "    \n",
+    "    <|start_header_id|>user<|end_header_id|>\n",
+    "    \n",
+    "    \n",
+    "    \n",
+    "    Original context: {question}\n",
+    "    three questions:\n",
+    "    <|eot_id|>\n",
+    "    \n",
+    "    <|start_header_id|>assistant<|end_header_id|>\"\"\"\n",
+    "    prompt_perspectives = ChatPromptTemplate.from_template(template)\n",
+    "\n",
+    "    \n",
+    "    # llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
+    "    # llm = ChatOllama(model=\"llama3\", num_gpu=1, temperature=0)\n",
+    "\n",
+    "    generate_queries = (\n",
+    "        prompt_perspectives \n",
+    "        | llm\n",
+    "        | StrOutputParser() \n",
+    "        | (lambda x: x.split(\"\\n\"))\n",
+    "    )\n",
+    "\n",
+    "    return generate_queries"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['什麼是固定燃燒排放量?']"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "llm = ChatOllama(model=local_llm, temperature=0)\n",
+    "question = \"建準廣興廠去年的固定燃燒排放量是多少?\"\n",
+    "generate_queries = ask_more_detail_chain(llm)\n",
+    "\n",
+    "questions = generate_queries.invoke(question)\n",
+    "questions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'固定燃燒(Fixed Combustion)是一種溫室氣體排放的方法,涉及到化石燃料的燃燒,以產生熱或蒸汽。例如:鍋爐、加熱爐、緊急發電機等設施。'"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "llm = ChatOllama(model=local_llm, temperature=0)\n",
+    "question = \"固定燃燒是什麼?\"\n",
+    "docs = retriever.get_relevant_documents(question, k=10)\n",
+    "generation = faiss_query(question, docs, llm)\n",
+    "generation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.tools import StructuredTool"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def rag(question: str):\n",
+    "    llm = ChatOllama(model=local_llm, temperature=0)\n",
+    "    docs = retriever.get_relevant_documents(question, k=10)\n",
+    "    answer = faiss_query(question, docs, llm)\n",
+    "    \n",
+    "    return answer\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "rag_tool = StructuredTool.from_function(\n",
+    "    func=rag,\n",
+    "    name=\"RAG\",\n",
+    "    description=\"查詢碳盤查知識庫\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/_api/deprecation.py:141: LangChainDeprecationWarning: The method `BaseTool.__call__` was deprecated in langchain-core 0.1.47 and will be removed in 1.0. Use invoke instead.\n",
+      "  warn_deprecated(\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "'溫室氣體是一種能吸收和釋放紅外線輻射的氣體,存在於大氣中,並將熱能留在地球表面,無法散出大氣層外。'\n"
+     ]
+    }
+   ],
+   "source": [
+    "from pprint import pprint\n",
+    "# from langchain_community.chat_models import ChatOllama\n",
+    "# from langchain_community.chat_models.ollama import ChatOllama\n",
+    "from langchain_ollama import ChatOllama\n",
+    "llm = ChatOllama(model=local_llm, temperature=0)\n",
+    "pprint(rag_tool(\"溫室氣體是什麼\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "工具名稱:RAG\n",
+      "工具描述:查詢碳盤查知識庫\n",
+      "工具參數:{'question': {'title': 'Question', 'type': 'string'}}\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(f'工具名稱:{rag_tool.name}')\n",
+    "print(f'工具描述:{rag_tool.description}')\n",
+    "print(f'工具參數:{rag_tool.args}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sql_tool = StructuredTool.from_function(\n",
+    "    func=run_text_to_sql,\n",
+    "    name=\"SQL\",\n",
+    "    description=\"查詢客戶內部資料\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "SELECT SUM(\"高雄總部及運通廠\" + \"台北辦事處\" + \"昆山廣興廠\" + \"北海建準廠\" + \"北海立準廠\" + \"菲律賓建準廠\" + \"Inc\" + \"SAS\" + \"India\") AS \"固定燃燒總排放量\"\n",
+      "FROM \"2023 清冊數據(GHG)\"\n",
+      "WHERE \"排放源\" = '固定燃燒'\n",
+      "[(13.5953,)]\n",
+      "'建準去年的固定燃燒總排放量是13.5953。'\n"
+     ]
+    }
+   ],
+   "source": [
+    "pprint(sql_tool(\"建準去年的固定燃燒總排放量是多少?\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tools = [rag_tool, sql_tool]\n",
+    "model_with_tools = llm.bind_tools(tools)\n",
+    "tool_map = {tool.name: tool for tool in tools}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_core.runnables import RunnablePassthrough\n",
+    "from operator import itemgetter, attrgetter\n",
+    "from langchain_core.runnables import (\n",
+    "    RunnableLambda, RunnablePassthrough)\n",
+    "\n",
+    "def call_tool(tool_invocation):\n",
+    "    tool = tool_map[tool_invocation[\"type\"]]\n",
+    "    return RunnablePassthrough.assign(\n",
+    "        output=itemgetter(\"args\") | tool)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.output_parsers import JsonOutputToolsParser\n",
+    "call_tool_list = RunnableLambda(call_tool).map()\n",
+    "chain = (model_with_tools\n",
+    "         | JsonOutputToolsParser()\n",
+    "         | call_tool_list)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'args': {'question': '測試測試測試介紹溫室氣體是什麼'},\n",
+       "  'type': 'RAG',\n",
+       "  'output': '溫室氣體(Greenhouse Gas)指的是會吸收和釋放紅外線輻射並存在大氣中的氣體。這些氣體包括二氧化碳(CO2)、甲烷(CH4)、氧化亞氮(N2O)、氫氟碳化物(HFCs)、全氟碳化物(PFCs)、六氟化硫(SF6)及三氟化氮(NF3)。'}]"
+      ]
+     },
+     "execution_count": 74,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# \"建準去年的綠電使用量是多少?\"\n",
+    "# 溫室氣體是什麼\n",
+    "chain.invoke(\"測試測試測試介紹溫室氣體是什麼?建準綠電?測試測試測試測試測試測試\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_core.prompts import MessagesPlaceholder\n",
+    "from langchain.agents import (\n",
+    "    AgentExecutor, create_openai_tools_agent)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.chat_models import ChatOllama\n",
+    "local_llm = \"llama3-groq-tool-use:latest\"\n",
+    "llm = ChatOllama(model=local_llm, temperature=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompt = ChatPromptTemplate.from_messages([\n",
+    "    ('system',\"\"\"你是一位善用工具的繁體中文助理,會判斷問題來選擇適合的工具,\n",
+    "                與公司私人數據來回答有關公司溫室氣體排放數據資訊的問題(工具名稱:SQL),\n",
+    "                或使用向量庫來解答有關 ESG 領域知識或有關 ESG 新聞的問題(工具名稱:RAG)。\"\"\"),\n",
+    "    ('human','{input}'),\n",
+    "    MessagesPlaceholder(variable_name=\"agent_scratchpad\")\n",
+    "])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# tools = [weather_data, search_run]\n",
+    "from langchain_openai import ChatOpenAI\n",
+    "\n",
+    "openai = ChatOpenAI(temperature=0)\n",
+    "agent = create_openai_tools_agent(llm=llm,\n",
+    "                                  tools=tools,\n",
+    "                                  prompt=prompt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "agent_executor = AgentExecutor(agent=agent,\n",
+    "                               tools=tools,\n",
+    "                               verbose=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
+      "\u001b[32;1m\u001b[1;3m\n",
+      "Invoking: `SQL` with `{'question': \"SELECT SUM(fixed_burn_total) AS total FROM last_year_data WHERE category = 'CO2'\"}`\n",
+      "\n",
+      "\n",
+      "\u001b[0mTo answer your question, I can help you with that. Could you please specify what kind of data you are looking for? For example, the total CO2 emissions or the direct emissions from a specific factory?\n",
+      "Error: (psycopg2.errors.SyntaxError) syntax error at or near \"To\"\n",
+      "LINE 1: To answer your question, I can help you with that. Could you...\n",
+      "        ^\n",
+      "\n",
+      "[SQL: To answer your question, I can help you with that. Could you please specify what kind of data you are looking for? For example, the total CO2 emissions or the direct emissions from a specific factory?]\n",
+      "(Background on this error at: https://sqlalche.me/e/20/f405)\n",
+      "\u001b[33;1m\u001b[1;3mIt seems like there's an issue with your SQL query. The error message indicates a syntax error near \"To\". Could you please provide the correct SQL query or clarify what you're trying to achieve?\u001b[0m\u001b[32;1m\u001b[1;3mIt looks like there was a mistake in the SQL query I tried to execute. Can you confirm if the category name is 'CO2' and if so, should it be enclosed within single quotes?\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n",
+      "It looks like there was a mistake in the SQL query I tried to execute. Can you confirm if the category name is 'CO2' and if so, should it be enclosed within single quotes?\n"
+     ]
+    }
+   ],
+   "source": [
+    "result = agent_executor.invoke({\"input\": \"建準去年的固定燃燒總排放量是多少?碳權是什麼?\"})\n",
+    "print(result['output'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
+      "\u001b[32;1m\u001b[1;3m您好!我可以幫助您查詢一些資訊。您需要查詢什麼?\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n",
+      "您好!我可以幫助您查詢一些資訊。您需要查詢什麼?\n"
+     ]
+    }
+   ],
+   "source": [
+    "result = agent_executor.invoke({\"input\": \"你好\"})\n",
+    "print(result['output'])"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "llama3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

+ 28 - 20
text_to_sql.py

@@ -43,19 +43,26 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 #     pipeline_kwargs={"return_full_text": False},
 #     device=0, device_map='cuda')
 
+##########################################################################################
+from langchain_community.chat_models import ChatOllama
+# local_llm = "llama3-groq-tool-use:latest"
+local_llm = "llama3-groq-tool-use:latest"
+llm = ChatOllama(model=local_llm, temperature=0)
+##########################################################################################
+# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# llm = HuggingFacePipeline.from_model_id(
+#     model_id=model_id,
+#     task="text-generation",
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     pipeline_kwargs={"return_full_text": False,
+#         "max_new_tokens": 512},
+#     device=0, device_map='cuda')
+# print(llm.pipeline)
+# llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
+##########################################################################################
 
-model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-llm = HuggingFacePipeline.from_model_id(
-    model_id=model_id,
-    task="text-generation",
-    model_kwargs={"torch_dtype": torch.bfloat16},
-    pipeline_kwargs={"return_full_text": False,
-        "max_new_tokens": 512},
-    device=0, device_map='cuda')
-print(llm.pipeline)
-llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
 # model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
 
 # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1, 
@@ -124,7 +131,7 @@ def table_description():
 
     return database_description
 
-def write_query_chain(db):
+def write_query_chain(db, llm):
 
     template = """
     <|begin_of_text|>
@@ -140,7 +147,7 @@ def write_query_chain(db):
     Wrap each column name in  Quotation Mark (") to denote them as delimited identifiers.\n\
     
     ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
-    ***Pay attention to only return PostgreSQL query.\n\
+    ***Pay attention to only return PostgreSQL query and no preamble or explanation.\n\
     <|eot_id|>
         
     <|begin_of_text|><|start_header_id|>user<|end_header_id|>
@@ -150,8 +157,9 @@ def write_query_chain(db):
     database description:
     {database_description}
 
+    Provide ONLY PostgreSQL query and NO preamble or explanation!
     The following SQL query best answers the question `{input}`:
-    ```sql
+    
     <|eot_id|>
     
     <|start_header_id|>assistant<|end_header_id|>
@@ -176,7 +184,7 @@ def write_query_chain(db):
 
     return write_query
 
-def sql_to_nl_chain():
+def sql_to_nl_chain(llm):
     # llm = Ollama(model = "llama3.1", num_gpu=1)
     # llm = Ollama(model = "llama3.1:8b-instruct-q2_K", num_gpu=1)
     # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
@@ -210,9 +218,9 @@ def sql_to_nl_chain():
 
     return chain
 
-def run(db, question, selected_table):
+def run(db, question, selected_table, llm):
 
-    write_query = write_query_chain(db)
+    write_query = write_query_chain(db, llm)
     query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     
     query = re.split('SQL query: ', query)[-1]
@@ -222,7 +230,7 @@ def run(db, question, selected_table):
     result = execute_query.invoke(query)
     print(result)
 
-    chain = sql_to_nl_chain()
+    chain = sql_to_nl_chain(llm)
     answer = chain.invoke({"question": question, "query": query, "result": result})
 
     return query, result, answer
@@ -234,7 +242,7 @@ if __name__ == "__main__":
     start = time.time()
     
     selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)', '水電使用量(GHG)', '水電使用量(ISO)']
-    question = "建準廣興廠去年的綠電使用量是多少?"
+    question = "建準去年的固定燃燒總排放量是多少?"
     query, result, answer = run(db, question, selected_table)
     print("question: ", question)
     print("query: ", query)