2 İşlemeler 2b07b528cb ... b0158324b4

Yazar SHA1 Mesaj Tarih
  ling b0158324b4 add ai-agent file 8 ay önce
  ling 87a0cdfc88 text-to-sql for '104_112碳排放公開及建準資料' table 8 ay önce
5 değiştirilmiş dosya ile 899 ekleme ve 132 silme
  1. BIN
      agent_workflow.png
  2. 404 97
      ai_agent.ipynb
  3. 453 0
      ai_agent.py
  4. 1 1
      faiss_index.py
  5. 41 34
      text_to_sql2.py

BIN
agent_workflow.png


Dosya farkı çok büyük olduğundan ihmal edildi
+ 404 - 97
ai_agent.ipynb


+ 453 - 0
ai_agent.py

@@ -0,0 +1,453 @@
+
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate

from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# graph usage (LangGraph workflow primitives)
from pprint import pprint
from typing import List
from langchain_core.documents import Document
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
from langgraph.pregel import RetryPolicy

# supabase db: PostgreSQL accessed through LangChain's SQLDatabase wrapper
from langchain_community.utilities import SQLDatabase
import os
from dotenv import load_dotenv
load_dotenv()
# NOTE(review): if SUPABASE_URI is missing from the environment/.env,
# os.environ.get returns None and from_uri will fail with an unhelpful
# error — confirm the deployment always provides it.
URI: str =  os.environ.get('SUPABASE_URI')
db = SQLDatabase.from_uri(URI)

# LLM
# local_llm = "llama3.1:8b-instruct-fp16"
local_llm = "llama3-groq-tool-use:latest"
# llm_json forces JSON-mode output (used by the grader/router chains);
# llm is the plain chat model. temperature=0 for deterministic output.
llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
llm = ChatOllama(model=local_llm, temperature=0)

# RAG usage: FAISS retriever over the ESG document store.
# NOTE(review): the imported faiss_query is shadowed by the local
# faiss_query defined below, so this import is effectively unused.
from faiss_index import create_faiss_retriever, faiss_query
retriever = create_faiss_retriever()

# text-to-sql usage
from text_to_sql2 import run, get_query, query_to_nl, table_description
+
+
def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
    """Answer *question* in Traditional Chinese using *docs* as RAG context.

    Args:
        question: The user question.
        docs: Retrieved documents passed verbatim into the prompt context.
        llm: Chat model used to generate the answer.
        multi_query: Unused; kept for interface compatibility.

    Returns:
        The model's answer as a plain string.
    """
    system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
    template = """
    <|begin_of_text|>
    
    <|start_header_id|>system<|end_header_id|>
    你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
    You should not mention anything about "根據提供的文件內容" or other similar terms.
    Use five sentences maximum and keep the answer concise.
    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
    勿回答無關資訊
    <|eot_id|>
    
    <|start_header_id|>user<|end_header_id|>
    Answer the following question based on this context:

    {context}

    Question: {question}
    用繁體中文
    <|eot_id|>
    
    <|start_header_id|>assistant<|end_header_id|>
    """
    # Prepend the system prompt to the Llama-3 style chat template.
    chat_prompt = ChatPromptTemplate.from_template(system_prompt + "\n\n" + template)
    answer_chain = chat_prompt | llm | StrOutputParser()
    return answer_chain.invoke({"context": docs, "question": question})
+
+### Hallucination Grader
+
def Hallucination_Grader():
    """Build a chain grading whether an answer is grounded in a set of facts.

    Returns:
        Runnable: prompt | llm_json | JsonOutputParser; invoked with
        {"documents": ..., "generation": ...} it yields {"score": "yes"|"no"}.
    """
    # Fixed prompt typos ("premable" -> "preamble", "Return the a JSON") and
    # dropped the duplicated instruction line; grading contract unchanged.
    prompt = PromptTemplate(
        template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|> 
        You are a grader assessing whether an answer is grounded in / supported by a set of facts. 
        Give 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. 
        Return a JSON with a single key 'score' and no preamble or explanation. 
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        Here are the facts:
        \n ------- \n
        {documents} 
        \n ------- \n
        Here is the answer: {generation} 
        Provide 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation.
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
        input_variables=["generation", "documents"],
    )

    hallucination_grader = prompt | llm_json | JsonOutputParser()

    return hallucination_grader
+
+### Answer Grader
+
def Answer_Grader():
    """Build a chain grading whether a generation usefully answers a question.

    Returns:
        Runnable invoked with {"question": ..., "generation": ...};
        yields {"score": "yes"|"no"}.
    """
    grading_template = """
        <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an 
        answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is 
        useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
        <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
        \n ------- \n
        {generation} 
        \n ------- \n
        Here is the question: {question} 
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
    grader_prompt = PromptTemplate(
        template=grading_template,
        input_variables=["generation", "question"],
    )
    return grader_prompt | llm_json | JsonOutputParser()
+
+# Text-to-SQL
def run_text_to_sql(question: str):
    """Run the full text-to-SQL pipeline: generate, execute, and verbalise.

    Args:
        question: Natural-language question about the emission tables.

    Returns:
        tuple: (natural-language answer, generated SQL query).
    """
    selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
    query, result, answer = run(db, question, selected_table, llm)
    return answer, query
+
def _get_query(question: str):
    """Generate (but do not execute) a PostgreSQL query for *question*."""
    target_tables = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
    return get_query(db, question, target_tables, llm)
+
def _query_to_nl(question: str, query: str):
    """Execute *query* and convert its result into a natural-language answer."""
    return query_to_nl(db, question, query, llm)
+
+
+### SQL Grader
+
def SQL_Grader():
    """Build a chain grading whether a PostgreSQL query matches the question.

    Returns:
        Runnable invoked with {"table_info": ..., "question": ..., "sql_query": ...};
        yields {"score": "yes"|"no"}.
    """
    # Fixed the "premable" typo in the prompt; grading contract unchanged.
    prompt = PromptTemplate(
        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
        You are a SQL query grader assessing correctness of PostgreSQL query to a user question. 
        Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.
        
        Here is database description:
        {table_info}
        
        You need to check that each where statement is correctly filtered out what user question need.
        
        For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
        "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
        FROM "104_112碳排放公開及建準資料"
        WHERE "事業名稱" like '%建準%'
        AND "排放源" = '下游租賃'
        AND "盤查標準" = 'GHG'
        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
        For the above example, we can find that user asked for "固定燃燒", but the PostgreSQL query gives "排放源" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
        
        Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
        "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
        FROM "104_112碳排放公開及建準資料"
        WHERE "事業名稱" like '%台積電%'
        AND "排放源" = '固定燃燒'
        AND "盤查標準" = 'GHG'
        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
        For the above example, we can find that user asked for "建準", but the PostgreSQL query gives "事業名稱" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
        
        and so on. You need to strictly examine whether the sql PostgreSQL query matches the user question.
        
        If the PostgreSQL query do not exactly matches the user question, grade it as incorrect. 
        You need to strictly examine whether the sql PostgreSQL query matches the user question.
        Give a binary score 'yes' or 'no' score to indicate whether the PostgreSQL query is correct to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
        <|eot_id|>
        
        <|start_header_id|>user<|end_header_id|>
        Here is the PostgreSQL query: \n\n {sql_query} \n\n
        Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """,
        input_variables=["table_info", "question", "sql_query"],
    )

    sql_query_grader = prompt | llm_json | JsonOutputParser()

    return sql_query_grader
+
+### Router
def Router():
    """Build a routing chain deciding between text-to-SQL and RAG retrieval.

    Returns:
        Runnable invoked with {"question": ...}; yields
        {"datasource": "company_private_data" | "vectorstore"}.
    """
    # Fixed prompt typos ("Return the a JSON", "premable", "informations");
    # routing contract unchanged.
    prompt = PromptTemplate(
        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
        You are an expert at routing a user question to a vectorstore or company private data. 
        Use company private data for questions about the information about a company's greenhouse gas emissions data.
        Otherwise, use the vectorstore for questions on ESG field knowledge or news about ESG. 
        You do not need to be stringent with the keywords in the question related to these topics. 
        Give a binary choice 'company_private_data' or 'vectorstore' based on the question. 
        Return a JSON with a single key 'datasource' and no preamble or explanation. 
        
        Question to route: {question} 
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
        input_variables=["question"],
    )

    question_router = prompt | llm_json | JsonOutputParser()

    return question_router
+
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user question being processed
        generation: the LLM-generated answer
        documents: retrieved documents used as RAG context
        retry: number of SQL-generation retries so far
        sql_query: the generated PostgreSQL query
    """

    question: str
    generation: str
    documents: List[str]
    retry: int
    sql_query: str
+    
+# Node
def retrieve_and_generation(state):
    """
    Retrieve documents from the vectorstore and generate an answer from them.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): adds "documents" (retrieved docs) and "generation"
        (the LLM answer) to the state
    """
    print("---RETRIEVE---")
    question = state["question"]

    # TODO: correct Retrieval function (retriever.invoke(question)?)
    retrieved_docs = retriever.get_relevant_documents(question, k=30)
    answer = faiss_query(question, retrieved_docs, llm)
    return {"documents": retrieved_docs, "question": question, "generation": answer}
+
def company_private_data_get_sql_query(state):
    """
    Generate a PostgreSQL query for the question and track retry attempts.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): with the generated "sql_query" and updated "retry" count
    """
    print("---SQL QUERY---")
    question = state["question"]

    # Bug fix: the original tested `if state["retry"]:` — that may KeyError
    # when the key is not yet in the state, and a stored value of 0 is falsy,
    # so the counter was reset to 0 on every pass and the retry limit in
    # grade_sql_query (retry >= 5) could never be reached. Count attempts
    # explicitly: first attempt -> 0, each re-entry increments by 1.
    previous_retry = state.get("retry")
    retry = 0 if previous_retry is None else previous_retry + 1

    sql_query = _get_query(question)

    return {"sql_query": sql_query, "question": question, "retry": retry}
+    
def company_private_data_search(state):
    """
    Execute the generated PostgreSQL query and verbalise its result.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): with "generation" holding the natural-language answer
    """
    print("---SQL TO NL---")
    question = state["question"]
    sql_query = state["sql_query"]

    answer = _query_to_nl(question, sql_query)

    return {"sql_query": sql_query, "question": question, "generation": answer}
+
+### Conditional edge
+
+
def route_question(state):
    """
    Route the question either to text-to-SQL or to RAG retrieval.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call ("company_private_data" or "vectorstore")
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    question_router = Router()
    source = question_router.invoke({"question": question})
    print(source["datasource"])
    if source["datasource"] == "company_private_data":
        print("---ROUTE QUESTION TO TEXT-TO-SQL---")
        return "company_private_data"
    # Bug fix: the original if/elif returned None when the model emitted an
    # unexpected datasource value, which breaks the conditional-edge mapping
    # in build_graph. Fall back to the vectorstore route instead.
    print("---ROUTE QUESTION TO RAG---")
    return "vectorstore"
+    
def grade_generation_v_documents_and_question(state):
    """
    Check that the generation is grounded in the documents and answers the question.

    Args:
        state (dict): The current graph state

    Returns:
        str: "useful", "not useful", or "not supported" — decision for the next node
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    hallucination_grader = Hallucination_Grader()
    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score["score"]

    # The JSON-mode model is not fully reliable about the exact token it
    # returns, so several truthy spellings of "yes" are accepted.
    if grade in ["yes", "true", 1, "1"]:
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        answer_grader = Answer_Grader()
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score["score"]
        if grade in ["yes", "true", 1, "1"]:
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
        return "not useful"
    # Consistency fix: the original used pprint here, which prints the status
    # marker wrapped in quotes; use print like every other status message.
    print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
    return "not supported"
+    
def grade_sql_query(state):
    """
    Decide whether the generated PostgreSQL query matches the question.

    Args:
        state (dict): The current graph state

    Returns:
        str: "correct", "incorrect", or "failed" (retry limit reached)
    """
    print("---CHECK SQL CORRECTNESS TO QUESTION---")
    question = state["question"]
    sql_query = state["sql_query"]
    retry = state["retry"]

    # Grade the query against the schema description and the user question.
    grader = SQL_Grader()
    verdict = grader.invoke(
        {"table_info": table_description(), "question": question, "sql_query": sql_query}
    )
    grade = verdict["score"]

    if grade in ["yes", "true", 1, "1"]:
        print("---GRADE: CORRECT SQL QUERY---")
        return "correct"
    if retry >= 5:
        print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
        return "failed"
    print("---GRADE: INCORRECT SQL QUERY---")
    return "incorrect"
+
def build_graph():
    """Assemble and compile the LangGraph workflow for the agent.

    Routes a question either through the text-to-SQL branch (generate query,
    grade it, execute and verbalise) or the RAG branch (retrieve + generate,
    then grade for hallucination/usefulness).

    Returns:
        The compiled LangGraph application.
    """
    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("company_private_data_query", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5))  # text-to-SQL query generation
    workflow.add_node("company_private_data_search", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # SQL execution + NL answer
    workflow.add_node("retrieve_and_generation", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # RAG retrieve + generate
    
    # Entry: route the question to the SQL branch or the RAG branch.
    workflow.add_conditional_edges(
        START,
        route_question,
        {
            "company_private_data": "company_private_data_query",
            "vectorstore": "retrieve_and_generation",
        },
    )

    # RAG branch: regenerate until grounded and useful.
    workflow.add_conditional_edges(
        "retrieve_and_generation",
        grade_generation_v_documents_and_question,
        {
            "not supported": "retrieve_and_generation",
            "useful": END,
            "not useful": "retrieve_and_generation",
        },
    )
    # SQL branch: regenerate the query until it is graded correct or the
    # retry limit is hit.
    workflow.add_conditional_edges(
        "company_private_data_query",
        grade_sql_query,
        {
            "correct": "company_private_data_search",
            "incorrect": "company_private_data_query",
            "failed": END
            
        },
    )
    workflow.add_edge("company_private_data_search", END)

    app = workflow.compile()    
    
    return app
+
def main():
    """Run the compiled agent graph on a sample question.

    Returns:
        str: the final "generation" value produced by the graph.

    Raises:
        RuntimeError: if the graph stream yields no output at all.
    """
    app = build_graph()
    # Alternative sample question: 建準去年的類別一排放量?
    inputs = {"question": "溫室氣體是什麼"}
    # Bug fix: the original referenced the loop variable `value` after the
    # loop, which raises UnboundLocalError when the stream yields nothing
    # and otherwise relies on the leaked loop binding. Track it explicitly.
    last_state = None
    for output in app.stream(inputs, {"recursion_limit": 10}):
        for key, last_state in output.items():
            pprint(f"Finished running: {key}:")
    if last_state is None:
        raise RuntimeError("graph produced no output")
    pprint(last_state["generation"])

    return last_state["generation"]
+
# Script entry point: run the agent once and echo the final answer.
if __name__ == "__main__":
    result = main()
    print("------------------------------------------------------")
    print(result)

+ 1 - 1
faiss_index.py

@@ -413,7 +413,7 @@ if __name__ == "__main__":
         
 
     # Save results to a JSON file
-    with open('qa_results+all.json', 'w', encoding='utf8') as outfile:
+    with open('qa_results_all.json', 'w', encoding='utf8') as outfile:
         json.dump(results, outfile, indent=4, ensure_ascii=False)
 
     print('All questions done!')

+ 41 - 34
text_to_sql2.py

@@ -75,15 +75,6 @@ llm = ChatOllama(model=local_llm, temperature=0)
 # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
 def get_examples():
     examples = [
-        {
-            "input": "建準去年的固定燃燒總排放量是多少?",
-            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
-                        FROM "104_112碳排放公開及建準資料"
-                        WHERE "事業名稱" like '%建準%'
-                        AND "排放源" = '固定燃燒'
-                        AND "盤查標準" = 'GHG'
-                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
-        },
         {
             "input": "建準廣興廠去年的類別1總排放量是多少?",
             "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
@@ -104,6 +95,15 @@ def get_examples():
                         AND "盤查標準" = 'GHG'
                         AND "年度" = 2022;""",
         },
+        {
+            "input": "建準去年的固定燃燒總排放量是多少?",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+                        FROM "104_112碳排放公開及建準資料"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "排放源" = '固定燃燒'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
         {
             "input": "台積電2022年的類別1總排放量是多少?",
             "query": """SELECT SUM("排放量(公噸CO2e)") AS "台積電2022年類別1總排放量"
@@ -120,10 +120,12 @@ def get_examples():
 
 def table_description():
     database_description = (
-        "The database consists of following tables: `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)`, `2023 清冊數據(GHG)`, `水電使用量(ISO)` and `水電使用量(GHG)`. "
+        "The database consists of following table: `104_112碳排放公開及建準資料`, `水電使用量(ISO)` and `水電使用量(GHG)`. "
         "This is a PostgreSQL database, so you need to use postgres-related queries.\n\n"
-        "The `2022 清冊數據(ISO)`, `2023 清冊數據(ISO)`, `2022 清冊數據(GHG)` and `2023 清冊數據(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
+        "The `104_112碳排放公開及建準資料` table 描述了不同事業單位或廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
         "It includes the following columns:\n"
+
+        "- `年度`: 盤查年度\n"
         "- `類別`: 溫室氣體的排放類別,包含以下:\n"
         "   \t*類別1-直接排放\n"
         "   \t*類別2-能源間接排放\n"
@@ -132,15 +134,9 @@ def table_description():
         "   \t*類別5-使用來自組織產品間接排放\n"
         "   \t*類別6\n"
         "- `排放源`: `類別`欄位進一步劃分的細項\n"
-        "- `高雄總部&運通廠`: 位於台灣的廠房據點\n"
-        "- `台北辦公室`: 位於台灣的廠房據點\n"
-        "- `北海建準廠`: 位於中國的廠房據點\n"
-        "- `北海立準廠`: 位於中國的廠房據點\n"
-        "- `昆山廣興廠`: 位於中國的廠房據點\n"
-        "- `菲律賓建準廠`: 位於菲律賓的廠房據點\n"
-        "- `India`: 位於印度的廠房據點\n"
-        "- `INC`: 位於美國的廠房據點\n"
-        "- `SAS`: 位於法國的廠房據點\n\n"
+        "- `排放量(公噸CO2e)`: 溫室氣體排放量\n"
+        "- `盤查標準`: ISO or GHG\n"
+        
 
         "The `水電使用量(ISO)` and `水電使用量(GHG)` table 描述了不同廠房分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的水電使用量,包含'外購電力 度數 (kwh)'與'自來水 度數 (立方公尺 m³)'。"
         "The `public.departments_table` table contains information about the various departments in the company. It includes:\n"
@@ -196,7 +192,7 @@ def write_query_chain(db, llm):
         input_variables=["input", "top_k", "table_info"],
     )
 
-    # llm = Ollama(model = "mannix/defog-llama3-sqlcoder-8b", num_gpu=1)
+    # llm = Ollama(model = "sqlcoder", num_gpu=1)
     # llm = HuggingFacePipeline(pipeline=pipe)
     
     
@@ -214,19 +210,9 @@ def sql_to_nl_chain(llm):
         <|begin_of_text|>
         <|begin_of_text|><|start_header_id|>system<|end_header_id|>
         Given the following user question, corresponding SQL query, and SQL result, answer the user question.
-        給定以下使用者問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
+        根據使用者的問題、對應的 SQL 查詢和 SQL 結果,以繁體中文回答使用者問題。
 
-        For example
-        Question: 建準廣興廠去年的類別總排放量是多少?
-        SQL Query: SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
-                        FROM "104_112碳排放公開及建準資料"
-                        WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%廣興廠%'
-                        AND "類別" = '類別1-直接排放'
-                        AND "盤查標準" = 'GHG'
-                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
-        SQL Result: [(1102.3712,)]
-        Answer: 建準廣興廠去年的類別1總排放量是1102.3712
+        
         <|eot_id|>
 
         <|begin_of_text|><|start_header_id|>user<|end_header_id|>
@@ -241,14 +227,35 @@ def sql_to_nl_chain(llm):
         """
         )
 
+    # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
     chain = answer_prompt | llm | StrOutputParser()
 
     return chain
 
def get_query(db, question, selected_table, llm):
    """Generate a PostgreSQL query for *question* without executing it.

    Args:
        db: SQLDatabase connection handed to the query-writing chain.
        question: Natural-language question.
        selected_table: Table names the chain may use.
        llm: Chat model that writes the SQL.

    Returns:
        str: the cleaned SQL query text.
    """
    chain = write_query_chain(db, llm)
    raw_output = chain.invoke({
        "question": question,
        'table_names_to_use': selected_table,
        "top_k": 1000,
        "table_info": context["table_info"],
        "database_description": table_description(),
    })

    # Keep only the text after the last "SQL query: " marker, then repair a
    # recurring typo in the table name (碰 -> 碳) the model tends to emit.
    sql = re.split('SQL query: ', raw_output)[-1]
    sql = sql.replace("104_112碰排放公開及建準資料", "104_112碳排放公開及建準資料")
    print(sql)

    return sql
+
def query_to_nl(db, question, query, llm):
    """Execute *query* against *db* and verbalise the result via the LLM.

    Args:
        db: SQLDatabase connection the query runs against.
        question: The original natural-language question.
        query: PostgreSQL query text to execute.
        llm: Chat model used to phrase the answer.

    Returns:
        str: natural-language answer built from the SQL result.
    """
    sql_result = QuerySQLDataBaseTool(db=db).invoke(query)
    print(sql_result)

    nl_chain = sql_to_nl_chain(llm)
    return nl_chain.invoke({"question": question, "query": query, "result": sql_result})
+
 def run(db, question, selected_table, llm):
 
     write_query = write_query_chain(db, llm)
-    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": "foo"})
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     
     query = re.split('SQL query: ', query)[-1]
     query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")

Bu fark içinde çok fazla dosya değişikliği olduğu için bazı dosyalar gösterilmiyor