adjust to gpt4o version, adjust agent flow, add chat history

ling committed 4 months ago · parent commit eac4569456

5 changed files with 102 additions and 31 deletions
  1. RAG_strategy.py (+14 -2)
  2. ai_agent.py (+69 -25)
  3. faiss_index.py (+4 -0)
  4. rewrite_question.py (+7 -3)
  5. systex_app.py (+8 -1)

RAG_strategy.py (+14 -2)

@@ -36,16 +36,28 @@ load_dotenv()


 def multi_query_chain(llm):
     # Multi Query: Different Perspectives
-    template = """You are an AI language model assistant. Your task is to generate three 
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    你是一個來自台灣的ESG的AI助理,請用繁體中文回答
+    You are an AI language model assistant. Your task is to generate three 
     different versions of the given user question to retrieve relevant documents from a vector 
     database. By generating multiple perspectives on the user question, your goal is to help
     the user overcome some of the limitations of the distance-based similarity search. 
     Provide these alternative questions separated by newlines. 

     You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
+    <|eot_id|>
+    
+    <|start_header_id|>user<|end_header_id|>
     
+    Original question: {question}
+    請用繁體中文
+    <|eot_id|>
     
-    Original question: {question}"""
+    <|start_header_id|>assistant<|end_header_id|>
+    """
     prompt_perspectives = ChatPromptTemplate.from_template(template)

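The hunk above shows only the new prompt template. For orientation, here is a minimal sketch of how multi_query_chain plausibly completes, following the standard LangChain multi-query pattern (prompt into LLM, parse to string, split on newlines); the trimmed template stands in for the full one shown in the diff:

    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.output_parsers import StrOutputParser

    def multi_query_chain(llm):
        # Trimmed stand-in for the template above: ask for the original
        # question plus three rephrasings, one per line.
        template = (
            "Generate three different versions of the given user question, "
            "and also return the original question, one per line.\n"
            "Original question: {question}"
        )
        prompt_perspectives = ChatPromptTemplate.from_template(template)
        # Splitting on newlines yields a list of query strings, which is
        # what faiss_multiquery in faiss_index.py iterates over.
        return prompt_perspectives | llm | StrOutputParser() | (lambda x: x.split("\n"))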
ai_agent.py (+69 -25)

@@ -24,20 +24,29 @@ db = SQLDatabase.from_uri(URI)

 # LLM
 # local_llm = "llama3.1:8b-instruct-fp16"
+# local_llm = "llama3.1:8b-instruct-q2_K"
 local_llm = "llama3-groq-tool-use:latest"
 llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
-llm = ChatOllama(model=local_llm, temperature=0)
+# llm = ChatOllama(model=local_llm, temperature=0)
+from langchain_openai import ChatOpenAI
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

 # RAG usage
-from faiss_index import create_faiss_retriever, faiss_query
+from faiss_index import create_faiss_retriever, faiss_multiquery, faiss_query
 retriever = create_faiss_retriever()

 # text-to-sql usage
 from text_to_sql_private import run, get_query, query_to_nl, table_description
 from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
 progress_bar = []
+
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
-    
+    if multi_query:
+        docs = faiss_multiquery(question, retriever, llm)
+        # print(docs)
+    else:
+        docs = retriever.get_relevant_documents(question, k=10)
+        # print(docs)
     context = docs
     
     system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
@@ -47,7 +56,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     <|start_header_id|>system<|end_header_id|>
     你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
     You should not mention anything about "根據提供的文件內容" or other similar terms.
-    Use five sentences maximum and keep the answer concise.
+    請盡可能的詳細回答問題。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
     勿回答無關資訊
     <|eot_id|>
@@ -58,8 +67,9 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     {context}

     Question: {question}
-    用繁體中文回答問題
+    用繁體中文回答問題,請用一段話詳細的回答。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    
     <|eot_id|>
     
     <|start_header_id|>assistant<|end_header_id|>
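Two things worth noting about faiss_query as revised above: when multi_query=True the docs argument passed by callers is discarded and retrieval is redone inside the function via faiss_multiquery, and the Llama-3 header tokens are kept in the template even though gpt-4o-mini ignores them. The tail of the function is not shown in this diff; a sketch of how it plausibly finishes, assuming the usual prompt-pipe-parse pattern:

    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.output_parsers import StrOutputParser

    # `template` is the chat-format string assembled above, containing the
    # {context} and {question} placeholders.
    rag_prompt = ChatPromptTemplate.from_template(template)
    rag_chain = rag_prompt | llm | StrOutputParser()
    return rag_chain.invoke({"context": context, "question": question})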
@@ -129,10 +139,14 @@ def run_text_to_sql(question: str):

 def _get_query(question: str):
     selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
+    question = question.replace("美國", "美國 Inc")
+    question = question.replace("法國", "法國 SAS")
     query, result = get_query(db, question, selected_table, llm)
     return query, result

 def _query_to_nl(question: str, query: str, result):
+    question = question.replace("美國", "美國 Inc")
+    question = question.replace("法國", "法國 SAS")
     answer = query_to_nl(question, query, result, llm)
     return answer
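The two hunks above apply the same literal rewrites ("美國" to "美國 Inc", "法國" to "法國 SAS") so that text-to-SQL filters match the site names stored in the database. A small refactor sketch that removes the duplication; normalize_sites and SITE_ALIASES are hypothetical names, and the mapping simply restates the two substitutions visible in the diff:

    # Map colloquial site names to the exact values stored in the database.
    SITE_ALIASES = {
        "美國": "美國 Inc",
        "法國": "法國 SAS",
    }

    def normalize_sites(question: str) -> str:
        # Apply every alias so generated WHERE clauses match stored names.
        for alias, canonical in SITE_ALIASES.items():
            question = question.replace(alias, canonical)
        return question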
@@ -150,19 +164,24 @@ def generate_additional_question(sql_query):

 def generate_additional_detail(sql_query):
     terms = parse_sql_where(sql_query)
     answer = ""
+    all_documents = []
     for term in list(set(terms)):
         if term is None: continue
-        question_format = [f"請解釋什麼是{term}?"]
+        question_format = [f"溫室氣體排放源中的{term}是什麼意思?", f"{term}是什麼意思?"]
         for question in question_format:
             # question = f"什麼是{term}?"
             documents = retriever.get_relevant_documents(question, k=5)
-            generation = faiss_query(question, documents, llm) + "\n"
-            if "test@systex.com" in generation:
-                generation = ""
-            
-            answer += generation
-            # print(question)
-            # print(generation)
+            all_documents.extend(documents)
+            # for doc in documents:
+            #     print(doc)
+        all_question = "\n".join(question_format)
+        generation = faiss_query(all_question, all_documents, llm, multi_query=True) + "\n"
+        if "test@systex.com" in generation:
+            generation = ""
+        
+        answer += generation
+        # print(question)
+        # print(generation)
     return answer
 ### SQL Grader
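For a concrete term such as "固定燃燒" (stationary combustion), the new question_format expands to two phrasings whose retrieved chunks are pooled into all_documents and answered in a single faiss_query call:

    question_format = [
        "溫室氣體排放源中的固定燃燒是什麼意思?",
        "固定燃燒是什麼意思?",
    ]

Note that all_documents is created once, outside the term loop, so each later term is answered against every earlier term's chunks as well; if that accumulation is unintentional, resetting the list at the top of the loop would keep the terms isolated.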
@@ -177,7 +196,7 @@ def SQL_Grader():
         
         You need to check that each where statement is correctly filtered out what user question need.
         
-        For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is 
+        For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is 
         "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
         FROM "建準碳排放清冊數據new"
         WHERE "事業名稱" like '%建準%'
@@ -293,19 +312,21 @@ def retrieve_and_generation(state):
         # documents = retriever.invoke(question)
         # TODO: correct Retrieval function
         documents = retriever.get_relevant_documents(question, k=5)
-        for doc in documents:
-            print(doc)
+        # for doc in documents:
+        #     print(doc)
             
         # docs_documents = "\n\n".join(doc.page_content for doc in documents)
         # print(documents)
-        generation = faiss_query(question, documents, llm)
+        generation = faiss_query(question, documents, llm, multi_query=True)
     else:
         generation = state["generation"]
         
         for sub_question in list(set(question_list)):
             print(sub_question)
-            documents = retriever.get_relevant_documents(sub_question, k=10)
-            generation += faiss_query(sub_question, documents, llm)
+            documents = retriever.get_relevant_documents(sub_question, k=5)
+            # for doc in documents:
+            #     print(doc)
+            generation += faiss_query(sub_question, documents, llm, multi_query=True)
             generation += "\n"
             
     print(generation)
@@ -513,10 +534,11 @@ def grade_sql_query(state):
     question = state["question"]
     sql_query = state["sql_query"]
     sql_result = state["sql_result"]
-    if "None" in sql_result:
+    if "None" in sql_result or sql_result.startswith("Error:"):
         progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
         return "incorrect"
     else:
+        print(sql_result)
         progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
         return "correct"
     # retry = state["retry"]
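The widened guard now routes execution errors to the retry path as well, not just empty results. For illustration, assuming sql_result holds the stringified query output (the exact format is produced upstream and not shown here):

    # "None"                        -> "incorrect"  (no rows / NULL aggregate)
    # "Error: syntax error at ..."  -> "incorrect"  (query failed to execute)
    # "[(123.45,)]"                 -> "correct"    (usable result)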
@@ -540,7 +562,16 @@ def grade_sql_query(state):
     #     # print("---GRADE: INCORRECT SQL QUERY---")
     #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
     #     return "incorrect"
-
+def check_sql_answer(state):
+    progress_bar = show_progress(state, "---CHECK SQL ANSWER QUALITY---")
+    generation = state["generation"]
+    if "test@systex.com" in generation:
+        progress_bar = show_progress(state, "---SQL CAN NOT GENERATE ANSWER---")
+        return "bad"
+    else:
+        progress_bar = show_progress(state, "---SQL CAN GENERATE ANSWER---")
+        return "good"
+    
 def build_graph():
     workflow = StateGraph(GraphState)

@@ -550,7 +581,7 @@ def build_graph():
     workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("ERROR", error)  # retrieve
-    
+    # company_private_data_search
     workflow.add_conditional_edges(
         START,
         route_question,
@@ -577,7 +608,17 @@ def build_graph():
             
         },
     )
-    workflow.add_edge("SQL Answer", "Additoinal Explanation")
+    workflow.add_conditional_edges(
+        "SQL Answer",
+        check_sql_answer,
+        {
+            "good": "Additoinal Explanation",
+            "bad": "RAG",
+        },
+    )
+    
+    # workflow.add_edge("SQL Answer", "Additoinal Explanation")
     workflow.add_edge("Additoinal Explanation", END)

     app = workflow.compile()
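Taken together, the build_graph hunks replace the unconditional "SQL Answer" to "Additoinal Explanation" edge with a quality gate: answers that fall back to the support e-mail are rerouted to RAG. A self-contained toy sketch of this wiring pattern, with a TypedDict state and stub nodes standing in for the real ones:

    from typing import TypedDict
    from langgraph.graph import StateGraph, START, END

    class State(TypedDict):
        question: str
        generation: str

    def sql_answer(state: State) -> State:
        return state  # stub for the text-to-SQL answer node

    def rag(state: State) -> State:
        return state  # stub for the RAG fallback node

    def check_sql_answer(state: State) -> str:
        # Fallback boilerplate in the answer means SQL could not help.
        return "bad" if "test@systex.com" in state["generation"] else "good"

    workflow = StateGraph(State)
    workflow.add_node("SQL Answer", sql_answer)
    workflow.add_node("RAG", rag)
    workflow.add_edge(START, "SQL Answer")
    workflow.add_conditional_edges(
        "SQL Answer", check_sql_answer, {"good": END, "bad": "RAG"}
    )
    workflow.add_edge("RAG", END)
    app = workflow.compile()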
@@ -604,11 +645,14 @@ def main(question: str):
     value["progress_bar"] = progress_bar
     # pprint(value["progress_bar"])
     
-    return value["generation"]
+    # return value["generation"]
+    return value

 if __name__ == "__main__":
     # result = main("建準去年的逸散排放總排放量是多少?")
-    result = main("建準夏威夷去年的綠電使用量是多少?")
+    # result = main("建準廣興廠去年的上游運輸總排放量是多少?")
+    
+    result = main("建準北海廠去年的固定燃燒排放量是多少?")
     # result = main("溫室氣體是什麼?")
     # result = main("什麼是外購電力(綠電)?")
     print("------------------------------------------------------")

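Because main now returns the whole graph state rather than just the answer string, callers pick out the fields they need. A usage sketch based on the keys visible in this diff:

    result = main("建準北海廠去年的固定燃燒排放量是多少?")
    print(result["generation"])    # the answer text
    print(result["progress_bar"])  # the progress trace attached in main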
faiss_index.py (+4 -0)

@@ -54,6 +54,7 @@ embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

 def download_embeddings():
     response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
+    # response = supabase.table(document_table).select("id, embedding, metadata, content").eq('metadata ->> source', 'supplement.docx').execute()
     embeddings = []
     ids = []
     metadatas = []
@@ -149,6 +150,9 @@ def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):

     questions = generate_queries.invoke(question)
     questions = [item for item in questions if item != ""]
+    questions.append(question)
+    for q in questions:
+        print(q)

     # docs = list(map(retriever.get_relevant_documents, questions))
     docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))

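With the original question appended, faiss_multiquery retrieves k=4 chunks per query, so docs ends up as a list of lists. A sketch of the flatten-and-deduplicate step that typically follows (unique_union is a hypothetical helper; the diff does not show how docs is merged):

    def unique_union(docs: list) -> list:
        # Flatten the per-query result lists and keep each chunk once,
        # keyed on its text content.
        seen = set()
        unique = []
        for doc_list in docs:
            for doc in doc_list:
                if doc.page_content not in seen:
                    seen.add(doc.page_content)
                    unique.append(doc)
        return unique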
rewrite_question.py (+7 -3)

@@ -14,9 +14,13 @@ from langchain_core.runnables import (
 from typing import Tuple, List, Optional
 from langchain_core.messages import AIMessage, HumanMessage

-local_llm = "llama3-groq-tool-use:latest"
+# local_llm = "llama3-groq-tool-use:latest"
+# llm = ChatOllama(model=local_llm, temperature=0)
 # llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
-llm = ChatOllama(model=local_llm, temperature=0)
+from dotenv import load_dotenv
+load_dotenv()
+from langchain_openai import ChatOpenAI
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

 def get_search_query():
     # Condense a chat history and follow-up question into a standalone question
@@ -105,6 +109,6 @@ if __name__ == "__main__":
     chat_history = [(history["q"], history["a"]) for history in chat_history if history["a"] != "" and history["a"] != "string"]
     print(chat_history)
     
-    question = "類別2呢"
+    question = "廣興廠"
     modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
     print(modified_question)

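The smoke test now feeds a bare follow-up, "廣興廠" (the Guangxing plant), instead of "類別2呢" ("what about category 2?"). A usage sketch of the condensation chain with an assumed one-turn history; the placeholder answer and the expected rewrite are illustrative, not output from the repository:

    _search_query = get_search_query()
    chat_history = [
        ("建準去年的固定燃燒排放量是多少?", "去年固定燃燒排放量為 X 公噸 CO2e。"),
    ]
    # Expected behaviour: the follow-up is condensed into a standalone
    # question along the lines of "建準廣興廠去年的固定燃燒排放量是多少?".
    modified_question = _search_query.invoke(
        {"question": "廣興廠", "chat_history": chat_history}
    )
    print(modified_question)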
systex_app.py (+8 -1)

@@ -17,6 +17,7 @@ from langchain.callbacks import get_openai_callback

 from ai_agent import main
 from semantic_search import semantic_cache
+from RAG_strategy import get_search_query

 load_dotenv()
 URI = os.getenv("SUPABASE_URI")
@@ -41,6 +42,10 @@ class ChatHistoryItem(BaseModel):
 def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
     print(question)
     start = time.time()
+    # TODO rewrite query
+    # _search_query = get_search_query()
+    # chat_history = [(item.q, item.a) for item in chat_history[-5:] if item.a != "" and item.a != "string"]
+    # modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
     
     with get_openai_callback() as cb:
         # cache_question, cache_answer = semantic_cache(supabase, question)
@@ -48,8 +53,10 @@ def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
         if cache_answer:
             answer = cache_answer
         else:
-            answer = main(question)
+            result = main(question)
+            answer = result["generation"]
     processing_time = time.time() - start
+    # save_history(question + "->" + modified_question, answer, cb, processing_time)
     save_history(question, answer, cb, processing_time)
     if "test@systex.com" in answer:
         answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
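The chat-history rewrite is still commented out above (note, too, that the import pulls get_search_query from RAG_strategy, while this commit edits the copy in rewrite_question.py; one of the two is presumably authoritative). A sketch of the endpoint with the TODO switched on; the route decorator and the omitted cache branch are assumptions, not part of this diff:

    @app.post("/agent")  # hypothetical route name
    def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
        start = time.time()
        # Condense the follow-up into a standalone question first.
        _search_query = get_search_query()
        history = [(item.q, item.a) for item in chat_history[-5:]
                   if item.a != "" and item.a != "string"]
        modified_question = _search_query.invoke(
            {"question": question, "chat_history": history}
        )
        with get_openai_callback() as cb:
            result = main(modified_question)
            answer = result["generation"]
        processing_time = time.time() - start
        save_history(question + "->" + modified_question, answer, cb, processing_time)
        return answer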