Browse Source

adjust to gpt4o version, adjust agent flow, add chat history

ling 4 months ago
parent
commit
eac4569456
5 changed files with 102 additions and 31 deletions
  1. 14 2
      RAG_strategy.py
  2. 69 25
      ai_agent.py
  3. 4 0
      faiss_index.py
  4. 7 3
      rewrite_question.py
  5. 8 1
      systex_app.py

+ 14 - 2
RAG_strategy.py

@@ -36,16 +36,28 @@ load_dotenv()
 
 def multi_query_chain(llm):
     # Multi Query: Different Perspectives
-    template = """You are an AI language model assistant. Your task is to generate three 
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    你是一個來自台灣的ESG的AI助理,請用繁體中文回答
+    You are an AI language model assistant. Your task is to generate three 
     different versions of the given user question to retrieve relevant documents from a vector 
     database. By generating multiple perspectives on the user question, your goal is to help
     the user overcome some of the limitations of the distance-based similarity search. 
     Provide these alternative questions separated by newlines. 
 
     You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
+    <|eot_id|>
+    
+    <|start_header_id|>user<|end_header_id|>
     
+    Original question: {question}
+    請用繁體中文
+    <|eot_id|>
     
-    Original question: {question}"""
+    <|start_header_id|>assistant<|end_header_id|>
+    """
     prompt_perspectives = ChatPromptTemplate.from_template(template)
 
     

+ 69 - 25
ai_agent.py

@@ -24,20 +24,29 @@ db = SQLDatabase.from_uri(URI)
 
 # LLM
 # local_llm = "llama3.1:8b-instruct-fp16"
+# local_llm = "llama3.1:8b-instruct-q2_K"
 local_llm = "llama3-groq-tool-use:latest"
 llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
-llm = ChatOllama(model=local_llm, temperature=0)
+# llm = ChatOllama(model=local_llm, temperature=0)
+from langchain_openai import ChatOpenAI
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
 
 # RAG usage
-from faiss_index import create_faiss_retriever, faiss_query
+from faiss_index import create_faiss_retriever, faiss_multiquery, faiss_query
 retriever = create_faiss_retriever()
 
 # text-to-sql usage
 from text_to_sql_private import run, get_query, query_to_nl, table_description
 from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
 progress_bar = []
+
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
-    
+    if multi_query:
+        docs = faiss_multiquery(question, retriever, llm)
+        # print(docs)
+    else:
+        docs = retriever.get_relevant_documents(question, k=10)
+        # print(docs)
     context = docs
     
     system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
@@ -47,7 +56,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     <|start_header_id|>system<|end_header_id|>
     你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
     You should not mention anything about "根據提供的文件內容" or other similar terms.
-    Use five sentences maximum and keep the answer concise.
+    請盡可能的詳細回答問題。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
     勿回答無關資訊
     <|eot_id|>
@@ -58,8 +67,9 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     {context}
 
     Question: {question}
-    用繁體中文回答問題
+    用繁體中文回答問題,請用一段話詳細的回答。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    
     <|eot_id|>
     
     <|start_header_id|>assistant<|end_header_id|>
@@ -129,10 +139,14 @@ def run_text_to_sql(question: str):
 
 def _get_query(question: str):
     selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
+    question = question.replace("美國", "美國 Inc")
+    question = question.replace("法國", "法國 SAS")
     query, result = get_query(db, question, selected_table, llm)
     return  query, result
 
 def _query_to_nl(question: str, query: str, result):
+    question = question.replace("美國", "美國 Inc")
+    question = question.replace("法國", "法國 SAS")
     answer = query_to_nl(question, query, result, llm)
     return  answer
 
@@ -150,19 +164,24 @@ def generate_additional_question(sql_query):
 def generate_additional_detail(sql_query):
     terms = parse_sql_where(sql_query)
     answer = ""
+    all_documents = []
     for term in list(set(terms)):
         if term is None: continue
-        question_format = [f"請解釋什麼是{term}?"]
+        question_format = [f"溫室氣體排放源中的{term}是什麼意思?", f"{term}是什麼意思?"]
         for question in question_format:
             # question = f"什麼是{term}?"
             documents = retriever.get_relevant_documents(question, k=5)
-            generation = faiss_query(question, documents, llm) + "\n"
-            if "test@systex.com" in generation:
-                generation = ""
-            
-            answer += generation
-            # print(question)
-            # print(generation)
+            all_documents.extend(documents)
+            # for doc in documents:
+            #     print(doc)
+        all_question = "\n".join(question_format)
+        generation = faiss_query(all_question, all_documents, llm, multi_query=True) + "\n"
+        if "test@systex.com" in generation:
+            generation = ""
+        
+        answer += generation
+        # print(question)
+        # print(generation)
     return answer
 ### SQL Grader
 
@@ -177,7 +196,7 @@ def SQL_Grader():
         
         You need to check that each where statement is correctly filtered out what user question need.
         
-        For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is 
+        For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is 
         "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
         FROM "建準碳排放清冊數據new"
         WHERE "事業名稱" like '%建準%'
@@ -293,19 +312,21 @@ def retrieve_and_generation(state):
         # documents = retriever.invoke(question)
         # TODO: correct Retrieval function
         documents = retriever.get_relevant_documents(question, k=5)
-        for doc in documents:
-            print(doc)
+        # for doc in documents:
+        #     print(doc)
             
         # docs_documents = "\n\n".join(doc.page_content for doc in documents)
         # print(documents)
-        generation = faiss_query(question, documents, llm)
+        generation = faiss_query(question, documents, llm, multi_query=True)
     else:
         generation = state["generation"]
         
         for sub_question in list(set(question_list)):
             print(sub_question)
-            documents = retriever.get_relevant_documents(sub_question, k=10)
-            generation += faiss_query(sub_question, documents, llm)
+            documents = retriever.get_relevant_documents(sub_question, k=5)
+            # for doc in documents:
+            #     print(doc)
+            generation += faiss_query(sub_question, documents, llm, multi_query=True)
             generation += "\n"
             
     print(generation)
@@ -513,10 +534,11 @@ def grade_sql_query(state):
     question = state["question"]
     sql_query = state["sql_query"]
     sql_result = state["sql_result"]
-    if "None" in sql_result:
+    if "None" in sql_result or sql_result.startswith("Error:"):
         progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
         return "incorrect"
     else:
+        print(sql_result)
         progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
         return "correct"
     # retry = state["retry"]
@@ -540,7 +562,16 @@ def grade_sql_query(state):
     #     # print("---GRADE: INCORRECT SQL QUERY---")
     #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
     #     return "incorrect"
-
+def check_sql_answer(state):
+    progress_bar = show_progress(state, "---CHECK SQL ANSWER QUALITY---")
+    generation = state["generation"]
+    if "test@systex.com" in generation:
+        progress_bar = show_progress(state, "---SQL CAN NOT GENERATE ANSWER---")
+        return "bad"
+    else:
+        progress_bar = show_progress(state, "---SQL CAN GENERATE ANSWER---")
+        return "good"
+    
 def build_graph():
     workflow = StateGraph(GraphState)
 
@@ -550,7 +581,7 @@ def build_graph():
     workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("ERROR", error)  # retrieve
-    
+    company_private_data_search
     workflow.add_conditional_edges(
         START,
         route_question,
@@ -577,7 +608,17 @@ def build_graph():
             
         },
     )
-    workflow.add_edge("SQL Answer", "Additoinal Explanation")
+    workflow.add_conditional_edges(
+        "SQL Answer",
+        check_sql_answer,
+        {
+            "good": "Additoinal Explanation",
+            "bad": "RAG",
+            
+        },
+    )
+    
+    # workflow.add_edge("SQL Answer", "Additoinal Explanation")
     workflow.add_edge("Additoinal Explanation", END)
 
     app = workflow.compile()    
@@ -604,11 +645,14 @@ def main(question: str):
     value["progress_bar"] = progress_bar
     # pprint(value["progress_bar"])
     
-    return value["generation"]
+    # return value["generation"]
+    return value
 
 if __name__ == "__main__":
     # result = main("建準去年的逸散排放總排放量是多少?")
-    result = main("建準夏威夷去年的綠電使用量是多少?")
+    # result = main("建準廣興廠去年的上游運輸總排放量是多少?")
+    
+    result = main("建準北海廠去年的固定燃燒排放量是多少?")
     # result = main("溫室氣體是什麼?")
     # result = main("什麼是外購電力(綠電)?")
     print("------------------------------------------------------")

+ 4 - 0
faiss_index.py

@@ -54,6 +54,7 @@ embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
 
 def download_embeddings():
     response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
+    # response = supabase.table(document_table).select("id, embedding, metadata, content").eq('metadata ->> source', 'supplement.docx').execute()
     embeddings = []
     ids = []
     metadatas = []
@@ -149,6 +150,9 @@ def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):
 
     questions = generate_queries.invoke(question)
     questions = [item for item in questions if item != ""]
+    questions.append(question)
+    for q in questions:
+        print(q)
 
     # docs = list(map(retriever.get_relevant_documents, questions))
     docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))

+ 7 - 3
rewrite_question.py

@@ -14,9 +14,13 @@ from langchain_core.runnables import (
 from typing import Tuple, List, Optional
 from langchain_core.messages import AIMessage, HumanMessage
 
-local_llm = "llama3-groq-tool-use:latest"
+# local_llm = "llama3-groq-tool-use:latest"
+# llm = ChatOllama(model=local_llm, temperature=0)
 # llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
-llm = ChatOllama(model=local_llm, temperature=0)
+from dotenv import load_dotenv
+load_dotenv()
+from langchain_openai import ChatOpenAI
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
 
 def get_search_query():
     # Condense a chat history and follow-up question into a standalone question
@@ -105,6 +109,6 @@ if __name__ == "__main__":
     chat_history = [(history["q"] , history["a"] ) for history in chat_history if history["a"] != "" and history["a"]  != "string"]
     print(chat_history)
     
-    question = "類別2呢"
+    question = "廣興廠"
     modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
     print(modified_question)

+ 8 - 1
systex_app.py

@@ -17,6 +17,7 @@ from langchain.callbacks import get_openai_callback
 
 from ai_agent import main
 from semantic_search import semantic_cache
+from RAG_strategy import get_search_query
 
 load_dotenv()
 URI = os.getenv("SUPABASE_URI")
@@ -41,6 +42,10 @@ class ChatHistoryItem(BaseModel):
 def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
     print(question)
     start = time.time()
+    # TODO rewrite query
+    # _search_query = get_search_query()
+    # chat_history = [(item.q, item.a) for item in chat_history[-5:] if item.a != "" and item.a != "string"]
+    # modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
     
     with get_openai_callback() as cb:
         # cache_question, cache_answer = semantic_cache(supabase, question)
@@ -48,8 +53,10 @@ def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
         if cache_answer:
             answer = cache_answer
         else:
-            answer = main(question)
+            result = main(question)
+            answer = result["generation"]
     processing_time = time.time() - start
+    # save_history(question + "->" + modified_question, answer, cb, processing_time)
     save_history(question, answer, cb, processing_time)
     if "test@systex.com" in answer:
         answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"