|
@@ -0,0 +1,309 @@
|
|
|
+import re
|
|
|
+from dotenv import load_dotenv
|
|
|
+load_dotenv()
|
|
|
+
|
|
|
+from langchain_community.utilities import SQLDatabase
|
|
|
+import os
|
|
|
+URI: str = os.environ.get('SUPABASE_URI')
|
|
|
+db = SQLDatabase.from_uri(URI)
|
|
|
+
|
|
|
+# print(db.dialect)
|
|
|
+# print(db.get_usable_table_names())
|
|
|
+# db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
|
|
|
+
|
|
|
+context = db.get_context()
|
|
|
+# print(list(context))
|
|
|
+# print(context["table_info"])
|
|
|
+
|
|
|
+from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
|
|
|
+from langchain.chains import create_sql_query_chain
|
|
|
+from langchain_community.llms import Ollama
|
|
|
+
|
|
|
+from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
|
|
|
+from operator import itemgetter
|
|
|
+
|
|
|
+from langchain_core.output_parsers import StrOutputParser
|
|
|
+from langchain_core.prompts import PromptTemplate
|
|
|
+from langchain_core.runnables import RunnablePassthrough
|
|
|
+
|
|
|
+# Load model directly
|
|
|
+from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
+from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
|
|
|
+import torch
|
|
|
+from langchain_huggingface import HuggingFacePipeline
|
|
|
+
|
|
|
+# Load model directly
|
|
|
+from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
+# model_id = "defog/llama-3-sqlcoder-8b"
|
|
|
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
+# sql_llm = HuggingFacePipeline.from_model_id(
|
|
|
+# model_id=model_id,
|
|
|
+# task="text-generation",
|
|
|
+# model_kwargs={"torch_dtype": torch.bfloat16},
|
|
|
+# pipeline_kwargs={"return_full_text": False},
|
|
|
+# device=0, device_map='cuda')
|
|
|
+
|
|
|
+##########################################################################################
|
|
|
+from langchain_community.chat_models import ChatOllama
|
|
|
+# local_llm = "llama3-groq-tool-use:latest"
|
|
|
+local_llm = "llama3-groq-tool-use:latest"
|
|
|
+llm = ChatOllama(model=local_llm, temperature=0)
|
|
|
+##########################################################################################
|
|
|
+# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
|
|
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
+
|
|
|
+# llm = HuggingFacePipeline.from_model_id(
|
|
|
+# model_id=model_id,
|
|
|
+# task="text-generation",
|
|
|
+# model_kwargs={"torch_dtype": torch.bfloat16},
|
|
|
+# pipeline_kwargs={"return_full_text": False,
|
|
|
+# "max_new_tokens": 512},
|
|
|
+# device=0, device_map='cuda')
|
|
|
+# print(llm.pipeline)
|
|
|
+# llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
|
|
|
+##########################################################################################
|
|
|
+
|
|
|
+# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
|
|
|
+
|
|
|
+# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1,
|
|
|
+# model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
|
|
|
+#, device="auto", load_in_4bit=True
|
|
|
+# llm = HuggingFacePipeline(pipeline=pipe)
|
|
|
+
|
|
|
+# llm = HuggingFacePipeline(pipeline=pipe)
|
|
|
+
|
|
|
+# llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
|
|
|
+def get_examples():
|
|
|
+ examples = [
|
|
|
+ {
|
|
|
+ "input": "建準廣興廠2023年的自產電力的綠電使用量是多少?",
|
|
|
+ "query": """SELECT SUM("用電度數(kwh)") AS "自產電力綠電使用量"
|
|
|
+ FROM "用電度數"
|
|
|
+ WHERE "項目" = '自產電力(綠電)'
|
|
|
+ AND "盤查標準" = 'GHG'
|
|
|
+ AND "年度" = 2023;""",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "input": "建準廣興廠去年的類別1總排放量是多少?",
|
|
|
+ "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
|
|
|
+ FROM "建準碳排放清冊數據"
|
|
|
+ WHERE "事業名稱" like '%建準%'
|
|
|
+ AND "事業名稱" like '%廣興廠%'
|
|
|
+ AND ("類別" like '%類別1-直接排放%' OR "排放源" like '%類別1-直接排放%')
|
|
|
+ AND "盤查標準" = 'GHG'
|
|
|
+ AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "input": "建準台北辦事處2022年的能源間接排放總排放量是多少?",
|
|
|
+ "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
|
|
|
+ FROM "建準碳排放清冊數據"
|
|
|
+ WHERE "事業名稱" like '%建準%'
|
|
|
+ AND "事業名稱" like '%台北辦事處%'
|
|
|
+ AND ("類別" like '%類別2-能源間接排放%' OR "排放源" like '%類別2-能源間接排放%')
|
|
|
+ AND "盤查標準" = 'GHG'
|
|
|
+ AND "年度" = 2022;""",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "input": "建準去年的固定燃燒總排放量是多少?",
|
|
|
+ "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
|
|
|
+ FROM "建準碳排放清冊數據"
|
|
|
+ WHERE "事業名稱" like '%建準%'
|
|
|
+ AND ("類別" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
|
|
|
+ AND "盤查標準" = 'GHG'
|
|
|
+ AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
|
|
|
+ },
|
|
|
+
|
|
|
+
|
|
|
+ ]
|
|
|
+
|
|
|
+ return examples
|
|
|
+
|
|
|
+def table_description():
|
|
|
+ database_description = (
|
|
|
+ "The database consists of following table: `用水度數`, `用水度數`, `建準碳排放清冊數據`. "
|
|
|
+ "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
|
|
|
+ "The `建準碳排放清冊數據` table 描述了不同事業單位或廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
|
|
|
+ "It includes the following columns:\n"
|
|
|
+ "- `年度`: 盤查年度\n"
|
|
|
+ "- `事業名稱`: 建準據點"
|
|
|
+ "- `國家`: 據點所在國家"
|
|
|
+ "- `類別`: 溫室氣體的排放類別,包含以下選項:\n"
|
|
|
+ " \t*類別1-直接排放:\n"
|
|
|
+ " \t*類別2-能源間接排放\n"
|
|
|
+ " \t*類別3-運輸間接排放\n"
|
|
|
+ " \t*類別4-組織使用產品間接排放\n"
|
|
|
+ " \t*類別5-使用來自組織產品間接排放\n"
|
|
|
+ " \t*類別6\n"
|
|
|
+ "- `排放源`: 由`類別`欄位進一步劃分的細項,包含以下選項:`固定燃燒`, `移動燃燒`, `製程排放`, `逸散排放`, `土地利用`, "
|
|
|
+ "`外購電力`, `外購能源`, `上游運輸`, `下游運輸`, `員工通勤`, `商務旅行`, `訪客運輸`, "
|
|
|
+ "`購買產品`, `外購燃料及能資源`, `資本貨物`, `上游租賃`, `廢棄物處理`, `廢棄物清運`, `其他委外業務`, "
|
|
|
+ "`產品加工`, `產品使用`, `產品最終處理`, `下游租賃`, `投資排放`, `其他`, `其他間接排放` \n"
|
|
|
+ "- `排放量(公噸CO2e)`: 溫室氣體排放量\n"
|
|
|
+ "- `盤查標準`: ISO or GHG\n"
|
|
|
+
|
|
|
+
|
|
|
+ "The `用電度數` 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
|
|
|
+ "It includes the following columns:\n"
|
|
|
+ "- `年度`: 盤查年度\n"
|
|
|
+ "- `事業名稱`: 建準據點"
|
|
|
+ "- `國家`: 據點所在國家"
|
|
|
+ "- `項目`: 用電項目,包含以下:\n"
|
|
|
+ " \t*外購電力(灰電): 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
|
|
|
+ " \t*外購電力(綠電): 綠電(太陽光電)的外購電力度數(kwh)\n"
|
|
|
+ " \t*自產電力(綠電): 綠電(太陽光電)的自產電力度數(kwh)\n"
|
|
|
+ "- `用電度數(kwh)`: 用電度數,單位為kwh\n"
|
|
|
+ "- `盤查標準`: ISO or GHG\n"
|
|
|
+
|
|
|
+ "The `用水度數` 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
|
|
|
+ "It includes the following columns:\n"
|
|
|
+ "- `年度`: 盤查年度\n"
|
|
|
+ "- `事業名稱`: 建準據點"
|
|
|
+ "- `國家`: 據點所在國家"
|
|
|
+ "- `自來水度數(立方公尺 m³)`: 用水度數,單位為m³\n"
|
|
|
+ "- `盤查標準`: ISO or GHG\n"
|
|
|
+ )
|
|
|
+
|
|
|
+ return database_description
|
|
|
+
|
|
|
+def write_query_chain(db, llm):
|
|
|
+
|
|
|
+ template = """
|
|
|
+ <|begin_of_text|>
|
|
|
+
|
|
|
+ <|start_header_id|>system<|end_header_id|>
|
|
|
+ Generate a SQL query to answer this question: `{input}`
|
|
|
+
|
|
|
+ You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run,
|
|
|
+ then look at the results of the query and return the answer to the input question.\n\
|
|
|
+ Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
|
|
|
+ You can order the results to return the most informative data in the database.\n\
|
|
|
+ Never query for all columns from a table. You must query only the columns that are needed to answer the question.
|
|
|
+ Wrap each column name in Quotation Mark (") to denote them as delimited identifiers.\n\
|
|
|
+
|
|
|
+ ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
|
|
|
+ ***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
|
|
|
+ <|eot_id|>
|
|
|
+
|
|
|
+ <|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
|
|
+ DDL statements:
|
|
|
+ {table_info}
|
|
|
+
|
|
|
+ The following is a description of database. Please refer to the database description to give the correct WHERE statement in the PostgreSQL query.\
|
|
|
+ In particular, the details of the `排放源` and `類別` columns.\n
|
|
|
+ database description:
|
|
|
+ {database_description}
|
|
|
+
|
|
|
+ Provide ONLY PostgreSQL query and NO premable or explanation!
|
|
|
+ Below are a number of examples of questions and their corresponding SQL queries.\n\
|
|
|
+
|
|
|
+ <|eot_id|>
|
|
|
+
|
|
|
+ <|start_header_id|>assistant<|end_header_id|>
|
|
|
+ """
|
|
|
+ # prompt_template = PromptTemplate.from_template(template)
|
|
|
+
|
|
|
+ example_prompt = PromptTemplate.from_template("The following SQL query best answers the question `{input}`\nSQL query: {query}")
|
|
|
+ prompt = FewShotPromptTemplate(
|
|
|
+ examples=get_examples(),
|
|
|
+ example_prompt=example_prompt,
|
|
|
+ prefix=template,
|
|
|
+ suffix="User input: {input}\nSQL query: ",
|
|
|
+ input_variables=["input", "top_k", "table_info"],
|
|
|
+ )
|
|
|
+
|
|
|
+ # llm = Ollama(model = "sqlcoder", num_gpu=1)
|
|
|
+ # llm = HuggingFacePipeline(pipeline=pipe)
|
|
|
+
|
|
|
+
|
|
|
+ write_query = create_sql_query_chain(llm, db, prompt)
|
|
|
+
|
|
|
+
|
|
|
+ return write_query
|
|
|
+
|
|
|
+def sql_to_nl_chain(llm):
|
|
|
+ # llm = Ollama(model = "llama3.1", num_gpu=1)
|
|
|
+ # llm = Ollama(model = "llama3.1:8b-instruct-q2_K", num_gpu=1)
|
|
|
+ # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
|
|
|
+ answer_prompt = PromptTemplate.from_template(
|
|
|
+ """
|
|
|
+ <|begin_of_text|>
|
|
|
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
|
|
+ Given the following user question, corresponding SQL query, and SQL result, answer the user question.
|
|
|
+ 根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
|
|
|
+
|
|
|
+
|
|
|
+ <|eot_id|>
|
|
|
+
|
|
|
+ <|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
|
|
+ Question: {question}
|
|
|
+ SQL Query: {query}
|
|
|
+ SQL Result: {result}
|
|
|
+ Answer:
|
|
|
+ <|eot_id|>
|
|
|
+
|
|
|
+ <|start_header_id|>assistant<|end_header_id|>
|
|
|
+
|
|
|
+ """
|
|
|
+ )
|
|
|
+
|
|
|
+ # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
|
|
|
+ chain = answer_prompt | llm | StrOutputParser()
|
|
|
+
|
|
|
+ return chain
|
|
|
+
|
|
|
+def get_query(db, question, selected_table, llm):
|
|
|
+ write_query = write_query_chain(db, llm)
|
|
|
+ query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
|
|
|
+
|
|
|
+ query = re.split('SQL query: ', query)[-1]
|
|
|
+ # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
|
|
|
+ print(query)
|
|
|
+
|
|
|
+ return query
|
|
|
+
|
|
|
+def query_to_nl(db, question, query, llm):
|
|
|
+ execute_query = QuerySQLDataBaseTool(db=db)
|
|
|
+ result = execute_query.invoke(query)
|
|
|
+ print(result)
|
|
|
+
|
|
|
+ chain = sql_to_nl_chain(llm)
|
|
|
+ answer = chain.invoke({"question": question, "query": query, "result": result})
|
|
|
+
|
|
|
+ return answer
|
|
|
+
|
|
|
+def run(db, question, selected_table, llm):
|
|
|
+
|
|
|
+ write_query = write_query_chain(db, llm)
|
|
|
+ query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
|
|
|
+
|
|
|
+ query = re.split('SQL query: ', query)[-1]
|
|
|
+ # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
|
|
|
+ print(query)
|
|
|
+
|
|
|
+ execute_query = QuerySQLDataBaseTool(db=db)
|
|
|
+ result = execute_query.invoke(query)
|
|
|
+ print(result)
|
|
|
+
|
|
|
+ chain = sql_to_nl_chain(llm)
|
|
|
+ answer = chain.invoke({"question": question, "query": query, "result": result})
|
|
|
+
|
|
|
+ return query, result, answer
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ import time
|
|
|
+
|
|
|
+ start = time.time()
|
|
|
+
|
|
|
+ selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據']
|
|
|
+ question = "建準去年的上游運輸總排放量是多少?"
|
|
|
+ # question = "台積電2022年的直接排放總排放量是多少?"
|
|
|
+ # question = "建準廣興廠去年的灰電使用量"
|
|
|
+ query, result, answer = run(db, question, selected_table, llm)
|
|
|
+ print("question: ", question)
|
|
|
+ print("query: ", query)
|
|
|
+ print("result: ", result)
|
|
|
+ print("answer: ", answer)
|
|
|
+
|
|
|
+ print(time.time()-start)
|
|
|
+
|