@@ -9,7 +9,14 @@ from langchain_core.runnables import RunnablePassthrough
 from langchain import hub
 from langchain.globals import set_llm_cache
 from langchain import PromptTemplate
-from langchain.llms.base import LLM
+import subprocess
+import json
+from typing import Any, List, Optional, Dict
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
+from langchain_core.outputs import ChatResult, ChatGeneration
+from pydantic import Field
 
 from langchain_core.runnables import (
     RunnableBranch,
@@ -17,9 +24,6 @@ from langchain_core.runnables import (
     RunnableParallel,
     RunnablePassthrough,
 )
-from typing import Tuple, List, Optional
-from langchain_core.messages import AIMessage, HumanMessage
-
 
 from datasets import Dataset
 from ragas import evaluate
@@ -29,92 +33,71 @@ from ragas.metrics import (
     context_recall,
     context_precision,
 )
-from typing import List
 import os
 from dotenv import load_dotenv
 load_dotenv('environment.env')
 
-########################################################################################################################
-########################################################################################################################
 from langchain.cache import SQLiteCache
-from langchain.cache import RedisSemanticCache
 from langchain_openai import OpenAIEmbeddings
 from langchain.globals import set_llm_cache
 
-########################################################################################################################
 import requests
 import openai
 openai_api_key = os.getenv("OPENAI_API_KEY")
 openai.api_key = openai_api_key
 URI = os.getenv("SUPABASE_URI")
 
-from typing import Optional, List, Any, Dict
-
-# 設置緩存,以減少對API的重複請求。使用Redis
+# Set up an LLM cache to reduce repeated API requests; backed by SQLite.
 set_llm_cache(SQLiteCache(database_path=".langchain.db"))
-# set_llm_cache(RedisSemanticCache(redis_url="redis://localhost:6379", embedding=OpenAIEmbeddings(openai_api_key=openai_api_key), score_threshold=0.0005))
 
-# # TAIDE model on Ollama https://ollama.com/jcai/llama3-taide-lx-8b-chat-alpha1
-# def interact_with_model(messages, api_url="http://localhost:11434/v1/chat/completions"):
-#     print("Using model: TAIDE")
-#     response = requests.post(api_url, json={"model": "jcai/llama3-taide-lx-8b-chat-alpha1:Q4_K_M", "messages": messages})
-#     return response.json()["choices"][0]["message"]["content"]
-
-import requests
-from typing import Optional, List, Any, Dict
-from langchain.llms.base import LLM
-
-class CustomTAIDELLM(LLM):
-    api_url: str = "http://localhost:11434/api/chat"
-    model_name: str = "taide-local"
-    system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
-
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        # Format the prompt according to TAIDE requirements
-        formatted_prompt = f"<s>[INST] <<SYS>>\n{self.system_prompt}\n<</SYS>>\n\n{prompt} [/INST]"
-        print(f"Formatted prompt being sent to TAIDE model: {formatted_prompt}")
-
-        payload = {
-            "model": self.model_name,
-            "messages": [
-                {"role": "system", "content": self.system_prompt},
-                {"role": "user", "content": prompt}
-            ]
-        }
-
-        try:
-            response = requests.post(self.api_url, json=payload)
-            response.raise_for_status()
-            result = response.json()
-            print(f"Full API response: {result}")
-            if "message" in result:
-                return result["message"]["content"].strip()
-            else:
-                print(f"Unexpected response structure: {result}")
-                return "Error: Unexpected response from the model."
-        except requests.RequestException as e:
-            print(f"Error calling Ollama API: {e}")
-            return f"Error: Unable to get response from the model. {str(e)}"
+system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+
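+# Chat model that shells out to the local Ollama CLI instead of calling an HTTP API.
+# Assumes an Ollama model tagged "taide-local" has already been created on this machine.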
+class OllamaChatModel(BaseChatModel):
+    model_name: str = Field(default="taide-local")
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
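+        # stop and run_manager are accepted for BaseChatModel compatibility but are not used here.
+        # First, flatten the incoming LangChain message objects into plain role/content dicts.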
+        formatted_messages = []
+        for msg in messages:
+            if isinstance(msg, HumanMessage):
+                formatted_messages.append({"role": "user", "content": msg.content})
+            elif isinstance(msg, AIMessage):
+                formatted_messages.append({"role": "assistant", "content": msg.content})
+            elif isinstance(msg, SystemMessage):
+                formatted_messages.append({"role": "system", "content": msg.content})
+
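+        # Build a Llama-2-style [INST] prompt; the system text comes from the module-level
+        # system_prompt, so any SystemMessage collected above is not injected here.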
+        prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
+        for msg in formatted_messages:
+            if msg['role'] == 'user':
+                prompt += f"{msg['content']} [/INST]"
+            elif msg['role'] == "assistant":
+                prompt += f"{msg['content']} </s><s>[INST]"
+
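+        # Invoke the model through the Ollama CLI and read the reply from stdout.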
+        command = ["ollama", "run", self.model_name, prompt]
+        result = subprocess.run(command, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise Exception(f"Ollama command failed: {result.stderr}")
+
+        content = result.stdout.strip()
 
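+        # Wrap the raw output in LangChain's ChatGeneration / ChatResult containers.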
+        message = AIMessage(content=content)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+
     @property
     def _llm_type(self) -> str:
-        return "custom_taide"
-
-    def get_model_name(self):
-        return self.model_name
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        return {"model_name": self.model_name}
-
-# Create an instance of the custom LLM
-taide_llm = CustomTAIDELLM(api_url="http://localhost:11434/api/chat", model_name="taide-local")
+        return "ollama-chat-model"
+
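+# taide_llm is the shared chat model handed to the LCEL chains below (prompt | taide_llm | StrOutputParser()).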
+taide_llm = OllamaChatModel(model_name="taide-local")
 
-# 生成多個不同版本的問題,進行檢索,並返回答案和參考文檔
 def multi_query(question, retriever, chat_history):
-
     def multi_query_chain():
-        # Multi Query: Different Perspectives
         template = """You are an AI language model assistant. Your task is to generate three
         different versions of the given user question to retrieve relevant documents from a vector
         database. By generating multiple perspectives on the user question, your goal is to help
@@ -123,27 +106,12 @@ def multi_query(question, retriever, chat_history):
 
         You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
 
-
         Original question: {question}"""
         prompt_perspectives = ChatPromptTemplate.from_template(template)
 
-        messages = [
-            {"role": "system", "content": template},
-            {"role": "user", "content": question},
-        ]
-        # generate_queries = interact_with_model(messages).split("\n")
-
-
-        # llm = ChatOpenAI(model="gpt-4-1106-preview")
-        # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
-        # llm = ChatOllama(model="gemma2", temperature=0)
-        # llm = ChatOllama(model=model)
-
-
         generate_queries = (
             prompt_perspectives
             | taide_llm
-            # | llm
             | StrOutputParser()
             | (lambda x: x.split("\n"))
         )
@@ -151,14 +119,9 @@ def multi_query(question, retriever, chat_history):
         return generate_queries
 
     def get_unique_union(documents: List[list]):
-        """ Unique union of retrieved docs """
-        # Flatten list of lists, and convert each Document to string
        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
-        # Get unique documents
        unique_docs = list(set(flattened_docs))
-        # Return
        return [loads(doc) for doc in unique_docs]
-
 
     _search_query = get_search_query()
     modified_question = _search_query.invoke({"question":question, "chat_history": chat_history})
@@ -173,7 +136,6 @@ def multi_query(question, retriever, chat_history):
 
     return answer, docs
 
-# 根據檢索到的文檔和用戶問題生成最後答案
 def multi_query_rag_prompt(retrieval_chain, question):
     template = """Answer the following question based on this context:
 
@@ -205,25 +167,8 @@ def multi_query_rag_prompt(retrieval_chain, question):
     except Exception as e:
         print(f"Error invoking rag_chain: {e}")
         return "Error occurred while processing the question."
-
-########################################################################################################################
 
-# 將聊天紀錄個跟進問題轉化為獨立問題
 def get_search_query():
-    # Condense a chat history and follow-up question into a standalone question
-    #
-    # _template = """Given the following conversation and a follow up question,
-    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
-    # Generate standalone question in its original language.
-    # Chat History:
-    # {chat_history}
-    # Follow Up Input: {question}
-
-    # Hint:
-    # * Refer to chat history and add the subject to the question
-    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
-
-    # Standalone question:""" # noqa: E501
     _template = """Rewrite the following query by incorporating relevant context from the conversation history.
     The rewritten query should:
 
@@ -244,7 +189,7 @@ def get_search_query():
     """
     CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
 
-    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
+    def _format_chat_history(chat_history: List[tuple[str, str]]) -> List:
         buffer = []
         for human, ai in chat_history:
             buffer.append(HumanMessage(content=human))
@@ -252,11 +197,10 @@ def get_search_query():
         return buffer
 
     _search_query = RunnableBranch(
-        # If input includes chat_history, we condense it with the follow-up question
         (
             RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                 run_name="HasChatHistoryCheck"
-            ), # Condense follow-up question and chat into a standalone_question
+            ),
             RunnablePassthrough.assign(
                 chat_history=lambda x: _format_chat_history(x["chat_history"])
             )
@@ -264,48 +208,31 @@ def get_search_query():
             | ChatOpenAI()
             | StrOutputParser(),
         ),
-        # Else, we have no chat history, so just pass through the question
         RunnableLambda(lambda x : x["question"]),
     )
 
     return _search_query
-########################################################################################################################
-# 檢索文檔並生成答案
-def naive_rag(question, retriever):
-    #### RETRIEVAL and GENERATION ####
 
-    # Prompt
+def naive_rag(question, retriever):
     prompt = hub.pull("rlm/rag-prompt")
 
-    # LLM
-    # llm = ChatOpenAI(model_name="gpt-3.5-turbo")
-
-    # Post-processing
     def format_docs(docs):
         return "\n\n".join(doc.page_content for doc in docs)
 
     reference = retriever.get_relevant_documents(question)
 
-    # Chain
     rag_chain = (
         {"context": retriever | format_docs, "question": RunnablePassthrough()}
         | prompt
         | taide_llm
-        # | llm
         | StrOutputParser()
     )
 
-    # Question
     answer = rag_chain.invoke(question)
 
     return answer, reference
-################################################################################################
-# 處理question-answer pairs的檢索和生成答案
-def naive_rag_for_qapairs(question, retriever):
-    #### RETRIEVAL and GENERATION ####
 
-    # Prompt
-    # prompt = hub.pull("rlm/rag-prompt")
+def naive_rag_for_qapairs(question, retriever):
     template = """You are an assistant for question-answering tasks.
     Use the following pieces of retrieved context to answer the question.
     Following retrieved context is question-answer pairs of historical QA, Find the suitable answer from the qa pairs
@@ -320,19 +247,13 @@ def naive_rag_for_qapairs(question, retriever):
     """
     prompt = PromptTemplate.from_template(template)
 
-    # LLM
     llm = ChatOpenAI(model_name="gpt-4-0125-preview")
-    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
-    # llm = ChatOllama(model="gemma2", num_gpu=1, temperature=0)
 
-
-    # Post-processing
     def format_docs(docs):
         return "\n\n".join(doc.page_content for doc in docs)
 
     reference = retriever.get_relevant_documents(question)
 
-    # Chain
     rag_chain = (
         {"context": retriever | format_docs, "question": RunnablePassthrough()}
         | prompt
@@ -340,19 +261,16 @@ def naive_rag_for_qapairs(question, retriever):
         | StrOutputParser()
     )
 
-    # Question
     answer = rag_chain.invoke(question)
 
     return answer, reference
-########################################################################################################################
 
 def rag_score(question, ground_truth, answer, reference_docs):
-
     datasets = {
-        "question": [question],           # question: list[str]
-        "answer": [answer],               # answer: list[str]
-        "contexts": [reference_docs],     # contexts: list[list[str]]
-        "ground_truths": [[ground_truth]] # ground_truth: list[list[str]]
+        "question": [question],
+        "answer": [answer],
+        "contexts": [reference_docs],
+        "ground_truths": [[ground_truth]]
     }
     evalsets = Dataset.from_dict(datasets)
 
@@ -371,9 +289,8 @@ def rag_score(question, ground_truth, answer, reference_docs):
     result_df.to_csv('ragas_rag.csv')
     return result
 
-
 def print_current_model(llm):
-    if isinstance(llm, CustomTAIDELLM):
-        print(f"Currently using model: {llm.get_model_name()}")
+    if isinstance(llm, OllamaChatModel):
+        print(f"Currently using model: {llm.model_name}")
     else:
         pass