Browse Source

Add a new node 'Additional Explanation' to the workflow and create new app file

ling 4 months ago
parent
commit
c70d6c8b9d
7 changed files with 522 additions and 121 deletions
  1. 9 2
      RAG_app.py
  2. BIN
      agent_workflow.png
  3. 244 94
      ai_agent.ipynb
  4. 65 25
      ai_agent.py
  5. 100 0
      post_processing_sqlparse.py
  6. 84 0
      systex_app.py
  7. 20 0
      text_to_sql.py

+ 9 - 2
RAG_app.py

@@ -22,6 +22,7 @@ import json
 from json import loads
 import pandas as pd
 import time
+from langchain_community.chat_models import ChatOllama
 from langchain.callbacks import get_openai_callback
 
 from langchain_community.vectorstores import Chroma
@@ -41,6 +42,7 @@ from supabase.client import Client, create_client
 from file_loader.add_vectordb import GetVectorStore
 from faiss_index import create_faiss_retriever, faiss_query
 from local_llm import ollama_, hf
+from ai_agent import main
 # from local_llm import ollama_, taide_llm, hf
 # llm = hf()
 
@@ -66,7 +68,9 @@ async def lifespan(app: FastAPI):
     # vector_store = GetVectorStore(embeddings, supabase, document_table)
     # global_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
     global_retriever = create_faiss_retriever()
-    llm = hf()
+    local_llm = "llama3-groq-tool-use:latest"
+    # llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
+    llm = ChatOllama(model=local_llm, temperature=0)
 
     print(time.time() - start)
     yield
@@ -138,7 +142,10 @@ def multi_query_answer(question: Optional[str] = '什麼是逸散排放源?',
     # print(response_content)
     # return json.loads(response_content)
 
-    
+@app.get("/agents")
+def agent(question: str):
+    answer = main(question)
+    return {"answer": answer}
 
 def save_history(question, answer, reference, cb, processing_time):
     # reference = [doc.dict() for doc in reference]

BIN
agent_workflow.png


File diff suppressed because it is too large
+ 244 - 94
ai_agent.ipynb


+ 65 - 25
ai_agent.py

@@ -34,6 +34,7 @@ retriever = create_faiss_retriever()
 
 # text-to-sql usage
 from text_to_sql2 import run, get_query, query_to_nl, table_description
+from post_processing_sqlparse import get_query_columns, parse_sql_for_stock_info, get_table_name
 
 
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
@@ -58,7 +59,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     {context}
 
     Question: {question}
-    用繁體中文
+    用繁體中文回答問題
     <|eot_id|>
     
     <|start_header_id|>assistant<|end_header_id|>
@@ -135,7 +136,21 @@ def _query_to_nl(question: str, query: str):
     answer = query_to_nl(db, question, query, llm)
     return  answer
 
-
+def generate_additional_detail(sql_query):
+    terms = parse_sql_for_stock_info(sql_query)
+    answer = ""
+    for term in terms:
+        if term is None: continue
+        question_format = [f"什麼是{term}?", f"{term}的用途是什麼", f"如何計算{term}?"]
+        for question in question_format:
+            # question = f"什麼是{term}?"
+            documents = retriever.get_relevant_documents(question, k=30)
+            generation = faiss_query(question, documents, llm)
+            answer += generation
+            answer += "\n"
+            # print(question)
+            # print(generation)
+    return answer
 ### SQL Grader
 
 def SQL_Grader():
@@ -190,11 +205,11 @@ def SQL_Grader():
 def Router():
     prompt = PromptTemplate(
         template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
-        You are an expert at routing a user question to a vectorstore or company private data. 
+        You are an expert at routing a user question to a 專業知識 or 自有數據. 
         Use company private data for questions about the informations about a company's greenhouse gas emissions data.
-        Otherwise, use the vectorstore for questions on ESG field knowledge or news about ESG. 
+        Otherwise, use the 專業知識 for questions on ESG field knowledge or news about ESG. 
         You do not need to be stringent with the keywords in the question related to these topics. 
-        Give a binary choice 'company_private_data' or 'vectorstore' based on the question. 
+        Give a binary choice '自有數據' or '專業知識' based on the question. 
         Return the a JSON with a single key 'datasource' and no premable or explanation. 
         
         Question to route: {question} 
@@ -291,6 +306,28 @@ def company_private_data_search(state):
     
     return {"sql_query": sql_query, "question": question, "generation": generation}
 
+def additional_explanation(state):
+    """
+    
+    Args:
+        state (_type_): _description_
+        
+    Returns:
+        state (dict): Appended additional explanation to state
+    """
+    
+    print("---ADDITIONAL EXPLANATION---")
+    print(state)
+    question = state["question"]
+    sql_query = state["sql_query"]
+    generation = state["generation"]
+    generation += "\n"
+    generation += generate_additional_detail(sql_query)
+    
+    # generation = [company_private_data_result]
+    
+    return {"sql_query": sql_query, "question": question, "generation": generation}
+
 ### Conditional edge
 
 
@@ -312,12 +349,12 @@ def route_question(state):
     source = question_router.invoke({"question": question})
     # print(source)
     print(source["datasource"])
-    if source["datasource"] == "company_private_data":
+    if source["datasource"] == "自有數據":
         print("---ROUTE QUESTION TO TEXT-TO-SQL---")
-        return "company_private_data"
-    elif source["datasource"] == "vectorstore":
+        return "自有數據"
+    elif source["datasource"] == "專業知識":
         print("---ROUTE QUESTION TO RAG---")
-        return "vectorstore"
+        return "專業知識"
     
 def grade_generation_v_documents_and_question(state):
     """
@@ -398,48 +435,51 @@ def build_graph():
     workflow = StateGraph(GraphState)
 
     # Define the nodes
-    workflow.add_node("company_private_data_query", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5))  # web search
-    workflow.add_node("company_private_data_search", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # web search
-    workflow.add_node("retrieve_and_generation", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
+    workflow.add_node("Text-to-SQL", company_private_data_get_sql_query, retry=RetryPolicy(max_attempts=5))  # web search
+    workflow.add_node("SQL Answer", company_private_data_search, retry=RetryPolicy(max_attempts=5))  # web search
+    workflow.add_node("Additoinal Explanation", additional_explanation, retry=RetryPolicy(max_attempts=5))  # retrieve
+    workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
     
     workflow.add_conditional_edges(
         START,
         route_question,
         {
-            "company_private_data": "company_private_data_query",
-            "vectorstore": "retrieve_and_generation",
+            "自有數據": "Text-to-SQL",
+            "專業知識": "RAG",
         },
     )
 
     workflow.add_conditional_edges(
-        "retrieve_and_generation",
+        "RAG",
         grade_generation_v_documents_and_question,
         {
-            "not supported": "retrieve_and_generation",
+            "not supported": "RAG",
             "useful": END,
-            "not useful": "retrieve_and_generation",
+            "not useful": "RAG",
         },
     )
     workflow.add_conditional_edges(
-        "company_private_data_query",
+        "Text-to-SQL",
         grade_sql_query,
         {
-            "correct": "company_private_data_search",
-            "incorrect": "company_private_data_query",
-            "failed": END
+            "correct": "SQL Answer",
+            "incorrect": "Text-to-SQL",
+            "failed": "RAG"
             
         },
     )
-    workflow.add_edge("company_private_data_search", END)
+    workflow.add_edge("SQL Answer", "Additoinal Explanation")
+    workflow.add_edge("Additoinal Explanation", END)
 
     app = workflow.compile()    
     
     return app
 
-def main():
+def main(question: str):
     app = build_graph()
     #建準去年的類別一排放量?
-    inputs = {"question": "溫室氣體是什麼"}
+    # inputs = {"question": "溫室氣體是什麼"}
+    inputs = {"question": question}
     for output in app.stream(inputs, {"recursion_limit": 10}):
         for key, value in output.items():
             pprint(f"Finished running: {key}:")
@@ -448,6 +488,6 @@ def main():
     return value["generation"]
 
 if __name__ == "__main__":
-    result = main()
+    result = main("建準去年的直接排放排放量?")
     print("------------------------------------------------------")
     print(result)

+ 100 - 0
post_processing_sqlparse.py

@@ -0,0 +1,100 @@
+
+import sqlparse
+from sqlparse.sql import Comparison, Parenthesis, Token
+from sqlparse.tokens import Literal
+from langchain_community.utilities import SQLDatabase
+
+from dotenv import load_dotenv
+import os
+load_dotenv()
+# URI = os.getenv("SUPABASE_URI")
+# db = SQLDatabase.from_uri(URI, sample_rows_in_table_info=5)
+
+def get_query_columns(sql, get_real_name=False):
+    stmt = sqlparse.parse(sql)[0]
+    columns = []
+    column_identifiers = []
+
+    # get column_identifieres
+    in_select = False
+    for token in stmt.tokens:
+        if isinstance(token, sqlparse.sql.Comment):
+            continue
+        if str(token).lower() == 'select':
+            in_select = True
+        elif in_select and token.ttype is None:
+            if isinstance(token, sqlparse.sql.IdentifierList):
+                for identifier in token.get_identifiers():
+                    column_identifiers.append(identifier)
+            else:
+                column_identifiers.append(token)
+
+            break
+
+    # get column names
+    for column_identifier in column_identifiers:
+        if get_real_name:
+            columns.append(column_identifier.get_real_name())
+        else:
+            columns.append(column_identifier.get_name())
+
+    return columns
+
+
+def extract_comparison_value(tokens, target):
+    """Helper function to extract value based on a comparison target."""
+    is_target = False
+    for token in tokens:
+        if token.value.strip("'\"") == target:
+            is_target = True
+        elif is_target and token.ttype is Literal.String.Single:
+            return token.value.strip("'\"")
+        elif is_target and isinstance(token, Parenthesis):
+            data = db.run(token.value.strip("()"))
+            return eval(data)[0][0]
+    return None
+
+def parse_sql_for_stock_info(sql):
+    """Parse the SQL statement to extract 排放源, 類別"""
+    stmt = sqlparse.parse(sql)[0]
+    emission, class_type = None, None
+    
+    for token in stmt.tokens:
+        if isinstance(token, sqlparse.sql.Comment):
+            continue
+        if token.value.lower().startswith('where'):
+            for token2 in token.tokens:
+                if isinstance(token2, Comparison):
+                    if emission is None:
+                        emission = extract_comparison_value(token2.tokens, "排放源")
+                    if class_type is None:
+                        class_type = extract_comparison_value(token2.tokens, "類別")
+    return emission, class_type
+
+def get_table_name(sql):
+    stmt = sqlparse.parse(sql)[0]
+
+    in_from = False
+    for token in stmt.tokens:
+        if isinstance(token, sqlparse.sql.Comment):
+            continue
+        if str(token).lower() == 'from':
+            in_from = True
+        elif in_from and token.ttype is None:
+            if isinstance(token, sqlparse.sql.Identifier):
+                # print(token, token.ttype)
+                return token.value
+
+if __name__ == "__main__":
+    sql_query = """
+        SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+        FROM "104_112碳排放公開及建準資料"
+        WHERE "事業名稱" like '%建準%'
+        AND "排放源" = '固定燃燒'
+        AND "盤查標準" = 'GHG'
+        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;
+        """
+        
+    print(get_query_columns(sql_query, get_real_name=True))
+    print(parse_sql_for_stock_info(sql_query))
+    print(get_table_name(sql_query))

+ 84 - 0
systex_app.py

@@ -0,0 +1,84 @@
+import datetime
+from json import loads
+import time
+from typing import List
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+import pandas as pd
+from pydantic import BaseModel
+import uvicorn
+
+from dotenv import load_dotenv
+import os
+from supabase.client import Client, create_client
+
+from langchain.callbacks import get_openai_callback
+
+from ai_agent import main
+
+load_dotenv()
+URI = os.getenv("SUPABASE_URI")
+supabase_url = os.environ.get("SUPABASE_URL")
+supabase_key = os.environ.get("SUPABASE_KEY")
+supabase: Client = create_client(supabase_url, supabase_key)
+
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+@app.get("/agents")
+def agent(question: str):
+    start = time.time()
+    with get_openai_callback() as cb:
+        answer = main(question)
+    processing_time = time.time() - start
+    save_history(question, answer, cb, processing_time)
+    return {"answer": answer}  
+
+def save_history(question, answer, cb, processing_time):
+    # reference = [doc.dict() for doc in reference]
+    record = {
+        'Question': question,
+        'Answer': answer,
+        'Total_Tokens': cb.total_tokens,
+        'Total_Cost': cb.total_cost,
+        'Processing_time': processing_time,
+    }
+    response = (
+        supabase.table("agent_records")
+        .insert(record)
+        .execute()
+    )
+
+class history_output(BaseModel):
+    Question: str
+    Answer: str
+    Total_Tokens: int
+    Total_Cost: float
+    Processing_time: float
+    Time: datetime.datetime
+    
+@app.get('/history', response_model=List[history_output])
+async def get_history():
+    response = supabase.table("agent_records").select("*").execute()
+    df = pd.DataFrame(response.data)
+
+    # engine = create_engine(URI, echo=True)
+
+    # df = pd.read_sql_table("systex_records", engine.connect())  
+    # df.fillna('', inplace=True)
+    result = df.to_json(orient='index', force_ascii=False)
+    result = loads(result)
+    return result.values()  
+
+if __name__ == "__main__":
+    
+    uvicorn.run("systex_app:app", host='0.0.0.0', reload=True, port=8080)
+

+ 20 - 0
text_to_sql.py

@@ -218,6 +218,26 @@ def sql_to_nl_chain(llm):
 
     return chain
 
+def get_query(db, question, selected_table, llm):
+    write_query = write_query_chain(db, llm)
+    query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
+    
+    query = re.split('SQL query: ', query)[-1]
+    query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
+    print(query)
+    
+    return query
+
+def query_to_nl(db, question, query, llm):
+    execute_query = QuerySQLDataBaseTool(db=db)
+    result = execute_query.invoke(query)
+    print(result)
+
+    chain = sql_to_nl_chain(llm)
+    answer = chain.invoke({"question": question, "query": query, "result": result})
+
+    return answer
+
 def run(db, question, selected_table, llm):
 
     write_query = write_query_chain(db, llm)

Some files were not shown because too many files changed in this diff