"""Text-to-SQL question answering over the carbon-emission database.

The module connects to the database named by the ``SUPABASE_URI`` environment
variable, asks an LLM to write a PostgreSQL query for a natural-language
question, executes that query, and asks the LLM to phrase the result as an
answer in Traditional Chinese.

Importing this module has side effects: it loads ``.env``, connects to the
database (``db``), reads the database context (``context``) and creates the
chat model (``llm``).
"""
import os
import re
from operator import itemgetter

from dotenv import load_dotenv

# Kept ahead of the remaining imports, as in the original file, so that values
# from .env are already in the environment when the libraries below load.
load_dotenv()

import torch
from langchain.chains import create_sql_query_chain
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import Ollama
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# --------------------------------------------------------------------------
# Database connection
# --------------------------------------------------------------------------
URI: str = os.environ.get("SUPABASE_URI", "")
if not URI:
    # Fail early with an actionable message instead of handing None to
    # SQLDatabase.from_uri.
    raise RuntimeError(
        "Environment variable SUPABASE_URI is not set; "
        "define it in the environment or in a .env file."
    )

db = SQLDatabase.from_uri(URI)
# print(db.dialect)
# print(db.get_usable_table_names())
# db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
context = db.get_context()
# print(list(context))
# print(context["table_info"])

# --------------------------------------------------------------------------
# Model selection
# --------------------------------------------------------------------------
# Alternative back-ends that were tried; kept for reference.
#
# model_id = "defog/llama-3-sqlcoder-8b"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# sql_llm = HuggingFacePipeline.from_model_id(
#     model_id=model_id,
#     task="text-generation",
#     model_kwargs={"torch_dtype": torch.bfloat16},
#     pipeline_kwargs={"return_full_text": False},
#     device=0, device_map='cuda')
#
# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# llm = HuggingFacePipeline.from_model_id(
#     model_id=model_id,
#     task="text-generation",
#     model_kwargs={"torch_dtype": torch.bfloat16},
#     pipeline_kwargs={"return_full_text": False,
#                      "max_new_tokens": 512},
#     device=0, device_map='cuda')
# print(llm.pipeline)
# llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
#
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1,
#                 model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})  #, device="auto", load_in_4bit=True
# llm = HuggingFacePipeline(pipeline=pipe)
#
# llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)

local_llm = "llama3-groq-tool-use:latest"
llm = ChatOllama(model=local_llm, temperature=0)

# --------------------------------------------------------------------------
# Constants used when generating and post-processing SQL
# --------------------------------------------------------------------------
# Row limit substituted for {top_k} in the query-writing prompt.
_TOP_K = 1000
# Marker that precedes the SQL in the few-shot examples and the prompt suffix.
_SQL_MARKER = "SQL query: "
# The generated SQL sometimes contains this misspelling of the table name
# (碰 instead of 碳); it is corrected before execution.
_TABLE_NAME_TYPO = "104_112碰排放公開及建準資料"
_TABLE_NAME = "104_112碳排放公開及建準資料"
# Matches output that is wrapped in a Markdown code fence (``` or ```sql).
_CODE_FENCE_RE = re.compile(r"^```(?:sql)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)

# NOTE(review): the line breaks inside the two prompt templates below were
# reconstructed; confirm them against the deployed prompt if exact whitespace
# matters. The wording itself is unchanged.
_WRITE_QUERY_PREFIX = """
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
Generate a SQL query to answer this question: `{input}`
You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question.\n\
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.\n\
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Wrap each column name in Quotation Mark (") to denote them as delimited identifiers.\n\
***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
<|eot_id|>
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
DDL statements:
{table_info}
database description:
{database_description}
Provide ONLY PostgreSQL query and NO premable or explanation!
Below are a number of examples of questions and their corresponding SQL queries.\n\
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

_ANSWER_TEMPLATE = """
<|begin_of_text|>
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Given the following user question, corresponding SQL query, and SQL result, answer the user question.
根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
<|eot_id|>
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer:
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""


def get_examples() -> list:
    """Return the few-shot examples used by the query-writing prompt.

    Returns:
        A list of dicts, each with an ``"input"`` question and the ``"query"``
        (PostgreSQL) that answers it.
    """
    examples = [
        {
            "input": "建準廣興廠去年的類別1總排放量是多少?",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
FROM "104_112碳排放公開及建準資料"
WHERE "事業名稱" like '%建準%'
AND "事業名稱" like '%廣興廠%'
AND "類別" = '類別1-直接排放'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
        {
            "input": "建準台北辦事處2022年的能源間接排放總排放量是多少?",
            # The alias used to read "直接排放總排放量", which contradicted the
            # question and the category filter of this example.
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "能源間接排放總排放量"
FROM "104_112碳排放公開及建準資料"
WHERE "事業名稱" like '%建準%'
AND "事業名稱" like '%台北辦事處%'
AND "類別" = '類別2-能源間接排放'
AND "盤查標準" = 'GHG'
AND "年度" = 2022;""",
        },
        {
            "input": "建準去年的固定燃燒總排放量是多少?",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
FROM "104_112碳排放公開及建準資料"
WHERE "事業名稱" like '%建準%'
AND "排放源" = '固定燃燒'
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
        {
            "input": "台積電2022年的類別1總排放量是多少?",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "台積電2022年類別1總排放量"
FROM "104_112碳排放公開及建準資料"
WHERE "事業名稱" like '%台灣積體電路製造股份有限公司%'
AND "類別" = '類別1-直接排放'
AND "年度" = 2022;""",
        },
    ]
    return examples


def table_description() -> str:
    """Return the free-text database description inserted into the prompt.

    Returns:
        A single string describing the tables and their columns; it fills the
        ``{database_description}`` placeholder of the query-writing prompt.
    """
    database_description = (
        "The database consists of following table: `104_112碳排放公開及建準資料`, `水電使用量(ISO)` and `水電使用量(GHG)`. "
        "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
        "The `104_112碳排放公開及建準資料` table 描述了不同事業單位或廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
        "It includes the following columns:\n"
        "- `年度`: 盤查年度\n"
        "- `類別`: 溫室氣體的排放類別,包含以下:\n"
        " \t*類別1-直接排放\n"
        " \t*類別2-能源間接排放\n"
        " \t*類別3-運輸間接排放\n"
        " \t*類別4-組織使用產品間接排放\n"
        " \t*類別5-使用來自組織產品間接排放\n"
        " \t*類別6\n"
        "- `排放源`: `類別`欄位進一步劃分的細項\n"
        "- `排放量(公噸CO2e)`: 溫室氣體排放量\n"
        "- `盤查標準`: ISO or GHG\n"
        "The `水電使用量(ISO)` and `水電使用量(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量,包含'外購電力 度數 (kwh)'與'自來水 度數 (立方公尺 m³)'。"
        # NOTE(review): the next sentence names `public.departments_table`,
        # which is not among the tables listed at the top of this description,
        # while the columns that follow look like water/electricity usage
        # columns. Possibly leftover template text - confirm and reword.
        "The `public.departments_table` table contains information about the various departments in the company. It includes:\n"
        "- `外購電力(灰電)`: 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
        "- `外購電力(綠電)`: 綠電(太陽光電)的外購電力度數(kwh)\n"
        "- `自產電力(綠電)`: 綠電(太陽光電)的自產電力度數(kwh)\n"
        "- `用水量`: 自來水的使用度數(m³)\n\n"
    )
    return database_description


def write_query_chain(db, llm):
    """Build the chain that turns a question into a PostgreSQL query.

    Args:
        db: The ``SQLDatabase`` the query will run against.
        llm: The language model used to write the query.

    Returns:
        The runnable produced by ``create_sql_query_chain`` for a few-shot
        prompt made of the system prefix, the examples from ``get_examples``
        and a ``User input`` suffix.
    """
    example_prompt = PromptTemplate.from_template(
        "The following SQL query best answers the question `{input}`:\nSQL query: {query}"
    )
    prompt = FewShotPromptTemplate(
        examples=get_examples(),
        example_prompt=example_prompt,
        prefix=_WRITE_QUERY_PREFIX,
        suffix="User input: {input}\nSQL query: ",
        # The prefix also contains {database_description}; declaring it lets
        # the prompt report a missing value instead of failing while
        # formatting.
        input_variables=["input", "top_k", "table_info", "database_description"],
    )
    return create_sql_query_chain(llm, db, prompt)


def sql_to_nl_chain(llm):
    """Build the chain that phrases a SQL result as a natural-language answer.

    Args:
        llm: The language model used to write the answer.

    Returns:
        A runnable expecting ``question``, ``query`` and ``result`` and
        producing the model output as a string.
    """
    answer_prompt = PromptTemplate.from_template(_ANSWER_TEMPLATE)
    return answer_prompt | llm | StrOutputParser()


def _clean_sql(raw_output: str) -> str:
    """Extract an executable SQL statement from the raw model output."""
    # Keep only what follows the last "SQL query: " marker, if there is one.
    candidate = raw_output.rsplit(_SQL_MARKER, 1)[-1].strip()
    # The prompt forbids code fences, but drop one if the model adds it anyway.
    fenced = _CODE_FENCE_RE.match(candidate)
    if fenced:
        candidate = fenced.group(1).strip()
    return candidate.replace(_TABLE_NAME_TYPO, _TABLE_NAME)


def _execute_query(db, query: str):
    """Run ``query`` against ``db``, print the result and return it."""
    result = QuerySQLDataBaseTool(db=db).invoke(query)
    print(result)
    return result


def _describe_result(question: str, query: str, result, llm):
    """Ask the model to answer ``question`` from the query and its result."""
    chain = sql_to_nl_chain(llm)
    return chain.invoke({"question": question, "query": query, "result": result})


def get_query(db, question, selected_table, llm):
    """Generate the SQL query for ``question``.

    Args:
        db: The ``SQLDatabase`` to query.
        question: The user's natural-language question.
        selected_table: Table names passed to the chain as
            ``table_names_to_use``.
        llm: The language model used to write the query.

    Returns:
        The cleaned SQL string. It is also printed.
    """
    write_query = write_query_chain(db, llm)
    raw_output = write_query.invoke(
        {
            "question": question,
            "table_names_to_use": selected_table,
            "top_k": _TOP_K,
            # NOTE(review): create_sql_query_chain may derive table_info from
            # table_names_to_use itself and ignore this value - confirm
            # against the installed LangChain version.
            "table_info": context["table_info"],
            "database_description": table_description(),
        }
    )
    query = _clean_sql(raw_output)
    print(query)
    return query


def query_to_nl(db, question, query, llm):
    """Execute ``query`` and return a natural-language answer.

    Args:
        db: The ``SQLDatabase`` to run the query against.
        question: The user's original question.
        query: The SQL statement to execute.
        llm: The language model used to phrase the answer.

    Returns:
        The answer produced by the model. The raw SQL result is printed.
    """
    result = _execute_query(db, query)
    return _describe_result(question, query, result, llm)


def run(db, question, selected_table, llm):
    """Answer ``question`` end to end.

    Generates the SQL, executes it and phrases the result, printing the query
    and the raw result along the way.

    Args:
        db: The ``SQLDatabase`` to query.
        question: The user's natural-language question.
        selected_table: Table names the query may use.
        llm: The language model used for both steps.

    Returns:
        A ``(query, result, answer)`` tuple.
    """
    query = get_query(db, question, selected_table, llm)
    result = _execute_query(db, query)
    answer = _describe_result(question, query, result, llm)
    return query, result, answer


if __name__ == "__main__":
    import time

    start = time.time()
    selected_table = ['104_112碳排放公開及建準資料']
    # Alternative sample question:
    # question = "建準去年的固定燃燒總排放量是多少?"
    question = "台積電2022年的直接排放總排放量是多少?"
    query, result, answer = run(db, question, selected_table, llm)
    print("question: ", question)
    print("query: ", query)
    print("result: ", result)
    print("answer: ", answer)
    print(time.time()-start)