RAG_strategy.py

from langchain.prompts import ChatPromptTemplate
from langchain.load import dumps, loads
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_community.llms import Ollama
from langchain_community.chat_models import ChatOllama
from operator import itemgetter
from langchain import hub
from langchain.globals import set_llm_cache
from langchain import PromptTemplate
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from typing import Tuple, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import (
    answer_relevancy,
    faithfulness,
    context_recall,
    context_precision,
)
from dotenv import load_dotenv

load_dotenv()
########################################################################################################################
########################################################################################################################
from langchain.cache import SQLiteCache
from langchain.cache import RedisSemanticCache
from langchain_openai import OpenAIEmbeddings

# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
set_llm_cache(RedisSemanticCache(redis_url="redis://localhost:6380", embedding=OpenAIEmbeddings(), score_threshold=0.0005))
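# Note: RedisSemanticCache returns a cached completion for any new prompt whose
# embedding falls within score_threshold of a previously seen prompt, so a very small
# threshold like 0.0005 only reuses answers for near-identical prompts. The URL above
# assumes a local Redis instance listening on port 6380.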
########################################################################################################################
def multi_query(question, retriever, chat_history):
    def multi_query_chain():
        # Multi Query: generate different perspectives on the user question
        template = """You are an AI language model assistant. Your task is to generate three
        different versions of the given user question to retrieve relevant documents from a vector
        database. By generating multiple perspectives on the user question, your goal is to help
        the user overcome some of the limitations of the distance-based similarity search.
        Provide these alternative questions separated by newlines.
        You must also return the original question, meaning you return 1 original version + 3 different versions = 4 questions.
        Original question: {question}"""
        prompt_perspectives = ChatPromptTemplate.from_template(template)

        llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
        # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)

        generate_queries = (
            prompt_perspectives
            | llm
            | StrOutputParser()
            | (lambda x: x.split("\n"))
        )
        return generate_queries

    def get_unique_union(documents: List[list]):
        """Unique union of retrieved docs."""
        # Flatten the list of lists and serialize each Document to a string
        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
        # Deduplicate the serialized documents
        unique_docs = list(set(flattened_docs))
        # Deserialize back to Document objects
        return [loads(doc) for doc in unique_docs]

    # Condense the follow-up question and chat history into a standalone question
    _search_query = get_search_query()
    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
    print(modified_question)

    generate_queries = multi_query_chain()
    retrieval_chain = generate_queries | retriever.map() | get_unique_union
    docs = retrieval_chain.invoke({"question": modified_question})

    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
    return answer, docs
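# Example usage (a minimal sketch; `vectorstore` stands for any pre-built LangChain
# vector store such as Chroma or FAISS and is not defined in this module):
#
#   retriever = vectorstore.as_retriever()
#   answer, docs = multi_query(
#       "How do I reset my password?",
#       retriever,
#       chat_history=[("What services do you offer?", "We offer account management and support.")],
#   )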
def multi_query_rag_prompt(retrieval_chain, question):
    # RAG: answer the question from the multi-query retrieval context
    template = """Answer the following question based on this context:

    {context}

    Question: {question}

    Output in the user's language. If the question is in zh-tw, then the output will be in zh-tw.
    You should not mention anything about "根據提供的文件內容" or other similar terms.
    If you don't know the answer, just say "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
    """
    prompt = ChatPromptTemplate.from_template(template)

    # llm = ChatOpenAI(temperature=0)
    llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)

    final_rag_chain = (
        {"context": retrieval_chain,
         "question": itemgetter("question")}
        | prompt
        | llm
        | StrOutputParser()
    )

    # answer = final_rag_chain.invoke({"question": question})
    # Stream the answer token by token, printing as it arrives and accumulating the full text
    answer = ""
    for text in final_rag_chain.stream({"question": question}):
        print(text, end="", flush=True)
        answer += text
    return answer
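# The function above streams to stdout as a side effect and also returns the complete
# answer string; callers that do not want console output can swap the streaming loop
# for the commented-out `invoke` call.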
########################################################################################################################
def get_search_query():
    # Condense a chat history and follow-up question into a standalone question
    #
    # _template = """Given the following conversation and a follow up question,
    # rephrase the follow up question to be a standalone question so that others can understand it without reading the conversation transcript.
    # Generate the standalone question in its original language.
    # Chat History:
    # {chat_history}
    # Follow Up Input: {question}
    # Hints:
    # * Refer to the chat history and add the subject to the question
    # * Replace the pronouns in the question with the correct person or thing, referring to the chat history
    # Standalone question:"""  # noqa: E501
    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
    The rewritten query should:

    - Preserve the core intent and meaning of the original query
    - Expand and clarify the query to make it more specific and informative for retrieving relevant context
    - Avoid introducing new topics or queries that deviate from the original query
    - NEVER answer the original query; focus only on rephrasing and expanding it into a new query
    - Be written in the original query's language

    Return ONLY the rewritten query text, without any additional formatting or explanations.

    Conversation History:
    {chat_history}

    Original query: [{question}]

    Rewritten query:
    """
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
        # Convert (human, ai) tuples into alternating message objects
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer

    _search_query = RunnableBranch(
        # If the input includes chat_history, condense it with the follow-up question
        # into a standalone question
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),
            RunnablePassthrough.assign(
                chat_history=lambda x: _format_chat_history(x["chat_history"])
            )
            | CONDENSE_QUESTION_PROMPT
            | ChatOpenAI(temperature=0)
            | StrOutputParser(),
        ),
        # Otherwise there is no chat history, so just pass the question through
        RunnableLambda(lambda x: x["question"]),
    )
    return _search_query
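# Example usage (a minimal sketch): the returned runnable rewrites the question only
# when a chat history is present.
#
#   _search_query = get_search_query()
#   standalone = _search_query.invoke({
#       "question": "How do I enable it?",
#       "chat_history": [("What is semantic caching?", "It reuses answers for similar prompts.")],
#   })
#   # With an empty history the question passes through unchanged:
#   same = _search_query.invoke({"question": "How do I enable it?", "chat_history": []})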
########################################################################################################################
def naive_rag(question, retriever):
    #### RETRIEVAL and GENERATION ####
    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Retrieve once up front so the source documents can be returned alongside the answer
    reference = retriever.get_relevant_documents(question)

    # Chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Question
    answer = rag_chain.invoke(question)
    return answer, reference
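# Example usage (a minimal sketch; `vectorstore` is an assumed pre-built vector store):
#
#   retriever = vectorstore.as_retriever()
#   answer, reference = naive_rag("What is semantic caching?", retriever)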
########################################################################################################################
def naive_rag_for_qapairs(question, retriever):
    #### RETRIEVAL and GENERATION ####
    # Prompt
    # prompt = hub.pull("rlm/rag-prompt")
    template = """You are an assistant for question-answering tasks.
    Use the following pieces of retrieved context to answer the question.
    The retrieved context consists of historical question-answer pairs; find a suitable answer among these QA pairs.
    If you cannot find a suitable answer, just return "False".
    Use three sentences maximum and do not make up an answer.
    Output in the user's language. If the question is in zh-tw, then the output will be in zh-tw.

    {context}

    Question: {question}
    """
    prompt = PromptTemplate.from_template(template)

    # LLM
    llm = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0)
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    reference = retriever.get_relevant_documents(question)

    # Chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Question
    answer = rag_chain.invoke(question)
    return answer, reference
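# Note: on a miss this function returns the literal string "False" (per the prompt),
# so callers should check `answer == "False"` before falling back to another strategy.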
########################################################################################################################
def rag_score(question, ground_truth, answer, reference_docs):
    # Build a single-row ragas evaluation dataset
    datasets = {
        "question": [question],             # question: list[str]
        "answer": [answer],                 # answer: list[str]
        "contexts": [reference_docs],       # contexts: list[list[str]]
        "ground_truths": [[ground_truth]],  # ground_truths: list[list[str]]
    }
    evalsets = Dataset.from_dict(datasets)

    result = evaluate(
        evalsets,
        metrics=[
            context_precision,
            faithfulness,
            answer_relevancy,
            context_recall,
        ],
    )
    return result
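########################################################################################################################
# Minimal end-to-end sketch, assuming a populated Chroma store on disk ("./chroma_db" is
# a hypothetical path), OPENAI_API_KEY in the environment, and the Redis instance
# configured above. Guarded by __main__ so importing this module triggers no work
# beyond cache setup.
if __name__ == "__main__":
    from langchain_community.vectorstores import Chroma

    vectorstore = Chroma(
        persist_directory="./chroma_db",  # hypothetical path
        embedding_function=OpenAIEmbeddings(),
    )
    retriever = vectorstore.as_retriever()

    question = "What is retrieval-augmented generation?"
    answer, reference = naive_rag(question, retriever)

    # ragas expects contexts as plain strings, not Document objects
    scores = rag_score(
        question,
        ground_truth="Retrieval-augmented generation grounds LLM answers in retrieved documents.",
        answer=answer,
        reference_docs=[doc.page_content for doc in reference],
    )
    print(scores)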