@@ -9,7 +9,7 @@ from langchain_core.runnables import RunnablePassthrough
 from langchain import hub
 from langchain.globals import set_llm_cache
 from langchain import PromptTemplate
-
+from langchain.llms.base import LLM
 
 from langchain_core.runnables import (
     RunnableBranch,
@@ -48,9 +48,11 @@ openai_api_key = os.getenv("OPENAI_API_KEY")
 openai.api_key = openai_api_key
 URI = os.getenv("SUPABASE_URI")
 
+from typing import Optional, List, Any, Dict
+
 # Set up a cache to reduce repeated API requests (SQLite here, Redis as an alternative)
-# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
-# set_llm_cache(RedisSemanticCache(redis_url="redis://localhost:6380", embedding=OpenAIEmbeddings(openai_api_key=openai_api_key), score_threshold=0.0005))
+set_llm_cache(SQLiteCache(database_path=".langchain.db"))
+# set_llm_cache(RedisSemanticCache(redis_url="redis://localhost:6379", embedding=OpenAIEmbeddings(openai_api_key=openai_api_key), score_threshold=0.0005))
 
 # # TAIDE model on Ollama https://ollama.com/jcai/llama3-taide-lx-8b-chat-alpha1
 # def interact_with_model(messages, api_url="http://localhost:11434/v1/chat/completions"):
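As an aside on the cache switch above: with `SQLiteCache` enabled, a repeated identical prompt is answered from `.langchain.db` instead of re-hitting the model. A minimal sketch, illustrative only — it assumes the `taide_llm` instance defined further down in this diff and a running local Ollama server, and the `SQLiteCache` import path can vary by LangChain version:

```python
# Illustrative only: time two identical calls to show the SQLite cache hit.
import time
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache  # assumption: adjust import to your LangChain version

set_llm_cache(SQLiteCache(database_path=".langchain.db"))

def timed_call(llm, prompt):
    start = time.time()
    llm.invoke(prompt)          # LLM.generate() consults the global cache first
    return time.time() - start

# print(timed_call(taide_llm, "你好"))   # first call goes to the model
# print(timed_call(taide_llm, "你好"))   # identical prompt is served from .langchain.db
```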
@@ -58,23 +60,55 @@ URI = os.getenv("SUPABASE_URI")
 #     response = requests.post(api_url, json={"model": "jcai/llama3-taide-lx-8b-chat-alpha1:Q4_K_M", "messages": messages})
 #     return response.json()["choices"][0]["message"]["content"]
 
-# class CustomTAIDELLM(LLM):
-#     api_url: str = "http://localhost:11434/v1/chat/completions"
-
-#     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-#         messages = [{"role": "user", "content": prompt}]
-#         response = requests.post(self.api_url, json={
-#             "model": "taide-local", # Use your local model name
-#             "messages": messages
-#         })
-#         return response.json()["choices"][0]["message"]["content"]
-
-#     @property
-#     def _llm_type(self) -> str:
-#         return "custom_taide"
-
-# # Create an instance of the custom LLM
-# taide_llm = CustomTAIDELLM()
+import requests
+from typing import Optional, List, Any, Dict
+from langchain.llms.base import LLM
+
+class CustomTAIDELLM(LLM):
+    api_url: str = "http://localhost:11434/api/chat"
+    model_name: str = "taide-local"
+    system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        # Format the prompt according to TAIDE requirements
+        formatted_prompt = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>>\n\n{prompt} [/INST]"
+        print(f"Formatted prompt being sent to TAIDE model: {formatted_prompt}")
+
+        payload = {
+            "model": self.model_name,
+            "messages": [
+                {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": prompt}
+            ]
+        }
+
+        try:
+            response = requests.post(self.api_url, json=payload)
+            response.raise_for_status()
+            result = response.json()
+            print(f"Full API response: {result}")
+            if "message" in result:
+                return result["message"]["content"].strip()
+            else:
+                print(f"Unexpected response structure: {result}")
+                return "Error: Unexpected response from the model."
+        except requests.RequestException as e:
+            print(f"Error calling Ollama API: {e}")
+            return f"Error: Unable to get response from the model. {str(e)}"
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom_taide"
+
+    def get_model_name(self):
+        return self.model_name
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        return {"model_name": self.model_name}
+
+# Create an instance of the custom LLM
+taide_llm = CustomTAIDELLM(api_url="http://localhost:11434/api/chat", model_name="taide-local")
 
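Before `taide_llm` is dropped into the chains below, it can be smoke-tested on its own. A hedged sketch: it assumes an Ollama server on localhost:11434 with a model tagged `taide-local`, and note that Ollama's `/api/chat` streams newline-delimited JSON by default, so the payload above may also need `"stream": False` for `response.json()` to parse a single object.

```python
# Minimal smoke test for CustomTAIDELLM (illustrative; requires a local Ollama
# instance serving a model named "taide-local").
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Direct call through the standard Runnable interface of an LLM subclass.
print(taide_llm.invoke("請用一句話介紹你自己。"))

# The same object composes into an LCEL pipeline like any other LLM.
chain = (
    ChatPromptTemplate.from_template("Answer briefly: {question}")
    | taide_llm
    | StrOutputParser()
)
print(chain.invoke({"question": "What is retrieval-augmented generation?"}))
```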
 # Generate several variants of the question, run retrieval with each, and return the answer plus the reference documents
 def multi_query(question, retriever, chat_history):
@@ -100,7 +134,7 @@ def multi_query(question, retriever, chat_history):
     # generate_queries = interact_with_model(messages).split("\n")
 
 
-    llm = ChatOpenAI(model="gpt-4-1106-preview")
+    # llm = ChatOpenAI(model="gpt-4-1106-preview")
     # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
     # llm = ChatOllama(model="gemma2", temperature=0)
     # llm = ChatOllama(model=model)
@@ -108,7 +142,8 @@ def multi_query(question, retriever, chat_history):
 
     generate_queries = (
         prompt_perspectives
-        | llm
+        | taide_llm
+        # | llm
         | StrOutputParser()
         | (lambda x: x.split("\n"))
     )
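For context on what `generate_queries` feeds into: the usual multi-query pattern fans the reformulated questions out over the retriever and deduplicates the union of results. The helper below is an assumption sketched from the standard LangChain recipe; the repository's actual retrieval step sits outside this hunk.

```python
# Hedged sketch of the step that typically follows generate_queries:
# retrieve with every reformulated query and keep the unique union of documents.
# get_unique_union and retrieval_chain are illustrative names, not the repo's code.
from langchain.load import dumps, loads

def get_unique_union(documents: list[list]):
    """Flatten per-query result lists and drop duplicate documents."""
    flattened = [dumps(doc) for sublist in documents for doc in sublist]
    return [loads(doc) for doc in set(flattened)]

# retrieval_chain = generate_queries | retriever.map() | get_unique_union
```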
@@ -140,49 +175,37 @@ def multi_query(question, retriever, chat_history):
 
 # Generate the final answer from the retrieved documents and the user's question
 def multi_query_rag_prompt(retrieval_chain, question):
-    # RAG
     template = """Answer the following question based on this context:
 
     {context}
 
     Question: {question}
-    Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. If the question is in English, then the output will be in English\n
+    Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. If the question is in English, then the output will be in English.
     You should not mention anything about "根據提供的文件內容" or other similar terms.
     If you don't know the answer, just say that "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@email.com 以便獲得更進一步的幫助,謝謝。I'm sorry I cannot answer your question. Please send your question to test@email.com for further assistance. Thank you."
     """
 
     prompt = ChatPromptTemplate.from_template(template)
-    context = retrieval_chain.invoke({"question": question}) # Ensure this returns the context
-
-
-    # llm = ChatOpenAI(temperature=0)
-    llm = ChatOpenAI(model="gpt-4-1106-preview")
-    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
-    # llm = ChatOllama(model="gemma2", temperature=0)
-
+    context = retrieval_chain.invoke({"question": question})
+    print(f"Retrieved context: {context[:200]}...") # Print first 200 chars of context
 
     final_rag_chain = (
         {"context": retrieval_chain,
          "question": itemgetter("question")}
         | prompt
-        | llm
+        | taide_llm
         | StrOutputParser()
     )
-    messages = [
-        {"role": "system", "content": template},
-        {"role": "user", "content": question},
-        {"role": "assistant", "content": context}
-    ]
-    # answer = interact_with_model(messages)
-    answer = final_rag_chain.invoke({"question":question})
-
-    answer = ""
-    for text in final_rag_chain.stream({"question":question}):
-        print(text, end="", flush=True)
-        answer += text
-
 
-    return answer
+    print(f"Sending question to model: {question}")
+    try:
+        answer = final_rag_chain.invoke({"question": question})
+        print(f"Received answer: {answer}")
+        return answer
+    except Exception as e:
+        print(f"Error invoking rag_chain: {e}")
+        return "Error occurred while processing the question."
+
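If token-by-token output is ever wanted again in `multi_query_rag_prompt`, the streaming loop deleted above can be reinstated alongside the new error handling. A sketch reconstructed from the removed lines:

```python
# Streaming variant of the final answer step, rebuilt from the deleted lines.
# With no _stream() on CustomTAIDELLM, .stream() simply yields one whole chunk.
try:
    answer = ""
    for text in final_rag_chain.stream({"question": question}):
        print(text, end="", flush=True)
        answer += text
    return answer
except Exception as e:
    print(f"Error streaming rag_chain: {e}")
    return "Error occurred while processing the question."
```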
 ########################################################################################################################
 
 # Condense the chat history and the follow-up question into a standalone question
@@ -255,7 +278,7 @@ def naive_rag(question, retriever):
     prompt = hub.pull("rlm/rag-prompt")
 
     # LLM
-    llm = ChatOpenAI(model_name="gpt-3.5-turbo")
+    # llm = ChatOpenAI(model_name="gpt-3.5-turbo")
 
     # Post-processing
     def format_docs(docs):
@@ -267,7 +290,8 @@ def naive_rag(question, retriever):
     rag_chain = (
         {"context": retriever | format_docs, "question": RunnablePassthrough()}
         | prompt
-        | llm
+        | taide_llm
+        # | llm
        | StrOutputParser()
     )
 
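`format_docs` is referenced in the chain above but its body falls outside this hunk; the conventional implementation is assumed to look roughly like the following, not necessarily the repository's exact code:

```python
# Assumed shape of format_docs (body not shown in this diff): join the
# retrieved Documents into one context string for the rlm/rag-prompt template.
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
```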
@@ -345,4 +369,11 @@ def rag_score(question, ground_truth, answer, reference_docs):
     result_df = result.to_pandas()
     print(result_df.head())
     result_df.to_csv('ragas_rag.csv')
-    return result
+    return result
+
+
+def print_current_model(llm):
+    if isinstance(llm, CustomTAIDELLM):
+        print(f"Currently using model: {llm.get_model_name()}")
+    else:
+        pass
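A short usage note for the new helper, tying it back to the `taide_llm` instance created earlier in this diff:

```python
# Example: log which model the pipeline is about to use.
print_current_model(taide_llm)   # -> Currently using model: taide-local
```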