Browse Source

update ai agents

ling 4 months ago
parent
commit
16c9e5af51
5 changed files with 394 additions and 31 deletions
  1. 2 1
      ai_agent.ipynb
  2. 58 20
      ai_agent.py
  3. 23 8
      post_processing_sqlparse.py
  4. 2 2
      text_to_sql2.py
  5. 309 0
      text_to_sql_private.py

+ 2 - 1
ai_agent.ipynb

@@ -933,7 +933,8 @@
     "    return {\"sql_query\": sql_query, \"question\": question, \"generation\": generation}\n",
     "\n",
     "def additional_explanation(state):\n",
-    "    \"\"\"_summary_\n",
+    "    \"\"\"\n",
+    "    \n",
     "\n",
     "    Args:\n",
     "        state (_type_): _description_\n",

+ 58 - 20
ai_agent.py

@@ -33,8 +33,8 @@ from faiss_index import create_faiss_retriever, faiss_query
 retriever = create_faiss_retriever()
 
 # text-to-sql usage
-from text_to_sql2 import run, get_query, query_to_nl, table_description
-from post_processing_sqlparse import get_query_columns, parse_sql_for_stock_info, get_table_name
+from text_to_sql_private import run, get_query, query_to_nl, table_description
+from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
 
 
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
@@ -136,8 +136,19 @@ def _query_to_nl(question: str, query: str):
     answer = query_to_nl(db, question, query, llm)
     return  answer
 
+def generate_additional_question(sql_query):
+    terms = parse_sql_where(sql_query)
+    question_list = []
+    for term in terms:
+        if term is None: continue
+        question_format = [f"什麼是{term}?", f"{term}的用途是什麼", f"如何計算{term}?"]
+        question_list.extend(question_format)
+        
+    return question_list
+        
+    
 def generate_additional_detail(sql_query):
-    terms = parse_sql_for_stock_info(sql_query)
+    terms = parse_sql_where(sql_query)
     answer = ""
     for term in terms:
         if term is None: continue
@@ -232,7 +243,9 @@ class GraphState(TypedDict):
         documents: list of documents
     """
 
+    route: str
     question: str
+    question_list: List[str]
     generation: str
     documents: List[str]
     retry: int
@@ -250,16 +263,32 @@ def retrieve_and_generation(state):
         state (dict): New key added to state, documents, that contains retrieved documents, and generation, genrating by LLM
     """
     print("---RETRIEVE---")
+    if not state["route"]:
+        route = "RAG"
+    else:
+        route = state["route"]
     question = state["question"]
-
+    print(state)
+    question_list = state["question_list"]
+    
     # Retrieval
-    # documents = retriever.invoke(question)
-    # TODO: correct Retrieval function
-    documents = retriever.get_relevant_documents(question, k=30)
-    # docs_documents = "\n\n".join(doc.page_content for doc in documents)
-    # print(documents)
-    generation = faiss_query(question, documents, llm)
-    return {"documents": documents, "question": question, "generation": generation}
+    if not question_list:
+        # documents = retriever.invoke(question)
+        # TODO: correct Retrieval function
+        documents = retriever.get_relevant_documents(question, k=30)
+        # docs_documents = "\n\n".join(doc.page_content for doc in documents)
+        # print(documents)
+        generation = faiss_query(question, documents, llm)
+    else:
+        generation = state["generation"]
+        
+        for sub_question in question_list:
+            documents = retriever.get_relevant_documents(sub_question, k=30)
+            generation += faiss_query(sub_question, documents, llm)
+            generation += "\n"
+            
+    
+    return {"route": route, "documents": documents, "question": question, "generation": generation}
 
 def company_private_data_get_sql_query(state):
     """
@@ -272,6 +301,10 @@ def company_private_data_get_sql_query(state):
         state (dict): return generated PostgreSQL query and record retry times
     """
     print("---SQL QUERY---")
+    if not state["route"]:
+        route = "SQL"
+    else:
+        route = state["route"]
     question = state["question"]
     
     if state["retry"]:
@@ -283,7 +316,7 @@ def company_private_data_get_sql_query(state):
     
     sql_query = _get_query(question)
     
-    return {"sql_query": sql_query, "question": question, "retry": retry}
+    return {"route": route,"sql_query": sql_query, "question": question, "retry": retry}
     
 def company_private_data_search(state):
     """
@@ -306,7 +339,7 @@ def company_private_data_search(state):
     
     return {"sql_query": sql_query, "question": question, "generation": generation}
 
-def additional_explanation(state):
+def additional_explanation_question(state):
     """
     
     Args:
@@ -320,13 +353,17 @@ def additional_explanation(state):
     print(state)
     question = state["question"]
     sql_query = state["sql_query"]
+    print(sql_query)
     generation = state["generation"]
-    generation += "\n"
-    generation += generate_additional_detail(sql_query)
+    question_list = generate_additional_question(sql_query)
+    print(question_list)
+    # generation += "\n"
+    # generation += generate_additional_detail(sql_query)
+    
     
     # generation = [company_private_data_result]
     
-    return {"sql_query": sql_query, "question": question, "generation": generation}
+    return {"sql_query": sql_query, "question": question, "generation": generation, "question_list": question_list}
 
 ### Conditional edge
 
@@ -437,7 +474,7 @@ def build_graph():
     # Define the nodes
     workflow.add_node("Text-to-SQL", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5))  # web search
     workflow.add_node("SQL Answer", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # web search
-    workflow.add_node("Additoinal Explanation", additional_explanation, retry=RetryPolicy(max_attempts=5))  # retrieve
+    workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
     
     workflow.add_conditional_edges(
@@ -469,7 +506,7 @@ def build_graph():
         },
     )
     workflow.add_edge("SQL Answer", "Additoinal Explanation")
-    workflow.add_edge("Additoinal Explanation", END)
+    workflow.add_edge("Additoinal Explanation", "RAG")
 
     app = workflow.compile()    
     
@@ -483,11 +520,12 @@ def main(question: str):
     for output in app.stream(inputs, {"recursion_limit": 10}):
         for key, value in output.items():
             pprint(f"Finished running: {key}:")
-    pprint(value["generation"])
+    # pprint(value["generation"])
+    # pprint(value)
     
     return value["generation"]
 
 if __name__ == "__main__":
-    result = main("建準去年的直接排放排放量?")
+    result = main("建準去年的逸散排放總排放量是多少?")
     print("------------------------------------------------------")
     print(result)

+ 23 - 8
post_processing_sqlparse.py

@@ -54,22 +54,37 @@ def extract_comparison_value(tokens, target):
             return eval(data)[0][0]
     return None
 
-def parse_sql_for_stock_info(sql):
+def parse_sql_where(sql):
     """Parse the SQL statement to extract 排放源, 類別"""
     stmt = sqlparse.parse(sql)[0]
-    emission, class_type = None, None
-    
+    column_dict = {
+        "排放源": None,
+        "類別": None
+    }
+
+    def get_column_details(token, column_args):
+        if isinstance(token, Comparison):
+            print(token, type(token))
+            for column_name in column_args.keys():
+                if column_args[column_name] is None:
+                    column_args[column_name] = extract_comparison_value(token.tokens, column_name)
+                
+        return column_args
+
     for token in stmt.tokens:
         if isinstance(token, sqlparse.sql.Comment):
             continue
         if token.value.lower().startswith('where'):
             for token2 in token.tokens:
+                # print(token2, type(token2))
                 if isinstance(token2, Comparison):
-                    if emission is None:
-                        emission = extract_comparison_value(token2.tokens, "排放源")
-                    if class_type is None:
-                        class_type = extract_comparison_value(token2.tokens, "類別")
-    return emission, class_type
+                    column_dict = get_column_details(token2, column_dict)
+                elif isinstance(token2, Parenthesis):
+                    # print(token2, type(token2))
+                    for token3 in token2.tokens:
+                        column_dict = get_column_details(token3, column_dict)
+    column_values = [column_dict[column_name].replace("%", "") for column_name in column_dict.keys()]
+    return column_values
 
 def get_table_name(sql):
     stmt = sqlparse.parse(sql)[0]

+ 2 - 2
text_to_sql2.py

@@ -175,7 +175,7 @@ def write_query_chain(db, llm):
     {database_description}
 
     Provide ONLY PostgreSQL query and NO premable or explanation!
-    The following SQL query best answers the question `{input}`:
+    Below are a number of examples of questions and their corresponding SQL queries.\n\
     
     <|eot_id|>
     
@@ -183,7 +183,7 @@ def write_query_chain(db, llm):
     """
     # prompt_template = PromptTemplate.from_template(template)
 
-    example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
+    example_prompt = PromptTemplate.from_template("The following SQL query best answers the question `{input}`:\nSQL query: {query}")
     prompt = FewShotPromptTemplate(
         examples=get_examples(),
         example_prompt=example_prompt,

+ 309 - 0
text_to_sql_private.py

@@ -0,0 +1,309 @@
+import re
+from dotenv import load_dotenv
+load_dotenv()
+
+from langchain_community.utilities import SQLDatabase
+import os
+URI: str =  os.environ.get('SUPABASE_URI')
+db = SQLDatabase.from_uri(URI)
+
+# print(db.dialect)
+# print(db.get_usable_table_names())
+# db.run('SELECT * FROM "2022 清冊數據(GHG)" LIMIT 10;')
+
+context = db.get_context()
+# print(list(context))
+# print(context["table_info"])
+
+from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
+from langchain.chains import create_sql_query_chain
+from langchain_community.llms import Ollama
+
+from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
+from operator import itemgetter
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+
+# Load model directly
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer,pipeline
+import torch
+from langchain_huggingface import HuggingFacePipeline
+
+# Load model directly
+from transformers import AutoTokenizer, AutoModelForCausalLM
+# model_id = "defog/llama-3-sqlcoder-8b"
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
+# sql_llm = HuggingFacePipeline.from_model_id(
+#     model_id=model_id,
+#     task="text-generation",
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     pipeline_kwargs={"return_full_text": False},
+#     device=0, device_map='cuda')
+
+##########################################################################################
+from langchain_community.chat_models import ChatOllama
+# local_llm = "llama3-groq-tool-use:latest"
+local_llm = "llama3-groq-tool-use:latest"
+llm = ChatOllama(model=local_llm, temperature=0)
+##########################################################################################
+# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# llm = HuggingFacePipeline.from_model_id(
+#     model_id=model_id,
+#     task="text-generation",
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     pipeline_kwargs={"return_full_text": False,
+#         "max_new_tokens": 512},
+#     device=0, device_map='cuda')
+# print(llm.pipeline)
+# llm.pipeline.tokenizer.pad_token_id = llm.pipeline.model.config.eos_token_id[0]
+##########################################################################################
+
+# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
+
+# pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=500, top_k=50, temperature=0.1, 
+#                 model_kwargs={"torch_dtype": torch.bfloat16, "return_full_text": False})
+#, device="auto", load_in_4bit=True
+# llm = HuggingFacePipeline(pipeline=pipe)
+
+# llm = HuggingFacePipeline(pipeline=pipe)
+
+# llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+def get_examples():
+    examples = [
+        {
+            "input": "建準廣興廠2023年的自產電力的綠電使用量是多少?",
+            "query": """SELECT SUM("用電度數(kwh)") AS "自產電力綠電使用量"
+                        FROM "用電度數"
+                        WHERE "項目" = '自產電力(綠電)'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = 2023;""",
+        },
+        {
+            "input": "建準廣興廠去年的類別1總排放量是多少?",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
+                        FROM "建準碳排放清冊數據"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "事業名稱" like '%廣興廠%'
+                        AND ("類別" like '%類別1-直接排放%' OR "排放源" like '%類別1-直接排放%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
+        {
+            "input": "建準台北辦事處2022年的能源間接排放總排放量是多少?",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
+                        FROM "建準碳排放清冊數據"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "事業名稱" like '%台北辦事處%'
+                        AND ("類別" like '%類別2-能源間接排放%' OR "排放源" like '%類別2-能源間接排放%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = 2022;""",
+        },
+        {
+            "input": "建準去年的固定燃燒總排放量是多少?",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+                        FROM "建準碳排放清冊數據"
+                        WHERE "事業名稱" like '%建準%'
+                        AND ("類別" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
+
+
+    ]
+
+    return examples
+
+def table_description():
+    database_description = (
+        "The database consists of following table: `用水度數`, `用水度數`, `建準碳排放清冊數據`. "
+        "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
+        "The `建準碳排放清冊數據` table 描述了不同事業單位或廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
+        "It includes the following columns:\n"
+        "- `年度`: 盤查年度\n"
+        "- `事業名稱`: 建準據點"
+        "- `國家`: 據點所在國家"
+        "- `類別`: 溫室氣體的排放類別,包含以下選項:\n"
+        "   \t*類別1-直接排放:\n"
+        "   \t*類別2-能源間接排放\n"
+        "   \t*類別3-運輸間接排放\n"
+        "   \t*類別4-組織使用產品間接排放\n"
+        "   \t*類別5-使用來自組織產品間接排放\n"
+        "   \t*類別6\n"
+        "- `排放源`: 由`類別`欄位進一步劃分的細項,包含以下選項:`固定燃燒`, `移動燃燒`, `製程排放`, `逸散排放`, `土地利用`, "
+        "`外購電力`, `外購能源`, `上游運輸`, `下游運輸`, `員工通勤`, `商務旅行`, `訪客運輸`, "
+        "`購買產品`, `外購燃料及能資源`, `資本貨物`, `上游租賃`, `廢棄物處理`, `廢棄物清運`, `其他委外業務`, "
+        "`產品加工`, `產品使用`, `產品最終處理`, `下游租賃`, `投資排放`, `其他`, `其他間接排放` \n"
+        "- `排放量(公噸CO2e)`: 溫室氣體排放量\n"
+        "- `盤查標準`: ISO or GHG\n"
+        
+
+        "The `用電度數` 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
+        "It includes the following columns:\n"
+        "- `年度`: 盤查年度\n"
+        "- `事業名稱`: 建準據點"
+        "- `國家`: 據點所在國家"
+        "- `項目`: 用電項目,包含以下:\n"
+        "   \t*外購電力(灰電): 灰電(火力發電、核能發電等)的外購電力度數(kwh)\n"
+        "   \t*外購電力(綠電): 綠電(太陽光電)的外購電力度數(kwh)\n"
+        "   \t*自產電力(綠電): 綠電(太陽光電)的自產電力度數(kwh)\n"
+        "- `用電度數(kwh)`: 用電度數,單位為kwh\n"
+        "- `盤查標準`: ISO or GHG\n"
+        
+        "The `用水度數` 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量。"
+        "It includes the following columns:\n"
+        "- `年度`: 盤查年度\n"
+        "- `事業名稱`: 建準據點"
+        "- `國家`: 據點所在國家"
+        "- `自來水度數(立方公尺 m³)`: 用水度數,單位為m³\n"
+        "- `盤查標準`: ISO or GHG\n"
+    )
+
+    return database_description
+
+def write_query_chain(db, llm):
+
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    Generate a SQL query to answer this question: `{input}`
+
+    You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
+    then look at the results of the query and return the answer to the input question.\n\
+    Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. 
+    You can order the results to return the most informative data in the database.\n\
+    Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
+    Wrap each column name in  Quotation Mark (") to denote them as delimited identifiers.\n\
+    
+    ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
+    ***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
+    <|eot_id|>
+        
+    <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+    DDL statements:
+    {table_info}
+
+    The following is a description of database. Please refer to the database description to give the correct WHERE statement in the PostgreSQL query.\
+    In particular, the details of the `排放源` and `類別` columns.\n
+    database description:
+    {database_description}
+
+    Provide ONLY PostgreSQL query and NO premable or explanation!
+    Below are a number of examples of questions and their corresponding SQL queries.\n\
+    
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
+    """
+    # prompt_template = PromptTemplate.from_template(template)
+
+    example_prompt = PromptTemplate.from_template("The following SQL query best answers the question `{input}`\nSQL query: {query}")
+    prompt = FewShotPromptTemplate(
+        examples=get_examples(),
+        example_prompt=example_prompt,
+        prefix=template,
+        suffix="User input: {input}\nSQL query: ",
+        input_variables=["input", "top_k", "table_info"],
+    )
+
+    # llm = Ollama(model = "sqlcoder", num_gpu=1)
+    # llm = HuggingFacePipeline(pipeline=pipe)
+    
+    
+    write_query = create_sql_query_chain(llm, db, prompt)
+
+
+    return write_query
+
+def sql_to_nl_chain(llm):
+    # llm = Ollama(model = "llama3.1", num_gpu=1)
+    # llm = Ollama(model = "llama3.1:8b-instruct-q2_K", num_gpu=1)
+    # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+    answer_prompt = PromptTemplate.from_template(
+        """
+        <|begin_of_text|>
+        <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+        Given the following user question, corresponding SQL query, and SQL result, answer the user question.
+        根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
+
+        
+        <|eot_id|>
+
+        <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+        Question: {question}
+        SQL Query: {query}
+        SQL Result: {result}
+        Answer: 
+        <|eot_id|>
+        
+        <|start_header_id|>assistant<|end_header_id|>
+        
+        """
+        )
+
+    # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+    chain = answer_prompt | llm | StrOutputParser()
+
+    return chain
+
+def get_query(db, question, selected_table, llm):
+    write_query = write_query_chain(db, llm)
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
+    
+    query = re.split('SQL query: ', query)[-1]
+    # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
+    print(query)
+    
+    return query
+
+def query_to_nl(db, question, query, llm):
+    execute_query = QuerySQLDataBaseTool(db=db)
+    result = execute_query.invoke(query)
+    print(result)
+
+    chain = sql_to_nl_chain(llm)
+    answer = chain.invoke({"question": question, "query": query, "result": result})
+
+    return answer
+
+def run(db, question, selected_table, llm):
+
+    write_query = write_query_chain(db, llm)
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
+    
+    query = re.split('SQL query: ', query)[-1]
+    # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
+    print(query)
+
+    execute_query = QuerySQLDataBaseTool(db=db)
+    result = execute_query.invoke(query)
+    print(result)
+
+    chain = sql_to_nl_chain(llm)
+    answer = chain.invoke({"question": question, "query": query, "result": result})
+
+    return query, result, answer
+
+
+if __name__ == "__main__":
+    import time
+    
+    start = time.time()
+    
+    selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據']
+    question = "建準去年的上游運輸總排放量是多少?"
+    # question = "台積電2022年的直接排放總排放量是多少?"
+    # question = "建準廣興廠去年的灰電使用量"
+    query, result, answer = run(db, question, selected_table, llm)
+    print("question: ", question)
+    print("query: ", query)
+    print("result: ", result)
+    print("answer: ", answer)
+    
+    print(time.time()-start)
+