Przeglądaj źródła

新增 database_description

ling 5 miesięcy temu
rodzic
commit
80e4332a99
2 zmienionych plików z 77 dodań i 11 usunięć
  1. 9 1
      pip_env.txt
  2. 68 10
      text_to_sql.py

+ 9 - 1
pip_env.txt

@@ -1,3 +1,11 @@
+sqlparse
+langchain_huggingface
+PyPDF2
+datasets
+supabase
+faiss-gpu-cu12
+ragas
+psycopg2-binary
 accelerate==0.33.0
 accelerate==0.33.0
 aenum==3.1.15
 aenum==3.1.15
 aiohappyeyeballs==2.3.7
 aiohappyeyeballs==2.3.7
@@ -104,4 +112,4 @@ uvloop==0.20.0
 watchfiles==0.23.0
 watchfiles==0.23.0
 websocket-client==1.8.0
 websocket-client==1.8.0
 wrapt==1.16.0
 wrapt==1.16.0
-zipp==3.20.0
+zipp==3.20.0

+ 68 - 10
text_to_sql.py

@@ -74,21 +74,62 @@ def get_examples():
         },
         },
         {
         {
             "input": "建準廣興廠去年的類別1總排放量是多少?",
             "input": "建準廣興廠去年的類別1總排放量是多少?",
-            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'',
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" = \'類別1-直接排放\'',
         },
         },
         {
         {
             "input": "建準廣興廠去年的直接排放總排放量是多少?",
             "input": "建準廣興廠去年的直接排放總排放量是多少?",
-            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%直接排放%\'',
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" = \'類別1-直接排放\'',
         },
         },
+        {
+            "input": "建準廣興廠去年的能源間接排放總排放量是多少?",
+            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠能源間接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" = \'類別2-能源間接排放\'',
+        },
+
 
 
     ]
     ]
 
 
     return examples
     return examples
 
 
-def write_query_chain(db):
+def table_description():
+    database_description = (
+        "The database consists of following tables: `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)`, `2023 清冊數據(GHG)`, `水電使用量(ISO)` and `水電使用量(GHG)`. "
+        "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
+        "The `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)` and `2023 清冊數據(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
+        "It includes the following columns:\n"
+        "- `類別`: 溫室氣體的排放類別,包含以下:\n"
+        "   \t*類別1-直接排放\n"
+        "   \t*類別2-能源間接排放\n"
+        "   \t*類別3-運輸間接排放\n"
+        "   \t*類別4-組織使用產品間接排放\n"
+        "   \t*類別5-使用來自組織產品間接排放\n"
+        "   \t*類別6\n"
+        "- `排放源`: `類別`欄位進一步劃分的細項\n"
+        "- `高雄總部&運通廠`: 位於台灣的廠房據點\n"
+        "- `台北辦公室`: 位於台灣的廠房據點\n"
+        "- `北海建準廠`: 位於中國的廠房據點\n"
+        "- `北海立準廠`: 位於中國的廠房據點\n"
+        "- `昆山廣興廠`: 位於中國的廠房據點\n"
+        "- `菲律賓建準廠`: 位於菲律賓的廠房據點\n"
+        "- `India`: 位於印度的廠房據點\n"
+        "- `INC`: 位於美國的廠房據點\n"
+        "- `SAS`: 位於法國的廠房據點\n\n"
+
+        "The `水電使用量(ISO)` and `水電使用量(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量,包含'外購電力 度數 (kwh)'與'自來水 度數 (立方公尺 m³)'。"
+        "The `水電使用量(ISO)` and `水電使用量(GHG)` tables include the following columns:\n"
+        "- `外購電力(灰電)`: 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
+        "- `外購電力(綠電)`: 綠電(太陽光電)的外購電力度數(kwh)\n"
+        "- `自產電力(綠電)`: 綠電(太陽光電)的自產電力度數(kwh)\n"
+        "- `用水量`: 自來水的使用度數(m³)\n\n"
+    )
+
+    return database_description
 
 
-    template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+def write_query_chain(db):
 
 
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
     Generate a SQL query to answer this question: `{input}`
     Generate a SQL query to answer this question: `{input}`
 
 
     You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
     You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
@@ -100,12 +141,20 @@ def write_query_chain(db):
     
     
     ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
     ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
     ***Pay attention to only return PostgreSQL query.\n\
     ***Pay attention to only return PostgreSQL query.\n\
-
+    <|eot_id|>
+        
+    <|begin_of_text|><|start_header_id|>user<|end_header_id|>
     DDL statements:
     DDL statements:
-    {table_info}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+    {table_info}
+
+    database description:
+    {database_description}
 
 
     The following SQL query best answers the question `{input}`:
     The following SQL query best answers the question `{input}`:
     ```sql
     ```sql
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
     """
     """
     # prompt_template = PromptTemplate.from_template(template)
     # prompt_template = PromptTemplate.from_template(template)
 
 
@@ -133,6 +182,7 @@ def sql_to_nl_chain():
     # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
     # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
     answer_prompt = PromptTemplate.from_template(
     answer_prompt = PromptTemplate.from_template(
         """
         """
+        <|begin_of_text|>
         <|begin_of_text|><|start_header_id|>system<|end_header_id|>
         <|begin_of_text|><|start_header_id|>system<|end_header_id|>
         Given the following user question, corresponding SQL query, and SQL result, answer the user question.
         Given the following user question, corresponding SQL query, and SQL result, answer the user question.
         給定以下使用者問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
         給定以下使用者問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
@@ -142,11 +192,18 @@ def sql_to_nl_chain():
         SQL Query: SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'
         SQL Query: SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'
         SQL Result: [(1102.3712,)]
         SQL Result: [(1102.3712,)]
         Answer: 建準廣興廠去年的類別1總排放量是1102.3712
         Answer: 建準廣興廠去年的類別1總排放量是1102.3712
+        <|eot_id|>
 
 
+        <|begin_of_text|><|start_header_id|>user<|end_header_id|>
         Question: {question}
         Question: {question}
         SQL Query: {query}
         SQL Query: {query}
         SQL Result: {result}
         SQL Result: {result}
-        Answer: """
+        Answer: 
+        <|eot_id|>
+        
+        <|start_header_id|>assistant<|end_header_id|>
+        
+        """
         )
         )
 
 
     chain = answer_prompt | llm | StrOutputParser()
     chain = answer_prompt | llm | StrOutputParser()
@@ -156,7 +213,7 @@ def sql_to_nl_chain():
 def run(db, question, selected_table):
 def run(db, question, selected_table):
 
 
     write_query = write_query_chain(db)
     write_query = write_query_chain(db)
-    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"]})
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     
     
     query = re.split('SQL query: ', query)[-1]
     query = re.split('SQL query: ', query)[-1]
     print(query)
     print(query)
@@ -176,9 +233,10 @@ if __name__ == "__main__":
     
     
     start = time.time()
     start = time.time()
     
     
-    selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)']
-    question = "去年的固定燃燒總排放量是多少?"
+    selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)', '水電使用量(GHG)', '水電使用量(ISO)']
+    question = "建準廣興廠去年的綠電使用量是多少?"
     query, result, answer = run(db, question, selected_table)
     query, result, answer = run(db, question, selected_table)
+    print("question: ", question)
     print("query: ", query)
     print("query: ", query)
     print("result: ", result)
     print("result: ", result)
     print("answer: ", answer)
     print("answer: ", answer)