Explorar el Código

update route prompt

ling hace 4 meses
padre
commit
4e2e6c2245
Se han modificado 1 ficheros con 23 adiciones y 11 borrados
  1. 23 11
      ai_agent.py

+ 23 - 11
ai_agent.py

@@ -80,7 +80,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     )
 
     rag_chain = prompt | llm | StrOutputParser()
-    return rag_chain.invoke({"context": context, "question": question})
+    return docs, rag_chain.invoke({"context": context, "question": question})
 
 ### Hallucination Grader
 
@@ -166,16 +166,21 @@ def generate_additional_detail(sql_query):
     answer = ""
     all_documents = []
     for term in list(set(terms)):
+        print(term)
         if term is None: continue
-        question_format = [f"溫室氣體排放源中的{term}是什麼意思?", f"{term}是什麼意思?"]
+        question_format = [ f"溫室氣體排放源中的{term}是什麼意思?",  f"{term}是什麼意思?"]
+        # f"溫室氣體排放源中的{term}是什麼意思?",
         for question in question_format:
             # question = f"什麼是{term}?"
             documents = retriever.get_relevant_documents(question, k=5)
             all_documents.extend(documents)
             # for doc in documents:
             #     print(doc)
-        all_question = "\n".join(question_format)
-        generation = faiss_query(all_question, all_documents, llm, multi_query=True) + "\n"
+        all_question = "".join(question_format)
+        documents, generation = faiss_query(all_question, all_documents, llm, multi_query=True) 
+        # print(generation)
+        # print("-----------------------")
+        # generation = answer + "\n"
         if "test@systex.com" in generation:
             generation = ""
         
@@ -238,10 +243,9 @@ def Router():
     prompt = PromptTemplate(
         template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 
         You are an expert at routing a user question to a 專業知識 or 自有數據. 
-        Use company private data for questions about the informations about a company's greenhouse gas emissions data.
-        Otherwise, use the 專業知識 for questions on ESG field knowledge or news about ESG. 
-        你需要分辨使用者問題是否在詢問公司的自有數據,例如想了解公司的碳排放源數據等等,如果判斷為是,則使用"自有數據",
-        若使用者的問題是想了解碳盤查或碳管理等等的 ESG 知識和相關新聞,請使用"專業知識"。
+        你需要分辨使用者問題是否在詢問某個公司與其據點廠房的自有數據或是尋求專業的碳盤查或碳管理等等的 ESG 知識和相關新聞,
+        如果問題是想了解某個公司與其據點廠房的碳排放源的排放量或用電、用水量等等,請使用"自有數據",
+        若使用者的問題是想了解碳盤查、碳交易或碳管理等等的 ESG 知識和相關新聞,請使用"專業知識"。
         You do not need to be stringent with the keywords in the question related to these topics. 
         Give a binary choice '自有數據' or '專業知識' based on the question. 
         Return the a JSON with a single key 'datasource' and no premable or explanation. 
@@ -319,7 +323,9 @@ def retrieve_and_generation(state):
             
         # docs_documents = "\n\n".join(doc.page_content for doc in documents)
         # print(documents)
-        generation = faiss_query(question, documents, llm, multi_query=True)
+        documents, generation = faiss_query(question, documents, llm, multi_query=True)
+        # for doc in documents:
+        #     print(doc)
     else:
         generation = state["generation"]
         
@@ -328,7 +334,8 @@ def retrieve_and_generation(state):
             documents = retriever.get_relevant_documents(sub_question, k=5)
             # for doc in documents:
             #     print(doc)
-            generation += faiss_query(sub_question, documents, llm, multi_query=True)
+            documents, answer = faiss_query(sub_question, documents, llm, multi_query=True)
+            generation += answer
             generation += "\n"
             
     print(generation)
@@ -443,8 +450,13 @@ def route_question(state):
     # print(question)
     question_router = Router()
     source = question_router.invoke({"question": question})
-    if "建準" in question:
+    print("Original:", source["datasource"])
+    # if "建準" in question:
+    kw = ["建準", "北海", "廣興", "崑山廣興", "Inc", "SAS", "立準"]
+    if any(char in question for char in kw):
         source["datasource"] = "自有數據"
+    elif "範例" in question:
+        source["datasource"] = "專業知識"
         
     # print(source)
     print(source["datasource"])