2 İşlemeler e46d4b62c3 ... eac4569456

Yazar SHA1 Mesaj Tarih
  ling eac4569456 adjust to gpt4o version, adjust agent flow, add chat history 4 ay önce
  ling 1348a079ba add sql examples and adjust prompt 4 ay önce
6 değiştirilmiş dosya ile 192 ekleme ve 51 silme
  1. 14 2
      RAG_strategy.py
  2. 69 25
      ai_agent.py
  3. 4 0
      faiss_index.py
  4. 7 3
      rewrite_question.py
  5. 8 1
      systex_app.py
  6. 90 20
      text_to_sql_private.py

+ 14 - 2
RAG_strategy.py

@@ -36,16 +36,28 @@ load_dotenv()
 
 
 def multi_query_chain(llm):
 def multi_query_chain(llm):
     # Multi Query: Different Perspectives
     # Multi Query: Different Perspectives
-    template = """You are an AI language model assistant. Your task is to generate three 
+    template = """
+    <|begin_of_text|>
+    
+    <|start_header_id|>system<|end_header_id|>
+    你是一個來自台灣的ESG的AI助理,請用繁體中文回答
+    You are an AI language model assistant. Your task is to generate three 
     different versions of the given user question to retrieve relevant documents from a vector 
     different versions of the given user question to retrieve relevant documents from a vector 
     database. By generating multiple perspectives on the user question, your goal is to help
     database. By generating multiple perspectives on the user question, your goal is to help
     the user overcome some of the limitations of the distance-based similarity search. 
     the user overcome some of the limitations of the distance-based similarity search. 
     Provide these alternative questions separated by newlines. 
     Provide these alternative questions separated by newlines. 
 
 
     You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
     You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
+    <|eot_id|>
+    
+    <|start_header_id|>user<|end_header_id|>
     
     
+    Original question: {question}
+    請用繁體中文
+    <|eot_id|>
     
     
-    Original question: {question}"""
+    <|start_header_id|>assistant<|end_header_id|>
+    """
     prompt_perspectives = ChatPromptTemplate.from_template(template)
     prompt_perspectives = ChatPromptTemplate.from_template(template)
 
 
     
     

+ 69 - 25
ai_agent.py

@@ -24,20 +24,29 @@ db = SQLDatabase.from_uri(URI)
 
 
 # LLM
 # LLM
 # local_llm = "llama3.1:8b-instruct-fp16"
 # local_llm = "llama3.1:8b-instruct-fp16"
+# local_llm = "llama3.1:8b-instruct-q2_K"
 local_llm = "llama3-groq-tool-use:latest"
 local_llm = "llama3-groq-tool-use:latest"
 llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
 llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
-llm = ChatOllama(model=local_llm, temperature=0)
+# llm = ChatOllama(model=local_llm, temperature=0)
+from langchain_openai import ChatOpenAI
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
 
 
 # RAG usage
 # RAG usage
-from faiss_index import create_faiss_retriever, faiss_query
+from faiss_index import create_faiss_retriever, faiss_multiquery, faiss_query
 retriever = create_faiss_retriever()
 retriever = create_faiss_retriever()
 
 
 # text-to-sql usage
 # text-to-sql usage
 from text_to_sql_private import run, get_query, query_to_nl, table_description
 from text_to_sql_private import run, get_query, query_to_nl, table_description
 from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
 from post_processing_sqlparse import get_query_columns, parse_sql_where, get_table_name
 progress_bar = []
 progress_bar = []
+
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
 def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
-    
+    if multi_query:
+        docs = faiss_multiquery(question, retriever, llm)
+        # print(docs)
+    else:
+        docs = retriever.get_relevant_documents(question, k=10)
+        # print(docs)
     context = docs
     context = docs
     
     
     system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
     system_prompt: str = "你是一個來自台灣的AI助理,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
@@ -47,7 +56,7 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     <|start_header_id|>system<|end_header_id|>
     <|start_header_id|>system<|end_header_id|>
     你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
     你是一個來自台灣的ESG的AI助理,請用繁體中文回答問題 \n
     You should not mention anything about "根據提供的文件內容" or other similar terms.
     You should not mention anything about "根據提供的文件內容" or other similar terms.
-    Use five sentences maximum and keep the answer concise.
+    請盡可能的詳細回答問題。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
     勿回答無關資訊
     勿回答無關資訊
     <|eot_id|>
     <|eot_id|>
@@ -58,8 +67,9 @@ def faiss_query(question: str, docs, llm, multi_query: bool = False) -> str:
     {context}
     {context}
 
 
     Question: {question}
     Question: {question}
-    用繁體中文回答問題
+    用繁體中文回答問題,請用一段話詳細的回答。
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
     如果你不知道答案請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+    
     <|eot_id|>
     <|eot_id|>
     
     
     <|start_header_id|>assistant<|end_header_id|>
     <|start_header_id|>assistant<|end_header_id|>
@@ -129,10 +139,14 @@ def run_text_to_sql(question: str):
 
 
 def _get_query(question: str):
 def _get_query(question: str):
     selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
     selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
+    question = question.replace("美國", "美國 Inc")
+    question = question.replace("法國", "法國 SAS")
     query, result = get_query(db, question, selected_table, llm)
     query, result = get_query(db, question, selected_table, llm)
     return  query, result
     return  query, result
 
 
 def _query_to_nl(question: str, query: str, result):
 def _query_to_nl(question: str, query: str, result):
+    question = question.replace("美國", "美國 Inc")
+    question = question.replace("法國", "法國 SAS")
     answer = query_to_nl(question, query, result, llm)
     answer = query_to_nl(question, query, result, llm)
     return  answer
     return  answer
 
 
@@ -150,19 +164,24 @@ def generate_additional_question(sql_query):
 def generate_additional_detail(sql_query):
 def generate_additional_detail(sql_query):
     terms = parse_sql_where(sql_query)
     terms = parse_sql_where(sql_query)
     answer = ""
     answer = ""
+    all_documents = []
     for term in list(set(terms)):
     for term in list(set(terms)):
         if term is None: continue
         if term is None: continue
-        question_format = [f"請解釋什麼是{term}?"]
+        question_format = [f"溫室氣體排放源中的{term}是什麼意思?", f"{term}是什麼意思?"]
         for question in question_format:
         for question in question_format:
             # question = f"什麼是{term}?"
             # question = f"什麼是{term}?"
             documents = retriever.get_relevant_documents(question, k=5)
             documents = retriever.get_relevant_documents(question, k=5)
-            generation = faiss_query(question, documents, llm) + "\n"
-            if "test@systex.com" in generation:
-                generation = ""
-            
-            answer += generation
-            # print(question)
-            # print(generation)
+            all_documents.extend(documents)
+            # for doc in documents:
+            #     print(doc)
+        all_question = "\n".join(question_format)
+        generation = faiss_query(all_question, all_documents, llm, multi_query=True) + "\n"
+        if "test@systex.com" in generation:
+            generation = ""
+        
+        answer += generation
+        # print(question)
+        # print(generation)
     return answer
     return answer
 ### SQL Grader
 ### SQL Grader
 
 
@@ -177,7 +196,7 @@ def SQL_Grader():
         
         
         You need to check that each where statement is correctly filtered out what user question need.
         You need to check that each where statement is correctly filtered out what user question need.
         
         
-        For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is 
+        For example, if user question is "建準去年固定燃燒總排放量是多少?", and the PostgreSQL query is 
         "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
         "SELECT SUM("排放量(公噸CO2e)") AS "下游租賃總排放量"
         FROM "建準碳排放清冊數據new"
         FROM "建準碳排放清冊數據new"
         WHERE "事業名稱" like '%建準%'
         WHERE "事業名稱" like '%建準%'
@@ -293,19 +312,21 @@ def retrieve_and_generation(state):
         # documents = retriever.invoke(question)
         # documents = retriever.invoke(question)
         # TODO: correct Retrieval function
         # TODO: correct Retrieval function
         documents = retriever.get_relevant_documents(question, k=5)
         documents = retriever.get_relevant_documents(question, k=5)
-        for doc in documents:
-            print(doc)
+        # for doc in documents:
+        #     print(doc)
             
             
         # docs_documents = "\n\n".join(doc.page_content for doc in documents)
         # docs_documents = "\n\n".join(doc.page_content for doc in documents)
         # print(documents)
         # print(documents)
-        generation = faiss_query(question, documents, llm)
+        generation = faiss_query(question, documents, llm, multi_query=True)
     else:
     else:
         generation = state["generation"]
         generation = state["generation"]
         
         
         for sub_question in list(set(question_list)):
         for sub_question in list(set(question_list)):
             print(sub_question)
             print(sub_question)
-            documents = retriever.get_relevant_documents(sub_question, k=10)
-            generation += faiss_query(sub_question, documents, llm)
+            documents = retriever.get_relevant_documents(sub_question, k=5)
+            # for doc in documents:
+            #     print(doc)
+            generation += faiss_query(sub_question, documents, llm, multi_query=True)
             generation += "\n"
             generation += "\n"
             
             
     print(generation)
     print(generation)
@@ -513,10 +534,11 @@ def grade_sql_query(state):
     question = state["question"]
     question = state["question"]
     sql_query = state["sql_query"]
     sql_query = state["sql_query"]
     sql_result = state["sql_result"]
     sql_result = state["sql_result"]
-    if "None" in sql_result:
+    if "None" in sql_result or sql_result.startswith("Error:"):
         progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
         progress_bar = show_progress(state, "---INCORRECT SQL QUERY---")
         return "incorrect"
         return "incorrect"
     else:
     else:
+        print(sql_result)
         progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
         progress_bar = show_progress(state, "---CORRECT SQL QUERY---")
         return "correct"
         return "correct"
     # retry = state["retry"]
     # retry = state["retry"]
@@ -540,7 +562,16 @@ def grade_sql_query(state):
     #     # print("---GRADE: INCORRECT SQL QUERY---")
     #     # print("---GRADE: INCORRECT SQL QUERY---")
     #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
     #     progress_bar = show_progress(state, "---GRADE: INCORRECT SQL QUERY---")
     #     return "incorrect"
     #     return "incorrect"
-
+def check_sql_answer(state):
+    progress_bar = show_progress(state, "---CHECK SQL ANSWER QUALITY---")
+    generation = state["generation"]
+    if "test@systex.com" in generation:
+        progress_bar = show_progress(state, "---SQL CAN NOT GENERATE ANSWER---")
+        return "bad"
+    else:
+        progress_bar = show_progress(state, "---SQL CAN GENERATE ANSWER---")
+        return "good"
+    
 def build_graph():
 def build_graph():
     workflow = StateGraph(GraphState)
     workflow = StateGraph(GraphState)
 
 
@@ -550,7 +581,7 @@ def build_graph():
     workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("Additoinal Explanation", additional_explanation_question, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("RAG", retrieve_and_generation, retry=RetryPolicy(max_attempts=5))  # retrieve
     workflow.add_node("ERROR", error)  # retrieve
     workflow.add_node("ERROR", error)  # retrieve
-    
+    company_private_data_search
     workflow.add_conditional_edges(
     workflow.add_conditional_edges(
         START,
         START,
         route_question,
         route_question,
@@ -577,7 +608,17 @@ def build_graph():
             
             
         },
         },
     )
     )
-    workflow.add_edge("SQL Answer", "Additoinal Explanation")
+    workflow.add_conditional_edges(
+        "SQL Answer",
+        check_sql_answer,
+        {
+            "good": "Additoinal Explanation",
+            "bad": "RAG",
+            
+        },
+    )
+    
+    # workflow.add_edge("SQL Answer", "Additoinal Explanation")
     workflow.add_edge("Additoinal Explanation", END)
     workflow.add_edge("Additoinal Explanation", END)
 
 
     app = workflow.compile()    
     app = workflow.compile()    
@@ -604,11 +645,14 @@ def main(question: str):
     value["progress_bar"] = progress_bar
     value["progress_bar"] = progress_bar
     # pprint(value["progress_bar"])
     # pprint(value["progress_bar"])
     
     
-    return value["generation"]
+    # return value["generation"]
+    return value
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     # result = main("建準去年的逸散排放總排放量是多少?")
     # result = main("建準去年的逸散排放總排放量是多少?")
-    result = main("建準夏威夷去年的綠電使用量是多少?")
+    # result = main("建準廣興廠去年的上游運輸總排放量是多少?")
+    
+    result = main("建準北海廠去年的固定燃燒排放量是多少?")
     # result = main("溫室氣體是什麼?")
     # result = main("溫室氣體是什麼?")
     # result = main("什麼是外購電力(綠電)?")
     # result = main("什麼是外購電力(綠電)?")
     print("------------------------------------------------------")
     print("------------------------------------------------------")

+ 4 - 0
faiss_index.py

@@ -54,6 +54,7 @@ embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
 
 
 def download_embeddings():
 def download_embeddings():
     response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
     response = supabase.table(document_table).select("id, embedding, metadata, content").execute()
+    # response = supabase.table(document_table).select("id, embedding, metadata, content").eq('metadata ->> source', 'supplement.docx').execute()
     embeddings = []
     embeddings = []
     ids = []
     ids = []
     metadatas = []
     metadatas = []
@@ -149,6 +150,9 @@ def faiss_multiquery(question: str, retriever: FAISSRetriever, llm):
 
 
     questions = generate_queries.invoke(question)
     questions = generate_queries.invoke(question)
     questions = [item for item in questions if item != ""]
     questions = [item for item in questions if item != ""]
+    questions.append(question)
+    for q in questions:
+        print(q)
 
 
     # docs = list(map(retriever.get_relevant_documents, questions))
     # docs = list(map(retriever.get_relevant_documents, questions))
     docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))
     docs = list(map(lambda query: retriever.get_relevant_documents(query, k=4), questions))

+ 7 - 3
rewrite_question.py

@@ -14,9 +14,13 @@ from langchain_core.runnables import (
 from typing import Tuple, List, Optional
 from typing import Tuple, List, Optional
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.messages import AIMessage, HumanMessage
 
 
-local_llm = "llama3-groq-tool-use:latest"
+# local_llm = "llama3-groq-tool-use:latest"
+# llm = ChatOllama(model=local_llm, temperature=0)
 # llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
 # llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
-llm = ChatOllama(model=local_llm, temperature=0)
+from dotenv import load_dotenv
+load_dotenv()
+from langchain_openai import ChatOpenAI
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
 
 
 def get_search_query():
 def get_search_query():
     # Condense a chat history and follow-up question into a standalone question
     # Condense a chat history and follow-up question into a standalone question
@@ -105,6 +109,6 @@ if __name__ == "__main__":
     chat_history = [(history["q"] , history["a"] ) for history in chat_history if history["a"] != "" and history["a"]  != "string"]
     chat_history = [(history["q"] , history["a"] ) for history in chat_history if history["a"] != "" and history["a"]  != "string"]
     print(chat_history)
     print(chat_history)
     
     
-    question = "類別2呢"
+    question = "廣興廠"
     modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
     modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
     print(modified_question)
     print(modified_question)

+ 8 - 1
systex_app.py

@@ -17,6 +17,7 @@ from langchain.callbacks import get_openai_callback
 
 
 from ai_agent import main
 from ai_agent import main
 from semantic_search import semantic_cache
 from semantic_search import semantic_cache
+from RAG_strategy import get_search_query
 
 
 load_dotenv()
 load_dotenv()
 URI = os.getenv("SUPABASE_URI")
 URI = os.getenv("SUPABASE_URI")
@@ -41,6 +42,10 @@ class ChatHistoryItem(BaseModel):
 def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
 def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
     print(question)
     print(question)
     start = time.time()
     start = time.time()
+    # TODO rewrite query
+    # _search_query = get_search_query()
+    # chat_history = [(item.q, item.a) for item in chat_history[-5:] if item.a != "" and item.a != "string"]
+    # modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
     
     
     with get_openai_callback() as cb:
     with get_openai_callback() as cb:
         # cache_question, cache_answer = semantic_cache(supabase, question)
         # cache_question, cache_answer = semantic_cache(supabase, question)
@@ -48,8 +53,10 @@ def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
         if cache_answer:
         if cache_answer:
             answer = cache_answer
             answer = cache_answer
         else:
         else:
-            answer = main(question)
+            result = main(question)
+            answer = result["generation"]
     processing_time = time.time() - start
     processing_time = time.time() - start
+    # save_history(question + "->" + modified_question, answer, cb, processing_time)
     save_history(question, answer, cb, processing_time)
     save_history(question, answer, cb, processing_time)
     if "test@systex.com" in answer:
     if "test@systex.com" in answer:
         answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
         answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"

+ 90 - 20
text_to_sql_private.py

@@ -46,8 +46,10 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 ##########################################################################################
 ##########################################################################################
 from langchain_community.chat_models import ChatOllama
 from langchain_community.chat_models import ChatOllama
 # local_llm = "llama3-groq-tool-use:latest"
 # local_llm = "llama3-groq-tool-use:latest"
-local_llm = "llama3-groq-tool-use:latest"
-llm = ChatOllama(model=local_llm, temperature=0)
+# local_llm = "llama3-groq-tool-use:latest"
+# local_llm = "sqlcoder:latest"
+# local_llm = "llama3.1:8b-instruct-q2_K"
+# llm = ChatOllama(model=local_llm, temperature=0)
 ##########################################################################################
 ##########################################################################################
 # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 # tokenizer = AutoTokenizer.from_pretrained(model_id)
 # tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -73,57 +75,110 @@ llm = ChatOllama(model=local_llm, temperature=0)
 # llm = HuggingFacePipeline(pipeline=pipe)
 # llm = HuggingFacePipeline(pipeline=pipe)
 
 
 # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
 # llm = Ollama(model = "llama3-groq-tool-use:latest", num_gpu=1)
+from langchain_openai import ChatOpenAI
+llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
+
 def get_examples():
 def get_examples():
     examples = [
     examples = [
         {
         {
-            "input": "建準廣興廠去年的自產電力的綠電使用量是多少?",
+            "input": "建準去年固定燃燒總排放量",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+                        FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
+        {
+            "input": "廣興廠去年的固定燃燒排放量是多少?",
+            "query": """FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "據點" = '昆山廣興廠'
+                        AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+        },
+        {
+            "input": "建準廣興廠去年自產電力的綠電使用量是多少?",
             "query": """SELECT SUM("用電度數(kwh)") AS "綠電使用量"
             "query": """SELECT SUM("用電度數(kwh)") AS "綠電使用量"
                         FROM "用電度數"
                         FROM "用電度數"
                         WHERE "項目" like '%綠電%'
                         WHERE "項目" like '%綠電%'
                         AND "事業名稱" like '%建準%'
                         AND "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%廣興廠%'
+                        AND "據點" = '昆山廣興廠'
                         AND "盤查標準" = 'GHG'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
         },
         },
         {
         {
-            "input": "建準北海廠去年的類別1總排放量是多少?",
+            "input": "建準北海廠去年的類別1總排放量",
             "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
             "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
                         FROM "建準碳排放清冊數據new"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%北海%'
+                        AND "據點" in ('北海建準廠', '北海立準廠')
                         AND "類別" = '類別1'
                         AND "類別" = '類別1'
                         AND "盤查標準" = 'GHG'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
         },
         },
         {
         {
             "input": "建準廣興廠去年的直接排放總排放量是多少?",
             "input": "建準廣興廠去年的直接排放總排放量是多少?",
-            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
                         FROM "建準碳排放清冊數據new"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%廣興%'
+                        AND "據點" = '昆山廣興廠'
                         AND ("類別項目" like '%直接排放%' OR "排放源" like '%直接排放%')
                         AND ("類別項目" like '%直接排放%' OR "排放源" like '%直接排放%')
                         AND "盤查標準" = 'GHG'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
         },
         },
         {
         {
             "input": "建準台北辦事處2022年的類別2總排放量是多少?",
             "input": "建準台北辦事處2022年的類別2總排放量是多少?",
-            "query": """SELECT SUM("排放量(公噸CO2e)") AS "直接排放總排放量"
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別2總排放量"
                         FROM "建準碳排放清冊數據new"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%台北%'
+                        AND "據點" = '台北辦事處'
                         AND "類別" = '類別2'
                         AND "類別" = '類別2'
                         AND "盤查標準" = 'GHG'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = 2022;""",
                         AND "年度" = 2022;""",
         },
         },
         {
         {
-            "input": "建準去年的固定燃燒總排放量是多少?",
-            "query": """SELECT SUM("排放量(公噸CO2e)") AS "固定燃燒總排放量"
+            "input": "建準法國廠2022年的類別2總排放量",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "類別2總排放量"
                         FROM "建準碳排放清冊數據new"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
                         WHERE "事業名稱" like '%建準%'
-                        AND ("類別項目" like '%固定燃燒%' OR "排放源" like '%固定燃燒%')
+                        AND "國家" = '法國'
+                        AND "類別" = '類別2'
                         AND "盤查標準" = 'GHG'
                         AND "盤查標準" = 'GHG'
-                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;""",
+                        AND "年度" = 2022;""",
+        },
+        {
+            "input": "建準北海2022的外購電力是多少",
+            "query": """SELECT SUM("用電度數(kwh)") AS "外購電力"
+                        FROM "用電度數"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "據點" in ('北海建準廠', '北海立準廠')
+                        AND "項目" like '%外購電力%'
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = 2022;""",
+        },
+        {
+            "input": "2023建準印度的其他間接排放是多少",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "其他間接排放總量"
+                        FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "國家" = '印度'
+                        AND ("類別項目" like '%其他間接排放%' OR "排放源" like '%其他間接排放%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = 2023;""",
         },
         },
+        {
+            "input": "建準台北前年的產品使用碳排放量是多少",
+            "query": """SELECT SUM("排放量(公噸CO2e)") AS "產品使用總量"
+                        FROM "建準碳排放清冊數據new"
+                        WHERE "事業名稱" like '%建準%'
+                        AND "據點" = '台北辦事處'
+                        AND ("類別項目" like '%產品使用%' OR "排放源" like '%產品使用%')
+                        AND "盤查標準" = 'GHG'
+                        AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-2;""",
+        },
+
 
 
 
 
     ]
     ]
@@ -137,7 +192,8 @@ def table_description():
         "The `建準碳排放清冊數據new` table 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
         "The `建準碳排放清冊數據new` table 描述了建準電機工業股份有限公司不同據點分別在 ISO 14064-1:2018 與 GHG Protocol 標準下的溫室氣體排放量,並依類別1至類別6劃分。"
         "It includes the following columns:\n"
         "It includes the following columns:\n"
         "- `年度`: 盤查年度\n"
         "- `年度`: 盤查年度\n"
-        "- `事業名稱`: 建準據點"
+        "- `事業名稱`: 公司名稱"
+        "- `據點`: 建準廠房據點 include '高雄總部及運通廠', '台北辦事處', '昆山廣興廠', '北海建準廠', '北海立準廠', '菲律賓建準廠', 'Inc', 'SAS', 'India'"
         "- `國家`: 據點所在國家"
         "- `國家`: 據點所在國家"
         "- `類別`: 溫室氣體的排放類別,包含以下選項:\n"
         "- `類別`: 溫室氣體的排放類別,包含以下選項:\n"
         "   \t*類別1-直接排放:\n"
         "   \t*類別1-直接排放:\n"
@@ -183,7 +239,9 @@ def write_query_chain(db, llm):
     <|begin_of_text|>
     <|begin_of_text|>
     
     
     <|start_header_id|>system<|end_header_id|>
     <|start_header_id|>system<|end_header_id|>
+
     Generate a SQL query to answer this question: `{input}`
     Generate a SQL query to answer this question: `{input}`
+    你是建準的AI助理,幫助建準查詢碳排放量,如果問題中有提到據點廠房,請使用 PostgreSQL query 進行篩選。
 
 
     You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
     You are a PostgreSQL expert in ESG field. Given an input question, first create a syntactically correct PostgreSQL query to run, 
     then look at the results of the query and return the answer to the input question.\n\
     then look at the results of the query and return the answer to the input question.\n\
@@ -192,6 +250,7 @@ def write_query_chain(db, llm):
     Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
     Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
     Wrap each column name in  Quotation Mark (") to denote them as delimited identifiers.\n\
     Wrap each column name in  Quotation Mark (") to denote them as delimited identifiers.\n\
     
     
+    Unless the user ask for the type of 盤查標準 to be 'ISO' or 'GHG', queries always include query "盤查標準"='GHG' in the WHERE clause.\n  
     ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
     ***Pay attention to only return query for PostgreSQL WITHOUT "```sql", And DO NOT content any other words.\n\
     ***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
     ***Pay attention to only return PostgreSQL query and no premable or explanation.\n\
     <|eot_id|>
     <|eot_id|>
@@ -209,9 +268,9 @@ def write_query_chain(db, llm):
     Below are a number of examples of questions and their corresponding SQL queries.\n\
     Below are a number of examples of questions and their corresponding SQL queries.\n\
     
     
     <|eot_id|>
     <|eot_id|>
-    
-    <|start_header_id|>assistant<|end_header_id|>
+    SQL query:
     """
     """
+    # <|start_header_id|>assistant<|end_header_id|>
     # prompt_template = PromptTemplate.from_template(template)
     # prompt_template = PromptTemplate.from_template(template)
 
 
     example_prompt = PromptTemplate.from_template("The following SQL query best answers the question `{input}`\nSQL query: {query}")
     example_prompt = PromptTemplate.from_template("The following SQL query best answers the question `{input}`\nSQL query: {query}")
@@ -227,6 +286,7 @@ def write_query_chain(db, llm):
     # llm = HuggingFacePipeline(pipeline=pipe)
     # llm = HuggingFacePipeline(pipeline=pipe)
     
     
     
     
+    # sqlcoder = Ollama(model = "sqlcoder", num_gpu=1)
     write_query = create_sql_query_chain(llm, db, prompt)
     write_query = create_sql_query_chain(llm, db, prompt)
 
 
 
 
@@ -245,11 +305,11 @@ def sql_to_nl_chain(llm):
         ** 請務必在回答中表達是建準的資料,即便問句中並未提及建準。
         ** 請務必在回答中表達是建準的資料,即便問句中並未提及建準。
         
         
         The following shows some example:
         The following shows some example:
-        Question: 廣興廠去年的類別1總排放量是多少?
+        Question: 建準廣興廠去年的類別1總排放量是多少?
         SQL Query: SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
         SQL Query: SELECT SUM("排放量(公噸CO2e)") AS "類別1總排放量"
                         FROM "建準碳排放清冊數據new"
                         FROM "建準碳排放清冊數據new"
                         WHERE "事業名稱" like '%建準%'
                         WHERE "事業名稱" like '%建準%'
-                        AND "事業名稱" like '%廣興%'
+                        AND "據點" = '昆山廣興廠'
                         AND "類別" = '類別1'
                         AND "類別" = '類別1'
                         AND "盤查標準" = 'GHG'
                         AND "盤查標準" = 'GHG'
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;,
                         AND "年度" = EXTRACT(YEAR FROM CURRENT_DATE)-1;,
@@ -257,6 +317,7 @@ def sql_to_nl_chain(llm):
         Answer: 建準廣興廠去年的類別1總排放量是1102.3712
         Answer: 建準廣興廠去年的類別1總排放量是1102.3712
 
 
         如果你不知道答案或SQL query 出現錯誤請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
         如果你不知道答案或SQL query 出現錯誤請回答:"很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
+        
         勿回答無關資訊
         勿回答無關資訊
         <|eot_id|>
         <|eot_id|>
 
 
@@ -278,10 +339,14 @@ def sql_to_nl_chain(llm):
     return chain
     return chain
 
 
 def get_query(db, question, selected_table, llm):
 def get_query(db, question, selected_table, llm):
+    
     write_query = write_query_chain(db, llm)
     write_query = write_query_chain(db, llm)
     query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     
     
     query = re.split('SQL query: ', query)[-1]
     query = re.split('SQL query: ', query)[-1]
+    query = query.replace("```sql","").replace("```","")
+    query = query.replace("碰排","碳排")
+    query = query.replace("%%","%")
     # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
     # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
     print(query)
     print(query)
     
     
@@ -308,6 +373,9 @@ def run(db, question, selected_table, llm):
     query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     query = write_query.invoke({"question": question, 'table_names_to_use': selected_table, "top_k": 1000, "table_info":context["table_info"], "database_description": table_description()})
     
     
     query = re.split('SQL query: ', query)[-1]
     query = re.split('SQL query: ', query)[-1]
+    query = query.replace("```sql","").replace("```","")
+    query = query.replace("碰排","碳排")
+    query = query.replace("%%","%")
     # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
     # query = query.replace("104_112碰排放公開及建準資料","104_112碳排放公開及建準資料")
     print(query)
     print(query)
 
 
@@ -327,7 +395,9 @@ if __name__ == "__main__":
     start = time.time()
     start = time.time()
     
     
     selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
     selected_table = ['用水度數', '用水度數', '建準碳排放清冊數據new']
-    question = "建準去年的上游運輸總排放量是多少?"
+    # question = "建準廣興廠去年的上游運輸總排放量是多少?"
+    question = "建準北海廠去年的固定燃燒排放量是多少?"
+    # question = "建準北海廠去年類別1總排放量是多少?"
     # question = "台積電2022年的直接排放總排放量是多少?"
     # question = "台積電2022年的直接排放總排放量是多少?"
     # question = "建準廣興廠去年的灰電使用量"
     # question = "建準廣興廠去年的灰電使用量"
     query, result, answer = run(db, question, selected_table, llm)
     query, result, answer = run(db, question, selected_table, llm)