# RAG_strategy_Taide.py

from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.load import dumps, loads
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.llms import Ollama
from langchain_community.chat_models import ChatOllama
from operator import itemgetter
from langchain import hub
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache
import subprocess
import json
from typing import Any, List, Optional, Dict
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatResult, ChatGeneration
from pydantic import Field
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
    answer_relevancy,
    faithfulness,
    context_recall,
    context_precision,
)
import os
import requests
import openai
from dotenv import load_dotenv

load_dotenv('environment.env')

openai_api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = openai_api_key
URI = os.getenv("SUPABASE_URI")

# Cache LLM responses in SQLite to avoid repeated identical API calls.
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

system_prompt: str = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"


class OllamaChatModel(BaseChatModel):
    """Chat model that answers by shelling out to a local Ollama model."""

    model_name: str = Field(default="taide-local")

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Normalize LangChain message objects into role/content dicts.
        formatted_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                formatted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                formatted_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                formatted_messages.append({"role": "system", "content": msg.content})

        # Build a Llama-2-style prompt with the TAIDE system prompt in <<SYS>> tags.
        prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
        for msg in formatted_messages:
            if msg['role'] == 'user':
                prompt += f"{msg['content']} [/INST]"
            elif msg['role'] == "assistant":
                prompt += f"{msg['content']} </s><s>[INST]"

        # Run the model via the Ollama CLI; stdout is the model's reply.
        command = ["ollama", "run", self.model_name, prompt]
        result = subprocess.run(command, capture_output=True, text=True)
        if result.returncode != 0:
            raise Exception(f"Ollama command failed: {result.stderr}")

        content = result.stdout.strip()
        message = AIMessage(content=content)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "ollama-chat-model"


taide_llm = OllamaChatModel(model_name="taide-local")
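
# Hedged example: a minimal smoke test for OllamaChatModel. This is a sketch,
# assuming an Ollama model tagged "taide-local" already exists on this machine
# and the `ollama` CLI is on PATH; `_demo_taide_llm` is an illustrative name,
# not part of the original module.
def _demo_taide_llm() -> None:
    reply = taide_llm.invoke([HumanMessage(content="你好,請簡短自我介紹。")])
    print(reply.content)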


def multi_query(question, retriever, chat_history):
    """Condense the question with chat history, fan it out into several query
    variants, retrieve for each, and answer over the deduplicated union."""

    def multi_query_chain():
        template = """You are an AI language model assistant. Your task is to generate three
        different versions of the given user question to retrieve relevant documents from a vector
        database. By generating multiple perspectives on the user question, your goal is to help
        the user overcome some of the limitations of the distance-based similarity search.
        Provide these alternative questions separated by newlines.
        You must also return the original question, i.e. 1 original version + 3 different versions = 4 questions.
        Original question: {question}"""
        prompt_perspectives = ChatPromptTemplate.from_template(template)
        generate_queries = (
            prompt_perspectives
            | taide_llm
            | StrOutputParser()
            | (lambda x: x.split("\n"))
        )
        return generate_queries

    def get_unique_union(documents: List[list]):
        # Deduplicate retrieved documents across queries by serializing them.
        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
        unique_docs = list(set(flattened_docs))
        return [loads(doc) for doc in unique_docs]

    # Rewrite the question against the chat history before generating variants.
    _search_query = get_search_query()
    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
    # print(modified_question)

    generate_queries = multi_query_chain()
    retrieval_chain = generate_queries | retriever.map() | get_unique_union
    docs = retrieval_chain.invoke({"question": modified_question})

    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
    return answer, docs
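
# Hedged example: driving multi_query end to end. A sketch only; `build_retriever`
# is a hypothetical factory standing in for however the rest of the project
# constructs its vector-store retriever (e.g. from the Supabase URI above).
def _demo_multi_query(build_retriever) -> None:
    retriever = build_retriever()
    answer, docs = multi_query("什麼是 TAIDE 模型?", retriever, chat_history=[])
    print(answer)
    print(f"Retrieved {len(docs)} unique documents")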


def multi_query_rag_prompt(retrieval_chain, question):
    template = """Answer the following question based on this context:

    {context}

    Question: {question}
    Output in the user's language: if the question is in zh-tw, answer in zh-tw; if it is in English, answer in English.
    Do not mention "根據提供的文件內容" or similar phrases.
    If you don't know the answer, just say "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@email.com 以便獲得更進一步的幫助,謝謝。I'm sorry I cannot answer your question. Please send your question to test@email.com for further assistance. Thank you."
    """
    prompt = ChatPromptTemplate.from_template(template)

    # Debug only; retrieval already runs inside final_rag_chain, so invoking it
    # here as well would retrieve twice.
    # context = retrieval_chain.invoke({"question": question})
    # print(f"Retrieved context: {context[:200]}...")

    final_rag_chain = (
        {"context": retrieval_chain,
         "question": itemgetter("question")}
        | prompt
        | taide_llm
        | StrOutputParser()
    )

    print(f"Sending question to model: {question}")
    try:
        answer = final_rag_chain.invoke({"question": question})
        # print(f"Received answer: {answer}")
        return answer
    except Exception as e:
        print(f"Error invoking rag_chain: {e}")
        return "Error occurred while processing the question."
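
# Hedged example: scoring a generated answer with the RAGAS metrics imported at
# the top of this file. A minimal sketch of the documented ragas flow
# (Dataset.from_dict + evaluate); all values below are placeholders, and the
# column names follow ragas' expected schema at the time of writing.
def _demo_ragas_eval(question: str, answer: str, contexts: List[str], ground_truth: str):
    dataset = Dataset.from_dict({
        "question": [question],
        "answer": [answer],
        "contexts": [contexts],        # one list of context strings per row
        "ground_truth": [ground_truth],
    })
    return evaluate(
        dataset,
        metrics=[answer_relevancy, faithfulness, context_recall, context_precision],
    )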


def get_search_query():
    """Return a runnable that rewrites a follow-up question into a standalone
    query, using the chat history when one is present."""
    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
    The rewritten query should:
    - Preserve the core intent and meaning of the original query
    - Expand and clarify the query to make it more specific and informative for retrieving relevant context
    - Avoid introducing new topics or queries that deviate from the original query
    - DONT EVER ANSWER the original query; instead focus on rephrasing and expanding it into a new query
    - Be in the same language as the original query
    Return ONLY the rewritten query text, without any additional formatting or explanations.

    Conversation History:
    {chat_history}

    Original query: [{question}]

    Rewritten query:
    """
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

    def _format_chat_history(chat_history: List[tuple[str, str]]) -> List:
        # Convert (human, ai) pairs into alternating message objects.
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer

    # With chat history, condense it with the question through ChatOpenAI;
    # otherwise pass the question through unchanged.
    _search_query = RunnableBranch(
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),
            RunnablePassthrough.assign(
                chat_history=lambda x: _format_chat_history(x["chat_history"])
            )
            | CONDENSE_QUESTION_PROMPT
            | ChatOpenAI()
            | StrOutputParser(),
        ),
        RunnableLambda(lambda x: x["question"]),
    )
    return _search_query
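
# Hedged example: condensing a follow-up question against prior turns. A sketch
# only; the history branch calls ChatOpenAI, so OPENAI_API_KEY must be present
# in environment.env for it to run.
def _demo_search_query() -> None:
    chain = get_search_query()
    history = [("TAIDE 是什麼?", "TAIDE 是一個以繁體中文為主的台灣大型語言模型。")]
    standalone = chain.invoke({"question": "它可以在本地執行嗎?", "chat_history": history})
    print(standalone)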