
add ai-agent file

ling 4 months ago
parent
commit
b0158324b4
4 changed files with 858 additions and 98 deletions:
  1. BIN          agent_workflow.png
  2. + 404 - 97   ai_agent.ipynb
  3. + 453 - 0    ai_agent.py
  4. + 1 - 1      faiss_index.py

BIN
agent_workflow.png


File diff suppressed because it is too large
+ 404 - 97
ai_agent.ipynb


+ 453 - 0
ai_agent.py

@@ -0,0 +1,453 @@
+
+from langchain_community.chat_models import ChatOllama
+from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
+from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
+
+# graph usage
+from pprint import pprint
+from typing import List
+from langchain_core.documents import Document
+from typing_extensions import TypedDict
+from langgraph.graph import END, StateGraph, START
+from langgraph.pregel import RetryPolicy
+
+# supabase db
+from langchain_community.utilities import SQLDatabase
+import os
+from dotenv import load_dotenv
+load_dotenv()
+URI: str = os.environ.get('SUPABASE_URI')
+if not URI:
+    raise RuntimeError('SUPABASE_URI is not set; check your .env file')
+db = SQLDatabase.from_uri(URI)
+
+# LLM
+# local_llm = "llama3.1:8b-instruct-fp16"
+local_llm = "llama3-groq-tool-use:latest"
+llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
+llm = ChatOllama(model=local_llm, temperature=0)
+
+# RAG usage
+from faiss_index import create_faiss_retriever  # note: faiss_query is redefined locally below
+retriever = create_faiss_retriever()
+
+# text-to-sql usage
+from text_to_sql2 import run, get_query, query_to_nl, table_description
+
+
+def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
+    # Join the retrieved documents into a single context string.
+    # (multi_query is currently unused; kept for interface compatibility.)
+    context = "\n\n".join(doc.page_content for doc in docs)
+
+    # System prompt (zh-TW): "You are an AI assistant from Taiwan, happy to help
+    # users from a Taiwanese point of view; answer in Traditional Chinese."
+    system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+    # The template below instructs the model (in zh-TW) to answer in Traditional
+    # Chinese, avoid phrases like "according to the provided documents", and to
+    # reply with a fixed apology directing users to test@systex.com when unsure.
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
+    You should not mention anything about "根據提供的文件內容" or other similar terms.
+    Use five sentences maximum and keep the answer concise.
+    如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    勿回答無關資訊
+    <|eot_id|>
+    
+    <|start_header_id|>user<|end_header_id|>
+    Answer the following question based on this context:
+
+    {context}
+
+    Question: {question}
+    用繁體中文
+    <|eot_id|>
+    
+    <|start_header_id|>assistant<|end_header_id|>
+    """
+    prompt = ChatPromptTemplate.from_template(
+        system_prompt + "\n\n" +
+        template
+    )
+
+    rag_chain = prompt | llm | StrOutputParser()
+    return rag_chain.invoke({"context": context, "question": question})
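+
+# A minimal usage sketch (hypothetical question; assumes the retriever created
+# by create_faiss_retriever() above):
+#   docs = retriever.get_relevant_documents("什麼是溫室氣體盤查?", k=30)
+#   answer = faiss_query("什麼是溫室氣體盤查?", docs, llm)
+#   print(answer)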
+
+### Hallucination Grader
+
+def Hallucination_Grader():
+    # Prompt
+    prompt = PromptTemplate(
+        template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|> 
+        You are a grader assessing whether an answer is grounded in / supported by a set of facts. 
+        Give a binary 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. 
+        Return a JSON with a single key 'score' and no preamble or explanation. 
+        <|eot_id|><|start_header_id|>user<|end_header_id|>
+        Here are the facts:
+        \n ------- \n
+        {documents} 
+        \n ------- \n
+        Here is the answer: {generation} 
+        Provide the 'yes' or 'no' score as a JSON with a single key 'score' and no preamble or explanation.
+        <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["generation", "documents"],
+    )
+
+    hallucination_grader = prompt | llm_json | JsonOutputParser()
+    
+    return hallucination_grader
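+
+# Sketch of the expected grader I/O (illustrative values, not verified output):
+# invoking with {"documents": docs_text, "generation": answer} should yield a
+# JSON such as {"score": "yes"} when the answer is grounded in the documents.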
+
+### Answer Grader
+
+def Answer_Grader():
+    # Prompt
+    prompt = PromptTemplate(
+        template="""
+        <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an 
+        answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is 
+        useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
+        <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
+        \n ------- \n
+        {generation} 
+        \n ------- \n
+        Here is the question: {question} 
+        <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["generation", "question"],
+    )
+
+    answer_grader = prompt | llm_json | JsonOutputParser()
+    
+    return answer_grader
+
+# Text-to-SQL
+def run_text_to_sql(question: str):
+    selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
+    # Example question: "建準去年的固定燃燒總排放量是多少?"
+    # ("What was 建準's total stationary combustion emissions last year?")
+    query, result, answer = run(db, question, selected_table, llm)
+    
+    return answer, query
+
+def _get_query(question: str):
+    selected_table = ['104_112碳排放公開及建準資料', '水電使用量(GHG)', '水電使用量(ISO)']
+    query = get_query(db, question, selected_table, llm)
+    return query
+
+def _query_to_nl(question: str, query: str):
+    answer = query_to_nl(db, question, query, llm)
+    return answer
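+
+# Note: run / get_query / query_to_nl come from text_to_sql2 (not shown in this
+# diff); the wrappers above pin the three candidate ESG tables for every query.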
+
+
+### SQL Grader
+
+def SQL_Grader():
+    prompt = PromptTemplate(
+        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
+        You are a SQL query grader assessing correctness of PostgreSQL query to a user question. 
+        Based on following database description, you need to grade whether the PostgreSQL query exactly matches the user question.
+        
+        Here is database description:
+        {table_info}
+        
+        You need to check that each WHERE clause correctly filters what the user question asks for.
+        
+        For example, if user question is "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
+        "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
+        FROM "104_112碳排放公開及建準資料"
+        WHERE "事業名稱" like '%建準%'
+        AND "排放源" = '下游租賃'
+        AND "盤查標準" = 'GHG'
+        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
+        For the above example, we can find that user asked for "固定燃燒", but the PostgreSQL query gives "排放源" = '下游租賃' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
+        
+        Another example like "建準去年的固定燃燒總排放量是多少?", and the PostgreSQL query is 
+        "SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+        FROM "104_112碳排放公開及建準資料"
+        WHERE "事業名稱" like '%台積電%'
+        AND "排放源" = '固定燃燒'
+        AND "盤查標準" = 'GHG'
+        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;"
+        For the above example, we can find that user asked for "建準", but the PostgreSQL query gives "事業名稱" like '%台積電%' in WHERE statement, which means the PostgreSQL query is incorrect for the user question.
+        
+        and so on. You must strictly examine whether the PostgreSQL query matches the user question.
+        
+        If the PostgreSQL query does not exactly match the user question, grade it as incorrect. 
+        Give a binary 'yes' or 'no' score to indicate whether the PostgreSQL query is correct for the question. \n
+        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
+        <|eot_id|>
+        
+        <|start_header_id|>user<|end_header_id|>
+        Here is the PostgreSQL query: \n\n {sql_query} \n\n
+        Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+        """,
+        input_variables=["table_info", "question", "sql_query"],
+    )
+
+    sql_query_grader = prompt | llm_json | JsonOutputParser()
+    
+    return sql_query_grader
+
+### Router
+def Router():
+    prompt = PromptTemplate(
+        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
+        You are an expert at routing a user question to a vectorstore or company private data. 
+        Use company private data for questions about a company's greenhouse gas emissions data.
+        Otherwise, use the vectorstore for questions on ESG field knowledge or news about ESG. 
+        You do not need to be stringent with the keywords in the question related to these topics. 
+        Give a binary choice 'company_private_data' or 'vectorstore' based on the question. 
+        Return a JSON with a single key 'datasource' and no preamble or explanation. 
+        
+        Question to route: {question} 
+        <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["question"],
+    )
+
+    question_router = prompt | llm_json | JsonOutputParser()
+    
+    return question_router
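+
+# Illustrative routing examples (expected behavior, not verified output):
+#   "建準去年的固定燃燒總排放量?"  -> {"datasource": "company_private_data"}
+#   "溫室氣體是什麼"              -> {"datasource": "vectorstore"}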
+
+class GraphState(TypedDict):
+    """
+    Represents the state of our graph.
+
+    Attributes:
+        question: user question
+        generation: LLM generation
+        documents: list of retrieved documents
+        retry: retry count for SQL-query generation
+        sql_query: generated PostgreSQL query
+    """
+
+    question: str
+    generation: str
+    documents: List[str]
+    retry: int
+    sql_query: str
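+
+# Each node below returns a partial dict; LangGraph merges it into GraphState,
+# so keys like "retry" and "sql_query" only exist after the nodes that set them.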
+    
+# Node
+def retrieve_and_generation(state):
+    """
+    Retrieve documents from the vectorstore and generate an answer.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): New keys added to state: documents (the retrieved
+        documents) and generation (the answer generated by the LLM)
+    """
+    print("---RETRIEVE---")
+    question = state["question"]
+
+    # Retrieval
+    # documents = retriever.invoke(question)
+    # TODO: correct Retrieval function
+    documents = retriever.get_relevant_documents(question, k=30)
+    generation = faiss_query(question, documents, llm)
+    return {"documents": documents, "question": question, "generation": generation}
+
+def company_private_data_get_sql_query(state):
+    """
+    Generate a PostgreSQL query for the question.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): The generated PostgreSQL query and the updated retry count
+    """
+    print("---SQL QUERY---")
+    question = state["question"]
+    
+    # Use .get() so the first pass (no "retry" key yet) starts the counter at 0;
+    # each later pass increments it.
+    retry = state.get("retry")
+    retry = retry + 1 if retry is not None else 0
+    
+    sql_query = _get_query(question)
+    
+    return {"sql_query": sql_query, "question": question, "retry": retry}
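+
+# The retry counter set here is read by grade_sql_query below, which routes to
+# "failed" once the count of incorrect queries reaches 5.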
+    
+def company_private_data_search(state):
+    """
+    Execute the PostgreSQL query and convert the result to natural language.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): The natural-language answer appended to state
+    """
+
+    print("---SQL TO NL---")
+    # print(state)
+    question = state["question"]
+    sql_query = state["sql_query"]
+    generation = _query_to_nl(question, sql_query)
+    
+    return {"sql_query": sql_query, "question": question, "generation": generation}
+
+### Conditional edge
+
+
+def route_question(state):
+    """
+    Route question to text-to-SQL or RAG.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        str: Next node to call
+    """
+
+    print("---ROUTE QUESTION---")
+    question = state["question"]
+    # print(question)
+    question_router = Router()
+    source = question_router.invoke({"question": question})
+    # print(source)
+    print(source["datasource"])
+    if source["datasource"] == "company_private_data":
+        print("---ROUTE QUESTION TO TEXT-TO-SQL---")
+        return "company_private_data"
+    elif source["datasource"] == "vectorstore":
+        print("---ROUTE QUESTION TO RAG---")
+        return "vectorstore"
+    else:
+        # Defensive fallback: route unexpected labels to RAG instead of returning None.
+        print("---ROUTE QUESTION TO RAG (FALLBACK)---")
+        return "vectorstore"
+    
+def grade_generation_v_documents_and_question(state):
+    """
+    Determines whether the generation is grounded in the document and answers question.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        str: Decision for next node to call
+    """
+
+    print("---CHECK HALLUCINATIONS---")
+    question = state["question"]
+    documents = state["documents"]
+    generation = state["generation"]
+
+    
+    # print(generation)
+    hallucination_grader = Hallucination_Grader()
+    score = hallucination_grader.invoke(
+        {"documents": documents, "generation": generation}
+    )
+    # print(score)
+    grade = score["score"]
+
+    # Check hallucination
+    if grade in ["yes", "true", 1, "1"]:
+        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+        # Check question-answering
+        print("---GRADE GENERATION vs QUESTION---")
+        answer_grader = Answer_Grader()
+        score = answer_grader.invoke({"question": question, "generation": generation})
+        grade = score["score"]
+        if grade in ["yes", "true", 1, "1"]:
+            print("---DECISION: GENERATION ADDRESSES QUESTION---")
+            return "useful"
+        else:
+            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+            return "not useful"
+    else:
+        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+        return "not supported"
+    
+def grade_sql_query(state):
+    """
+    Determines whether the Postgresql query are correct to the question
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): Decision for retry or continue
+    """
+
+    print("---CHECK SQL CORRECTNESS TO QUESTION---")
+    question = state["question"]
+    sql_query = state["sql_query"]
+    retry = state["retry"]
+
+    # Grade the generated SQL query against the question
+    sql_query_grader = SQL_Grader()
+    score = sql_query_grader.invoke({"table_info": table_description(), "question": question, "sql_query": sql_query})
+    grade = score["score"]
+    # Route on the grader's verdict
+    if grade in ["yes", "true", 1, "1"]:
+        print("---GRADE: CORRECT SQL QUERY---")
+        return "correct"
+    elif retry >= 5:
+        print("---GRADE: INCORRECT SQL QUERY AND REACH RETRY LIMIT---")
+        return "failed"
+    else:
+        print("---GRADE: INCORRECT SQL QUERY---")
+        return "incorrect"
+
+def build_graph():
+    workflow = StateGraph(GraphState)
+
+    # Define the nodes
+    workflow.add_node("company_private_data_query", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5))  # web search
+    workflow.add_node("company_private_data_search", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # web search
+    workflow.add_node("retrieve_and_generation", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
+    
+    workflow.add_conditional_edges(
+        START,
+        route_question,
+        {
+            "company_private_data": "company_private_data_query",
+            "vectorstore": "retrieve_and_generation",
+        },
+    )
+
+    workflow.add_conditional_edges(
+        "retrieve_and_generation",
+        grade_generation_v_documents_and_question,
+        {
+            "not supported": "retrieve_and_generation",
+            "useful": END,
+            "not useful": "retrieve_and_generation",
+        },
+    )
+    workflow.add_conditional_edges(
+        "company_private_data_query",
+        grade_sql_query,
+        {
+            "correct": "company_private_data_search",
+            "incorrect": "company_private_data_query",
+            "failed": END
+            
+        },
+    )
+    workflow.add_edge("company_private_data_search", END)
+
+    app = workflow.compile()    
+    
+    return app
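+
+# A sketch of how agent_workflow.png (added in this commit) could be produced,
+# assuming LangGraph's built-in Mermaid renderer is available:
+#   png_bytes = build_graph().get_graph().draw_mermaid_png()
+#   with open("agent_workflow.png", "wb") as f:
+#       f.write(png_bytes)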
+
+def main():
+    app = build_graph()
+    # Alternative example question: "建準去年的類別一排放量?"
+    # ("What were 建準's Category 1 emissions last year?")
+    inputs = {"question": "溫室氣體是什麼"}  # "What are greenhouse gases?"
+    for output in app.stream(inputs, {"recursion_limit": 10}):
+        for key, value in output.items():
+            pprint(f"Finished running: {key}:")
+    pprint(value["generation"])
+    
+    return value["generation"]
+
+if __name__ == "__main__":
+    result = main()
+    print("------------------------------------------------------")
+    print(result)

+ 1 - 1
faiss_index.py

@@ -413,7 +413,7 @@ if __name__ == "__main__":
         
 
     # Save results to a JSON file
-    with open('qa_results+all.json', 'w', encoding='utf8') as outfile:
+    with open('qa_results_all.json', 'w', encoding='utf8') as outfile:
         json.dump(results, outfile, indent=4, ensure_ascii=False)
 
     print('All questions done!')
