Browse Source

Add a new node 'Additional Explanation' to the workflow and create new app file

ling 4 tháng trước cách đây
mục cha
commit
c70d6c8b9d
7 tập tin đã thay đổi với 522 bổ sung121 xóa
  1. 9 2
      RAG_app.py
  2. BIN
      agent_workflow.png
  3. 244 94
      ai_agent.ipynb
  4. 65 25
      ai_agent.py
  5. 100 0
      post_processing_sqlparse.py
  6. 84 0
      systex_app.py
  7. 20 0
      text_to_sql.py

+ 9 - 2
RAG_app.py

@@ -22,6 +22,7 @@ import json
 from json import loads
 from json import loads
 import pandas as pd
 import pandas as pd
 import time
 import time
+from langchain_community.chat_models import ChatOllama
 from langchain.callbacks import get_openai_callback
 from langchain.callbacks import get_openai_callback
 
 
 from langchain_community.vectorstores import Chroma
 from langchain_community.vectorstores import Chroma
@@ -41,6 +42,7 @@ from supabase.client import Client, create_client
 from file_loader.add_vectordb import GetVectorStore
 from file_loader.add_vectordb import GetVectorStore
 from faiss_index import create_faiss_retriever, faiss_query
 from faiss_index import create_faiss_retriever, faiss_query
 from local_llm import ollama_, hf
 from local_llm import ollama_, hf
+from ai_agent import main
 # from local_llm import ollama_, taide_llm, hf
 # from local_llm import ollama_, taide_llm, hf
 # llm = hf()
 # llm = hf()
 
 
@@ -66,7 +68,9 @@ async def lifespan(app: FastAPI):
     # vector_store = GetVectorStore(embeddings, supabase, document_table)
     # vector_store = GetVectorStore(embeddings, supabase, document_table)
     # global_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
     # global_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
     global_retriever = create_faiss_retriever()
     global_retriever = create_faiss_retriever()
-    llm = hf()
+    local_llm = "llama3-groq-tool-use:latest"
+    # llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
+    llm = ChatOllama(model=local_llm, temperature=0)
 
 
     print(time.time() - start)
     print(time.time() - start)
     yield
     yield
@@ -138,7 +142,10 @@ def multi_query_answer(question: Optional[str] = '什麼是逸散排放源?',
     # print(response_content)
     # print(response_content)
     # return json.loads(response_content)
     # return json.loads(response_content)
 
 
-    
+@app.get("/agents")
+def agent(question: str):
+    answer = main(question)
+    return {"answer": answer}
 
 
 def save_history(question, answer, reference, cb, processing_time):
 def save_history(question, answer, reference, cb, processing_time):
     # reference = [doc.dict() for doc in reference]
     # reference = [doc.dict() for doc in reference]

BIN
agent_workflow.png


Những thai đổi đã bị hủy bỏ vì nó quá lớn
+ 244 - 94
ai_agent.ipynb


+ 65 - 25
ai_agent.py

@@ -34,6 +34,7 @@ retriever = create_faiss_retriever()
 
 
 # text-to-sql usage
 # text-to-sql usage
 from text_to_sql2 import run, get_query, query_to_nl, table_description
 from text_to_sql2 import run, get_query, query_to_nl, table_description
+from post_processing_sqlparse import get_query_columns, parse_sql_for_stock_info, get_table_name
 
 
 
 
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
@@ -58,7 +59,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     {context}
     {context}
 
 
     Question: {question}
     Question: {question}
-    用繁體中文
+    用繁體中文回答問題
     <|eot_id|>
     <|eot_id|>
     
     
     <|start_header_id|>assistant<|end_header_id|>
     <|start_header_id|>assistant<|end_header_id|>
@@ -135,7 +136,21 @@ def _query_to_nl(question: str, query: str):
     answer = query_to_nl(db, question, query, llm)
     answer = query_to_nl(db, question, query, llm)
     return  answer
     return  answer
 
 
-
+def generate_additional_detail(sql_query):
+    terms = parse_sql_for_stock_info(sql_query)
+    answer = ""
+    for term in terms:
+        if term is None: continue
+        question_format = [f"什麼是{term}?", f"{term}的用途是什麼", f"如何計算{term}?"]
+        for question in question_format:
+            # question = f"什麼是{term}?"
+            documents = retriever.get_relevant_documents(question, k=30)
+            generation = faiss_query(question, documents, llm)
+            answer += generation
+            answer += "\n"
+            # print(question)
+            # print(generation)
+    return answer
 ### SQL Grader
 ### SQL Grader
 
 
 def SQL_Grader():
 def SQL_Grader():
@@ -190,11 +205,11 @@ def SQL_Grader():
 def Router():
 def Router():
     prompt = PromptTemplate(
     prompt = PromptTemplate(
         template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
         template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
-        You are an expert at routing a user question to a vectorstore or company private data. 
+        You are an expert at routing a user question to a 專業知識 or 自有數據. 
         Use company private data for questions about the informations about a company's greenhouse gas emissions data.
         Use company private data for questions about the informations about a company's greenhouse gas emissions data.
-        Otherwise, use the vectorstore for questions on ESG field knowledge or news about ESG. 
+        Otherwise, use the 專業知識 for questions on ESG field knowledge or news about ESG. 
         You do not need to be stringent with the keywords in the question related to these topics. 
         You do not need to be stringent with the keywords in the question related to these topics. 
-        Give a binary choice 'company_private_data' or 'vectorstore' based on the question. 
+        Give a binary choice '自有數據' or '專業知識' based on the question. 
         Return the a JSON with a single key 'datasource' and no premable or explanation. 
         Return the a JSON with a single key 'datasource' and no premable or explanation. 
         
         
         Question to route: {question} 
         Question to route: {question} 
@@ -291,6 +306,28 @@ def company_private_data_search(state):
     
     
     return {"sql_query": sql_query, "question": question, "generation": generation}
     return {"sql_query": sql_query, "question": question, "generation": generation}
 
 
+def additional_explanation(state):
+    """
+    
+    Args:
+        state (_type_): _description_
+        
+    Returns:
+        state (dict): Appended additional explanation to state
+    """
+    
+    print("---ADDITIONAL EXPLANATION---")
+    print(state)
+    question = state["question"]
+    sql_query = state["sql_query"]
+    generation = state["generation"]
+    generation += "\n"
+    generation += generate_additional_detail(sql_query)
+    
+    # generation = [company_private_data_result]
+    
+    return {"sql_query": sql_query, "question": question, "generation": generation}
+
 ### Conditional edge
 ### Conditional edge
 
 
 
 
@@ -312,12 +349,12 @@ def route_question(state):
     source = question_router.invoke({"question": question})
     source = question_router.invoke({"question": question})
     # print(source)
     # print(source)
     print(source["datasource"])
     print(source["datasource"])
-    if source["datasource"] == "company_private_data":
+    if source["datasource"] == "自有數據":
         print("---ROUTE QUESTION TO TEXT-TO-SQL---")
         print("---ROUTE QUESTION TO TEXT-TO-SQL---")
-        return "company_private_data"
-    elif source["datasource"] == "vectorstore":
+        return "自有數據"
+    elif source["datasource"] == "專業知識":
         print("---ROUTE QUESTION TO RAG---")
         print("---ROUTE QUESTION TO RAG---")
-        return "vectorstore"
+        return "專業知識"
     
     
 def grade_generation_v_documents_and_question(state):
 def grade_generation_v_documents_and_question(state):
     """
     """
@@ -398,48 +435,51 @@ def build_graph():
     workflow = StateGraph(GraphState)
     workflow = StateGraph(GraphState)
 
 
     # Define the nodes
     # Define the nodes
-    workflow.add_node("company_private_data_query", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5))  # web search
-    workflow.add_node("company_private_data_search", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # web search
-    workflow.add_node("retrieve_and_generation", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
+    workflow.add_node("Text-to-SQL", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5))  # web search
+    workflow.add_node("SQL Answer", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # web search
+    workflow.add_node("Additoinal Explanation", additional_explanation, retry=RetryPolicy(max_attempts=5))  # retrieve
+    workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
     
     
     workflow.add_conditional_edges(
     workflow.add_conditional_edges(
         START,
         START,
         route_question,
         route_question,
         {
         {
-            "company_private_data": "company_private_data_query",
-            "vectorstore": "retrieve_and_generation",
+            "自有數據": "Text-to-SQL",
+            "專業知識": "RAG",
         },
         },
     )
     )
 
 
     workflow.add_conditional_edges(
     workflow.add_conditional_edges(
-        "retrieve_and_generation",
+        "RAG",
         grade_generation_v_documents_and_question,
         grade_generation_v_documents_and_question,
         {
         {
-            "not supported": "retrieve_and_generation",
+            "not supported": "RAG",
             "useful": END,
             "useful": END,
-            "not useful": "retrieve_and_generation",
+            "not useful": "RAG",
         },
         },
     )
     )
     workflow.add_conditional_edges(
     workflow.add_conditional_edges(
-        "company_private_data_query",
+        "Text-to-SQL",
         grade_sql_query,
         grade_sql_query,
         {
         {
-            "correct": "company_private_data_search",
-            "incorrect": "company_private_data_query",
-            "failed": END
+            "correct": "SQL Answer",
+            "incorrect": "Text-to-SQL",
+            "failed": "RAG"
             
             
         },
         },
     )
     )
-    workflow.add_edge("company_private_data_search", END)
+    workflow.add_edge("SQL Answer", "Additoinal Explanation")
+    workflow.add_edge("Additoinal Explanation", END)
 
 
     app = workflow.compile()    
     app = workflow.compile()    
     
     
     return app
     return app
 
 
-def main():
+def main(question: str):
     app = build_graph()
     app = build_graph()
     #建準去年的類別一排放量?
     #建準去年的類別一排放量?
-    inputs = {"question": "溫室氣體是什麼"}
+    # inputs = {"question": "溫室氣體是什麼"}
+    inputs = {"question": question}
     for output in app.stream(inputs, {"recursion_limit": 10}):
     for output in app.stream(inputs, {"recursion_limit": 10}):
         for key, value in output.items():
         for key, value in output.items():
             pprint(f"Finished running: {key}:")
             pprint(f"Finished running: {key}:")
@@ -448,6 +488,6 @@ def main():
     return value["generation"]
     return value["generation"]
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    result = main()
+    result = main("建準去年的直接排放排放量?")
     print("------------------------------------------------------")
     print("------------------------------------------------------")
     print(result)
     print(result)

+ 100 - 0
post_processing_sqlparse.py

@@ -0,0 +1,100 @@
+
+import sqlparse
+from sqlparse.sql import Comparison, Parenthesis, Token
+from sqlparse.tokens import Literal
+from langchain_community.utilities import SQLDatabase
+
+from dotenv import load_dotenv
+import os
+load_dotenv()
+# URI = os.getenv("SUPABASE_URI")
+# db = SQLDatabase.from_uri(URI, sample_rows_in_table_info=5)
+
+def get_query_columns(sql, get_real_name=False):
+    stmt = sqlparse.parse(sql)[0]
+    columns = []
+    column_identifiers = []
+
+    # get column_identifieres
+    in_select = False
+    for token in stmt.tokens:
+        if isinstance(token, sqlparse.sql.Comment):
+            continue
+        if str(token).lower() == 'select':
+            in_select = True
+        elif in_select and token.ttype is None:
+            if isinstance(token, sqlparse.sql.IdentifierList):
+                for identifier in token.get_identifiers():
+                    column_identifiers.append(identifier)
+            else:
+                column_identifiers.append(token)
+
+            break
+
+    # get column names
+    for column_identifier in column_identifiers:
+        if get_real_name:
+            columns.append(column_identifier.get_real_name())
+        else:
+            columns.append(column_identifier.get_name())
+
+    return columns
+
+
+def extract_comparison_value(tokens, target):
+    """Helper function to extract value based on a comparison target."""
+    is_target = False
+    for token in tokens:
+        if token.value.strip("'\"") == target:
+            is_target = True
+        elif is_target and token.ttype is Literal.String.Single:
+            return token.value.strip("'\"")
+        elif is_target and isinstance(token, Parenthesis):
+            data = db.run(token.value.strip("()"))
+            return eval(data)[0][0]
+    return None
+
+def parse_sql_for_stock_info(sql):
+    """Parse the SQL statement to extract 排放源, 類別"""
+    stmt = sqlparse.parse(sql)[0]
+    emission, class_type = None, None
+    
+    for token in stmt.tokens:
+        if isinstance(token, sqlparse.sql.Comment):
+            continue
+        if token.value.lower().startswith('where'):
+            for token2 in token.tokens:
+                if isinstance(token2, Comparison):
+                    if emission is None:
+                        emission = extract_comparison_value(token2.tokens, "排放源")
+                    if class_type is None:
+                        class_type = extract_comparison_value(token2.tokens, "類別")
+    return emission, class_type
+
+def get_table_name(sql):
+    stmt = sqlparse.parse(sql)[0]
+
+    in_from = False
+    for token in stmt.tokens:
+        if isinstance(token, sqlparse.sql.Comment):
+            continue
+        if str(token).lower() == 'from':
+            in_from = True
+        elif in_from and token.ttype is None:
+            if isinstance(token, sqlparse.sql.Identifier):
+                # print(token, token.ttype)
+                return token.value
+
+if __name__ == "__main__":
+    sql_query = """
+        SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+        FROM "104_112碳排放公開及建準資料"
+        WHERE "事業名稱" like '%建準%'
+        AND "排放源" = '固定燃燒'
+        AND "盤查標準" = 'GHG'
+        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
+        """
+        
+    print(get_query_columns(sql_query, get_real_name=True))
+    print(parse_sql_for_stock_info(sql_query))
+    print(get_table_name(sql_query))

+ 84 - 0
systex_app.py

@@ -0,0 +1,84 @@
+import datetime
+from json import loads
+import time
+from typing import List
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+import pandas as pd
+from pydantic import BaseModel
+import uvicorn
+
+from dotenv import load_dotenv
+import os
+from supabase.client import Client, create_client
+
+from langchain.callbacks import get_openai_callback
+
+from ai_agent import main
+
+load_dotenv()
+URI = os.getenv("SUPABASE_URI")
+supabase_url = os.environ.get("SUPABASE_URL")
+supabase_key = os.environ.get("SUPABASE_KEY")
+supabase: Client = create_client(supabase_url, supabase_key)
+
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+@app.get("/agents")
+def agent(question: str):
+    start = time.time()
+    with get_openai_callback() as cb:
+        answer = main(question)
+    processing_time = time.time() - start
+    save_history(question, answer, cb, processing_time)
+    return {"answer": answer}  
+
+def save_history(question, answer, cb, processing_time):
+    # reference = [doc.dict() for doc in reference]
+    record = {
+        'Question': question,
+        'Answer': answer,
+        'Total_Tokens': cb.total_tokens,
+        'Total_Cost': cb.total_cost,
+        'Processing_time': processing_time,
+    }
+    response = (
+        supabase.table("agent_records")
+        .insert(record)
+        .execute()
+    )
+
+class history_output(BaseModel):
+    Question: str
+    Answer: str
+    Total_Tokens: int
+    Total_Cost: float
+    Processing_time: float
+    Time: datetime.datetime
+    
+@app.get('/history', response_model=List[history_output])
+async def get_history():
+    response = supabase.table("agent_records").select("*").execute()
+    df = pd.DataFrame(response.data)
+
+    # engine = create_engine(URI, echo=True)
+
+    # df = pd.read_sql_table("systex_records", engine.connect())  
+    # df.fillna('', inplace=True)
+    result = df.to_json(orient='index', force_ascii=False)
+    result = loads(result)
+    return result.values()  
+
+if __name__ == "__main__":
+    
+    uvicorn.run("systex_app:app", host='0.0.0.0', reload=True, port=8080)
+

+ 20 - 0
text_to_sql.py

@@ -218,6 +218,26 @@ def sql_to_nl_chain(llm):
 
 
     return chain
     return chain
 
 
+def get_query(db, question, selected_table, llm):
+    write_query = write_query_chain(db, llm)
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
+    
+    query = re.split('SQL query: ', query)[-1]
+    query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
+    print(query)
+    
+    return query
+
+def query_to_nl(db, question, query, llm):
+    execute_query = QuerySQLDataBaseTool(db=db)
+    result = execute_query.invoke(query)
+    print(result)
+
+    chain = sql_to_nl_chain(llm)
+    answer = chain.invoke({"question": question, "query": query, "result": result})
+
+    return answer
+
 def run(db, question, selected_table, llm):
 def run(db, question, selected_table, llm):
 
 
     write_query = write_query_chain(db, llm)
     write_query = write_query_chain(db, llm)

Một số tệp đã không được hiển thị bởi vì quá nhiều tập tin thay đổi trong này khác