# text_to_sql_private.py

import re
import os

from dotenv import load_dotenv
load_dotenv()

from langchain_community.utilities import SQLDatabase

URI: str = os.environ.get('SUPABASE_URI')
db = SQLDatabase.from_uri(URI)
# print(db.dialect)
# print(db.get_usable_table_names())
# db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
context = db.get_context()
# print(list(context))
# print(context["table_info"])

from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain.chains import create_sql_query_chain
from langchain_community.llms import Ollama
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Load model directly
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_huggingface import HuggingFacePipeline

# model_id = "defog/llama-3-sqlcoder-8b"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# sql_llm = HuggingFacePipeline.from_model_id(
#     model_id=model_id,
#     task="text-generation",
#     model_kwargs={"torch_dtype": torch.bfloat16},
#     pipeline_kwargs={"return_full_text": False},
#     device=0, device_map='cuda')
##########################################################################################
from langchain_community.chat_models import ChatOllama

local_llm = "llama3-groq-tool-use:latest"
llm = ChatOllama(model=local_llm, temperature=0)
##########################################################################################
# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# llm = HuggingFacePipeline.from_model_id(
#     model_id=model_id,
#     task="text-generation",
#     model_kwargs={"torch_dtype": torch.bfloat16},
#     pipeline_kwargs={"return_full_text": False,
#                      "max_new_tokens": 512},
#     device=0, device_map='cuda')
# print(llm.pipeline)
# llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
##########################################################################################
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1,
#                 model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
#, device="auto", load_in_4bit=True
# llm = HuggingFacePipeline(pipeline=pipe)
# llm = Ollama(model="llama3-groq-tool-use:latest", num_gpu=1)
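
# The commented blocks above are alternative ways to host the model locally via
# transformers / HuggingFacePipeline; the active configuration is the ChatOllama
# instance (local_llm) with temperature=0, which keeps SQL generation deterministic.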
def get_examples():
    examples = [
        {
            "input": "建準廣興廠2023年的自產電力的綠電使用量是多少?",
            "query": """SELECT SUM("用電度數(kwh)") AS "自產電力綠電使用量"
FROM "用電度數"
WHERE "項目" = '自產電力(綠電)'
AND "盤查標準" = 'GHG'
AND "年度" = 2023;""",
        },
        {
            "input": "建準廣興廠去年的類別1總排放量是多少?",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
FROM "建準碳排放清冊數據"
WHERE "事業名稱" like '%建準%'
AND "事業名稱" like '%廣興廠%'
AND ("類別" like '%類別1-直接排放%' OR "排放源" like '%類別1-直接排放%')
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
        {
            "input": "建準台北辦事處2022年的能源間接排放總排放量是多少?",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "能源間接排放總排放量"
FROM "建準碳排放清冊數據"
WHERE "事業名稱" like '%建準%'
AND "事業名稱" like '%台北辦事處%'
AND ("類別" like '%類別2-能源間接排放%' OR "排放源" like '%類別2-能源間接排放%')
AND "盤查標準" = 'GHG'
AND "年度" = 2022;""",
        },
        {
            "input": "建準去年的固定燃燒總排放量是多少?",
            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
FROM "建準碳排放清冊數據"
WHERE "事業名稱" like '%建準%'
AND ("類別" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
AND "盤查標準" = 'GHG'
AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
        },
    ]
    return examples
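
# Illustrative sketch (not executed): how one of these examples is rendered by the
# few-shot example_prompt defined in write_query_chain() below.
#
#   example_prompt = PromptTemplate.from_template(
#       "The following SQL query best answers the question `{input}`\nSQL query: {query}"
#   )
#   print(example_prompt.format(**get_examples()[0]))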
def table_description():
    database_description = (
        "The database consists of the following tables: `用電度數`, `用水度數`, `建準碳排放清冊數據`. "
        "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
        "The `建準碳排放清冊數據` table 描述了不同事業單位或廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
        "It includes the following columns:\n"
        "- `年度`: 盤查年度\n"
        "- `事業名稱`: 建準據點\n"
        "- `國家`: 據點所在國家\n"
        "- `類別`: 溫室氣體的排放類別,包含以下選項:\n"
        " \t*類別1-直接排放\n"
        " \t*類別2-能源間接排放\n"
        " \t*類別3-運輸間接排放\n"
        " \t*類別4-組織使用產品間接排放\n"
        " \t*類別5-使用來自組織產品間接排放\n"
        " \t*類別6\n"
        "- `排放源`: 由`類別`欄位進一步劃分的細項,包含以下選項:`固定燃燒`, `移動燃燒`, `製程排放`, `逸散排放`, `土地利用`, "
        "`外購電力`, `外購能源`, `上游運輸`, `下游運輸`, `員工通勤`, `商務旅行`, `訪客運輸`, "
        "`購買產品`, `外購燃料及能資源`, `資本貨物`, `上游租賃`, `廢棄物處理`, `廢棄物清運`, `其他委外業務`, "
        "`產品加工`, `產品使用`, `產品最終處理`, `下游租賃`, `投資排放`, `其他`, `其他間接排放`\n"
        "- `排放量(公噸CO2e)`: 溫室氣體排放量\n"
        "- `盤查標準`: ISO or GHG\n"
        "The `用電度數` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的用電量。"
        "It includes the following columns:\n"
        "- `年度`: 盤查年度\n"
        "- `事業名稱`: 建準據點\n"
        "- `國家`: 據點所在國家\n"
        "- `項目`: 用電項目,包含以下:\n"
        " \t*外購電力(灰電): 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
        " \t*外購電力(綠電): 綠電(太陽光電)的外購電力度數(kwh)\n"
        " \t*自產電力(綠電): 綠電(太陽光電)的自產電力度數(kwh)\n"
        "- `用電度數(kwh)`: 用電度數,單位為kwh\n"
        "- `盤查標準`: ISO or GHG\n"
        "The `用水度數` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的用水量。"
        "It includes the following columns:\n"
        "- `年度`: 盤查年度\n"
        "- `事業名稱`: 建準據點\n"
        "- `國家`: 據點所在國家\n"
        "- `自來水度數(立方公尺 m³)`: 用水度數,單位為m³\n"
        "- `盤查標準`: ISO or GHG\n"
    )
    return database_description
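
# Note: the table and column names quoted in the description above must match the
# DDL that db.get_context() returns, otherwise the model may emit invalid
# identifiers. A quick manual check (already hinted at near the top of the file):
#
#   print(db.get_usable_table_names())
#   print(context["table_info"])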
def write_query_chain(db, llm):
    template = """
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
Generate a SQL query to answer this question: `{input}`
You are a PostgreSQL expert in the ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run,
then look at the results of the query and return the answer to the input question.\n\
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
You can order the results to return the most informative data in the database.\n\
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Wrap each column name in quotation marks (") to denote them as delimited identifiers.\n\
***Pay attention to only return a query for PostgreSQL WITHOUT "```sql", and DO NOT include any other words.\n\
***Pay attention to only return the PostgreSQL query and no preamble or explanation.\n\
<|eot_id|>
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
DDL statements:
{table_info}
The following is a description of the database. Please refer to the database description to give the correct WHERE statement in the PostgreSQL query.\
In particular, the details of the `排放源` and `類別` columns.\n
database description:
{database_description}
Provide ONLY the PostgreSQL query and NO preamble or explanation!
Below are a number of examples of questions and their corresponding SQL queries.\n\
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""

    # prompt_template = PromptTemplate.from_template(template)
    example_prompt = PromptTemplate.from_template(
        "The following SQL query best answers the question `{input}`\nSQL query: {query}"
    )
    prompt = FewShotPromptTemplate(
        examples=get_examples(),
        example_prompt=example_prompt,
        prefix=template,
        suffix="User input: {input}\nSQL query: ",
        input_variables=["input", "top_k", "table_info", "database_description"],
    )

    # llm = Ollama(model="sqlcoder", num_gpu=1)
    # llm = HuggingFacePipeline(pipeline=pipe)
    write_query = create_sql_query_chain(llm, db, prompt)
    return write_query
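
# Illustrative usage (a sketch, mirroring get_query()/run() below): create_sql_query_chain
# expects the user question under the "question" key, while the remaining prompt
# variables are passed through to the FewShotPromptTemplate.
#
#   write_query = write_query_chain(db, llm)
#   sql = write_query.invoke({
#       "question": "建準去年的固定燃燒總排放量是多少?",
#       "table_names_to_use": ["建準碳排放清冊數據"],
#       "top_k": 1000,
#       "table_info": context["table_info"],
#       "database_description": table_description(),
#   })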
def sql_to_nl_chain(llm):
    # llm = Ollama(model="llama3.1", num_gpu=1)
    # llm = Ollama(model="llama3.1:8b-instruct-q2_K", num_gpu=1)
    # llm = Ollama(model="llama3-groq-tool-use:latest", num_gpu=1)
    answer_prompt = PromptTemplate.from_template(
        """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Given the following user question, corresponding SQL query, and SQL result, answer the user question.
根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
<|eot_id|>
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer:
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
    )
    # llm = Ollama(model="llama3-groq-tool-use:latest", num_gpu=1)
    chain = answer_prompt | llm | StrOutputParser()
    return chain
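
# Illustrative usage (a sketch, mirroring query_to_nl() below): the chain takes the
# original question, the generated SQL, and the executed result, and returns a
# natural-language answer in Traditional Chinese.
#
#   chain = sql_to_nl_chain(llm)
#   answer = chain.invoke({"question": question, "query": query, "result": result})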
def get_query(db, question, selected_table, llm):
    write_query = write_query_chain(db, llm)
    query = write_query.invoke({
        "question": question,
        "table_names_to_use": selected_table,
        "top_k": 1000,
        "table_info": context["table_info"],
        "database_description": table_description(),
    })
    query = re.split('SQL query: ', query)[-1]
    # query = query.replace("104_112碰排放公開及建準資料", "104_112碳排放公開及建準資料")
    print(query)
    return query
def query_to_nl(db, question, query, llm):
    execute_query = QuerySQLDataBaseTool(db=db)
    result = execute_query.invoke(query)
    print(result)
    chain = sql_to_nl_chain(llm)
    answer = chain.invoke({"question": question, "query": query, "result": result})
    return answer
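
# get_query() and query_to_nl() can also be chained by hand; run() below performs
# the same steps in one call and additionally returns the intermediate query and result:
#
#   query = get_query(db, question, selected_table, llm)
#   answer = query_to_nl(db, question, query, llm)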
def run(db, question, selected_table, llm):
    query = get_query(db, question, selected_table, llm)
    execute_query = QuerySQLDataBaseTool(db=db)
    result = execute_query.invoke(query)
    print(result)
    chain = sql_to_nl_chain(llm)
    answer = chain.invoke({"question": question, "query": query, "result": result})
    return query, result, answer
if __name__ == "__main__":
    import time
    start = time.time()

    selected_table = ['用電度數', '用水度數', '建準碳排放清冊數據']
    question = "建準去年的上游運輸總排放量是多少?"
    # question = "台積電2022年的直接排放總排放量是多少?"
    # question = "建準廣興廠去年的灰電使用量"

    query, result, answer = run(db, question, selected_table, llm)

    print("question: ", question)
    print("query: ", query)
    print("result: ", result)
    print("answer: ", answer)
    print(time.time() - start)