@@ -1,296 +0,0 @@
-from langchain.prompts import ChatPromptTemplate
-from langchain.load import dumps, loads
-from langchain_core.output_parsers import StrOutputParser
-from langchain_openai import ChatOpenAI
-from langchain_community.llms import Ollama
-from langchain_community.chat_models import ChatOllama
-from operator import itemgetter
-from langchain_core.runnables import RunnablePassthrough
-from langchain import hub
-from langchain.globals import set_llm_cache
-from langchain import PromptTemplate
-import subprocess
-import json
-from typing import Any, List, Optional, Dict
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
-from langchain_core.outputs import ChatResult, ChatGeneration
-from pydantic import Field
-
-from langchain_core.runnables import (
-    RunnableBranch,
-    RunnableLambda,
-    RunnableParallel,
-    RunnablePassthrough,
-)
-
-from datasets import Dataset
-from ragas import evaluate
-from ragas.metrics import (
-    answer_relevancy,
-    faithfulness,
-    context_recall,
-    context_precision,
-)
-import os
-from dotenv import load_dotenv
-load_dotenv('environment.env')
-
-from langchain.cache import SQLiteCache
-from langchain_openai import OpenAIEmbeddings
-from langchain.globals import set_llm_cache
-
-import requests
-import openai
-openai_api_key = os.getenv("OPENAI_API_KEY")
-openai.api_key = openai_api_key
-URI = os.getenv("SUPABASE_URI")
-
-# 設置緩存,以減少對API的重複請求。使用SQLite
-set_llm_cache(SQLiteCache(database_path=".langchain.db"))
-
-system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
-
-class OllamaChatModel(BaseChatModel):
-    model_name: str = Field(default="taide-local")
-
-    def _generate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        formatted_messages = []
-        for msg in messages:
-            if isinstance(msg, HumanMessage):
-                formatted_messages.append({"role": "user", "content": msg.content})
-            elif isinstance(msg, AIMessage):
-                formatted_messages.append({"role": "assistant", "content": msg.content})
-            elif isinstance(msg, SystemMessage):
-                formatted_messages.append({"role": "system", "content": msg.content})
-
-        prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
-        for msg in formatted_messages:
-            if msg['role'] == 'user':
-                prompt += f"{msg['content']} [/INST]"
-            elif msg['role'] == "assistant":
-                prompt += f"{msg['content']} </s><s>[INST]"
-
-        command = ["ollama", "run", self.model_name, prompt]
-        result = subprocess.run(command, capture_output=True, text=True)
-
-        if result.returncode != 0:
-            raise Exception(f"Ollama command failed: {result.stderr}")
-
-        content = result.stdout.strip()
-
-        message = AIMessage(content=content)
-        generation = ChatGeneration(message=message)
-        return ChatResult(generations=[generation])
-
-    @property
-    def _llm_type(self) -> str:
-        return "ollama-chat-model"
-
-taide_llm = OllamaChatModel(model_name="taide-local")
-
-def multi_query(question, retriever, chat_history):
-    def multi_query_chain():
-        template = """You are an AI language model assistant. Your task is to generate three
-        different versions of the given user question to retrieve relevant documents from a vector
-        database. By generating multiple perspectives on the user question, your goal is to help
-        the user overcome some of the limitations of the distance-based similarity search.
-        Provide these alternative questions separated by newlines.
-
-        You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
-
-        Original question: {question}"""
-        prompt_perspectives = ChatPromptTemplate.from_template(template)
-
-        generate_queries = (
-            prompt_perspectives
-            | taide_llm
-            | StrOutputParser()
-            | (lambda x: x.split("\n"))
-        )
-
-        return generate_queries
-
-    def get_unique_union(documents: List[list]):
-        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
-        unique_docs = list(set(flattened_docs))
-        return [loads(doc) for doc in unique_docs]
-
-    _search_query = get_search_query()
-    modified_question = _search_query.invoke({"question":question, "chat_history": chat_history})
-    print(modified_question)
-
-    generate_queries = multi_query_chain()
-
-    retrieval_chain = generate_queries | retriever.map() | get_unique_union
-    docs = retrieval_chain.invoke({"question":modified_question})
-
-    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
-
-    return answer, docs
-
-def multi_query_rag_prompt(retrieval_chain, question):
-    template = """Answer the following question based on this context:
-
-    {context}
-
-    Question: {question}
-    Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. If the question is in English, then the output will be in English.
-    You should not mention anything about "根據提供的文件內容" or other similar terms.
-    If you don't know the answer, just say that "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@email.com 以便獲得更進一步的幫助,謝謝。I'm sorry I cannot answer your question. Please send your question to test@email.com for further assistance. Thank you."
-    """
-
-    prompt = ChatPromptTemplate.from_template(template)
-    context = retrieval_chain.invoke({"question": question})
-    print(f"Retrieved context: {context[:200]}...") # Print first 200 chars of context
-
-    final_rag_chain = (
-        {"context": retrieval_chain,
-         "question": itemgetter("question")}
-        | prompt
-        | taide_llm
-        | StrOutputParser()
-    )
-
-    print(f"Sending question to model: {question}")
-    try:
-        answer = final_rag_chain.invoke({"question": question})
-        print(f"Received answer: {answer}")
-        return answer
-    except Exception as e:
-        print(f"Error invoking rag_chain: {e}")
-        return "Error occurred while processing the question."
-
-def get_search_query():
-    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
-    The rewritten query should:
-
-    - Preserve the core intent and meaning of the original query
-    - Expand and clarify the query to make it more specific and informative for retrieving relevant context
-    - Avoid introducing new topics or queries that deviate from the original query
-    - DONT EVER ANSWER the Original query, but instead focus on rephrasing and expanding it into a new query
-    - The rewritten query should be in its original language.
-
-    Return ONLY the rewritten query text, without any additional formatting or explanations.
-
-    Conversation History:
-    {chat_history}
-
-    Original query: [{question}]
-
-    Rewritten query:
-    """
-    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
-
-    def _format_chat_history(chat_history: List[tuple[str, str]]) -> List:
-        buffer = []
-        for human, ai in chat_history:
-            buffer.append(HumanMessage(content=human))
-            buffer.append(AIMessage(content=ai))
-        return buffer
-
-    _search_query = RunnableBranch(
-        (
-            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
-                run_name="HasChatHistoryCheck"
-            ),
-            RunnablePassthrough.assign(
-                chat_history=lambda x: _format_chat_history(x["chat_history"])
-            )
-            | CONDENSE_QUESTION_PROMPT
-            | ChatOpenAI()
-            | StrOutputParser(),
-        ),
-        RunnableLambda(lambda x : x["question"]),
-    )
-
-    return _search_query
-
-def naive_rag(question, retriever):
-    prompt = hub.pull("rlm/rag-prompt")
-
-    def format_docs(docs):
-        return "\n\n".join(doc.page_content for doc in docs)
-
-    reference = retriever.get_relevant_documents(question)
-
-    rag_chain = (
-        {"context": retriever | format_docs, "question": RunnablePassthrough()}
-        | prompt
-        | taide_llm
-        | StrOutputParser()
-    )
-
-    answer = rag_chain.invoke(question)
-
-    return answer, reference
-
-def naive_rag_for_qapairs(question, retriever):
-    template = """You are an assistant for question-answering tasks.
-    Use the following pieces of retrieved context to answer the question.
-    Following retrieved context is question-answer pairs of historical QA, Find the suitable answer from the qa pairs
-    If you can not find the suitable answer, just return "False".
-    Use three sentences maximum and Do not make up the answer.
-
-    Output in user's language. If the question is in zh-tw, then the output will be in zh-tw.
-
-    {context}
-
-    Question: {question}
-    """
-    prompt = PromptTemplate.from_template(template)
-
-    llm = ChatOpenAI(model_name="gpt-4-0125-preview")
-
-    def format_docs(docs):
-        return "\n\n".join(doc.page_content for doc in docs)
-
-    reference = retriever.get_relevant_documents(question)
-
-    rag_chain = (
-        {"context": retriever | format_docs, "question": RunnablePassthrough()}
-        | prompt
-        | llm
-        | StrOutputParser()
-    )
-
-    answer = rag_chain.invoke(question)
-
-    return answer, reference
-
-def rag_score(question, ground_truth, answer, reference_docs):
-    datasets = {
-        "question": [question],
-        "answer": [answer],
-        "contexts": [reference_docs],
-        "ground_truths": [[ground_truth]]
-    }
-    evalsets = Dataset.from_dict(datasets)
-
-    result = evaluate(
-        evalsets,
-        metrics=[
-            context_precision,
-            faithfulness,
-            answer_relevancy,
-            context_recall,
-        ],
-    )
-
-    result_df = result.to_pandas()
-    print(result_df.head())
-    result_df.to_csv('ragas_rag.csv')
-    return result
-
-def print_current_model(llm):
-    if isinstance(llm, OllamaChatModel):
-        print(f"Currently using model: {llm.model_name}")
-    else:
-        pass