ソースを参照

Text-to-SQL 整合環境部公開資料

ling 5 ヶ月 前
コミット
2b07b528cb
1 ファイル変更282 行追加0 行削除
  1. 282 0
      text_to_sql2.py

+ 282 - 0
text_to_sql2.py

@@ -0,0 +1,282 @@
+import re
+from dotenv import load_dotenv
+load_dotenv()
+
+from langchain_community.utilities import SQLDatabase
+import os
+URI: str =  os.environ.get('SUPABASE_URI')
+db = SQLDatabase.from_uri(URI)
+
+# print(db.dialect)
+# print(db.get_usable_table_names())
+# db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
+
+context = db.get_context()
+# print(list(context))
+# print(context["table_info"])
+
+from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
+from langchain.chains import create_sql_query_chain
+from langchain_community.llms import Ollama
+
+from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
+from operator import itemgetter
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+
+# Load model directly
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
+import torch
+from langchain_huggingface import HuggingFacePipeline
+
+# Load model directly
+from transformers import AutoTokenizer, AutoModelForCausalLM
+# model_id = "defog/llama-3-sqlcoder-8b"
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
+# sql_llm = HuggingFacePipeline.from_model_id(
+#     model_id=model_id,
+#     task="text-generation",
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     pipeline_kwargs={"return_full_text": False},
+#     device=0, device_map='cuda')
+
+##########################################################################################
+from langchain_community.chat_models import ChatOllama
+# local_llm = "llama3-groq-tool-use:latest"
+local_llm = "llama3-groq-tool-use:latest"
+llm = ChatOllama(model=local_llm, temperature=0)
+##########################################################################################
+# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# llm = HuggingFacePipeline.from_model_id(
+#     model_id=model_id,
+#     task="text-generation",
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     pipeline_kwargs={"return_full_text": False,
+#         "max_new_tokens": 512},
+#     device=0, device_map='cuda')
+# print(llm.pipeline)
+# llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
+##########################################################################################
+
+# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
+
+# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1, 
+#                 model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
+#, device="auto", load_in_4bit=True
+# llm = HuggingFacePipeline(pipeline=pipe)
+
+# llm = HuggingFacePipeline(pipeline=pipe)
+
+# llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+def get_examples():
+    examples = [
+        {
+            "input": "建準去年的固定燃燒總排放量是多少?",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+                        FROM "104_112碳排放公開及建準資料"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "排放源" = '固定燃燒'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
+        {
+            "input": "建準廣興廠去年的類別1總排放量是多少?",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
+                        FROM "104_112碳排放公開及建準資料"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "事業名稱" like '%廣興廠%'
+                        AND "類別" = '類別1-直接排放'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
+        {
+            "input": "建準台北辦事處2022年的能源間接排放總排放量是多少?",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
+                        FROM "104_112碳排放公開及建準資料"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "事業名稱" like '%台北辦事處%'
+                        AND "類別" = '類別2-能源間接排放'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = 2022;""",
+        },
+        {
+            "input": "台積電2022年的類別1總排放量是多少?",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "台積電2022年類別1總排放量"
+                        FROM "104_112碳排放公開及建準資料"
+                        WHERE "事業名稱" like '%台灣積體電路製造股份有限公司%'
+                        AND "類別" = '類別1-直接排放'
+                        AND "年度" = 2022;""",
+        },
+
+
+    ]
+
+    return examples
+
+def table_description():
+    database_description = (
+        "The database consists of following tables: `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)`, `2023 清冊數據(GHG)`, `水電使用量(ISO)` and `水電使用量(GHG)`. "
+        "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
+        "The `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)` and `2023 清冊數據(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
+        "It includes the following columns:\n"
+        "- `類別`: 溫室氣體的排放類別,包含以下:\n"
+        "   \t*類別1-直接排放\n"
+        "   \t*類別2-能源間接排放\n"
+        "   \t*類別3-運輸間接排放\n"
+        "   \t*類別4-組織使用產品間接排放\n"
+        "   \t*類別5-使用來自組織產品間接排放\n"
+        "   \t*類別6\n"
+        "- `排放源`: `類別`欄位進一步劃分的細項\n"
+        "- `高雄總部&運通廠`: 位於台灣的廠房據點\n"
+        "- `台北辦公室`: 位於台灣的廠房據點\n"
+        "- `北海建準廠`: 位於中國的廠房據點\n"
+        "- `北海立準廠`: 位於中國的廠房據點\n"
+        "- `昆山廣興廠`: 位於中國的廠房據點\n"
+        "- `菲律賓建準廠`: 位於菲律賓的廠房據點\n"
+        "- `India`: 位於印度的廠房據點\n"
+        "- `INC`: 位於美國的廠房據點\n"
+        "- `SAS`: 位於法國的廠房據點\n\n"
+
+        "The `水電使用量(ISO)` and `水電使用量(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量,包含'外購電力 度數 (kwh)'與'自來水 度數 (立方公尺 m³)'。"
+        "The `public.departments_table` table contains information about the various departments in the company. It includes:\n"
+        "- `外購電力(灰電)`: 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
+        "- `外購電力(綠電)`: 綠電(太陽光電)的外購電力度數(kwh)\n"
+        "- `自產電力(綠電)`: 綠電(太陽光電)的自產電力度數(kwh)\n"
+        "- `用水量`: 自來水的使用度數(m³)\n\n"
+    )
+
+    return database_description
+
+def write_query_chain(db, llm):
+
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    Generate a SQL query to answer this question: `{input}`
+
+    You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
+    then look at the results of the query and return the answer to the input question.\n\
+    Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. 
+    You can order the results to return the most informative data in the database.\n\
+    Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
+    Wrap each column name in  Quotation Mark (") to denote them as delimited identifiers.\n\
+    
+    ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
+    ***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
+    <|eot_id|>
+        
+    <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+    DDL statements:
+    {table_info}
+
+    database description:
+    {database_description}
+
+    Provide ONLY PostgreSQL query and NO premable or explanation!
+    The following SQL query best answers the question `{input}`:
+    
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
+    """
+    # prompt_template = PromptTemplate.from_template(template)
+
+    example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
+    prompt = FewShotPromptTemplate(
+        examples=get_examples(),
+        example_prompt=example_prompt,
+        prefix=template,
+        suffix="User input: {input}\nSQL query: ",
+        input_variables=["input", "top_k", "table_info"],
+    )
+
+    # llm = Ollama(model = "mannix/defog-llama3-sqlcoder-8b", num_gpu=1)
+    # llm = HuggingFacePipeline(pipeline=pipe)
+    
+    
+    write_query = create_sql_query_chain(llm, db, prompt)
+
+
+    return write_query
+
+def sql_to_nl_chain(llm):
+    # llm = Ollama(model = "llama3.1", num_gpu=1)
+    # llm = Ollama(model = "llama3.1:8b-instruct-q2_K", num_gpu=1)
+    # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+    answer_prompt = PromptTemplate.from_template(
+        """
+        <|begin_of_text|>
+        <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+        Given the following user question, corresponding SQL query, and SQL result, answer the user question.
+        給定以下使用者問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
+
+        For example
+        Question: 建準廣興廠去年的類別總排放量是多少?
+        SQL Query: SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
+                        FROM "104_112碳排放公開及建準資料"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "事業名稱" like '%廣興廠%'
+                        AND "類別" = '類別1-直接排放'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
+        SQL Result: [(1102.3712,)]
+        Answer: 建準廣興廠去年的類別1總排放量是1102.3712
+        <|eot_id|>
+
+        <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+        Question: {question}
+        SQL Query: {query}
+        SQL Result: {result}
+        Answer: 
+        <|eot_id|>
+        
+        <|start_header_id|>assistant<|end_header_id|>
+        
+        """
+        )
+
+    chain = answer_prompt | llm | StrOutputParser()
+
+    return chain
+
+def run(db, question, selected_table, llm):
+
+    write_query = write_query_chain(db, llm)
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": "foo"})
+    
+    query = re.split('SQL query: ', query)[-1]
+    query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
+    print(query)
+
+    execute_query = QuerySQLDataBaseTool(db=db)
+    result = execute_query.invoke(query)
+    print(result)
+
+    chain = sql_to_nl_chain(llm)
+    answer = chain.invoke({"question": question, "query": query, "result": result})
+
+    return query, result, answer
+
+
+if __name__ == "__main__":
+    import time
+    
+    start = time.time()
+    
+    selected_table = ['104_112碳排放公開及建準資料']
+    question = "建準去年的固定燃燒總排放量是多少?"
+    question = "台積電2022年的直接排放總排放量是多少?"
+    query, result, answer = run(db, question, selected_table, llm)
+    print("question: ", question)
+    print("query: ", query)
+    print("result: ", result)
+    print("answer: ", answer)
+    
+    print(time.time()-start)
+