|
@@ -0,0 +1,1619 @@
|
|
|
|
+{
|
|
|
|
+ "cells": [
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# LLM"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 1,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "### LLM\n",
|
|
|
|
+ "from langchain_community.chat_models import ChatOllama\n",
|
|
|
|
+ "# local_llm = \"llama3.1:8b-instruct-fp16\"\n",
|
|
|
|
+ "local_llm = \"llama3-groq-tool-use:latest\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
|
|
|
|
+ "llm = ChatOllama(model=local_llm, temperature=0)\n"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# RAG"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 2,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stderr",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
|
|
+ " from .autonotebook import tqdm as notebook_tqdm\n"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "FAISS index loaded from faiss_index.bin\n",
|
|
|
|
+ "Metadata loaded from faiss_metadata.pkl\n",
|
|
|
|
+ "Using existing FAISS index and metadata.\n",
|
|
|
|
+ "Creating FAISS retriever...\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "from faiss_index import create_faiss_retriever, faiss_query\n",
|
|
|
|
+ "retriever = create_faiss_retriever()"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 3,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "from langchain.prompts import ChatPromptTemplate\n",
|
|
|
|
+ "from langchain_core.output_parsers import StrOutputParser\n",
|
|
|
|
+ "def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:\n",
+ "    \"\"\"Answer `question` from retrieved context using `llm`.\n",
+ "\n",
+ "    Args:\n",
+ "        question: User question substituted into the prompt.\n",
+ "        docs: Context for the answer. A list or tuple of strings / objects\n",
+ "            exposing `page_content` is joined with newlines; any other value\n",
+ "            (typically an already-joined string) is used as-is.\n",
+ "        llm: Model (runnable) piped between the prompt and the string parser.\n",
+ "        multi_query: Unused; kept so existing callers keep working.\n",
+ "\n",
+ "    Returns:\n",
+ "        The chain output parsed to a string.\n",
+ "    \"\"\"\n",
+ "    if isinstance(docs, (list, tuple)):\n",
+ "        # A list used to be interpolated into the prompt as its repr; join the texts instead.\n",
+ "        context = \"\\n\".join(str(getattr(doc, \"page_content\", doc)) for doc in docs)\n",
+ "    else:\n",
+ "        context = docs\n",
+ "\n",
+ "    system_prompt: str = \"你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。\"\n",
+ "    template = \"\"\"\n",
+ "    <|begin_of_text|>\n",
+ "    \n",
+ "    <|start_header_id|>system<|end_header_id|>\n",
+ "    你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \\n\n",
+ "    You should not mention anything about \"根據提供的文件內容\" or other similar terms.\n",
+ "    Use five sentences maximum and keep the answer concise.\n",
+ "    如果你不知道答案請回答:\"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。\"\n",
+ "    勿回答無關資訊\n",
+ "    <|eot_id|>\n",
+ "    \n",
+ "    <|start_header_id|>user<|end_header_id|>\n",
+ "    Answer the following question based on this context:\n",
+ "\n",
+ "    {context}\n",
+ "\n",
+ "    Question: {question}\n",
+ "    用繁體中文\n",
+ "    <|eot_id|>\n",
+ "    \n",
+ "    <|start_header_id|>assistant<|end_header_id|>\n",
+ "    \"\"\"\n",
+ "    # The short system sentence is prepended ahead of the chat-formatted template.\n",
+ "    prompt = ChatPromptTemplate.from_template(system_prompt + \"\\n\\n\" + template)\n",
+ "\n",
+ "    chain = prompt | llm | StrOutputParser()\n",
+ "    return chain.invoke({\"context\": context, \"question\": question})"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 4,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "### Generate\n",
|
|
|
|
+ "# llm = ChatOllama(model=local_llm, temperature=0)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n",
|
|
|
|
+ "# generation = faiss_query(question, docs_documents, llm)\n",
|
|
|
|
+ "# generation"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "## Retrieval Grader"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 5,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "### Retrieval Grader\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from langchain_community.chat_models import ChatOllama\n",
|
|
|
|
+ "from langchain_core.output_parsers import JsonOutputParser\n",
|
|
|
|
+ "from langchain_core.prompts import PromptTemplate\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# LLM\n",
|
|
|
|
+ "# llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "prompt = PromptTemplate(\n",
|
|
|
|
+ " template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance \n",
|
|
|
|
+ " of a retrieved document to a user question. If the document contains keywords related to the user question, \n",
|
|
|
|
+ " grade it as relevant. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n",
|
|
|
|
+ " Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \\n\n",
|
|
|
|
+ " Provide the binary score as a JSON with a single key 'score' and no premable or explanation.\n",
|
|
|
|
+ " <|eot_id|><|start_header_id|>user<|end_header_id|>\n",
|
|
|
|
+ " Here is the retrieved document: \\n\\n {document} \\n\\n\n",
|
|
|
|
+ " Here is the user question: {question} \\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
|
|
|
|
+ " \"\"\",\n",
|
|
|
|
+ " input_variables=[\"question\", \"document\"],\n",
|
|
|
|
+ ")\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "retrieval_grader = prompt | llm_json | JsonOutputParser()\n",
|
|
|
|
+ "# question = \"溫室氣體是什麼\"\n",
|
|
|
|
+ "# # docs = retriever.invoke(question)\n",
|
|
|
|
+ "# docs = retriever.get_relevant_documents(question, k=10)\n",
|
|
|
|
+ "# doc_txt = docs[1].page_content\n",
|
|
|
|
+ "# print(retrieval_grader.invoke({\"question\": question, \"document\": doc_txt}))"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 6,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "# for doc in docs:\n",
|
|
|
|
+ "# doc_txt = doc.page_content\n",
|
|
|
|
+ "# print(retrieval_grader.invoke({\"question\": question, \"document\": doc_txt}))"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "## Hallucination Grader"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 7,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "### Hallucination Grader\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from langchain_community.chat_models import ChatOllama\n",
|
|
|
|
+ "from langchain_core.output_parsers import JsonOutputParser\n",
|
|
|
|
+ "from langchain_core.prompts import PromptTemplate\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# LLM\n",
|
|
|
|
+ "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# Prompt\n",
|
|
|
|
+ "prompt = PromptTemplate(\n",
|
|
|
|
+ " template=\"\"\" <|begin_of_text|><|start_header_id|>system<|end_header_id|> \n",
|
|
|
|
+ " You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n",
|
|
|
|
+ " Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n",
|
|
|
|
+ " Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation. \n",
|
|
|
|
+ " Return the a JSON with a single key 'score' and no premable or explanation. \n",
|
|
|
|
+ " <|eot_id|><|start_header_id|>user<|end_header_id|>\n",
|
|
|
|
+ " Here are the facts:\n",
|
|
|
|
+ " \\n ------- \\n\n",
|
|
|
|
+ " {documents} \n",
|
|
|
|
+ " \\n ------- \\n\n",
|
|
|
|
+ " Here is the answer: {generation} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n",
|
|
|
|
+ " input_variables=[\"generation\", \"documents\"],\n",
|
|
|
|
+ ")\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "hallucination_grader = prompt | llm_json | JsonOutputParser()\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# hallucination_grader.invoke({\"documents\": docs_documents, \"generation\": generation})"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "## Answer Grader"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 8,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "### Answer Grader\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# LLM\n",
|
|
|
|
+ "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# Prompt\n",
|
|
|
|
+ "prompt = PromptTemplate(\n",
|
|
|
|
+ " template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an \n",
|
|
|
|
+ " answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is \n",
|
|
|
|
+ " useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\n",
|
|
|
|
+ " <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:\n",
|
|
|
|
+ " \\n ------- \\n\n",
|
|
|
|
+ " {generation} \n",
|
|
|
|
+ " \\n ------- \\n\n",
|
|
|
|
+ " Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n",
|
|
|
|
+ " input_variables=[\"generation\", \"question\"],\n",
|
|
|
|
+ ")\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "answer_grader = prompt | llm_json | JsonOutputParser()\n",
|
|
|
|
+ "# answer_grader.invoke({\"question\": question, \"generation\": generation})"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# SQL"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 32,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stderr",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n",
|
|
|
|
+ " self._metadata.reflect(\n"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "data": {
|
|
|
|
+ "text/plain": [
|
|
|
|
+ "<module 'text_to_sql' from '/home/ling/systex/text_to_sql.py'>"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ "execution_count": 32,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "output_type": "execute_result"
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "from importlib import reload # Python 3.4+\n",
|
|
|
|
+ "import text_to_sql\n",
|
|
|
|
+ "reload(text_to_sql)"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 49,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stderr",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_community/utilities/sql_database.py:123: SAWarning: Did not recognize type 'vector' of column 'embedding'\n",
|
|
|
|
+ " self._metadata.reflect(\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "\n",
|
|
|
|
+ "from text_to_sql import run\n",
|
|
|
|
+ "from langchain_community.utilities import SQLDatabase\n",
|
|
|
|
+ "import os\n",
|
|
|
|
+ "URI: str = os.environ.get('SUPABASE_URI')\n",
|
|
|
|
+ "db = SQLDatabase.from_uri(URI)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def run_text_to_sql(question: str, selected_table=None):\n",
+ "    \"\"\"Answer `question` through the text-to-SQL pipeline (`text_to_sql.run`).\n",
+ "\n",
+ "    Uses the notebook-level `db` and `llm` objects.\n",
+ "\n",
+ "    Args:\n",
+ "        question: Natural-language question forwarded to `run`.\n",
+ "        selected_table: Table names handed to `run`. When omitted, the\n",
+ "            inventory and utility-usage tables listed below are used.\n",
+ "\n",
+ "    Returns:\n",
+ "        Tuple `(answer, query)`: the third and first values returned by `run`.\n",
+ "    \"\"\"\n",
+ "    if selected_table is None:\n",
+ "        selected_table = [\n",
+ "            '2022 清冊數據(GHG)',\n",
+ "            '2022 清冊數據(ISO)',\n",
+ "            '2023 清冊數據(GHG)',\n",
+ "            '2023 清冊數據(ISO)',\n",
+ "            '水電使用量(GHG)',\n",
+ "            '水電使用量(ISO)',\n",
+ "        ]\n",
+ "    # The middle value returned by `run` is not needed by callers of this helper.\n",
+ "    query, _result, answer = run(db, question, selected_table, llm)\n",
+ "\n",
+ "    return answer, query"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "question = \"建準去年的固定燃燒總排放量是多少?\"\n",
|
|
|
|
+ "run_text_to_sql(question)"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "## Hallucination Grader"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "### Hallucination Grader\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from langchain_community.chat_models import ChatOllama\n",
|
|
|
|
+ "from langchain_core.output_parsers import JsonOutputParser\n",
|
|
|
|
+ "from langchain_core.prompts import PromptTemplate\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# LLM\n",
|
|
|
|
+ "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# Prompt\n",
|
|
|
|
+ "prompt = PromptTemplate(\n",
|
|
|
|
+ " template=\"\"\" <|begin_of_text|><|start_header_id|>system<|end_header_id|> \n",
|
|
|
|
+ " You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n",
|
|
|
|
+ " Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n",
|
|
|
|
+ " Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation. \n",
|
|
|
|
+ " Return the a JSON with a single key 'score' and no premable or explanation. \n",
|
|
|
|
+ " <|eot_id|><|start_header_id|>user<|end_header_id|>\n",
|
|
|
|
+ " Here are the facts:\n",
|
|
|
|
+ " \\n ------- \\n\n",
|
|
|
|
+ " {documents} \n",
|
|
|
|
+ " \\n ------- \\n\n",
|
|
|
|
+ " Here is the answer: {generation} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n",
|
|
|
|
+ " input_variables=[\"generation\", \"documents\"],\n",
|
|
|
|
+ ")\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "hallucination_grader = prompt | llm_json | JsonOutputParser()"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# Router"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 11,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "{'datasource': 'company_private_data'}\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "### Router\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from langchain_community.chat_models import ChatOllama\n",
|
|
|
|
+ "from langchain_core.output_parsers import JsonOutputParser\n",
|
|
|
|
+ "from langchain_core.prompts import PromptTemplate\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# LLM\n",
|
|
|
|
+ "llm_json = ChatOllama(model=local_llm, format=\"json\", temperature=0)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "prompt = PromptTemplate(\n",
|
|
|
|
+ " template=\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|> \n",
|
|
|
|
+ " You are an expert at routing a user question to a vectorstore or company private data. \n",
|
|
|
|
+ " Use company private data for questions about the informations about a company's greenhouse gas emissions data.\n",
|
|
|
|
+ " Otherwise, use the vectorstore for questions on ESG field knowledge or news about ESG. \n",
|
|
|
|
+ " You do not need to be stringent with the keywords in the question related to these topics. \n",
|
|
|
|
+ " Give a binary choice 'company_private_data' or 'vectorstore' based on the question. \n",
|
|
|
|
+ " Return the a JSON with a single key 'datasource' and no premable or explanation. \n",
|
|
|
|
+ " \n",
|
|
|
|
+ " Question to route: {question} \n",
|
|
|
|
+ " <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n",
|
|
|
|
+ " input_variables=[\"question\"],\n",
|
|
|
|
+ ")\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "question_router = prompt | llm_json | JsonOutputParser()\n",
|
|
|
|
+ "question = \"建準去年的類別1排放量是多少?\"\n",
|
|
|
|
+ "# docs = retriever.get_relevant_documents(question)\n",
|
|
|
|
+ "# doc_txt = docs[1].page_content\n",
|
|
|
|
+ "print(question_router.invoke({\"question\": question}))"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# Node"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 12,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "# RAG + text-to-sql\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from pprint import pprint\n",
|
|
|
|
+ "from typing import List\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from langchain_core.documents import Document\n",
|
|
|
|
+ "from typing_extensions import TypedDict\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from langgraph.graph import END, StateGraph, START\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "### State\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "class GraphState(TypedDict):\n",
+ "    \"\"\"\n",
+ "    State passed between the nodes of the graph.\n",
+ "\n",
+ "    Attributes:\n",
+ "        question: the user's question\n",
+ "        generation: answer text produced by a node\n",
+ "        company_private_data: whether to search company private data\n",
+ "        documents: grounding text for the generation, stored as one string\n",
+ "            (the joined retrieved passages, or the generated SQL query)\n",
+ "    \"\"\"\n",
+ "\n",
+ "    question: str\n",
+ "    generation: str\n",
+ "    company_private_data: str\n",
+ "    # Both nodes of the compiled graph write a single string here, not a list.\n",
+ "    documents: str"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 48,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "def retrieve_and_generation(state):\n",
+ "    \"\"\"\n",
+ "    Retrieve passages for the question and generate an answer from them.\n",
+ "\n",
+ "    Args:\n",
+ "        state (dict): The current graph state; only \"question\" is read\n",
+ "\n",
+ "    Returns:\n",
+ "        dict: \"documents\" (page contents of the retrieved documents joined by\n",
+ "        blank lines), \"question\" (unchanged) and \"generation\" (faiss_query output)\n",
+ "    \"\"\"\n",
+ "    print(\"---RETRIEVE---\")\n",
+ "    user_question = state[\"question\"]\n",
+ "\n",
+ "    # TODO: correct Retrieval function\n",
+ "    retrieved = retriever.get_relevant_documents(user_question, k=10)\n",
+ "    joined_passages = \"\\n\\n\".join([item.page_content for item in retrieved])\n",
+ "    answer = faiss_query(user_question, joined_passages, llm)\n",
+ "    return {\n",
+ "        \"documents\": joined_passages,\n",
+ "        \"question\": user_question,\n",
+ "        \"generation\": answer,\n",
+ "    }\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def company_private_data_search(state):\n",
+ "    \"\"\"\n",
+ "    Answer the question from company private data through text-to-SQL.\n",
+ "\n",
+ "    Args:\n",
+ "        state (dict): The current graph state; only \"question\" is read\n",
+ "\n",
+ "    Returns:\n",
+ "        dict: \"documents\" (second value returned by run_text_to_sql, the SQL query),\n",
+ "        \"question\" (unchanged) and \"generation\" (first value returned, the answer)\n",
+ "    \"\"\"\n",
+ "\n",
+ "    print(\"---SQL---\")\n",
+ "    user_question = state[\"question\"]\n",
+ "\n",
+ "    # TODO: correct company_private_data_search function\n",
+ "    answer, sql_query = run_text_to_sql(user_question)\n",
+ "\n",
+ "    node_update = {}\n",
+ "    node_update[\"documents\"] = sql_query\n",
+ "    node_update[\"question\"] = user_question\n",
+ "    node_update[\"generation\"] = answer\n",
+ "    return node_update"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# Conditional edge"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 64,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "### Conditional edge\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def route_question(state):\n",
+ "    \"\"\"\n",
+ "    Route question to text-to-SQL (company private data) or RAG (vectorstore).\n",
+ "\n",
+ "    Args:\n",
+ "        state (dict): The current graph state; only \"question\" is read\n",
+ "\n",
+ "    Returns:\n",
+ "        str: Next route, always \"company_private_data\" or \"vectorstore\"\n",
+ "    \"\"\"\n",
+ "\n",
+ "    print(\"---ROUTE QUESTION---\")\n",
+ "    question = state[\"question\"]\n",
+ "    print(question)\n",
+ "    source = question_router.invoke({\"question\": question})\n",
+ "    print(source)\n",
+ "    # The router is an LLM: the key may be missing or hold an unexpected value.\n",
+ "    datasource = source.get(\"datasource\") if isinstance(source, dict) else None\n",
+ "    print(datasource)\n",
+ "    if datasource == \"company_private_data\":\n",
+ "        print(\"---ROUTE QUESTION TO TEXT-TO-SQL---\")\n",
+ "        return \"company_private_data\"\n",
+ "    # Everything else goes to RAG, matching the router prompt's \"Otherwise, use the\n",
+ "    # vectorstore\". Falling through used to return None, which is not a key of the\n",
+ "    # edge mapping registered for this router.\n",
+ "    if datasource != \"vectorstore\":\n",
+ "        print(\"---UNEXPECTED ROUTER OUTPUT, DEFAULTING TO RAG---\")\n",
+ "    print(\"---ROUTE QUESTION TO RAG---\")\n",
+ "    return \"vectorstore\"\n",
|
|
|
|
+ " \n",
|
|
|
|
+ "def _grader_said_yes(result) -> bool:\n",
+ "    \"\"\"Return True when a grader's JSON reply carries an affirmative 'score'.\n",
+ "\n",
+ "    Accepts \"yes\" / \"true\" / \"1\" in any letter case and with surrounding\n",
+ "    whitespace, as well as True and 1. A reply that is not a dict, or that has\n",
+ "    no 'score' key, counts as not affirmative.\n",
+ "    \"\"\"\n",
+ "    score = result.get(\"score\") if isinstance(result, dict) else None\n",
+ "    if isinstance(score, str):\n",
+ "        return score.strip().lower() in (\"yes\", \"true\", \"1\")\n",
+ "    return score in (True, 1)\n",
+ "\n",
+ "\n",
+ "def grade_generation_v_documents_and_question(state):\n",
+ "    \"\"\"\n",
+ "    Determines whether the generation is grounded in the document and answers question.\n",
+ "\n",
+ "    Args:\n",
+ "        state (dict): The current graph state; reads \"question\", \"documents\" and \"generation\"\n",
+ "\n",
+ "    Returns:\n",
+ "        str: \"useful\", \"not useful\" or \"not supported\"\n",
+ "    \"\"\"\n",
+ "\n",
+ "    print(\"---CHECK HALLUCINATIONS---\")\n",
+ "    question = state[\"question\"]\n",
+ "    documents = state[\"documents\"]\n",
+ "    generation = state[\"generation\"]\n",
+ "\n",
+ "    score = hallucination_grader.invoke(\n",
+ "        {\"documents\": documents, \"generation\": generation}\n",
+ "    )\n",
+ "    print(score)\n",
+ "\n",
+ "    # Check hallucination first; an ungrounded answer is never graded for usefulness.\n",
+ "    if not _grader_said_yes(score):\n",
+ "        print(\"---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---\")\n",
+ "        return \"not supported\"\n",
+ "\n",
+ "    print(\"---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\")\n",
+ "    # Check question-answering\n",
+ "    print(\"---GRADE GENERATION vs QUESTION---\")\n",
+ "    score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n",
+ "    if _grader_said_yes(score):\n",
+ "        print(\"---DECISION: GENERATION ADDRESSES QUESTION---\")\n",
+ "        return \"useful\"\n",
+ "    print(\"---DECISION: GENERATION DOES NOT ADDRESS QUESTION---\")\n",
+ "    return \"not useful\"\n",
|
|
|
|
+ " "
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "# Graph"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 65,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "\n",
|
|
|
|
+ "# Graph over the shared GraphState\n",
+ "workflow = StateGraph(GraphState)\n",
+ "\n",
+ "# Nodes: one per answering strategy\n",
+ "workflow.add_node(\"company_private_data_search\", company_private_data_search)  # text-to-SQL\n",
+ "workflow.add_node(\"retrieve_and_generation\", retrieve_and_generation)  # retrieve + generate\n",
+ "\n",
+ "# Entry point: route_question picks the first node\n",
+ "workflow.add_conditional_edges(\n",
+ "    START,\n",
+ "    route_question,\n",
+ "    {\"company_private_data\": \"company_private_data_search\", \"vectorstore\": \"retrieve_and_generation\"},\n",
+ ")\n",
+ "\n",
+ "# After RAG: both failure grades run the same node again\n",
+ "workflow.add_conditional_edges(\n",
+ "    \"retrieve_and_generation\",\n",
+ "    grade_generation_v_documents_and_question,\n",
+ "    {\"not supported\": \"retrieve_and_generation\", \"useful\": END, \"not useful\": \"retrieve_and_generation\"},\n",
+ ")\n",
+ "\n",
+ "# After text-to-SQL: an ungrounded answer is retried, an unhelpful one goes to RAG\n",
+ "workflow.add_conditional_edges(\n",
+ "    \"company_private_data_search\",\n",
+ "    grade_generation_v_documents_and_question,\n",
+ "    {\"not supported\": \"company_private_data_search\", \"useful\": END, \"not useful\": \"retrieve_and_generation\"},\n",
+ ")\n",
+ "\n",
+ "app = workflow.compile()"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 37,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "%%{init: {'flowchart': {'curve': 'linear'}}}%%\n",
|
|
|
|
+ "graph TD;\n",
|
|
|
|
+ "\t__start__([__start__]):::first\n",
|
|
|
|
+ "\tcompany_private_data_search(company_private_data_search)\n",
|
|
|
|
+ "\tretrieve_and_generation(retrieve_and_generation)\n",
|
|
|
|
+ "\t__end__([__end__]):::last\n",
|
|
|
|
+ "\tcompany_private_data_search --> __end__;\n",
|
|
|
|
+ "\t__start__ -.  company_private_data  .-> company_private_data_search;\n",
|
|
|
|
+ "\t__start__ -.  vectorstore  .-> retrieve_and_generation;\n",
|
|
|
|
+ "\tretrieve_and_generation -.  not supported  .-> retrieve_and_generation;\n",
|
|
|
|
+ "\tretrieve_and_generation -.  useful  .-> __end__;\n",
|
|
|
|
+ "\tretrieve_and_generation -.  not useful  .-> company_private_data_search;\n",
|
|
|
|
+ "\tclassDef default fill:#f2f0ff,line-height:1.2\n",
|
|
|
|
+ "\tclassDef first fill-opacity:0\n",
|
|
|
|
+ "\tclassDef last fill:#bfb6fc\n",
|
|
|
|
+ "\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "print(app.get_graph().draw_mermaid())"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 1,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "ename": "NameError",
|
|
|
|
+ "evalue": "name 'app' is not defined",
|
|
|
|
+ "output_type": "error",
|
|
|
|
+ "traceback": [
|
|
|
|
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
|
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|
|
|
+ "Cell \u001b[0;32mIn[1], line 6\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mIPython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdisplay\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Image, display\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlangchain_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrunnables\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mgraph\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m CurveStyle, MermaidDrawMethod, NodeStyles\n\u001b[1;32m 4\u001b[0m display(\n\u001b[1;32m 5\u001b[0m Image(\n\u001b[0;32m----> 6\u001b[0m \u001b[43mapp\u001b[49m\u001b[38;5;241m.\u001b[39mget_graph()\u001b[38;5;241m.\u001b[39mdraw_mermaid_png(\n\u001b[1;32m 7\u001b[0m draw_method\u001b[38;5;241m=\u001b[39mMermaidDrawMethod\u001b[38;5;241m.\u001b[39mAPI,\n\u001b[1;32m 8\u001b[0m output_file_path\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124magent_workflow.png\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 9\u001b[0m )\n\u001b[1;32m 10\u001b[0m )\n\u001b[1;32m 11\u001b[0m )\n",
|
|
|
|
+ "\u001b[0;31mNameError\u001b[0m: name 'app' is not defined"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "from IPython.display import Image, display\n",
|
|
|
|
+ "from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "display(\n",
|
|
|
|
+ " Image(\n",
|
|
|
|
+ " app.get_graph().draw_mermaid_png(\n",
|
|
|
|
+ " draw_method=MermaidDrawMethod.API,\n",
|
|
|
|
+ " output_file_path=\"agent_workflow.png\",\n",
|
|
|
|
+ " )\n",
|
|
|
|
+ " )\n",
|
|
|
|
+ ")"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "markdown",
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "source": [
|
|
|
|
+ "## Test"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 67,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "---ROUTE QUESTION---\n",
|
|
|
|
+ "建準2024的固定燃燒總排放量是多少?\n",
|
|
|
|
+ "{'datasource': 'company_private_data'}\n",
|
|
|
|
+ "company_private_data\n",
|
|
|
|
+ "---ROUTE QUESTION TO TEXT-TO-SQL---\n",
|
|
|
|
+ "---SQL---\n",
|
|
|
|
+ "SELECT SUM(\"高雄總部及運通廠\" + \"台北辦事處\" + \"昆山廣興廠\" + \"北海建準廠\" + \"北海立準廠\" + \"菲律賓建準廠\" + \"Inc\" + \"SAS\" + \"India\") AS \"固定燃燒總排放量\"\n",
|
|
|
|
+ "FROM \"2023 清冊數據(GHG)\"\n",
|
|
|
|
+ "WHERE \"排放源\" = '固定燃燒'\n",
|
|
|
|
+ "[(13.5953,)]\n",
|
|
|
|
+ "---CHECK HALLUCINATIONS---\n",
|
|
|
|
+ "{'score': True}\n",
|
|
|
|
+ "---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\n",
|
|
|
|
+ "---GRADE GENERATION vs QUESTION---\n",
|
|
|
|
+ "---DECISION: GENERATION ADDRESSES QUESTION---\n",
|
|
|
|
+ "'Finished running: company_private_data_search:'\n",
|
|
|
|
+ "'建準2024的固定燃燒總排放量是13.5953。'\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "# Test\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "inputs = {\"question\": \"建準2024的固定燃燒總排放量是多少?\"}\n",
|
|
|
|
+ "for output in app.stream(inputs):\n",
|
|
|
|
+ " for key, value in output.items():\n",
|
|
|
|
+ " pprint(f\"Finished running: {key}:\")\n",
|
|
|
|
+ "pprint(value[\"generation\"])"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 17,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "# RAG + text-to-sql\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from pprint import pprint\n",
|
|
|
|
+ "from typing import List\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from langchain_core.documents import Document\n",
|
|
|
|
+ "from typing_extensions import TypedDict\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "from langgraph.graph import END, StateGraph, START\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "### State\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "class GraphState(TypedDict):\n",
+ "    \"\"\"\n",
+ "    Represents the state of our graph.\n",
+ "\n",
+ "    Attributes:\n",
+ "        question: question\n",
+ "        generation: LLM generation\n",
+ "        company_private_data: whether to search company private data\n",
+ "        documents: retrieved documents, as returned by the retriever\n",
+ "    \"\"\"\n",
+ "\n",
+ "    question: str\n",
+ "    generation: str\n",
+ "    company_private_data: str\n",
+ "    # retrieve() in this cell stores the retriever's results here. Elsewhere in the\n",
+ "    # notebook those results are read through .page_content, so they are presumably\n",
+ "    # Document objects rather than plain strings -- TODO confirm against faiss_index.\n",
+ "    documents: List[Document]\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "### Nodes\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def retrieve(state):\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ " Retrieve documents from vectorstore\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Args:\n",
|
|
|
|
+ " state (dict): The current graph state\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Returns:\n",
|
|
|
|
+ " state (dict): New key added to state, documents, that contains retrieved documents\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ " print(\"---RETRIEVE---\")\n",
|
|
|
|
+ " question = state[\"question\"]\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " # Retrieval\n",
|
|
|
|
+ " # documents = retriever.invoke(question)\n",
|
|
|
|
+ " # TODO: correct Retrieval function\n",
|
|
|
|
+ " documents = retriever.get_relevant_documents(question, k=10)\n",
|
|
|
|
+ " docs_documents = \"\\n\\n\".join(doc.page_content for doc in docs)\n",
|
|
|
|
+ " generation = faiss_query(question, docs_documents, llm)\n",
|
|
|
|
+ " return {\"documents\": documents, \"question\": question}\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def generate(state):\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ " Generate answer using RAG on retrieved documents\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Args:\n",
|
|
|
|
+ " state (dict): The current graph state\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Returns:\n",
|
|
|
|
+ " state (dict): New key added to state, generation, that contains LLM generation\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ " print(\"---GENERATE---\")\n",
|
|
|
|
+ " question = state[\"question\"]\n",
|
|
|
|
+ " documents = state[\"documents\"]\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " # RAG generation\n",
|
|
|
|
+ " # generation = rag_chain.invoke({\"context\": documents, \"question\": question})\n",
|
|
|
|
+ " # TODO: correct generation function\n",
|
|
|
|
+ " print(documents)\n",
|
|
|
|
+ " print(question)\n",
|
|
|
|
+ " generation = faiss_query(question, documents, llm)\n",
|
|
|
|
+ " generation = eval(generation)['data']['answer']\n",
|
|
|
|
+ " return {\"documents\": documents, \"question\": question, \"generation\": generation}\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def grade_documents(state):\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ " Determines whether the retrieved documents are relevant to the question\n",
|
|
|
|
+ " Irrelevant documents are dropped; the company_private_data flag is currently always \"No\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Args:\n",
|
|
|
|
+ " state (dict): The current graph state\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Returns:\n",
|
|
|
|
+ " state (dict): Filtered out irrelevant documents and updated company_private_data state\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " print(\"---CHECK DOCUMENT RELEVANCE TO QUESTION---\")\n",
|
|
|
|
+ " question = state[\"question\"]\n",
|
|
|
|
+ " documents = state[\"documents\"]\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " # Score each doc\n",
|
|
|
|
+ " filtered_docs = []\n",
|
|
|
|
+ " company_private_data = \"No\"\n",
|
|
|
|
+ " for d in documents:\n",
|
|
|
|
+ " score = retrieval_grader.invoke(\n",
|
|
|
|
+ " {\"question\": question, \"document\": d.page_content}\n",
|
|
|
|
+ " )\n",
|
|
|
|
+ " grade = score[\"score\"]\n",
|
|
|
|
+ " # Document relevant\n",
|
|
|
|
+ " if grade.lower() == \"yes\":\n",
|
|
|
|
+ " print(\"---GRADE: DOCUMENT RELEVANT---\")\n",
|
|
|
|
+ " filtered_docs.append(d)\n",
|
|
|
|
+ " # Document not relevant\n",
|
|
|
|
+ " else:\n",
|
|
|
|
+ " print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n",
|
|
|
|
+ " # We do not include the document in filtered_docs\n",
|
|
|
|
+ " # (Disabled) flag that would route to the company private data search\n",
|
|
|
|
+ " # company_private_data = \"Yes\"\n",
|
|
|
|
+ " continue\n",
|
|
|
|
+ " return {\"documents\": filtered_docs, \"question\": question, \"company_private_data\": company_private_data}\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def company_private_data_search(state):\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ " Run text-to-SQL over company private data based on the question\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Args:\n",
|
|
|
|
+ " state (dict): The current graph state\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Returns:\n",
|
|
|
|
+ " state (dict): Appended text-to-SQL result to documents\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " print(\"---WEB SEARCH---\")\n",
|
|
|
|
+ " question = state[\"question\"]\n",
|
|
|
|
+ " documents = state[\"documents\"]\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " # company_private_data_search\n",
|
|
|
|
+ " # docs = web_search_tool.invoke({\"query\": question})\n",
|
|
|
|
+ " # web_results = \"\\n\".join([d[\"content\"] for d in docs])\n",
|
|
|
|
+ " # web_results = Document(page_content=web_results)\n",
|
|
|
|
+ " # TODO: correct company_private_data_search function\n",
|
|
|
|
+ " company_private_data_result = run_text_to_sql(question)\n",
|
|
|
|
+ " if documents is not None:\n",
|
|
|
|
+ " documents.append(company_private_data_result)\n",
|
|
|
|
+ " else:\n",
|
|
|
|
+ " documents = [company_private_data_result]\n",
|
|
|
|
+ " return {\"documents\": documents, \"question\": question}\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "### Conditional edge\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def route_question(state):\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ " Route question to text-to-SQL (company private data) or RAG.\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Args:\n",
|
|
|
|
+ " state (dict): The current graph state\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Returns:\n",
|
|
|
|
+ " str: Next node to call\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " print(\"---ROUTE QUESTION---\")\n",
|
|
|
|
+ " question = state[\"question\"]\n",
|
|
|
|
+ " print(question)\n",
|
|
|
|
+ " source = question_router.invoke({\"question\": question})\n",
|
|
|
|
+ " print(source)\n",
|
|
|
|
+ " print(source[\"datasource\"])\n",
|
|
|
|
+ " if source[\"datasource\"] == \"company_private_data\":\n",
|
|
|
|
+ " print(\"---ROUTE QUESTION TO TEXT-TO-SQL---\")\n",
|
|
|
|
+ " return \"company_private_data\"\n",
|
|
|
|
+ " elif source[\"datasource\"] == \"vectorstore\":\n",
|
|
|
|
+ " print(\"---ROUTE QUESTION TO RAG---\")\n",
|
|
|
|
+ " return \"vectorstore\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def decide_to_generate(state):\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ " Determines whether to generate an answer, or add company_private_data\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Args:\n",
|
|
|
|
+ " state (dict): The current graph state\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Returns:\n",
|
|
|
|
+ " str: Binary decision for next node to call\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " print(\"---ASSESS GRADED DOCUMENTS---\")\n",
|
|
|
|
+ " state[\"question\"]\n",
|
|
|
|
+ " company_private_data = state[\"company_private_data\"]\n",
|
|
|
|
+ " state[\"documents\"]\n",
|
|
|
|
+ " \n",
|
|
|
|
+ "\n",
|
|
|
|
+ " if company_private_data == \"Yes\":\n",
|
|
|
|
+ " # All documents have been filtered check_relevance\n",
|
|
|
|
+ " # We will fall back to the company private data search\n",
|
|
|
|
+ " print(\n",
|
|
|
|
+ " \"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE company_private_data---\"\n",
|
|
|
|
+ " )\n",
|
|
|
|
+ " return \"company_private_data\"\n",
|
|
|
|
+ " else:\n",
|
|
|
|
+ " # We have relevant documents, so generate answer\n",
|
|
|
|
+ " print(\"---DECISION: GENERATE---\")\n",
|
|
|
|
+ " return \"generate\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "### Conditional edge\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def grade_generation_v_documents_and_question(state):\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ " Determines whether the generation is grounded in the document and answers question.\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Args:\n",
|
|
|
|
+ " state (dict): The current graph state\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " Returns:\n",
|
|
|
|
+ " str: Decision for next node to call\n",
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " print(\"---CHECK HALLUCINATIONS---\")\n",
|
|
|
|
+ " question = state[\"question\"]\n",
|
|
|
|
+ " documents = state[\"documents\"]\n",
|
|
|
|
+ " generation = state[\"generation\"]\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " try:\n",
|
|
|
|
+ " docs_documents = \"\\n\\n\".join(doc.page_content for doc in documents)\n",
|
|
|
|
+ " except AttributeError:\n",
|
|
|
|
+ " docs_documents = \"\\n\\n\".join(doc for doc in documents)\n",
|
|
|
|
+ " # print(docs_documents)\n",
|
|
|
|
+ " # print(generation)\n",
|
|
|
|
+ " score = hallucination_grader.invoke(\n",
|
|
|
|
+ " {\"documents\": docs_documents, \"generation\": generation}\n",
|
|
|
|
+ " )\n",
|
|
|
|
+ " print(score)\n",
|
|
|
|
+ " grade = score[\"score\"]\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " # Check hallucination\n",
|
|
|
|
+ " if grade == \"yes\":\n",
|
|
|
|
+ " print(\"---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\")\n",
|
|
|
|
+ " # Check question-answering\n",
|
|
|
|
+ " print(\"---GRADE GENERATION vs QUESTION---\")\n",
|
|
|
|
+ " score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n",
|
|
|
|
+ " grade = score[\"score\"]\n",
|
|
|
|
+ " if grade == \"yes\":\n",
|
|
|
|
+ " print(\"---DECISION: GENERATION ADDRESSES QUESTION---\")\n",
|
|
|
|
+ " return \"useful\"\n",
|
|
|
|
+ " else:\n",
|
|
|
|
+ " print(\"---DECISION: GENERATION DOES NOT ADDRESS QUESTION---\")\n",
|
|
|
|
+ " return \"not useful\"\n",
|
|
|
|
+ " else:\n",
|
|
|
|
+ " pprint(\"---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---\")\n",
|
|
|
|
+ " return \"not supported\"\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "workflow = StateGraph(GraphState)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# Define the nodes\n",
|
|
|
|
+ "workflow.add_node(\"company_private_data_search\", company_private_data_search) # text-to-SQL on company private data\n",
|
|
|
|
+ "workflow.add_node(\"retrieve\", retrieve) # retrieve\n",
|
|
|
|
+ "workflow.add_node(\"grade_documents\", grade_documents) # grade documents\n",
|
|
|
|
+ "workflow.add_node(\"generate\", generate) # generate"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 18,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "# Build graph\n",
|
|
|
|
+ "workflow.add_conditional_edges(\n",
|
|
|
|
+ " START,\n",
|
|
|
|
+ " route_question,\n",
|
|
|
|
+ " {\n",
|
|
|
|
+ " \"company_private_data\": \"company_private_data_search\",\n",
|
|
|
|
+ " \"vectorstore\": \"retrieve\",\n",
|
|
|
|
+ " },\n",
|
|
|
|
+ ")\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "workflow.add_edge(\"retrieve\", \"grade_documents\")\n",
|
|
|
|
+ "workflow.add_conditional_edges(\n",
|
|
|
|
+ " \"grade_documents\",\n",
|
|
|
|
+ " decide_to_generate,\n",
|
|
|
|
+ " {\n",
|
|
|
|
+ " \"company_private_data\": \"company_private_data_search\",\n",
|
|
|
|
+ " \"generate\": \"generate\",\n",
|
|
|
|
+ " },\n",
|
|
|
|
+ ")\n",
|
|
|
|
+ "# workflow.add_edge(\"retrieve\", \"generate\")\n",
|
|
|
|
+ "workflow.add_edge(\"company_private_data_search\", \"generate\")\n",
|
|
|
|
+ "workflow.add_conditional_edges(\n",
|
|
|
|
+ " \"generate\",\n",
|
|
|
|
+ " grade_generation_v_documents_and_question,\n",
|
|
|
|
+ " {\n",
|
|
|
|
+ " \"not supported\": \"generate\",\n",
|
|
|
|
+ " \"useful\": END,\n",
|
|
|
|
+ " \"not useful\": \"company_private_data_search\",\n",
|
|
|
|
+ " },\n",
|
|
|
|
+ ")\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "app = workflow.compile()"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": 19,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "ename": "ValueError",
|
|
|
|
+ "evalue": "no intersection found (point inside ?!). view: <langchain_core.runnables.graph_ascii.VertexViewer object at 0x70acd7938500> topt: (0.0, 25.5)",
|
|
|
|
+ "output_type": "error",
|
|
|
|
+ "traceback": [
|
|
|
|
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
|
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
|
|
|
+ "Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_graph\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprint_ascii\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
|
|
+ "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/runnables/graph.py:485\u001b[0m, in \u001b[0;36mGraph.print_ascii\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 483\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprint_ascii\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 484\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Print the graph as an ASCII art string.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 485\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdraw_ascii\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n",
|
|
|
|
+ "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/runnables/graph.py:478\u001b[0m, in \u001b[0;36mGraph.draw_ascii\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Draw the graph as an ASCII art string.\"\"\"\u001b[39;00m\n\u001b[1;32m 476\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlangchain_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrunnables\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mgraph_ascii\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m draw_ascii\n\u001b[0;32m--> 478\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mdraw_ascii\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 479\u001b[0m \u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnodes\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 480\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medges\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 481\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
|
|
+ "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/runnables/graph_ascii.py:251\u001b[0m, in \u001b[0;36mdraw_ascii\u001b[0;34m(vertices, edges)\u001b[0m\n\u001b[1;32m 248\u001b[0m Xs \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 249\u001b[0m Ys \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m--> 251\u001b[0m sug \u001b[38;5;241m=\u001b[39m \u001b[43m_build_sugiyama_layout\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvertices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43medges\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 253\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m vertex \u001b[38;5;129;01min\u001b[39;00m sug\u001b[38;5;241m.\u001b[39mg\u001b[38;5;241m.\u001b[39msV:\n\u001b[1;32m 254\u001b[0m \u001b[38;5;66;03m# NOTE: moving boxes w/2 to the left\u001b[39;00m\n\u001b[1;32m 255\u001b[0m Xs\u001b[38;5;241m.\u001b[39mappend(vertex\u001b[38;5;241m.\u001b[39mview\u001b[38;5;241m.\u001b[39mxy[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m-\u001b[39m vertex\u001b[38;5;241m.\u001b[39mview\u001b[38;5;241m.\u001b[39mw \u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2.0\u001b[39m)\n",
|
|
|
|
+ "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/runnables/graph_ascii.py:210\u001b[0m, in \u001b[0;36m_build_sugiyama_layout\u001b[0;34m(vertices, edges)\u001b[0m\n\u001b[1;32m 207\u001b[0m sug\u001b[38;5;241m.\u001b[39mxspace \u001b[38;5;241m=\u001b[39m minw\n\u001b[1;32m 208\u001b[0m sug\u001b[38;5;241m.\u001b[39mroute_edge \u001b[38;5;241m=\u001b[39m route_with_lines\n\u001b[0;32m--> 210\u001b[0m \u001b[43msug\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdraw\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 212\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m sug\n",
|
|
|
|
+ "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/grandalf/layouts.py:442\u001b[0m, in \u001b[0;36mSugiyamaLayout.draw\u001b[0;34m(self, N)\u001b[0m\n\u001b[1;32m 440\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[1;32m 441\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msetxy()\n\u001b[0;32m--> 442\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdraw_edges\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
|
|
+ "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/grandalf/layouts.py:814\u001b[0m, in \u001b[0;36mSugiyamaLayout.draw_edges\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 812\u001b[0m l\u001b[38;5;241m.\u001b[39mappend(e\u001b[38;5;241m.\u001b[39mv[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mview\u001b[38;5;241m.\u001b[39mxy)\n\u001b[1;32m 813\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 814\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mroute_edge\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ml\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 815\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m:\n\u001b[1;32m 816\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n",
|
|
|
|
+ "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/grandalf/routing.py:31\u001b[0m, in \u001b[0;36mroute_with_lines\u001b[0;34m(e, pts)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mroute_with_lines\u001b[39m(e, pts):\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(e, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mview\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 31\u001b[0m tail_pos \u001b[38;5;241m=\u001b[39m \u001b[43mintersectR\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mv\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtopt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpts\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 32\u001b[0m head_pos \u001b[38;5;241m=\u001b[39m intersectR(e\u001b[38;5;241m.\u001b[39mv[\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mview, topt\u001b[38;5;241m=\u001b[39mpts[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m])\n\u001b[1;32m 33\u001b[0m pts[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m=\u001b[39m tail_pos\n",
|
|
|
|
+ "File \u001b[0;32m/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/grandalf/utils/geometry.py:69\u001b[0m, in \u001b[0;36mintersectR\u001b[0;34m(view, topt)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (x, y)\n\u001b[1;32m 67\u001b[0m \u001b[38;5;66;03m# there can't be no intersection unless the endpoint was\u001b[39;00m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;66;03m# inside the bb !\u001b[39;00m\n\u001b[0;32m---> 69\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 70\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mno intersection found (point inside ?!). view: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m topt: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (view, topt)\n\u001b[1;32m 71\u001b[0m )\n",
|
|
|
|
+ "\u001b[0;31mValueError\u001b[0m: no intersection found (point inside ?!). view: <langchain_core.runnables.graph_ascii.VertexViewer object at 0x70acd7938500> topt: (0.0, 25.5)"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "app.get_graph().print_ascii()"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "def multi_query_chain(llm):\n",
|
|
|
|
+ " # Multi Query: Different Perspectives\n",
|
|
|
|
+ " template = \"\"\"\n",
|
|
|
|
+ " <|begin_of_text|>\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " <|start_header_id|>system<|end_header_id|>\n",
|
|
|
|
+ " 你是一個來自台灣的AI助理,你的專長是根據使用者的問題來產生三種不同版本的問題,請用繁體中文。 \\n\n",
|
|
|
|
+ " You are an AI language model assistant. Your task is to generate three \n",
|
|
|
|
+ " different versions of the given user question to retrieve relevant documents from a vector \n",
|
|
|
|
+ " database. By generating multiple perspectives on the user question, your goal is to help\n",
|
|
|
|
+ " the user overcome some of the limitations of the distance-based similarity search. \n",
|
|
|
|
+ " Provide these alternative questions separated by newlines. \n",
|
|
|
|
+ "\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " output must in user's language and no preamble or explanation..\n",
|
|
|
|
+ " <|eot_id|>\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " <|start_header_id|>user<|end_header_id|>\n",
|
|
|
|
+ " Original question: {question}\n",
|
|
|
|
+ " <|eot_id|>\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " <|start_header_id|>assistant<|end_header_id|>\"\"\"\n",
|
|
|
|
+ " prompt_perspectives = ChatPromptTemplate.from_template(template)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " # You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.\n",
|
|
|
|
+ " # llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
|
|
|
|
+ " # llm = ChatOllama(model=\"llama3\", num_gpu=1, temperature=0)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " generate_queries = (\n",
|
|
|
|
+ " prompt_perspectives \n",
|
|
|
|
+ " | llm\n",
|
|
|
|
+ " | StrOutputParser() \n",
|
|
|
|
+ " | (lambda x: x.split(\"\\n\"))\n",
|
|
|
|
+ " )\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " return generate_queries"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "data": {
|
|
|
|
+ "text/plain": [
|
|
|
|
+ "['1. 溫室氣體的定義是什麼?', '2. 溫室氣體有哪些種類?', '3. 溫室氣體對環境的影響是什麼?']"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ "execution_count": 135,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "output_type": "execute_result"
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "# from RAG_strategy import multi_query_chain\n",
|
|
|
|
+ "llm = ChatOllama(model=local_llm, temperature=0)\n",
|
|
|
|
+ "question = \"溫室氣體是什麼\"\n",
|
|
|
|
+ "generate_queries = multi_query_chain(llm)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "questions = generate_queries.invoke(question)\n",
|
|
|
|
+ "questions"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "def ask_more_detail_chain(llm):\n",
|
|
|
|
+ " # Ask More Detail: in-depth follow-up questions about the given context\n",
|
|
|
|
+ " template = \"\"\"\n",
|
|
|
|
+ " <|begin_of_text|>\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " <|start_header_id|>system<|end_header_id|>\n",
|
|
|
|
+ " 你是一個來自台灣的AI助理,你的專長是根據使用者提供的文本來進一步詢問當中的細節,例如名詞解釋,請用繁體中文。 \\n\n",
|
|
|
|
+ " You are an AI language model assistant. \n",
|
|
|
|
+ " Your task is to generate three questions about the given user context as additional explanation.\n",
|
|
|
|
+ " By generating three in-depth questions about the user context, your goal is to help the user realize more details. \n",
|
|
|
|
+ " Provide these questions separated by newlines.\n",
|
|
|
|
+ " For example:\n",
|
|
|
|
+ " context: 建準廣興廠去年2023年一共自產發電了684,508度綠電\n",
|
|
|
|
+ " in-depth question:什麼是綠電?\\n為何要使用綠電? \n",
|
|
|
|
+ "\n",
|
|
|
|
+ " output must in user's language and no preamble or explanation.\n",
|
|
|
|
+ " <|eot_id|>\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " <|start_header_id|>user<|end_header_id|>\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " \n",
|
|
|
|
+ " \n",
|
|
|
|
+ " Original context: {question}\n",
|
|
|
|
+ " three questions:\n",
|
|
|
|
+ " <|eot_id|>\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " <|start_header_id|>assistant<|end_header_id|>\"\"\"\n",
|
|
|
|
+ " prompt_perspectives = ChatPromptTemplate.from_template(template)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " # llm = ChatOpenAI(temperature=0, model=\"gpt-4-1106-preview\")\n",
|
|
|
|
+ " # llm = ChatOllama(model=\"llama3\", num_gpu=1, temperature=0)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " generate_queries = (\n",
|
|
|
|
+ " prompt_perspectives \n",
|
|
|
|
+ " | llm\n",
|
|
|
|
+ " | StrOutputParser() \n",
|
|
|
|
+ " | (lambda x: x.split(\"\\n\"))\n",
|
|
|
|
+ " )\n",
|
|
|
|
+ "\n",
|
|
|
|
+ " return generate_queries"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "data": {
|
|
|
|
+ "text/plain": [
|
|
|
|
+ "['什麼是固定燃燒排放量?']"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ "execution_count": 30,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "output_type": "execute_result"
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "llm = ChatOllama(model=local_llm, temperature=0)\n",
|
|
|
|
+ "question = \"建準廣興廠去年的固定燃燒排放量是多少?\"\n",
|
|
|
|
+ "generate_queries = ask_more_detail_chain(llm)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "questions = generate_queries.invoke(question)\n",
|
|
|
|
+ "questions"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "data": {
|
|
|
|
+ "text/plain": [
|
|
|
|
+ "'固定燃燒(Fixed Combustion)是一種溫室氣體排放的方法,涉及到化石燃料的燃燒,以產生熱或蒸汽。例如:鍋爐、加熱爐、緊急發電機等設施。'"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ "execution_count": 21,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "output_type": "execute_result"
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "llm = ChatOllama(model=local_llm, temperature=0)\n",
|
|
|
|
+ "question = \"固定燃燒是什麼?\"\n",
|
|
|
|
+ "docs = retriever.get_relevant_documents(question, k=10)\n",
|
|
|
|
+ "generation = faiss_query(question, docs, llm)\n",
|
|
|
|
+ "generation"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "from langchain.tools import StructuredTool"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "def rag(question: str):\n",
|
|
|
|
+ " llm = ChatOllama(model=local_llm, temperature=0)\n",
|
|
|
|
+ " docs = retriever.get_relevant_documents(question, k=10)\n",
|
|
|
|
+ " answer = faiss_query(question, docs, llm)\n",
|
|
|
|
+ " \n",
|
|
|
|
+ " return answer\n",
|
|
|
|
+ " "
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "rag_tool = StructuredTool.from_function(\n",
|
|
|
|
+ " func=rag,\n",
|
|
|
|
+ " name=\"RAG\",\n",
|
|
|
|
+ " description=\"查詢碳盤查知識庫\"\n",
|
|
|
|
+ ")"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stderr",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "/usr/local/anaconda3/envs/llama3/lib/python3.12/site-packages/langchain_core/_api/deprecation.py:141: LangChainDeprecationWarning: The method `BaseTool.__call__` was deprecated in langchain-core 0.1.47 and will be removed in 1.0. Use invoke instead.\n",
|
|
|
|
+ " warn_deprecated(\n"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "'溫室氣體是一種能吸收和釋放紅外線輻射的氣體,存在於大氣中,並將熱能留在地球表面,無法散出大氣層外。'\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "from pprint import pprint\n",
|
|
|
|
+ "# from langchain_community.chat_models import ChatOllama\n",
|
|
|
|
+ "# from langchain_community.chat_models.ollama import ChatOllama\n",
|
|
|
|
+ "from langchain_ollama import ChatOllama\n",
|
|
|
|
+ "llm = ChatOllama(model=local_llm, temperature=0)\n",
|
|
|
|
+ "pprint(rag_tool(\"溫室氣體是什麼\"))"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "工具名稱:RAG\n",
|
|
|
|
+ "工具描述:查詢碳盤查知識庫\n",
|
|
|
|
+ "工具參數:{'question': {'title': 'Question', 'type': 'string'}}\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "print(f'工具名稱:{rag_tool.name}')\n",
|
|
|
|
+ "print(f'工具描述:{rag_tool.description}')\n",
|
|
|
|
+ "print(f'工具參數:{rag_tool.args}')"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "sql_tool = StructuredTool.from_function(\n",
|
|
|
|
+ " func=run_text_to_sql,\n",
|
|
|
|
+ " name=\"SQL\",\n",
|
|
|
|
+ " description=\"查詢客戶內部資料\"\n",
|
|
|
|
+ ")"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "SELECT SUM(\"高雄總部及運通廠\" + \"台北辦事處\" + \"昆山廣興廠\" + \"北海建準廠\" + \"北海立準廠\" + \"菲律賓建準廠\" + \"Inc\" + \"SAS\" + \"India\") AS \"固定燃燒總排放量\"\n",
|
|
|
|
+ "FROM \"2023 清冊數據(GHG)\"\n",
|
|
|
|
+ "WHERE \"排放源\" = '固定燃燒'\n",
|
|
|
|
+ "[(13.5953,)]\n",
|
|
|
|
+ "'建準去年的固定燃燒總排放量是13.5953。'\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "pprint(sql_tool(\"建準去年的固定燃燒總排放量是多少?\"))"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "tools = [rag_tool, sql_tool]\n",
|
|
|
|
+ "model_with_tools = llm.bind_tools(tools)\n",
|
|
|
|
+ "tool_map = {tool.name: tool for tool in tools}"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "from langchain_core.runnables import RunnablePassthrough\n",
|
|
|
|
+ "from operator import itemgetter, attrgetter\n",
|
|
|
|
+ "from langchain_core.runnables import (\n",
|
|
|
|
+ " RunnableLambda, RunnablePassthrough)\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "def call_tool(tool_invocation):\n",
|
|
|
|
+ " tool = tool_map[tool_invocation[\"type\"]]\n",
|
|
|
|
+ " return RunnablePassthrough.assign(\n",
|
|
|
|
+ " output=itemgetter(\"args\") | tool)"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "from langchain.output_parsers import JsonOutputToolsParser\n",
|
|
|
|
+ "# Apply call_tool to every tool call in the parsed list.\n",
|
|
|
|
+ "call_tool_list = RunnableLambda(call_tool).map()\n",
|
|
|
|
+ "# Model output -> list of tool-call dicts -> executed tool calls.\n",
|
|
|
|
+ "chain = model_with_tools | JsonOutputToolsParser() | call_tool_list"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "data": {
|
|
|
|
+ "text/plain": [
|
|
|
|
+ "[{'args': {'question': '測試測試測試介紹溫室氣體是什麼'},\n",
|
|
|
|
+ " 'type': 'RAG',\n",
|
|
|
|
+ " 'output': '溫室氣體(Greenhouse Gas)指的是會吸收和釋放紅外線輻射並存在大氣中的氣體。這些氣體包括二氧化碳(CO2)、甲烷(CH4)、氧化亞氮(N2O)、氫氟碳化物(HFCs)、全氟碳化物(PFCs)、六氟化硫(SF6)及三氟化氮(NF3)。'}]"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ "execution_count": 74,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "output_type": "execute_result"
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "# \"建準去年的綠電使用量是多少?\"\n",
|
|
|
|
+ "# 溫室氣體是什麼\n",
|
|
|
|
+ "chain.invoke(\"測試測試測試介紹溫室氣體是什麼?建準綠電?測試測試測試測試測試測試\")"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "from langchain_core.prompts import MessagesPlaceholder\n",
|
|
|
|
+ "from langchain.agents import (\n",
|
|
|
|
+ " AgentExecutor, create_openai_tools_agent)"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "from langchain_community.chat_models import ChatOllama\n",
|
|
|
|
+ "# Chat model used by the agent cells: the Ollama tool-use model at temperature 0.\n",
|
|
|
|
+ "local_llm = \"llama3-groq-tool-use:latest\"\n",
|
|
|
|
+ "llm = ChatOllama(temperature=0, model=local_llm)"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "# Agent prompt. The system message names the two tools and tells the model to\n",
|
|
|
|
+ "# forward the user's natural-language question unchanged: calling the SQL tool\n",
|
|
|
|
+ "# with a plain question works, while a model-written SELECT passed as the\n",
|
|
|
|
+ "# question made the tool fail. Adjacent string literals are concatenated so the\n",
|
|
|
|
+ "# message carries no stray indentation or line breaks.\n",
|
|
|
|
+ "prompt = ChatPromptTemplate.from_messages([\n",
|
|
|
|
+ "    (\"system\",\n",
|
|
|
|
+ "     \"你是一位善用工具的繁體中文助理,會判斷問題來選擇適合的工具:\"\n",
|
|
|
|
+ "     \"使用公司私人數據來回答有關公司溫室氣體排放數據資訊的問題(工具名稱:SQL),\"\n",
|
|
|
|
+ "     \"或使用向量庫來解答有關 ESG 領域知識或有關 ESG 新聞的問題(工具名稱:RAG)。\"\n",
|
|
|
|
+ "     \"呼叫工具時,請把使用者的原始問題(自然語言)直接放入 question 參數,不要自行撰寫 SQL 語句;\"\n",
|
|
|
|
+ "     \"若問題包含多個子問題,請針對每個子問題分別呼叫對應的工具。\"),\n",
|
|
|
|
+ "    (\"human\", \"{input}\"),\n",
|
|
|
|
+ "    # Slot where the agent's intermediate tool calls and results are inserted.\n",
|
|
|
|
+ "    MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
|
|
|
|
+ "])"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "# NOTE(review): ChatOpenAI is not used in this cell; the import is kept in case\n",
|
|
|
|
+ "# other cells rely on it.\n",
|
|
|
|
+ "from langchain_openai import ChatOpenAI\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "# Build the tool-calling agent on the Ollama chat model (`llm`).\n",
|
|
|
|
+ "# No ChatOpenAI client is constructed: the agent runs on `llm`, and constructing\n",
|
|
|
|
+ "# ChatOpenAI would require an OpenAI API key.\n",
|
|
|
|
+ "agent = create_openai_tools_agent(llm=llm, tools=tools, prompt=prompt)"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [],
|
|
|
|
+ "source": [
|
|
|
|
+ "# Executor that runs the agent with the same tool list; verbose=True prints each step.\n",
|
|
|
|
+ "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
|
|
|
+ "\u001b[32;1m\u001b[1;3m\n",
|
|
|
|
+ "Invoking: `SQL` with `{'question': \"SELECT SUM(fixed_burn_total) AS total FROM last_year_data WHERE category = 'CO2'\"}`\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\u001b[0mTo answer your question, I can help you with that. Could you please specify what kind of data you are looking for? For example, the total CO2 emissions or the direct emissions from a specific factory?\n",
|
|
|
|
+ "Error: (psycopg2.errors.SyntaxError) syntax error at or near \"To\"\n",
|
|
|
|
+ "LINE 1: To answer your question, I can help you with that. Could you...\n",
|
|
|
|
+ " ^\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "[SQL: To answer your question, I can help you with that. Could you please specify what kind of data you are looking for? For example, the total CO2 emissions or the direct emissions from a specific factory?]\n",
|
|
|
|
+ "(Background on this error at: https://sqlalche.me/e/20/f405)\n",
|
|
|
|
+ "\u001b[33;1m\u001b[1;3mIt seems like there's an issue with your SQL query. The error message indicates a syntax error near \"To\". Could you please provide the correct SQL query or clarify what you're trying to achieve?\u001b[0m\u001b[32;1m\u001b[1;3mIt looks like there was a mistake in the SQL query I tried to execute. Can you confirm if the category name is 'CO2' and if so, should it be enclosed within single quotes?\u001b[0m\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\u001b[1m> Finished chain.\u001b[0m\n",
|
|
|
|
+ "It looks like there was a mistake in the SQL query I tried to execute. Can you confirm if the category name is 'CO2' and if so, should it be enclosed within single quotes?\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "result = agent_executor.invoke({\"input\": \"建準去年的固定燃燒總排放量是多少?碳權是什麼?\"})\n",
|
|
|
|
+ "print(result['output'])"
|
|
|
|
+ ]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "cell_type": "code",
|
|
|
|
+ "execution_count": null,
|
|
|
|
+ "metadata": {},
|
|
|
|
+ "outputs": [
|
|
|
|
+ {
|
|
|
|
+ "name": "stdout",
|
|
|
|
+ "output_type": "stream",
|
|
|
|
+ "text": [
|
|
|
|
+ "\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
|
|
|
+ "\u001b[32;1m\u001b[1;3m您好!我可以幫助您查詢一些資訊。您需要查詢什麼?\u001b[0m\n",
|
|
|
|
+ "\n",
|
|
|
|
+ "\u001b[1m> Finished chain.\u001b[0m\n",
|
|
|
|
+ "您好!我可以幫助您查詢一些資訊。您需要查詢什麼?\n"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "source": [
|
|
|
|
+ "result = agent_executor.invoke({\"input\": \"你好\"})\n",
|
|
|
|
+ "print(result['output'])"
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ ],
|
|
|
|
+ "metadata": {
|
|
|
|
+ "kernelspec": {
|
|
|
|
+ "display_name": "llama3",
|
|
|
|
+ "language": "python",
|
|
|
|
+ "name": "python3"
|
|
|
|
+ },
|
|
|
|
+ "language_info": {
|
|
|
|
+ "codemirror_mode": {
|
|
|
|
+ "name": "ipython",
|
|
|
|
+ "version": 3
|
|
|
|
+ },
|
|
|
|
+ "file_extension": ".py",
|
|
|
|
+ "mimetype": "text/x-python",
|
|
|
|
+ "name": "python",
|
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
|
+ "version": "3.12.4"
|
|
|
|
+ }
|
|
|
|
+ },
|
|
|
|
+ "nbformat": 4,
|
|
|
|
+ "nbformat_minor": 2
|
|
|
|
+}
|