RAG_strategy.py

from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.load import dumps, loads
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_community.llms import Ollama
from langchain_community.chat_models import ChatOllama
from operator import itemgetter
from langchain import hub
from langchain.globals import set_llm_cache
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from typing import Tuple, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
    answer_relevancy,
    faithfulness,
    context_recall,
    context_precision,
)
import os
from dotenv import load_dotenv

load_dotenv('environment.env')
########################################################################################################################
########################################################################################################################
from langchain.cache import SQLiteCache
from langchain.cache import RedisSemanticCache
from langchain_openai import OpenAIEmbeddings
########################################################################################################################

import requests
import openai

openai_api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = openai_api_key
URI = os.getenv("SUPABASE_URI")

# Cache LLM responses to cut down on repeated API requests. Uncomment one of the
# following: SQLiteCache for exact-match caching, or RedisSemanticCache (Redis on
# port 6380) for embedding-based semantic caching.
# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
# set_llm_cache(RedisSemanticCache(redis_url="redis://localhost:6380", embedding=OpenAIEmbeddings(openai_api_key=openai_api_key), score_threshold=0.0005))
# TAIDE model on Ollama: https://ollama.com/jcai/llama3-taide-lx-8b-chat-alpha1
# def interact_with_model(messages, api_url="http://localhost:11434/v1/chat/completions"):
#     print("Using model: TAIDE")
#     response = requests.post(api_url, json={"model": "jcai/llama3-taide-lx-8b-chat-alpha1:Q4_K_M", "messages": messages})
#     return response.json()["choices"][0]["message"]["content"]

# Custom LangChain wrapper around the local TAIDE model. Note: enabling this class
# also requires importing the LLM base class, e.g.
# `from langchain_core.language_models.llms import LLM`.
# class CustomTAIDELLM(LLM):
#     api_url: str = "http://localhost:11434/v1/chat/completions"
#
#     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
#         messages = [{"role": "user", "content": prompt}]
#         response = requests.post(self.api_url, json={
#             "model": "taide-local",  # Use your local model name
#             "messages": messages
#         })
#         return response.json()["choices"][0]["message"]["content"]
#
#     @property
#     def _llm_type(self) -> str:
#         return "custom_taide"
#
# # Create an instance of the custom LLM
# taide_llm = CustomTAIDELLM()
# Generate several versions of the user question, retrieve documents with each version,
# and return the generated answer together with the reference documents.
def multi_query(question, retriever, chat_history):

    def multi_query_chain():
        # Multi Query: Different Perspectives
        template = """You are an AI language model assistant. Your task is to generate three
different versions of the given user question to retrieve relevant documents from a vector
database. By generating multiple perspectives on the user question, your goal is to help
the user overcome some of the limitations of the distance-based similarity search.
Provide these alternative questions separated by newlines.
You must also return the original question, i.e. 1 original version + 3 different versions = 4 questions.
Original question: {question}"""
        prompt_perspectives = ChatPromptTemplate.from_template(template)

        # `messages` is only needed by the commented-out interact_with_model() path above.
        messages = [
            {"role": "system", "content": template},
            {"role": "user", "content": question},
        ]
        # generate_queries = interact_with_model(messages).split("\n")

        llm = ChatOpenAI(model="gpt-4-1106-preview")
        # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
        # llm = ChatOllama(model="gemma2", temperature=0)
        # llm = ChatOllama(model=model)

        generate_queries = (
            prompt_perspectives
            | llm
            | StrOutputParser()
            | (lambda x: x.split("\n"))
        )
        return generate_queries

    def get_unique_union(documents: List[list]):
        """Unique union of retrieved docs."""
        # Flatten the list of lists and serialize each Document so it is hashable
        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
        # Deduplicate
        unique_docs = list(set(flattened_docs))
        # Deserialize back to Document objects
        return [loads(doc) for doc in unique_docs]

    # Condense the follow-up question and chat history into a standalone question
    _search_query = get_search_query()
    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
    print(modified_question)

    generate_queries = multi_query_chain()
    retrieval_chain = generate_queries | retriever.map() | get_unique_union
    docs = retrieval_chain.invoke({"question": modified_question})

    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
    return answer, docs
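# Example usage (illustrative sketch, not part of the original code): `retriever` is
# assumed to be any LangChain retriever built elsewhere, e.g. on the vector store behind
# SUPABASE_URI; the names and question below are placeholders.
# answer, docs = multi_query(
#     "What payment methods are supported?",
#     retriever,        # hypothetical retriever object
#     chat_history=[],  # empty history -> the question is passed through unchanged
# )
# print(answer)
# print(f"{len(docs)} unique reference documents retrieved")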
# Generate the final answer from the retrieved documents and the user question.
def multi_query_rag_prompt(retrieval_chain, question):
    # RAG
    template = """Answer the following question based on this context:
{context}
Question: {question}
Output in the user's language. If the question is in zh-tw, then the output will be in zh-tw. If the question is in English, then the output will be in English.
You should not mention anything about "根據提供的文件內容" or other similar terms.
If you don't know the answer, just say "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@email.com 以便獲得更進一步的幫助,謝謝。I'm sorry I cannot answer your question. Please send your question to test@email.com for further assistance. Thank you."
"""
    prompt = ChatPromptTemplate.from_template(template)

    # llm = ChatOpenAI(temperature=0)
    llm = ChatOpenAI(model="gpt-4-1106-preview")
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
    # llm = ChatOllama(model="gemma2", temperature=0)

    final_rag_chain = (
        {"context": retrieval_chain,
         "question": itemgetter("question")}
        | prompt
        | llm
        | StrOutputParser()
    )

    # The context/message variables below are only needed by the commented-out
    # interact_with_model() path; building them live would run the retrieval chain twice.
    # context = retrieval_chain.invoke({"question": question})
    # messages = [
    #     {"role": "system", "content": template},
    #     {"role": "user", "content": question},
    #     {"role": "assistant", "content": context}
    # ]
    # answer = interact_with_model(messages)

    # Stream the answer (the earlier duplicate invoke() call was dropped, since its
    # result was immediately overwritten and only doubled the LLM cost).
    answer = ""
    for text in final_rag_chain.stream({"question": question}):
        print(text, end="", flush=True)
        answer += text
    return answer
########################################################################################################################
# Condense the chat history and a follow-up question into a standalone question.
def get_search_query():
    # Condense a chat history and follow-up question into a standalone question
    #
    # _template = """Given the following conversation and a follow up question,
    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
    # Generate the standalone question in its original language.
    # Chat History:
    # {chat_history}
    # Follow Up Input: {question}
    # Hint:
    # * Refer to the chat history and add the subject to the question
    # * Replace the pronouns in the question with the correct person or thing, referring to the chat history
    # Standalone question:"""  # noqa: E501
    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
The rewritten query should:
- Preserve the core intent and meaning of the original query
- Expand and clarify the query to make it more specific and informative for retrieving relevant context
- Avoid introducing new topics or queries that deviate from the original query
- NEVER answer the original query; instead, focus on rephrasing and expanding it into a new query
- Be written in the original query's language
Return ONLY the rewritten query text, without any additional formatting or explanations.

Conversation History:
{chat_history}

Original query: [{question}]

Rewritten query:
"""
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer

    _search_query = RunnableBranch(
        # If the input includes chat_history, condense it with the follow-up question
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),  # Condense the follow-up question and chat history into a standalone question
            RunnablePassthrough.assign(
                chat_history=lambda x: _format_chat_history(x["chat_history"])
            )
            | CONDENSE_QUESTION_PROMPT
            | ChatOpenAI()
            | StrOutputParser(),
        ),
        # Else there is no chat history, so just pass the question through
        RunnableLambda(lambda x: x["question"]),
    )
    return _search_query
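# Illustrative sketch of how the returned runnable behaves (example values are hypothetical):
# with chat history it rewrites the follow-up into a standalone query; without it,
# the question is passed through unchanged.
# _search_query = get_search_query()
# standalone = _search_query.invoke({
#     "question": "How much does it cost?",
#     "chat_history": [("Tell me about plan A", "Plan A includes ...")],
# })  # -> e.g. "How much does plan A cost?"
# passthrough = _search_query.invoke({"question": "How much does it cost?", "chat_history": []})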
########################################################################################################################
# Retrieve documents and generate an answer (plain RAG baseline).
def naive_rag(question, retriever):
    #### RETRIEVAL and GENERATION ####
    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo")

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    reference = retriever.get_relevant_documents(question)

    # Chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Question
    answer = rag_chain.invoke(question)
    return answer, reference
################################################################################################
# Retrieve from historical question-answer pairs and generate an answer.
def naive_rag_for_qapairs(question, retriever):
    #### RETRIEVAL and GENERATION ####
    # Prompt
    # prompt = hub.pull("rlm/rag-prompt")
    template = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
The retrieved context consists of question-answer pairs from historical QA; find a suitable answer among these pairs.
If you cannot find a suitable answer, just return "False".
Use three sentences maximum and do not make up an answer.
Output in the user's language. If the question is in zh-tw, then the output will be in zh-tw.
{context}
Question: {question}
"""
    prompt = PromptTemplate.from_template(template)

    # LLM
    llm = ChatOpenAI(model_name="gpt-4-0125-preview")
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)
    # llm = ChatOllama(model="gemma2", num_gpu=1, temperature=0)

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    reference = retriever.get_relevant_documents(question)

    # Chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Question
    answer = rag_chain.invoke(question)
    return answer, reference
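# One possible way to combine the two helpers (illustrative sketch; `qa_retriever` and
# `doc_retriever` are hypothetical retriever objects): try the QA-pair store first and
# fall back to plain document RAG when it returns "False".
# qa_answer, qa_refs = naive_rag_for_qapairs("How do I return a product?", qa_retriever)
# if qa_answer.strip() == "False":
#     answer, refs = naive_rag("How do I return a product?", doc_retriever)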
########################################################################################################################
# Score a single question/answer/context triple with RAGAS.
def rag_score(question, ground_truth, answer, reference_docs):
    datasets = {
        "question": [question],            # question: list[str]
        "answer": [answer],                # answer: list[str]
        "contexts": [reference_docs],      # contexts: list[list[str]]
        "ground_truths": [[ground_truth]]  # ground_truths: list[list[str]]
    }
    evalsets = Dataset.from_dict(datasets)

    result = evaluate(
        evalsets,
        metrics=[
            context_precision,
            faithfulness,
            answer_relevancy,
            context_recall,
        ],
    )

    result_df = result.to_pandas()
    print(result_df.head())
    result_df.to_csv('ragas_rag.csv')
    return result
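########################################################################################################################
if __name__ == "__main__":
    # Minimal smoke test for rag_score with toy data (illustrative only; it assumes
    # OPENAI_API_KEY is set, since the RAGAS metrics call the OpenAI API internally).
    toy_result = rag_score(
        question="What is RAG?",
        ground_truth="RAG is retrieval-augmented generation: it retrieves documents and lets an LLM answer from them.",
        answer="RAG (retrieval-augmented generation) first retrieves relevant documents, then generates an answer grounded in them.",
        reference_docs=["Retrieval-Augmented Generation (RAG) combines document retrieval with LLM generation."],
    )
    print(toy_result)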