Browse Source

新增 database_description

ling 5 months ago
parent
commit
80e4332a99
2 changed files with 77 additions and 11 deletions
  1. 9 1
      pip_env.txt
  2. 68 10
      text_to_sql.py

+ 9 - 1
pip_env.txt

@@ -1,3 +1,11 @@
+sqlparse
+langchain_huggingface
+PyPDF2
+datasets
+supabase
+faiss-gpu-cu12
+ragas
+psycopg2-binary
 accelerate==0.33.0
 aenum==3.1.15
 aiohappyeyeballs==2.3.7
@@ -104,4 +112,4 @@ uvloop==0.20.0
 watchfiles==0.23.0
 websocket-client==1.8.0
 wrapt==1.16.0
-zipp==3.20.0
+zipp==3.20.0

+ 68 - 10
text_to_sql.py

@@ -74,21 +74,62 @@ def get_examples():
         },
         {
             "input": "建準廣興廠去年的類別1總排放量是多少?",
-            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'',
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" = \'類別1-直接排放\'',
         },
         {
             "input": "建準廣興廠去年的直接排放總排放量是多少?",
-            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%直接排放%\'',
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" = \'類別1-直接排放\'',
         },
+        {
+            "input": "建準廣興廠去年的能源間接排放總排放量是多少?",
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" = \'類別2-能源間接排放\'',
+        },
+
 
     ]
 
     return examples
 
-def write_query_chain(db):
+def table_description():
+    database_description = (
+        "The database consists of following tables: `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)`, `2023 清冊數據(GHG)`, `水電使用量(ISO)` and `水電使用量(GHG)`. "
+        "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
+        "The `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)` and `2023 清冊數據(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
+        "It includes the following columns:\n"
+        "- `類別`: 溫室氣體的排放類別,包含以下:\n"
+        "   \t*類別1-直接排放\n"
+        "   \t*類別2-能源間接排放\n"
+        "   \t*類別3-運輸間接排放\n"
+        "   \t*類別4-組織使用產品間接排放\n"
+        "   \t*類別5-使用來自組織產品間接排放\n"
+        "   \t*類別6\n"
+        "- `排放源`: `類別`欄位進一步劃分的細項\n"
+        "- `高雄總部&運通廠`: 位於台灣的廠房據點\n"
+        "- `台北辦公室`: 位於台灣的廠房據點\n"
+        "- `北海建準廠`: 位於中國的廠房據點\n"
+        "- `北海立準廠`: 位於中國的廠房據點\n"
+        "- `昆山廣興廠`: 位於中國的廠房據點\n"
+        "- `菲律賓建準廠`: 位於菲律賓的廠房據點\n"
+        "- `India`: 位於印度的廠房據點\n"
+        "- `INC`: 位於美國的廠房據點\n"
+        "- `SAS`: 位於法國的廠房據點\n\n"
+
+        "The `水電使用量(ISO)` and `水電使用量(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量,包含'外購電力 度數 (kwh)'與'自來水 度數 (立方公尺 m³)'。"
+        "The `public.departments_table` table contains information about the various departments in the company. It includes:\n"
+        "- `外購電力(灰電)`: 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
+        "- `外購電力(綠電)`: 綠電(太陽光電)的外購電力度數(kwh)\n"
+        "- `自產電力(綠電)`: 綠電(太陽光電)的自產電力度數(kwh)\n"
+        "- `用水量`: 自來水的使用度數(m³)\n\n"
+    )
+
+    return database_description
 
-    template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+def write_query_chain(db):
 
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
     Generate a SQL query to answer this question: `{input}`
 
     You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
@@ -100,12 +141,20 @@ def write_query_chain(db):
     
     ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
     ***Pay attention to only return PostgreSQL query.\n\
-
+    <|eot_id|>
+        
+    <|begin_of_text|><|start_header_id|>user<|end_header_id|>
     DDL statements:
-    {table_info}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+    {table_info}
+
+    database description:
+    {database_description}
 
     The following SQL query best answers the question `{input}`:
     ```sql
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
     """
     # prompt_template = PromptTemplate.from_template(template)
 
@@ -133,6 +182,7 @@ def sql_to_nl_chain():
     # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
     answer_prompt = PromptTemplate.from_template(
         """
+        <|begin_of_text|>
         <|begin_of_text|><|start_header_id|>system<|end_header_id|>
         Given the following user question, corresponding SQL query, and SQL result, answer the user question.
         給定以下使用者問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
@@ -142,11 +192,18 @@ def sql_to_nl_chain():
         SQL Query: SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'
         SQL Result: [(1102.3712,)]
         Answer: 建準廣興廠去年的類別1總排放量是1102.3712
+        <|eot_id|>
 
+        <|begin_of_text|><|start_header_id|>user<|end_header_id|>
         Question: {question}
         SQL Query: {query}
         SQL Result: {result}
-        Answer: """
+        Answer: 
+        <|eot_id|>
+        
+        <|start_header_id|>assistant<|end_header_id|>
+        
+        """
         )
 
     chain = answer_prompt | llm | StrOutputParser()
@@ -156,7 +213,7 @@ def sql_to_nl_chain():
 def run(db, question, selected_table):
 
     write_query = write_query_chain(db)
-    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"]})
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     
     query = re.split('SQL query: ', query)[-1]
     print(query)
@@ -176,9 +233,10 @@ if __name__ == "__main__":
     
     start = time.time()
     
-    selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)']
-    question = "去年的固定燃燒總排放量是多少?"
+    selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)', '水電使用量(GHG)', '水電使用量(ISO)']
+    question = "建準廣興廠去年的綠電使用量是多少?"
     query, result, answer = run(db, question, selected_table)
+    print("question: ", question)
     print("query: ", query)
     print("result: ", result)
     print("answer: ", answer)