# text_to_sql.py - natural-language question -> PostgreSQL query -> natural-language answer (LangChain + Ollama)
  1. import re
  2. from dotenv import load_dotenv
  3. load_dotenv()
  4. from langchain_community.utilities import SQLDatabase
  5. import os
  6. URI: str = os.environ.get('SUPABASE_URI')
  7. db = SQLDatabase.from_uri(URI)
  8. # print(db.dialect)
  9. # print(db.get_usable_table_names())
  10. # db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
  11. context = db.get_context()
  12. # print(list(context))
  13. # print(context["table_info"])
  14. from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
  15. from langchain.chains import create_sql_query_chain
  16. from langchain_community.llms import Ollama
  17. from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
  18. from operator import itemgetter
  19. from langchain_core.output_parsers import StrOutputParser
  20. from langchain_core.prompts import PromptTemplate
  21. from langchain_core.runnables import RunnablePassthrough
  22. # Load model directly
  23. from transformers import AutoTokenizer, AutoModelForCausalLM
  24. from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
  25. import torch
  26. from langchain_huggingface import HuggingFacePipeline
  27. # Load model directly
  28. from transformers import AutoTokenizer, AutoModelForCausalLM
  29. # model_id = "defog/llama-3-sqlcoder-8b"
  30. # tokenizer = AutoTokenizer.from_pretrained(model_id)
  31. # sql_llm = HuggingFacePipeline.from_model_id(
  32. # model_id=model_id,
  33. # task="text-generation",
  34. # model_kwargs={"torch_dtype": torch.bfloat16},
  35. # pipeline_kwargs={"return_full_text": False},
  36. # device=0, device_map='cuda')
  37. ##########################################################################################
  38. from langchain_community.chat_models import ChatOllama
  39. # local_llm = "llama3-groq-tool-use:latest"
  40. local_llm = "llama3-groq-tool-use:latest"
  41. llm = ChatOllama(model=local_llm, temperature=0)
  42. ##########################################################################################
  43. # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
  44. # tokenizer = AutoTokenizer.from_pretrained(model_id)
  45. # llm = HuggingFacePipeline.from_model_id(
  46. # model_id=model_id,
  47. # task="text-generation",
  48. # model_kwargs={"torch_dtype": torch.bfloat16},
  49. # pipeline_kwargs={"return_full_text": False,
  50. # "max_new_tokens": 512},
  51. # device=0, device_map='cuda')
  52. # print(llm.pipeline)
  53. # llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
  54. ##########################################################################################
  55. # model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
  56. # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1,
  57. # model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
  58. #, device="auto", load_in_4bit=True
  59. # llm = HuggingFacePipeline(pipeline=pipe)
  60. # llm = HuggingFacePipeline(pipeline=pipe)
  61. # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
  62. def get_examples():
  63. examples = [
  64. {
  65. "input": "去年的固定燃燒總排放量是多少?",
  66. "query": 'SELECT SUM("高雄總部及運通廠" + "台北辦事處" + "昆山廣興廠" + "北海建準廠" + "北海立準廠" + "菲律賓建準廠" + "Inc" + "SAS" + "India") AS "固定燃燒總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "排放源" = \'固定燃燒\'',
  67. },
  68. {
  69. "input": "建準廣興廠去年的類別1總排放量是多少?",
  70. "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" = \'類別1-直接排放\'',
  71. },
  72. {
  73. "input": "建準廣興廠去年的直接排放總排放量是多少?",
  74. "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" = \'類別1-直接排放\'',
  75. },
  76. {
  77. "input": "建準廣興廠去年的能源間接排放總排放量是多少?",
  78. "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" = \'類別2-能源間接排放\'',
  79. },
  80. ]
  81. return examples
  82. def table_description():
  83. database_description = (
  84. "The database consists of following tables: `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)`, `2023 清冊數據(GHG)`, `水電使用量(ISO)` and `水電使用量(GHG)`. "
  85. "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
  86. "The `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)` and `2023 清冊數據(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
  87. "It includes the following columns:\n"
  88. "- `類別`: 溫室氣體的排放類別,包含以下:\n"
  89. " \t*類別1-直接排放\n"
  90. " \t*類別2-能源間接排放\n"
  91. " \t*類別3-運輸間接排放\n"
  92. " \t*類別4-組織使用產品間接排放\n"
  93. " \t*類別5-使用來自組織產品間接排放\n"
  94. " \t*類別6\n"
  95. "- `排放源`: `類別`欄位進一步劃分的細項\n"
  96. "- `高雄總部&運通廠`: 位於台灣的廠房據點\n"
  97. "- `台北辦公室`: 位於台灣的廠房據點\n"
  98. "- `北海建準廠`: 位於中國的廠房據點\n"
  99. "- `北海立準廠`: 位於中國的廠房據點\n"
  100. "- `昆山廣興廠`: 位於中國的廠房據點\n"
  101. "- `菲律賓建準廠`: 位於菲律賓的廠房據點\n"
  102. "- `India`: 位於印度的廠房據點\n"
  103. "- `INC`: 位於美國的廠房據點\n"
  104. "- `SAS`: 位於法國的廠房據點\n\n"
  105. "The `水電使用量(ISO)` and `水電使用量(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量,包含'外購電力 度數 (kwh)'與'自來水 度數 (立方公尺 m³)'。"
  106. "The `public.departments_table` table contains information about the various departments in the company. It includes:\n"
  107. "- `外購電力(灰電)`: 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
  108. "- `外購電力(綠電)`: 綠電(太陽光電)的外購電力度數(kwh)\n"
  109. "- `自產電力(綠電)`: 綠電(太陽光電)的自產電力度數(kwh)\n"
  110. "- `用水量`: 自來水的使用度數(m³)\n\n"
  111. )
  112. return database_description
  113. def write_query_chain(db, llm):
  114. template = """
  115. <|begin_of_text|>
  116. <|start_header_id|>system<|end_header_id|>
  117. Generate a SQL query to answer this question: `{input}`
  118. You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run,
  119. then look at the results of the query and return the answer to the input question.\n\
  120. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
  121. You can order the results to return the most informative data in the database.\n\
  122. Never query for all columns from a table. You must query only the columns that are needed to answer the question.
  123. Wrap each column name in Quotation Mark (") to denote them as delimited identifiers.\n\
  124. ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
  125. ***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
  126. <|eot_id|>
  127. <|begin_of_text|><|start_header_id|>user<|end_header_id|>
  128. DDL statements:
  129. {table_info}
  130. database description:
  131. {database_description}
  132. Provide ONLY PostgreSQL query and NO premable or explanation!
  133. The following SQL query best answers the question `{input}`:
  134. <|eot_id|>
  135. <|start_header_id|>assistant<|end_header_id|>
  136. """
  137. # prompt_template = PromptTemplate.from_template(template)
  138. example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
  139. prompt = FewShotPromptTemplate(
  140. examples=get_examples(),
  141. example_prompt=example_prompt,
  142. prefix=template,
  143. suffix="User input: {input}\nSQL query: ",
  144. input_variables=["input", "top_k", "table_info"],
  145. )
  146. # llm = Ollama(model = "mannix/defog-llama3-sqlcoder-8b", num_gpu=1)
  147. # llm = HuggingFacePipeline(pipeline=pipe)
  148. write_query = create_sql_query_chain(llm, db, prompt)
  149. return write_query
  150. def sql_to_nl_chain(llm):
  151. # llm = Ollama(model = "llama3.1", num_gpu=1)
  152. # llm = Ollama(model = "llama3.1:8b-instruct-q2_K", num_gpu=1)
  153. # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
  154. answer_prompt = PromptTemplate.from_template(
  155. """
  156. <|begin_of_text|>
  157. <|begin_of_text|><|start_header_id|>system<|end_header_id|>
  158. Given the following user question, corresponding SQL query, and SQL result, answer the user question.
  159. 給定以下使用者問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
  160. For example
  161. Question: 建準廣興廠去年的類別1總排放量是多少?
  162. SQL Query: SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'
  163. SQL Result: [(1102.3712,)]
  164. Answer: 建準廣興廠去年的類別1總排放量是1102.3712
  165. <|eot_id|>
  166. <|begin_of_text|><|start_header_id|>user<|end_header_id|>
  167. Question: {question}
  168. SQL Query: {query}
  169. SQL Result: {result}
  170. Answer:
  171. <|eot_id|>
  172. <|start_header_id|>assistant<|end_header_id|>
  173. """
  174. )
  175. chain = answer_prompt | llm | StrOutputParser()
  176. return chain
  177. def get_query(db, question, selected_table, llm):
  178. write_query = write_query_chain(db, llm)
  179. query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
  180. query = re.split('SQL query: ', query)[-1]
  181. query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
  182. print(query)
  183. return query
  184. def query_to_nl(db, question, query, llm):
  185. execute_query = QuerySQLDataBaseTool(db=db)
  186. result = execute_query.invoke(query)
  187. print(result)
  188. chain = sql_to_nl_chain(llm)
  189. answer = chain.invoke({"question": question, "query": query, "result": result})
  190. return answer
  191. def run(db, question, selected_table, llm):
  192. write_query = write_query_chain(db, llm)
  193. query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
  194. query = re.split('SQL query: ', query)[-1]
  195. print(query)
  196. execute_query = QuerySQLDataBaseTool(db=db)
  197. result = execute_query.invoke(query)
  198. print(result)
  199. chain = sql_to_nl_chain(llm)
  200. answer = chain.invoke({"question": question, "query": query, "result": result})
  201. return query, result, answer
  202. if __name__ == "__main__":
  203. import time
  204. start = time.time()
  205. selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)', '水電使用量(GHG)', '水電使用量(ISO)']
  206. question = "建準去年的固定燃燒總排放量是多少?"
  207. query, result, answer = run(db, question, selected_table)
  208. print("question: ", question)
  209. print("query: ", query)
  210. print("result: ", result)
  211. print("answer: ", answer)
  212. print(time.time()-start)