text_to_sql2.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. import re
  2. from dotenv import load_dotenv
  3. load_dotenv()
  4. from langchain_community.utilities import SQLDatabase
  5. import os
  6. URI: str = os.environ.get('SUPABASE_URI')
  7. db = SQLDatabase.from_uri(URI)
  8. # print(db.dialect)
  9. # print(db.get_usable_table_names())
  10. # db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
  11. context = db.get_context()
  12. # print(list(context))
  13. # print(context["table_info"])
  14. from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
  15. from langchain.chains import create_sql_query_chain
  16. from langchain_community.llms import Ollama
  17. from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
  18. from operator import itemgetter
  19. from langchain_core.output_parsers import StrOutputParser
  20. from langchain_core.prompts import PromptTemplate
  21. from langchain_core.runnables import RunnablePassthrough
  22. # Load model directly
  23. from transformers import AutoTokenizer, AutoModelForCausalLM
  24. from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
  25. import torch
  26. from langchain_huggingface import HuggingFacePipeline
  27. # Load model directly
  28. from transformers import AutoTokenizer, AutoModelForCausalLM
  29. # model_id = "defog/llama-3-sqlcoder-8b"
  30. # tokenizer = AutoTokenizer.from_pretrained(model_id)
  31. # sql_llm = HuggingFacePipeline.from_model_id(
  32. # model_id=model_id,
  33. # task="text-generation",
  34. # model_kwargs={"torch_dtype": torch.bfloat16},
  35. # pipeline_kwargs={"return_full_text": False},
  36. # device=0, device_map='cuda')
  37. ##########################################################################################
  38. from langchain_community.chat_models import ChatOllama
# Name of the Ollama model backing the shared `llm` object below.
local_llm = "llama3-groq-tool-use:latest"
# Chat model handed to run() in the __main__ block; temperature is fixed at 0
# (presumably to keep the generated SQL as repeatable as possible).
llm = ChatOllama(model=local_llm, temperature=0)
  42. ##########################################################################################
  43. # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
  44. # tokenizer = AutoTokenizer.from_pretrained(model_id)
  45. # llm = HuggingFacePipeline.from_model_id(
  46. # model_id=model_id,
  47. # task="text-generation",
  48. # model_kwargs={"torch_dtype": torch.bfloat16},
  49. # pipeline_kwargs={"return_full_text": False,
  50. # "max_new_tokens": 512},
  51. # device=0, device_map='cuda')
  52. # print(llm.pipeline)
  53. # llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
  54. ##########################################################################################
  55. # model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
  56. # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1,
  57. # model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
  58. #, device="auto", load_in_4bit=True
  59. # llm = HuggingFacePipeline(pipeline=pipe)
  60. # llm = HuggingFacePipeline(pipeline=pipe)
  61. # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
  62. def get_examples():
  63. examples = [
  64. {
  65. "input": "建準廣興廠去年的類別1總排放量是多少?",
  66. "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
  67. FROM "104_112碳排放公開及建準資料"
  68. WHERE "事業名稱" like '%建準%'
  69. AND "事業名稱" like '%廣興廠%'
  70. AND "類別" = '類別1-直接排放'
  71. AND "盤查標準" = 'GHG'
  72. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
  73. },
  74. {
  75. "input": "建準台北辦事處2022年的能源間接排放總排放量是多少?",
  76. "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
  77. FROM "104_112碳排放公開及建準資料"
  78. WHERE "事業名稱" like '%建準%'
  79. AND "事業名稱" like '%台北辦事處%'
  80. AND "類別" = '類別2-能源間接排放'
  81. AND "盤查標準" = 'GHG'
  82. AND "年度" = 2022;""",
  83. },
  84. {
  85. "input": "建準去年的固定燃燒總排放量是多少?",
  86. "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
  87. FROM "104_112碳排放公開及建準資料"
  88. WHERE "事業名稱" like '%建準%'
  89. AND "排放源" = '固定燃燒'
  90. AND "盤查標準" = 'GHG'
  91. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
  92. },
  93. {
  94. "input": "台積電2022年的類別1總排放量是多少?",
  95. "query": """SELECT SUM("排放量(公噸CO2e)") AS "台積電2022年類別1總排放量"
  96. FROM "104_112碳排放公開及建準資料"
  97. WHERE "事業名稱" like '%台灣積體電路製造股份有限公司%'
  98. AND "類別" = '類別1-直接排放'
  99. AND "年度" = 2022;""",
  100. },
  101. ]
  102. return examples
  103. def table_description():
  104. database_description = (
  105. "The database consists of following table: `104_112碳排放公開及建準資料`, `水電使用量(ISO)` and `水電使用量(GHG)`. "
  106. "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
  107. "The `104_112碳排放公開及建準資料` table 描述了不同事業單位或廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
  108. "It includes the following columns:\n"
  109. "- `年度`: 盤查年度\n"
  110. "- `類別`: 溫室氣體的排放類別,包含以下:\n"
  111. " \t*類別1-直接排放\n"
  112. " \t*類別2-能源間接排放\n"
  113. " \t*類別3-運輸間接排放\n"
  114. " \t*類別4-組織使用產品間接排放\n"
  115. " \t*類別5-使用來自組織產品間接排放\n"
  116. " \t*類別6\n"
  117. "- `排放源`: `類別`欄位進一步劃分的細項\n"
  118. "- `排放量(公噸CO2e)`: 溫室氣體排放量\n"
  119. "- `盤查標準`: ISO or GHG\n"
  120. "The `水電使用量(ISO)` and `水電使用量(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量,包含'外購電力 度數 (kwh)'與'自來水 度數 (立方公尺 m³)'。"
  121. "The `public.departments_table` table contains information about the various departments in the company. It includes:\n"
  122. "- `外購電力(灰電)`: 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
  123. "- `外購電力(綠電)`: 綠電(太陽光電)的外購電力度數(kwh)\n"
  124. "- `自產電力(綠電)`: 綠電(太陽光電)的自產電力度數(kwh)\n"
  125. "- `用水量`: 自來水的使用度數(m³)\n\n"
  126. )
  127. return database_description
  128. def write_query_chain(db, llm):
  129. template = """
  130. <|begin_of_text|>
  131. <|start_header_id|>system<|end_header_id|>
  132. Generate a SQL query to answer this question: `{input}`
  133. You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run,
  134. then look at the results of the query and return the answer to the input question.\n\
  135. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
  136. You can order the results to return the most informative data in the database.\n\
  137. Never query for all columns from a table. You must query only the columns that are needed to answer the question.
  138. Wrap each column name in Quotation Mark (") to denote them as delimited identifiers.\n\
  139. ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
  140. ***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
  141. <|eot_id|>
  142. <|begin_of_text|><|start_header_id|>user<|end_header_id|>
  143. DDL statements:
  144. {table_info}
  145. database description:
  146. {database_description}
  147. Provide ONLY PostgreSQL query and NO premable or explanation!
  148. Below are a number of examples of questions and their corresponding SQL queries.\n\
  149. <|eot_id|>
  150. <|start_header_id|>assistant<|end_header_id|>
  151. """
  152. # prompt_template = PromptTemplate.from_template(template)
  153. example_prompt = PromptTemplate.from_template("The following SQL query best answers the question `{input}`:\nSQL query: {query}")
  154. prompt = FewShotPromptTemplate(
  155. examples=get_examples(),
  156. example_prompt=example_prompt,
  157. prefix=template,
  158. suffix="User input: {input}\nSQL query: ",
  159. input_variables=["input", "top_k", "table_info"],
  160. )
  161. # llm = Ollama(model = "sqlcoder", num_gpu=1)
  162. # llm = HuggingFacePipeline(pipeline=pipe)
  163. write_query = create_sql_query_chain(llm, db, prompt)
  164. return write_query
  165. def sql_to_nl_chain(llm):
  166. # llm = Ollama(model = "llama3.1", num_gpu=1)
  167. # llm = Ollama(model = "llama3.1:8b-instruct-q2_K", num_gpu=1)
  168. # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
  169. answer_prompt = PromptTemplate.from_template(
  170. """
  171. <|begin_of_text|>
  172. <|begin_of_text|><|start_header_id|>system<|end_header_id|>
  173. Given the following user question, corresponding SQL query, and SQL result, answer the user question.
  174. 根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
  175. <|eot_id|>
  176. <|begin_of_text|><|start_header_id|>user<|end_header_id|>
  177. Question: {question}
  178. SQL Query: {query}
  179. SQL Result: {result}
  180. Answer:
  181. <|eot_id|>
  182. <|start_header_id|>assistant<|end_header_id|>
  183. """
  184. )
  185. # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
  186. chain = answer_prompt | llm | StrOutputParser()
  187. return chain
  188. def get_query(db, question, selected_table, llm):
  189. write_query = write_query_chain(db, llm)
  190. query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
  191. query = re.split('SQL query: ', query)[-1]
  192. query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
  193. print(query)
  194. return query
  195. def query_to_nl(db, question, query, llm):
  196. execute_query = QuerySQLDataBaseTool(db=db)
  197. result = execute_query.invoke(query)
  198. print(result)
  199. chain = sql_to_nl_chain(llm)
  200. answer = chain.invoke({"question": question, "query": query, "result": result})
  201. return answer
  202. def run(db, question, selected_table, llm):
  203. write_query = write_query_chain(db, llm)
  204. query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
  205. query = re.split('SQL query: ', query)[-1]
  206. query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
  207. print(query)
  208. execute_query = QuerySQLDataBaseTool(db=db)
  209. result = execute_query.invoke(query)
  210. print(result)
  211. chain = sql_to_nl_chain(llm)
  212. answer = chain.invoke({"question": question, "query": query, "result": result})
  213. return query, result, answer
  214. if __name__ == "__main__":
  215. import time
  216. start = time.time()
  217. selected_table = ['104_112碳排放公開及建準資料']
  218. question = "建準去年的固定燃燒總排放量是多少?"
  219. question = "台積電2022年的直接排放總排放量是多少?"
  220. query, result, answer = run(db, question, selected_table, llm)
  221. print("question: ", question)
  222. print("query: ", query)
  223. print("result: ", result)
  224. print("answer: ", answer)
  225. print(time.time()-start)