SherryLiu 7 months ago
parent
commit
bcec934286
4 changed files with 144 additions and 55 deletions
  1. RAG_app_copy.py (+1 −1)
  2. RAG_strategy.py (+81 −50)
  3. TAIDE-test.py (+35 −0)
  4. docker-compose.yml (+27 −4)

+ 1 - 1
RAG_app_copy.py

@@ -121,7 +121,7 @@ def multi_query_answer(question, retriever=Depends(get_retriever)):
         return {"Answer": final_answer}
     except Exception as e:
         logger.error(f"Error in /answer2 endpoint: {e}")
-        raise HTTPException(status_code=500, detail="Internal Server Error")
+        raise HTTPException(status_code=500, detail=str(e))
 
 class ChatHistoryItem(BaseModel):
     q: str

+ 81 - 50
RAG_strategy.py

@@ -9,7 +9,7 @@ from langchain_core.runnables import RunnablePassthrough
 from langchain import hub
 from langchain.globals import set_llm_cache
 from langchain import PromptTemplate
-
+from langchain.llms.base import LLM
 
 from langchain_core.runnables import (
     RunnableBranch,
@@ -48,9 +48,11 @@ openai_api_key = os.getenv("OPENAI_API_KEY")
 openai.api_key = openai_api_key
 URI = os.getenv("SUPABASE_URI")
 
+from typing import Optional, List, Any, Dict
+
 # Cache LLM responses to avoid repeated API calls (SQLite enabled here; Redis variant below)
-# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
-# set_llm_cache(RedisSemanticCache(redis_url="redis://localhost:6380", embedding=OpenAIEmbeddings(openai_api_key=openai_api_key), score_threshold=0.0005))
+set_llm_cache(SQLiteCache(database_path=".langchain.db"))
+# set_llm_cache(RedisSemanticCache(redis_url="redis://localhost:6379", embedding=OpenAIEmbeddings(openai_api_key=openai_api_key), score_threshold=0.0005))
 
 # # TAIDE model on Ollama https://ollama.com/jcai/llama3-taide-lx-8b-chat-alpha1
 # def interact_with_model(messages, api_url="http://localhost:11434/v1/chat/completions"):
@@ -58,23 +60,55 @@ URI = os.getenv("SUPABASE_URI")
 #     response = requests.post(api_url, json={"model": "jcai/llama3-taide-lx-8b-chat-alpha1:Q4_K_M", "messages": messages})
 #     return response.json()["choices"][0]["message"]["content"]
 
-# class CustomTAIDELLM(LLM):
-#     api_url: str = "http://localhost:11434/v1/chat/completions"
-    
-#     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-#         messages = [{"role": "user", "content": prompt}]
-#         response = requests.post(self.api_url, json={
-#             "model": "taide-local",  # Use your local model name
-#             "messages": messages
-#         })
-#         return response.json()["choices"][0]["message"]["content"]
-    
-#     @property
-#     def _llm_type(self) -> str:
-#         return "custom_taide"
-
-# # Create an instance of the custom LLM
-# taide_llm = CustomTAIDELLM()
+import requests
+from typing import Optional, List, Any, Dict
+from langchain.llms.base import LLM
+
+class CustomTAIDELLM(LLM):
+    api_url: str = "http://localhost:11434/api/chat"
+    model_name: str = "taide-local"
+    system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        # /api/chat applies the model's own chat template server-side, so the prompt
+        # is sent as plain chat messages rather than a hand-built [INST] string.
+        print(f"Prompt being sent to TAIDE model: {prompt}")
+
+        payload = {
+            "model": self.model_name,
+            "stream": False,  # return one JSON object instead of a stream of chunks
+            "messages": [
+                {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": prompt}],
+        }
+
+        try:
+            response = requests.post(self.api_url, json=payload)
+            response.raise_for_status()
+            result = response.json()
+            print(f"Full API response: {result}")
+            if "message" in result:
+                return result["message"]["content"].strip()
+            else:
+                print(f"Unexpected response structure: {result}")
+                return "Error: Unexpected response from the model."
+        except requests.RequestException as e:
+            print(f"Error calling Ollama API: {e}")
+            return f"Error: Unable to get response from the model. {str(e)}"
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom_taide"
+
+    def get_model_name(self):
+        return self.model_name
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        return {"model_name": self.model_name}
+
+# Create an instance of the custom LLM
+taide_llm = CustomTAIDELLM(api_url="http://localhost:11434/api/chat", model_name="taide-local")
 
 # Generate several reworded versions of the question, run retrieval, and return the answer plus reference documents
 def multi_query(question, retriever, chat_history):
@@ -100,7 +134,7 @@ def multi_query(question, retriever, chat_history):
         # generate_queries = interact_with_model(messages).split("\n")
 
         
-        llm = ChatOpenAI(model="gpt-4-1106-preview")
+        # llm = ChatOpenAI(model="gpt-4-1106-preview")
         # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
         # llm = ChatOllama(model="gemma2", temperature=0)
         # llm = ChatOllama(model=model)
@@ -108,7 +142,8 @@ def multi_query(question, retriever, chat_history):
 
         generate_queries = (
             prompt_perspectives 
-            | llm
+            | taide_llm
+            # | llm
             | StrOutputParser() 
             | (lambda x: x.split("\n"))
         )
@@ -140,49 +175,37 @@ def multi_query(question, retriever, chat_history):
 
 # Generate the final answer from the retrieved documents and the user's question
 def multi_query_rag_prompt(retrieval_chain, question):
-    # RAG
     template = """Answer the following question based on this context:
 
     {context}
 
     Question: {question}
-    Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. If the question is in English, then the output will be in English\n
+    Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. If the question is in English, then the output will be in English.
     You should not mention anything about "根據提供的文件內容" or other similar terms.
     If you don't know the answer, just say that "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@email.com 以便獲得更進一步的幫助,謝謝。I'm sorry I cannot answer your question. Please send your question to test@email.com for further assistance. Thank you."
     """
 
     prompt = ChatPromptTemplate.from_template(template)
-    context = retrieval_chain.invoke({"question": question})  # Ensure this returns the context
-
-
-    # llm = ChatOpenAI(temperature=0)
-    llm = ChatOpenAI(model="gpt-4-1106-preview")
-    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
-    # llm = ChatOllama(model="gemma2", temperature=0)
-
+    context = retrieval_chain.invoke({"question": question})
+    print(f"Retrieved context: {str(context)[:200]}...")  # debug: first 200 characters of the retrieved context
 
     final_rag_chain = (
         {"context": retrieval_chain, 
         "question": itemgetter("question")} 
         | prompt
-        | llm
+        | taide_llm
         | StrOutputParser()
     )
-    messages = [
-        {"role": "system", "content": template},
-        {"role": "user", "content": question},
-        {"role": "assistant", "content": context}
-    ]
-    # answer = interact_with_model(messages)
-    answer = final_rag_chain.invoke({"question":question})
-
-    answer = ""
-    for text in final_rag_chain.stream({"question":question}):
-        print(text, end="", flush=True)
-        answer += text
-
 
-    return answer
+    print(f"Sending question to model: {question}")
+    try:
+        answer = final_rag_chain.invoke({"question": question})
+        print(f"Received answer: {answer}")
+        return answer
+    except Exception as e:
+        print(f"Error invoking rag_chain: {e}")
+        return "Error occurred while processing the question."
+    
 ########################################################################################################################
 
 # Condense the chat history and a follow-up question into a standalone question
@@ -255,7 +278,7 @@ def naive_rag(question, retriever):
     prompt = hub.pull("rlm/rag-prompt")
 
     # LLM
-    llm = ChatOpenAI(model_name="gpt-3.5-turbo")
+    # llm = ChatOpenAI(model_name="gpt-3.5-turbo")
 
     # Post-processing
     def format_docs(docs):
@@ -267,7 +290,8 @@ def naive_rag(question, retriever):
     rag_chain = (
         {"context": retriever | format_docs, "question": RunnablePassthrough()}
         | prompt
-        | llm
+        | taide_llm
+        # | llm
         | StrOutputParser()
     )
 
@@ -345,4 +369,11 @@ def rag_score(question, ground_truth, answer, reference_docs):
     result_df = result.to_pandas()
     print(result_df.head())
     result_df.to_csv('ragas_rag.csv')
-    return result
+    return result
+
+
+def print_current_model(llm):
+    if isinstance(llm, CustomTAIDELLM):
+        print(f"Currently using model: {llm.get_model_name()}")
+    else:
+        print(f"Currently using model: {type(llm).__name__}")

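For reference, a minimal usage sketch (not part of this commit) showing how the new CustomTAIDELLM wrapper can be exercised on its own and inside a small LCEL chain, mirroring how the commit wires it into generate_queries and final_rag_chain. It assumes Ollama is serving a model tagged "taide-local" on localhost:11434, as configured in docker-compose.yml.

# Hypothetical usage sketch for the CustomTAIDELLM wrapper added in RAG_strategy.py.
# Assumes Ollama is running locally with the "taide-local" model registered.
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from RAG_strategy import taide_llm  # instance created in this commit

# Direct call: .invoke() routes through CustomTAIDELLM._call()
print(taide_llm.invoke("台灣的淨零排放目標是什麼?"))

# The same wrapper dropped into a minimal prompt | llm | parser chain
prompt = ChatPromptTemplate.from_template("用一句話回答:{question}")
chain = prompt | taide_llm | StrOutputParser()
print(chain.invoke({"question": "什麼是碳盤查?"}))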
+ 35 - 0
TAIDE-test.py

@@ -0,0 +1,35 @@
+import requests
+import json 
+
+OLLAMA_API_URL = "http://localhost:11434/api/chat"
+
+def query_ollama_taide(messages):
+    payload = {
+        "model": "taide-local",
+        "prompt": messages,
+        "stream": False
+    }
+
+    try:
+        response = requests.post(OLLAMA_API_URL, json=payload)
+        response.raise_for_status()
+        result = response.json()
+        return result.get('message', {}).get('content', '')
+    except requests.RequestException as e:
+        print(f"Error calling Ollama API:{e}")
+        return None 
+    
+def simple_rag(question, context):
+    system_prompt = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+    user_prompt = f"根據以下上下文回答問題:\n\n上下文:{context}\n\n問題:{question}\n\n請提供簡潔的回答。"
+
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt}
+    ]
+    
+    return query_ollama_taide(messages)
+
+if __name__ == "__main__":
+    context = "台灣在2050年訂立了淨零排放目標,並制定了相關法規和政策來推動減碳。"
+    question = "台灣的溫室氣體排放法規目標是什麼?"

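A small companion check, hypothetical and not part of the commit, that can be run before this smoke test: it asks Ollama's /api/tags endpoint whether a "taide-local" model is actually registered, so a missing model shows up as a clear message rather than a failed chat call.

# Hypothetical pre-flight check for TAIDE-test.py: verify that "taide-local"
# is registered with the local Ollama server before sending chat requests.
import requests

def taide_model_available(base_url="http://localhost:11434"):
    try:
        tags = requests.get(f"{base_url}/api/tags", timeout=5).json()
    except requests.RequestException as e:
        print(f"Ollama not reachable: {e}")
        return False
    names = [m.get("name", "") for m in tags.get("models", [])]
    return any(name.startswith("taide-local") for name in names)

if __name__ == "__main__":
    print("taide-local available:", taide_model_available())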
+ 27 - 4
docker-compose.yml

@@ -1,4 +1,27 @@
-version: '3'
+# version: '3'
+# services:
+#   ollama:
+#     image: ollama/ollama
+#     volumes:
+#       - ollama:/root/.ollama
+#       - /Users/sherry/Documents/_Personal/ChoozeMo/notebooks/carbon/llm/ollama:/models
+#     ports:
+#       - "11434:11434"
+#     mem_limit: 16g
+#     cpus: 6
+#     command: sh -c "ollama create taide-local -f /models/ide-7b-a.2-q4_k_m.gguf && ollama run taide-local"
+
+#   redis: 
+#     image: redis/redis-stack:latest
+#     ports:
+#       - "6379:6379"
+#       - "8001:8001"
+
+# volumes:
+#   ollama:
+
+
+
 services:
   ollama:
     image: ollama/ollama
@@ -9,12 +32,12 @@ services:
       - "11434:11434"
     mem_limit: 16g
     cpus: 6
-    command: sh -c "ollama create taide-local -f /models/ide-7b-a.2-q4_k_m.gguf && ollama run taide-local"
 
-  redis: 
-    image: redis:alpine
+  redis:
+    image: redis/redis-stack:latest
     ports:
       - "6379:6379"
+      - "8001:8001"
 
 volumes:
   ollama:
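With redis-stack now exposed on 6379 (and the RedisInsight UI on 8001), the Redis semantic cache line that is commented out in RAG_strategy.py can point at this container. A minimal sketch, assuming the compose stack is up, OPENAI_API_KEY is set, and the Redis extra for LangChain is installed; import paths vary between LangChain versions, so adjust them to match the project's existing cache imports:

# Hypothetical wiring of the RedisSemanticCache option against the compose redis-stack service.
import os
from langchain.globals import set_llm_cache
from langchain.cache import RedisSemanticCache  # or langchain_community.cache in newer releases
from langchain_openai import OpenAIEmbeddings   # or the project's existing embeddings import

set_llm_cache(
    RedisSemanticCache(
        redis_url="redis://localhost:6379",
        embedding=OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY")),
        score_threshold=0.0005,  # same threshold as the commented-out line in RAG_strategy.py
    )
)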