123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441 |
- import re
- from dotenv import load_dotenv
- load_dotenv()
- from langchain_community.utilities import SQLDatabase
- import os
# Connection string for the SQL database, read from the SUPABASE_URI
# environment variable (load_dotenv() has already run at this point).
URI = os.environ.get('SUPABASE_URI')
if not URI:
    # os.environ.get returns None when the variable is missing; fail fast with
    # an actionable message instead of handing None to SQLDatabase.from_uri.
    raise RuntimeError(
        "Environment variable SUPABASE_URI is not set; cannot connect to the database."
    )

# Shared LangChain database wrapper used by every chain in this module.
db = SQLDatabase.from_uri(URI)
# Debug helpers:
# print(db.dialect)
# print(db.get_usable_table_names())
# db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')

# Context dict from the wrapper; its "table_info" entry is read by
# get_query() and run() below.
context = db.get_context()
# print(list(context))
# print(context["table_info"])
- from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
- from langchain.chains import create_sql_query_chain
- from langchain_community.llms import Ollama
- from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
- from operator import itemgetter
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.prompts import PromptTemplate
- from langchain_core.runnables import RunnablePassthrough
- # Load model directly
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
- import torch
- from langchain_huggingface import HuggingFacePipeline
- # Load model directly
- from transformers import AutoTokenizer, AutoModelForCausalLM
- # model_id = "defog/llama-3-sqlcoder-8b"
- # tokenizer = AutoTokenizer.from_pretrained(model_id)
- # sql_llm = HuggingFacePipeline.from_model_id(
- # model_id=model_id,
- # task="text-generation",
- # model_kwargs={"torch_dtype": torch.bfloat16},
- # pipeline_kwargs={"return_full_text": False},
- # device=0, device_map='cuda')
- ##########################################################################################
- from langchain_community.chat_models import ChatOllama
- # local_llm = "llama3-groq-tool-use:latest"
- # local_llm = "llama3-groq-tool-use:latest"
- # local_llm = "sqlcoder:latest"
- # local_llm = "llama3.1:8b-instruct-q2_K"
- # llm = ChatOllama(model=local_llm, temperature=0)
- ##########################################################################################
- # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
- # tokenizer = AutoTokenizer.from_pretrained(model_id)
- # llm = HuggingFacePipeline.from_model_id(
- # model_id=model_id,
- # task="text-generation",
- # model_kwargs={"torch_dtype": torch.bfloat16},
- # pipeline_kwargs={"return_full_text": False,
- # "max_new_tokens": 512},
- # device=0, device_map='cuda')
- # print(llm.pipeline)
- # llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
- ##########################################################################################
- # model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
- # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1,
- # model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
- #, device="auto", load_in_4bit=True
- # llm = HuggingFacePipeline(pipeline=pipe)
- # llm = HuggingFacePipeline(pipeline=pipe)
- # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
- from langchain_openai import ChatOpenAI
# Module-level chat model handed to run() by the script entry point;
# the sampling temperature is fixed at 0.
llm = ChatOpenAI(
    temperature=0,
    model_name="gpt-4o-mini",
)
def get_examples():
    """Return the few-shot question/SQL pairs shown to the SQL-writing model.

    Each element is a dict with two string entries:

    * ``"input"``: a natural-language question.
    * ``"query"``: the PostgreSQL statement that answers it.

    Relative years in the questions ("去年", "前年") are expressed in SQL as
    ``EXTRACT(YEAR FROM CURRENT_DATE)`` minus 1 or 2, so the examples stay
    correct regardless of the current year.

    Returns:
        A freshly built list of example dicts (safe for callers to mutate).
    """
    examples = [
        {
            "input": "建準去年固定燃燒總排放量",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
        {
            "input": "廣興廠去年的固定燃燒排放量是多少?",
            # The SELECT clause was missing here, which made this example
            # invalid SQL; restored to match the sibling examples.
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒排放量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND "據點" = '昆山廣興廠'
AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
        {
            "input": "建準廣興廠去年自產電力的綠電使用量是多少?",
            "query": """SELECT SUM("用電度數(kwh)") AS "綠電使用量"
FROM "用電度數"
WHERE "項目" like '%綠電%'
AND "事業名稱" like '%建準%'
AND "據點" = '昆山廣興廠'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
        {
            "input": "廣興廠2023綠電使用量",
            "query": """SELECT SUM("用電度數(kwh)") AS "綠電使用量"
FROM "用電度數"
WHERE "項目" like '%綠電%'
AND "事業名稱" like '%建準%'
AND "據點" = '昆山廣興廠'
AND "盤查標準" = 'GHG'
AND "年度" = 2023;""",
        },
        {
            "input": "北海廠去年的類別1總排放量",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND "據點" in ('北海建準廠', '北海立準廠')
AND "類別" = '類別1'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
        {
            "input": "廣興廠去年的直接排放總排放量是多少?",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND "據點" = '昆山廣興廠'
AND ("類別項目" like '%直接排放%' OR "排放源" like '%直接排放%')
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
        {
            "input": "建準台北辦事處2022年的類別2總排放量是多少?",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別2總排放量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND "據點" = '台北辦事處'
AND "類別" = '類別2'
AND "盤查標準" = 'GHG'
AND "年度" = 2022;""",
        },
        {
            "input": "建準法國廠2022年的類別2總排放量",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別2總排放量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND "國家" = '法國'
AND "類別" = '類別2'
AND "盤查標準" = 'GHG'
AND "年度" = 2022;""",
        },
        {
            "input": "建準北海2022的外購電力是多少",
            "query": """SELECT SUM("用電度數(kwh)") AS "外購電力"
FROM "用電度數"
WHERE "事業名稱" like '%建準%'
AND "據點" in ('北海建準廠', '北海立準廠')
AND "項目" like '%外購電力%'
AND "盤查標準" = 'GHG'
AND "年度" = 2022;""",
        },
        {
            "input": "2023建準印度的其他間接排放是多少",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "其他間接排放總量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND "國家" = '印度'
AND ("類別項目" like '%其他間接排放%' OR "排放源" like '%其他間接排放%')
AND "盤查標準" = 'GHG'
AND "年度" = 2023;""",
        },
        {
            "input": "建準台北前年的產品使用碳排放量是多少",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "產品使用總量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND "據點" = '台北辦事處'
AND ("類別項目" like '%產品使用%' OR "排放源" like '%產品使用%')
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-2;""",
        },
        {
            "input": "建準去年範疇三排放量",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "範疇三排放量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND "範疇" = '範疇三'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
    ]
    return examples
def table_description():
    """Return the hand-written description of the database used in the SQL prompt.

    The text lists the three tables (`建準碳排放清冊數據new`, `用電度數`,
    `用水度數`) and, for each, its columns as a bullet list with one column
    per line.

    Returns:
        str: the description, ready to be substituted into the
        ``{database_description}`` placeholder of the SQL-writing prompt.
    """
    database_description = (
        # The table list previously named `用水度數` twice and omitted `用電度數`.
        "The database consists of following table: `用電度數`, `用水度數`, `建準碳排放清冊數據new`. "
        "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
        "The `建準碳排放清冊數據new` table 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
        "It includes the following columns:\n"
        "- `年度`: 盤查年度\n"
        # Each bullet ends in "\n" so adjacent column entries are not fused
        # into a single line.
        "- `事業名稱`: 公司名稱\n"
        "- `據點`: 建準廠房據點 include '高雄總部及運通廠', '台北辦事處', '昆山廣興廠', '北海建準廠', '北海立準廠', '菲律賓建準廠', 'Inc', 'SAS', 'India'\n"
        "- `國家`: 據點所在國家\n"
        "- `範疇`: 碳盤查中把溫室氣體排放源分成三大範疇\n"
        "- `類別`: 溫室氣體的排放類別,包含以下選項:\n"
        " \t*類別1-直接排放:\n"
        " \t*類別2-能源間接排放\n"
        " \t*類別3-運輸間接排放\n"
        " \t*類別4-組織使用產品間接排放\n"
        " \t*類別5-使用來自組織產品間接排放\n"
        " \t*類別6\n"
        "- `排放源`: 由`類別`欄位進一步劃分的細項,包含以下選項:`固定燃燒`, `移動燃燒`, `製程排放`, `逸散排放`, `土地利用`, "
        "`外購電力`, `外購能源`, `上游運輸`, `下游運輸`, `員工通勤`, `商務旅行`, `訪客運輸`, "
        "`購買產品`, `外購燃料及能資源`, `資本貨物`, `上游租賃`, `廢棄物處理`, `廢棄物清運`, `其他委外業務`, "
        "`產品加工`, `產品使用`, `產品最終處理`, `下游租賃`, `投資排放`, `其他`, `其他間接排放` \n"
        "- `排放量(公噸CO2e)`: 溫室氣體排放量\n"
        "- `盤查標準`: ISO or GHG\n"

        "The `用電度數` 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
        "It includes the following columns:\n"
        "- `年度`: 盤查年度\n"
        "- `事業名稱`: 建準據點\n"
        "- `國家`: 據點所在國家\n"
        "- `項目`: 用電項目,包含以下:\n"
        " \t*外購電力(灰電): 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
        " \t*外購電力(綠電): 綠電(太陽光電)的外購電力度數(kwh)\n"
        " \t*自產電力(綠電): 綠電(太陽光電)的自產電力度數(kwh)\n"
        "- `用電度數(kwh)`: 用電度數,單位為kwh\n"
        "- `盤查標準`: ISO or GHG\n"

        "The `用水度數` 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
        "It includes the following columns:\n"
        "- `年度`: 盤查年度\n"
        "- `事業名稱`: 建準據點\n"
        "- `國家`: 據點所在國家\n"
        "- `自來水度數(立方公尺 m³)`: 用水度數,單位為m³\n"
        "- `盤查標準`: ISO or GHG\n"
    )
    return database_description
def write_query_chain(db, llm):
    """Build the chain that turns a natural-language question into SQL text.

    A few-shot prompt is assembled from a fixed instruction prefix, the
    examples returned by ``get_examples()``, and a suffix that repeats the
    user question; it is then passed to ``create_sql_query_chain``.

    Args:
        db: database wrapper forwarded to ``create_sql_query_chain``.
        llm: language model forwarded to ``create_sql_query_chain``.

    Returns:
        The runnable produced by ``create_sql_query_chain``. Callers in this
        file invoke it with the keys ``question``, ``table_names_to_use``,
        ``top_k``, ``table_info`` and ``database_description``.
    """
    template = """
<|begin_of_text|>

<|start_header_id|>system<|end_header_id|>
Generate a SQL query to answer this question: `{input}`
你是建準的AI助理,幫助建準查詢碳排放量,如果問題中有提到據點廠房,請使用 PostgreSQL query 進行篩選。
You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run,
then look at the results of the query and return the answer to the input question.\n\
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
You can order the results to return the most informative data in the database.\n\
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Wrap each column name in Quotation Mark (") to denote them as delimited identifiers.\n\

Unless the user ask for the type of 盤查標準 to be 'ISO' or 'GHG', queries always include query "盤查標準"='GHG' in the WHERE clause.\n
***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
<|eot_id|>

<|begin_of_text|><|start_header_id|>user<|end_header_id|>
DDL statements:
{table_info}
The following is a description of database. Please refer to the database description to give the correct WHERE statement in the PostgreSQL query.\
In particular, the details of the `排放源` and `類別` columns.\n
database description:
{database_description}
Provide ONLY PostgreSQL query and NO premable or explanation!
Below are a number of examples of questions and their corresponding SQL queries.\n\

<|eot_id|>
SQL query:
"""
    # Rendering applied to every few-shot example before it is appended
    # after the prefix.
    example_prompt = PromptTemplate.from_template(
        "The following SQL query best answers the question `{input}`\nSQL query: {query}"
    )
    prompt = FewShotPromptTemplate(
        examples=get_examples(),
        example_prompt=example_prompt,
        prefix=template,
        suffix="User input: {input}\nSQL query: ",
        # The prefix references {database_description}, so it is declared
        # here alongside the other placeholders.
        input_variables=["input", "top_k", "table_info", "database_description"],
    )

    write_query = create_sql_query_chain(llm, db, prompt)
    return write_query
def sql_to_nl_chain(llm):
    """Build the chain that phrases a SQL result as a natural-language answer.

    Args:
        llm: language model placed between the prompt and the string parser.

    Returns:
        A runnable composed as ``prompt | llm | StrOutputParser()``. The
        prompt has three placeholders: ``question``, ``query`` and ``result``.
    """
    prompt_text = """
<|begin_of_text|>
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Given the following user question, corresponding SQL query, and SQL result, answer the user question.
根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
** 請務必在回答中表達是建準的資料,即便問句中並未提及建準。
如果有單位,請回答時使用單位。

The following shows some example:
Question: 建準廣興廠去年的類別1總排放量是多少?
SQL Query: SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
FROM "建準碳排放清冊數據new"
WHERE "事業名稱" like '%建準%'
AND "據點" = '昆山廣興廠'
AND "類別" = '類別1'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;,
SQL Result: [(1102.3712,)]
Answer: 建準廣興廠去年的類別1總排放量是1102.3712公噸CO2e
如果你不知道答案或SQL query 出現Error請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
若 SQL Result 為 0 代表數據為0。
勿回答無關資訊
<|eot_id|>
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer:
<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>

"""
    answer_prompt = PromptTemplate.from_template(prompt_text)
    # Prompt -> model -> plain string.
    return answer_prompt | llm | StrOutputParser()
def _clean_generated_sql(raw_output):
    """Reduce raw model output to the SQL text that will be executed."""
    # Prefer the first "SELECT ... ;" span; otherwise fall back to the raw text.
    match = re.search(r'SELECT[\s\S]+?;', raw_output)
    if match:
        sql = match.group()
    else:
        print("No SQL query found.")
        sql = raw_output

    # The clean-up below runs on both paths so that an extracted statement
    # also gets its "%%" and mistyped-character fixes.
    sql = re.split('SQL query: ', sql)[-1]  # drop any echoed prompt label
    sql = sql.replace("```sql", "").replace("```", "")  # strip markdown fences
    sql = sql.replace("碰排", "碳排")  # fix a character the model sometimes gets wrong
    sql = sql.replace("%%", "%")  # collapse doubled LIKE wildcards
    return sql


def get_query(db, question, selected_table, llm):
    """Generate SQL for ``question``, execute it, and return both.

    Args:
        db: database wrapper used to build the chain and to run the query.
        question: the user's natural-language question.
        selected_table: table names passed to the chain as
            ``table_names_to_use``.
        llm: language model used to write the SQL.

    Returns:
        tuple: ``(query, result)`` where ``query`` is the cleaned SQL string
        and ``result`` is whatever ``QuerySQLDataBaseTool.invoke`` returns.

    Note:
        ``table_info`` is read from the module-level ``context``, not from
        the ``db`` argument.
    """
    write_query = write_query_chain(db, llm)
    raw_output = write_query.invoke(
        {
            "question": question,
            'table_names_to_use': selected_table,
            "top_k": 1000,
            "table_info": context["table_info"],
            "database_description": table_description(),
        }
    )

    query = _clean_generated_sql(raw_output)
    print(query)

    execute_query = QuerySQLDataBaseTool(db=db)
    result = execute_query.invoke(query)
    print(result)
    return query, result
def query_to_nl(question, query, result, llm):
    """Phrase an already-executed SQL query and its result as an answer.

    Args:
        question: the user's original question.
        query: the SQL text that was executed.
        result: the output of executing ``query``.
        llm: language model used by the answer chain.

    Returns:
        The answer produced by the chain from ``sql_to_nl_chain(llm)``.
    """
    answer_chain = sql_to_nl_chain(llm)
    print(result)
    payload = {"question": question, "query": query, "result": result}
    return answer_chain.invoke(payload)
def run(db, question, selected_table, llm):
    """Answer ``question`` end to end: write SQL, execute it, phrase the result.

    SQL generation, clean-up and execution are delegated to ``get_query`` so
    that this entry point gets the same ``SELECT ...;`` extraction and
    post-processing instead of a diverging copy of that logic.

    Args:
        db: database wrapper used to generate and execute the query.
        question: the user's natural-language question.
        selected_table: table names the SQL writer may use.
        llm: language model used for both SQL writing and answer phrasing.

    Returns:
        tuple: ``(query, result, answer)``: the executed SQL string, the
        execution output, and the natural-language answer.
    """
    # get_query prints the query and the result once each.
    query, result = get_query(db, question, selected_table, llm)

    # The answer chain is invoked directly rather than through query_to_nl,
    # which would print the result a second time.
    chain = sql_to_nl_chain(llm)
    answer = chain.invoke({"question": question, "query": query, "result": result})
    return query, result, answer
if __name__ == "__main__":
    # Manual smoke test: answer one sample question and report elapsed time.
    import time

    start = time.time()

    # One entry per table the SQL writer may use. The list previously named
    # `用水度數` twice and left out `用電度數`, so the electricity table was
    # never offered to the model.
    selected_table = ['用電度數', '用水度數', '建準碳排放清冊數據new']

    # Other sample questions to try:
    # question = "建準廣興廠去年的上游運輸總排放量是多少?"
    # question = "建準北海廠去年類別1總排放量是多少?"
    # question = "台積電2022年的直接排放總排放量是多少?"
    # question = "建準廣興廠去年的灰電使用量"
    question = "建準北海廠去年的固定燃燒排放量是多少?"

    query, result, answer = run(db, question, selected_table, llm)
    print("question: ", question)
    print("query: ", query)
    print("result: ", result)
    print("answer: ", answer)

    # Elapsed wall-clock seconds for the whole round trip.
    print(time.time()-start)
|