# text_to_sql.py

import os
import re

from dotenv import load_dotenv

# Load environment variables (including SUPABASE_URI) before the DB connection is made.
load_dotenv()

from langchain_community.utilities import SQLDatabase

URI: str = os.environ.get('SUPABASE_URI')
db = SQLDatabase.from_uri(URI)
# print(db.dialect)
# print(db.get_usable_table_names())
# db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')

context = db.get_context()
# print(list(context))
# print(context["table_info"])
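# `context["table_info"]` from SQLDatabase.get_context() holds the CREATE TABLE
# DDL for each table (plus, by default, a few sample rows); run() below injects
# it into the prompt so the model sees the actual schema.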
from operator import itemgetter

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain.chains import create_sql_query_chain
from langchain_community.llms import Ollama
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFacePipeline
# Alternative: a SQL-specialised model.
# model_id = "defog/llama-3-sqlcoder-8b"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# sql_llm = HuggingFacePipeline.from_model_id(
#     model_id=model_id,
#     task="text-generation",
#     model_kwargs={"torch_dtype": torch.bfloat16},
#     pipeline_kwargs={"return_full_text": False},
#     device=0,
# )

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm = HuggingFacePipeline.from_model_id(
    model_id=model_id,
    task="text-generation",
    model_kwargs={"torch_dtype": torch.bfloat16},
    pipeline_kwargs={
        "return_full_text": False,
        "max_new_tokens": 512,
    },
    # Pass only one of `device` / `device_map`; transformers rejects both at once.
    device=0,
)
print(llm.pipeline)
# Llama 3.1 ships without a pad token and with a list of EOS token ids;
# use the first EOS id for padding so generation does not warn.
llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]

# Alternative loading paths kept for reference:
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1,
#                 model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
# llm = HuggingFacePipeline(pipeline=pipe)
# llm = Ollama(model="llama3-groq-tool-use:latest", num_gpu=1)
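# A minimal 4-bit loading sketch, assuming the optional `bitsandbytes` package
# is installed (not used by the original code, so kept commented out):
# from transformers import BitsAndBytesConfig
# quant_config = BitsAndBytesConfig(load_in_4bit=True,
#                                   bnb_4bit_compute_dtype=torch.bfloat16)
# model = AutoModelForCausalLM.from_pretrained(
#     model_id, quantization_config=quant_config, device_map="auto")
# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer,
#                 max_new_tokens=512, return_full_text=False)
# llm = HuggingFacePipeline(pipeline=pipe)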
def get_examples():
    """Few-shot examples mapping questions about 去年 (last year) to the 2023 tables."""
    examples = [
        {
            "input": "去年的固定燃燒總排放量是多少?",
            "query": 'SELECT SUM("高雄總部及運通廠" + "台北辦事處" + "昆山廣興廠" + "北海建準廠" + "北海立準廠" + "菲律賓建準廠" + "Inc" + "SAS" + "India") AS "固定燃燒總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "排放源" = \'固定燃燒\'',
        },
        {
            "input": "建準廣興廠去年的類別1總排放量是多少?",
            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%類別1%\'',
        },
        {
            "input": "建準廣興廠去年的直接排放總排放量是多少?",
            "query": 'SELECT SUM("昆山廣興廠") AS "建準廣興廠直接排放總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like \'%直接排放%\'',
        },
    ]
    return examples
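# Each example above renders through `example_prompt` in write_query_chain() as:
#   User input: 去年的固定燃燒總排放量是多少?
#   SQL query: SELECT SUM("高雄總部及運通廠" + ...) ...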
def write_query_chain(db):
    """Build the question -> SQL chain from a few-shot Llama 3 prompt."""
    template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: `{input}`

You are a PostgreSQL expert in the ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
***Return only the PostgreSQL query, WITHOUT "```sql" fences, and DO NOT include any other words.
***Pay attention to only return the PostgreSQL query.

DDL statements:
{table_info}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The following SQL query best answers the question `{input}`:
```sql
"""

    example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
    prompt = FewShotPromptTemplate(
        examples=get_examples(),
        example_prompt=example_prompt,
        prefix=template,
        suffix="User input: {input}\nSQL query: ",
        input_variables=["input", "top_k", "table_info"],
    )

    # Alternative: llm = Ollama(model="mannix/defog-llama3-sqlcoder-8b", num_gpu=1)
    write_query = create_sql_query_chain(llm, db, prompt)
    return write_query
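# Note: create_sql_query_chain maps the "question" key of the invoke payload to
# the prompt's {input} variable; {table_info} and {top_k} can be supplied
# explicitly at invoke time (as run() does below) or filled in by the chain.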
def sql_to_nl_chain():
    """Build the chain that phrases (question, SQL query, SQL result) as an answer."""
    # Alternatives tried: Ollama "llama3.1", "llama3.1:8b-instruct-q2_K", "llama3-groq-tool-use:latest"
    answer_prompt = PromptTemplate.from_template(
        """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Given the following user question, corresponding SQL query, and SQL result, answer the user question in Traditional Chinese.

For example:
Question: 建準廣興廠去年的類別1總排放量是多少?
SQL Query: SELECT SUM("昆山廣興廠") AS "建準廣興廠類別1總排放量"\nFROM "2023 清冊數據(GHG)"\nWHERE "類別" like '%類別1%'
SQL Result: [(1102.3712,)]
Answer: 建準廣興廠去年的類別1總排放量是1102.3712

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
    )

    chain = answer_prompt | llm | StrOutputParser()
    return chain
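# Usage sketch (values are illustrative):
# chain = sql_to_nl_chain()
# chain.invoke({"question": "建準廣興廠去年的類別1總排放量是多少?",
#               "query": "SELECT ...", "result": "[(1102.3712,)]"})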
def run(db, question, selected_table):
    """Generate SQL for `question`, execute it, and phrase the result as an answer."""
    write_query = write_query_chain(db)
    # `context` is the module-level schema info gathered at import time.
    query = write_query.invoke({
        "question": question,
        "table_names_to_use": selected_table,
        "top_k": 1000,
        "table_info": context["table_info"],
    })
    # The model may echo the few-shot scaffold; keep only the text after the last "SQL query: ".
    query = re.split('SQL query: ', query)[-1]
    print(query)

    execute_query = QuerySQLDataBaseTool(db=db)
    result = execute_query.invoke(query)
    print(result)

    chain = sql_to_nl_chain()
    answer = chain.invoke({"question": question, "query": query, "result": result})
    return query, result, answer
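# Caution: the generated SQL is executed verbatim against the database. A
# minimal guard (an assumption, not part of the original design) would be:
# if not re.match(r"\s*SELECT\b", query, re.IGNORECASE):
#     raise ValueError(f"Refusing to execute non-SELECT statement: {query}")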
if __name__ == "__main__":
    import time

    start = time.time()
    selected_table = ['2022 清冊數據(GHG)', '2022 清冊數據(ISO)', '2023 清冊數據(GHG)', '2023 清冊數據(ISO)']
    question = "去年的固定燃燒總排放量是多少?"
    query, result, answer = run(db, question, selected_table)
    print("query: ", query)
    print("result: ", result)
    print("answer: ", answer)
    print(time.time() - start)