text_to_sql_private.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. import re
  2. from dotenv import load_dotenv
  3. load_dotenv()
  4. from langchain_community.utilities import SQLDatabase
  5. import os
  6. URI: str = os.environ.get('SUPABASE_URI')
  7. db = SQLDatabase.from_uri(URI)
  8. # print(db.dialect)
  9. # print(db.get_usable_table_names())
  10. # db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
  11. context = db.get_context()
  12. # print(list(context))
  13. # print(context["table_info"])
  14. from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
  15. from langchain.chains import create_sql_query_chain
  16. from langchain_community.llms import Ollama
  17. from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
  18. from operator import itemgetter
  19. from langchain_core.output_parsers import StrOutputParser
  20. from langchain_core.prompts import PromptTemplate
  21. from langchain_core.runnables import RunnablePassthrough
  22. # Load model directly
  23. from transformers import AutoTokenizer, AutoModelForCausalLM
  24. from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
  25. import torch
  26. from langchain_huggingface import HuggingFacePipeline
  27. # Load model directly
  28. from transformers import AutoTokenizer, AutoModelForCausalLM
  29. # model_id = "defog/llama-3-sqlcoder-8b"
  30. # tokenizer = AutoTokenizer.from_pretrained(model_id)
  31. # sql_llm = HuggingFacePipeline.from_model_id(
  32. # model_id=model_id,
  33. # task="text-generation",
  34. # model_kwargs={"torch_dtype": torch.bfloat16},
  35. # pipeline_kwargs={"return_full_text": False},
  36. # device=0, device_map='cuda')
  37. ##########################################################################################
  38. from langchain_community.chat_models import ChatOllama
  39. # local_llm = "llama3-groq-tool-use:latest"
  40. # local_llm = "llama3-groq-tool-use:latest"
  41. # local_llm = "sqlcoder:latest"
  42. # local_llm = "llama3.1:8b-instruct-q2_K"
  43. # llm = ChatOllama(model=local_llm, temperature=0)
  44. ##########################################################################################
  45. # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
  46. # tokenizer = AutoTokenizer.from_pretrained(model_id)
  47. # llm = HuggingFacePipeline.from_model_id(
  48. # model_id=model_id,
  49. # task="text-generation",
  50. # model_kwargs={"torch_dtype": torch.bfloat16},
  51. # pipeline_kwargs={"return_full_text": False,
  52. # "max_new_tokens": 512},
  53. # device=0, device_map='cuda')
  54. # print(llm.pipeline)
  55. # llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
  56. ##########################################################################################
  57. # model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
  58. # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1,
  59. # model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
  60. #, device="auto", load_in_4bit=True
  61. # llm = HuggingFacePipeline(pipeline=pipe)
  62. # llm = HuggingFacePipeline(pipeline=pipe)
  63. # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
  64. from langchain_openai import ChatOpenAI
  65. llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
  66. def get_examples():
  67. examples = [
  68. {
  69. "input": "建準去年固定燃燒總排放量",
  70. "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
  71. FROM "建準碳排放清冊數據new"
  72. WHERE "事業名稱" like '%建準%'
  73. AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
  74. AND "盤查標準" = 'GHG'
  75. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
  76. },
  77. {
  78. "input": "廣興廠去年的固定燃燒排放量是多少?",
  79. "query": """FROM "建準碳排放清冊數據new"
  80. WHERE "事業名稱" like '%建準%'
  81. AND "據點" = '昆山廣興廠'
  82. AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
  83. AND "盤查標準" = 'GHG'
  84. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
  85. },
  86. {
  87. "input": "建準廣興廠去年自產電力的綠電使用量是多少?",
  88. "query": """SELECT SUM("用電度數(kwh)") AS "綠電使用量"
  89. FROM "用電度數"
  90. WHERE "項目" like '%綠電%'
  91. AND "事業名稱" like '%建準%'
  92. AND "據點" = '昆山廣興廠'
  93. AND "盤查標準" = 'GHG'
  94. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
  95. },
  96. {
  97. "input": "建準北海廠去年的類別1總排放量",
  98. "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
  99. FROM "建準碳排放清冊數據new"
  100. WHERE "事業名稱" like '%建準%'
  101. AND "據點" in ('北海建準廠', '北海立準廠')
  102. AND "類別" = '類別1'
  103. AND "盤查標準" = 'GHG'
  104. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
  105. },
  106. {
  107. "input": "建準廣興廠去年的直接排放總排放量是多少?",
  108. "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
  109. FROM "建準碳排放清冊數據new"
  110. WHERE "事業名稱" like '%建準%'
  111. AND "據點" = '昆山廣興廠'
  112. AND ("類別項目" like '%直接排放%' OR "排放源" like '%直接排放%')
  113. AND "盤查標準" = 'GHG'
  114. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
  115. },
  116. {
  117. "input": "建準台北辦事處2022年的類別2總排放量是多少?",
  118. "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別2總排放量"
  119. FROM "建準碳排放清冊數據new"
  120. WHERE "事業名稱" like '%建準%'
  121. AND "據點" = '台北辦事處'
  122. AND "類別" = '類別2'
  123. AND "盤查標準" = 'GHG'
  124. AND "年度" = 2022;""",
  125. },
  126. {
  127. "input": "建準法國廠2022年的類別2總排放量",
  128. "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別2總排放量"
  129. FROM "建準碳排放清冊數據new"
  130. WHERE "事業名稱" like '%建準%'
  131. AND "國家" = '法國'
  132. AND "類別" = '類別2'
  133. AND "盤查標準" = 'GHG'
  134. AND "年度" = 2022;""",
  135. },
  136. {
  137. "input": "建準北海2022的外購電力是多少",
  138. "query": """SELECT SUM("用電度數(kwh)") AS "外購電力"
  139. FROM "用電度數"
  140. WHERE "事業名稱" like '%建準%'
  141. AND "據點" in ('北海建準廠', '北海立準廠')
  142. AND "項目" like '%外購電力%'
  143. AND "盤查標準" = 'GHG'
  144. AND "年度" = 2022;""",
  145. },
  146. {
  147. "input": "2023建準印度的其他間接排放是多少",
  148. "query": """SELECT SUM("排放量(公噸CO2e)") AS "其他間接排放總量"
  149. FROM "建準碳排放清冊數據new"
  150. WHERE "事業名稱" like '%建準%'
  151. AND "國家" = '印度'
  152. AND ("類別項目" like '%其他間接排放%' OR "排放源" like '%其他間接排放%')
  153. AND "盤查標準" = 'GHG'
  154. AND "年度" = 2023;""",
  155. },
  156. {
  157. "input": "建準台北前年的產品使用碳排放量是多少",
  158. "query": """SELECT SUM("排放量(公噸CO2e)") AS "產品使用總量"
  159. FROM "建準碳排放清冊數據new"
  160. WHERE "事業名稱" like '%建準%'
  161. AND "據點" = '台北辦事處'
  162. AND ("類別項目" like '%產品使用%' OR "排放源" like '%產品使用%')
  163. AND "盤查標準" = 'GHG'
  164. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-2;""",
  165. },
  166. ]
  167. return examples
  168. def table_description():
  169. database_description = (
  170. "The database consists of following table: `用水度數`, `用水度數`, `建準碳排放清冊數據new`."
  171. "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
  172. "The `建準碳排放清冊數據new` table 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
  173. "It includes the following columns:\n"
  174. "- `年度`: 盤查年度\n"
  175. "- `事業名稱`: 公司名稱"
  176. "- `據點`: 建準廠房據點 include '高雄總部及運通廠', '台北辦事處', '昆山廣興廠', '北海建準廠', '北海立準廠', '菲律賓建準廠', 'Inc', 'SAS', 'India'"
  177. "- `國家`: 據點所在國家"
  178. "- `類別`: 溫室氣體的排放類別,包含以下選項:\n"
  179. " \t*類別1-直接排放:\n"
  180. " \t*類別2-能源間接排放\n"
  181. " \t*類別3-運輸間接排放\n"
  182. " \t*類別4-組織使用產品間接排放\n"
  183. " \t*類別5-使用來自組織產品間接排放\n"
  184. " \t*類別6\n"
  185. "- `排放源`: 由`類別`欄位進一步劃分的細項,包含以下選項:`固定燃燒`, `移動燃燒`, `製程排放`, `逸散排放`, `土地利用`, "
  186. "`外購電力`, `外購能源`, `上游運輸`, `下游運輸`, `員工通勤`, `商務旅行`, `訪客運輸`, "
  187. "`購買產品`, `外購燃料及能資源`, `資本貨物`, `上游租賃`, `廢棄物處理`, `廢棄物清運`, `其他委外業務`, "
  188. "`產品加工`, `產品使用`, `產品最終處理`, `下游租賃`, `投資排放`, `其他`, `其他間接排放` \n"
  189. "- `排放量(公噸CO2e)`: 溫室氣體排放量\n"
  190. "- `盤查標準`: ISO or GHG\n"
  191. "The `用電度數` 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
  192. "It includes the following columns:\n"
  193. "- `年度`: 盤查年度\n"
  194. "- `事業名稱`: 建準據點"
  195. "- `國家`: 據點所在國家"
  196. "- `項目`: 用電項目,包含以下:\n"
  197. " \t*外購電力(灰電): 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
  198. " \t*外購電力(綠電): 綠電(太陽光電)的外購電力度數(kwh)\n"
  199. " \t*自產電力(綠電): 綠電(太陽光電)的自產電力度數(kwh)\n"
  200. "- `用電度數(kwh)`: 用電度數,單位為kwh\n"
  201. "- `盤查標準`: ISO or GHG\n"
  202. "The `用水度數` 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
  203. "It includes the following columns:\n"
  204. "- `年度`: 盤查年度\n"
  205. "- `事業名稱`: 建準據點"
  206. "- `國家`: 據點所在國家"
  207. "- `自來水度數(立方公尺 m³)`: 用水度數,單位為m³\n"
  208. "- `盤查標準`: ISO or GHG\n"
  209. )
  210. return database_description
  211. def write_query_chain(db, llm):
  212. template = """
  213. <|begin_of_text|>
  214. <|start_header_id|>system<|end_header_id|>
  215. Generate a SQL query to answer this question: `{input}`
  216. 你是建準的AI助理,幫助建準查詢碳排放量,如果問題中有提到據點廠房,請使用 PostgreSQL query 進行篩選。
  217. You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run,
  218. then look at the results of the query and return the answer to the input question.\n\
  219. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
  220. You can order the results to return the most informative data in the database.\n\
  221. Never query for all columns from a table. You must query only the columns that are needed to answer the question.
  222. Wrap each column name in Quotation Mark (") to denote them as delimited identifiers.\n\
  223. Unless the user ask for the type of 盤查標準 to be 'ISO' or 'GHG', queries always include query "盤查標準"='GHG' in the WHERE clause.\n
  224. ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
  225. ***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
  226. <|eot_id|>
  227. <|begin_of_text|><|start_header_id|>user<|end_header_id|>
  228. DDL statements:
  229. {table_info}
  230. The following is a description of database. Please refer to the database description to give the correct WHERE statement in the PostgreSQL query.\
  231. In particular, the details of the `排放源` and `類別` columns.\n
  232. database description:
  233. {database_description}
  234. Provide ONLY PostgreSQL query and NO premable or explanation!
  235. Below are a number of examples of questions and their corresponding SQL queries.\n\
  236. <|eot_id|>
  237. SQL query:
  238. """
  239. # <|start_header_id|>assistant<|end_header_id|>
  240. # prompt_template = PromptTemplate.from_template(template)
  241. example_prompt = PromptTemplate.from_template("The following SQL query best answers the question `{input}`\nSQL query: {query}")
  242. prompt = FewShotPromptTemplate(
  243. examples=get_examples(),
  244. example_prompt=example_prompt,
  245. prefix=template,
  246. suffix="User input: {input}\nSQL query: ",
  247. input_variables=["input", "top_k", "table_info"],
  248. )
  249. # llm = Ollama(model = "sqlcoder", num_gpu=1)
  250. # llm = HuggingFacePipeline(pipeline=pipe)
  251. # sqlcoder = Ollama(model = "sqlcoder", num_gpu=1)
  252. write_query = create_sql_query_chain(llm, db, prompt)
  253. return write_query
  254. def sql_to_nl_chain(llm):
  255. # llm = Ollama(model = "llama3.1", num_gpu=1)
  256. # llm = Ollama(model = "llama3.1:8b-instruct-q2_K", num_gpu=1)
  257. # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
  258. answer_prompt = PromptTemplate.from_template(
  259. """
  260. <|begin_of_text|>
  261. <|begin_of_text|><|start_header_id|>system<|end_header_id|>
  262. Given the following user question, corresponding SQL query, and SQL result, answer the user question.
  263. 根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
  264. ** 請務必在回答中表達是建準的資料,即便問句中並未提及建準。
  265. The following shows some example:
  266. Question: 建準廣興廠去年的類別1總排放量是多少?
  267. SQL Query: SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
  268. FROM "建準碳排放清冊數據new"
  269. WHERE "事業名稱" like '%建準%'
  270. AND "據點" = '昆山廣興廠'
  271. AND "類別" = '類別1'
  272. AND "盤查標準" = 'GHG'
  273. AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;,
  274. SQL Result: [(1102.3712,)]
  275. Answer: 建準廣興廠去年的類別1總排放量是1102.3712
  276. 如果你不知道答案或SQL query 出現錯誤請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
  277. 勿回答無關資訊
  278. <|eot_id|>
  279. <|begin_of_text|><|start_header_id|>user<|end_header_id|>
  280. Question: {question}
  281. SQL Query: {query}
  282. SQL Result: {result}
  283. Answer:
  284. <|eot_id|>
  285. <|start_header_id|>assistant<|end_header_id|>
  286. """
  287. )
  288. # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
  289. chain = answer_prompt | llm | StrOutputParser()
  290. return chain
  291. def get_query(db, question, selected_table, llm):
  292. write_query = write_query_chain(db, llm)
  293. query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
  294. query = re.split('SQL query: ', query)[-1]
  295. query = query.replace("```sql","").replace("```","")
  296. query = query.replace("碰排","碳排")
  297. query = query.replace("%%","%")
  298. # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
  299. print(query)
  300. execute_query = QuerySQLDataBaseTool(db=db)
  301. result = execute_query.invoke(query)
  302. print(result)
  303. return query, result
  304. def query_to_nl(question, query, result, llm):
  305. # execute_query = QuerySQLDataBaseTool(db=db)
  306. # result = execute_query.invoke(query)
  307. # print(result)
  308. chain = sql_to_nl_chain(llm)
  309. print(result)
  310. answer = chain.invoke({"question": question, "query": query, "result": result})
  311. return answer
  312. def run(db, question, selected_table, llm):
  313. write_query = write_query_chain(db, llm)
  314. query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
  315. query = re.split('SQL query: ', query)[-1]
  316. query = query.replace("```sql","").replace("```","")
  317. query = query.replace("碰排","碳排")
  318. query = query.replace("%%","%")
  319. # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
  320. print(query)
  321. execute_query = QuerySQLDataBaseTool(db=db)
  322. result = execute_query.invoke(query)
  323. print(result)
  324. chain = sql_to_nl_chain(llm)
  325. answer = chain.invoke({"question": question, "query": query, "result": result})
  326. return query, result, answer
  327. if __name__ == "__main__":
  328. import time
  329. start = time.time()
  330. selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
  331. # question = "建準廣興廠去年的上游運輸總排放量是多少?"
  332. question = "建準北海廠去年的固定燃燒排放量是多少?"
  333. # question = "建準北海廠去年類別1總排放量是多少?"
  334. # question = "台積電2022年的直接排放總排放量是多少?"
  335. # question = "建準廣興廠去年的灰電使用量"
  336. query, result, answer = run(db, question, selected_table, llm)
  337. print("question: ", question)
  338. print("query: ", query)
  339. print("result: ", result)
  340. print("answer: ", answer)
  341. print(time.time()-start)