Explorar el Código

add sql examples and adjust prompt

ling hace 4 meses
padre
commit
1348a079ba
Se han modificado 1 ficheros con 90 adiciones y 20 borrados
  1. 90 20
      text_to_sql_private.py

+ 90 - 20
text_to_sql_private.py

@@ -46,8 +46,10 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 ##########################################################################################
 from langchain_community.chat_models import ChatOllama
 # local_llm = "llama3-groq-tool-use:latest"
-local_llm = "llama3-groq-tool-use:latest"
-llm = ChatOllama(model=local_llm, temperature=0)
+# local_llm = "llama3-groq-tool-use:latest"
+# local_llm = "sqlcoder:latest"
+# local_llm = "llama3.1:8b-instruct-q2_K"
+# llm = ChatOllama(model=local_llm, temperature=0)
 ##########################################################################################
 # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 # tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -73,57 +75,110 @@ llm = ChatOllama(model=local_llm, temperature=0)
 # llm = HuggingFacePipeline(pipeline=pipe)
 
 # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+from langchain_openai import ChatOpenAI
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+
 def get_examples():
     examples = [
         {
-            "input": "建準廣興廠去年的自產電力的綠電使用量是多少?",
+            "input": "建準去年固定燃燒總排放量",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+                        FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
+        {
+            "input": "廣興廠去年的固定燃燒排放量是多少?",
+            "query": """FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "據點" = '昆山廣興廠'
+                        AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
+        {
+            "input": "建準廣興廠去年自產電力的綠電使用量是多少?",
             "query": """SELECT SUM("用電度數(kwh)") AS "綠電使用量"
                         FROM "用電度數"
                         WHERE "項目" like '%綠電%'
                         AND "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%廣興廠%'
+                        AND "據點" = '昆山廣興廠'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
         },
         {
-            "input": "建準北海廠去年的類別1總排放量是多少?",
+            "input": "建準北海廠去年的類別1總排放量",
             "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%北海%'
+                        AND "據點" in ('北海建準廠', '北海立準廠')
                         AND "類別" = '類別1'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
         },
         {
             "input": "建準廣興廠去年的直接排放總排放量是多少?",
-            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%廣興%'
+                        AND "據點" = '昆山廣興廠'
                         AND ("類別項目" like '%直接排放%' OR "排放源" like '%直接排放%')
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
         },
         {
             "input": "建準台北辦事處2022年的類別2總排放量是多少?",
-            "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別2總排放量"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%台北%'
+                        AND "據點" = '台北辦事處'
                         AND "類別" = '類別2'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = 2022;""",
         },
         {
-            "input": "建準去年的固定燃燒總排放量是多少?",
-            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+            "input": "建準法國廠2022年的類別2總排放量",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別2總排放量"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
-                        AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
+                        AND "國家" = '法國'
+                        AND "類別" = '類別2'
                         AND "盤查標準" = 'GHG'
-                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+                        AND "年度" = 2022;""",
+        },
+        {
+            "input": "建準北海2022的外購電力是多少",
+            "query": """SELECT SUM("用電度數(kwh)") AS "外購電力"
+                        FROM "用電度數"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "據點" in ('北海建準廠', '北海立準廠')
+                        AND "項目" like '%外購電力%'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = 2022;""",
+        },
+        {
+            "input": "2023建準印度的其他間接排放是多少",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "其他間接排放總量"
+                        FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "國家" = '印度'
+                        AND ("類別項目" like '%其他間接排放%' OR "排放源" like '%其他間接排放%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = 2023;""",
         },
+        {
+            "input": "建準台北前年的產品使用碳排放量是多少",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "產品使用總量"
+                        FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "據點" = '台北辦事處'
+                        AND ("類別項目" like '%產品使用%' OR "排放源" like '%產品使用%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-2;""",
+        },
+
 
 
     ]
@@ -137,7 +192,8 @@ def table_description():
         "The `建準碳排放清冊數據new` table 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
         "It includes the following columns:\n"
         "- `年度`: 盤查年度\n"
-        "- `事業名稱`: 建準據點"
+        "- `事業名稱`: 公司名稱"
+        "- `據點`: 建準廠房據點 include '高雄總部及運通廠', '台北辦事處', '昆山廣興廠', '北海建準廠', '北海立準廠', '菲律賓建準廠', 'Inc', 'SAS', 'India'"
         "- `國家`: 據點所在國家"
         "- `類別`: 溫室氣體的排放類別,包含以下選項:\n"
         "   \t*類別1-直接排放:\n"
@@ -183,7 +239,9 @@ def write_query_chain(db, llm):
     <|begin_of_text|>
     
     <|start_header_id|>system<|end_header_id|>
+
     Generate a SQL query to answer this question: `{input}`
+    你是建準的AI助理,幫助建準查詢碳排放量,如果問題中有提到據點廠房,請使用 PostgreSQL query 進行篩選。
 
     You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
     then look at the results of the query and return the answer to the input question.\n\
@@ -192,6 +250,7 @@ def write_query_chain(db, llm):
     Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
     Wrap each column name in  Quotation Mark (") to denote them as delimited identifiers.\n\
     
+    Unless the user ask for the type of 盤查標準 to be 'ISO' or 'GHG', queries always include query "盤查標準"='GHG' in the WHERE clause.\n  
     ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
     ***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
     <|eot_id|>
@@ -209,9 +268,9 @@ def write_query_chain(db, llm):
     Below are a number of examples of questions and their corresponding SQL queries.\n\
     
     <|eot_id|>
-    
-    <|start_header_id|>assistant<|end_header_id|>
+    SQL query:
     """
+    # <|start_header_id|>assistant<|end_header_id|>
     # prompt_template = PromptTemplate.from_template(template)
 
     example_prompt = PromptTemplate.from_template("The following SQL query best answers the question `{input}`\nSQL query: {query}")
@@ -227,6 +286,7 @@ def write_query_chain(db, llm):
     # llm = HuggingFacePipeline(pipeline=pipe)
     
     
+    # sqlcoder = Ollama(model = "sqlcoder", num_gpu=1)
     write_query = create_sql_query_chain(llm, db, prompt)
 
 
@@ -245,11 +305,11 @@ def sql_to_nl_chain(llm):
         ** 請務必在回答中表達是建準的資料,即便問句中並未提及建準。
         
         The following shows some example:
-        Question: 廣興廠去年的類別1總排放量是多少?
+        Question: 建準廣興廠去年的類別1總排放量是多少?
         SQL Query: SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%廣興%'
+                        AND "據點" = '昆山廣興廠'
                         AND "類別" = '類別1'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;,
@@ -257,6 +317,7 @@ def sql_to_nl_chain(llm):
         Answer: 建準廣興廠去年的類別1總排放量是1102.3712
 
         如果你不知道答案或SQL query 出現錯誤請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+        
         勿回答無關資訊
         <|eot_id|>
 
@@ -278,10 +339,14 @@ def sql_to_nl_chain(llm):
     return chain
 
 def get_query(db, question, selected_table, llm):
+    
     write_query = write_query_chain(db, llm)
     query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     
     query = re.split('SQL query: ', query)[-1]
+    query = query.replace("```sql","").replace("```","")
+    query = query.replace("碰排","碳排")
+    query = query.replace("%%","%")
     # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
     print(query)
     
@@ -308,6 +373,9 @@ def run(db, question, selected_table, llm):
     query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     
     query = re.split('SQL query: ', query)[-1]
+    query = query.replace("```sql","").replace("```","")
+    query = query.replace("碰排","碳排")
+    query = query.replace("%%","%")
     # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
     print(query)
 
@@ -327,7 +395,9 @@ if __name__ == "__main__":
     start = time.time()
     
     selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
-    question = "建準去年的上游運輸總排放量是多少?"
+    # question = "建準廣興廠去年的上游運輸總排放量是多少?"
+    question = "建準北海廠去年的固定燃燒排放量是多少?"
+    # question = "建準北海廠去年類別1總排放量是多少?"
     # question = "台積電2022年的直接排放總排放量是多少?"
     # question = "建準廣興廠去年的灰電使用量"
     query, result, answer = run(db, question, selected_table, llm)