RAG_strategy.py

from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.load import dumps, loads
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_community.llms import Ollama
from langchain_community.chat_models import ChatOllama
from operator import itemgetter
from langchain import hub
from langchain.globals import set_llm_cache
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from typing import Tuple, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from dotenv import load_dotenv

load_dotenv()

# from local_llm import ollama_, hf
# llm = hf()
# llm = taide_llm
########################################################################################################################
# Optional: cache LLM responses in a local SQLite database
# from langchain.cache import SQLiteCache
# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
########################################################################################################################
# Default LLM shared by the chains below; swap in one of the commented
# alternatives (e.g. a local Ollama model) if needed.
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
# llm = ollama_()
def multi_query_chain(llm):
    # Multi Query: generate several rephrasings ("perspectives") of the user
    # question so distance-based vector search casts a wider net.
    template = """You are an AI language model assistant. Your task is to generate three
different versions of the given user question to retrieve relevant documents from a vector
database. By generating multiple perspectives on the user question, your goal is to help
the user overcome some of the limitations of the distance-based similarity search.
Provide these alternative questions separated by newlines.
You must also return the original question, so the output contains 1 original version + 3 alternative versions = 4 questions in total.
Original question: {question}"""
    prompt_perspectives = ChatPromptTemplate.from_template(template)

    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)

    generate_queries = (
        prompt_perspectives
        | llm
        | StrOutputParser()
        | (lambda x: x.split("\n"))  # one query per line
    )
    return generate_queries
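
# A minimal usage sketch (hypothetical question; assumes the module-level
# `llm` above and that the model follows the "4 questions" instruction):
#
#   queries = multi_query_chain(llm).invoke({"question": "What is RAG?"})
#   queries = [q for q in queries if q != ""]  # drop blank lines from the split
#   # -> ["What is RAG?", "How does retrieval-augmented generation work?", ...]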
def multi_query(question, retriever, chat_history):
    def get_unique_union(documents: List[list]):
        """Unique union of retrieved docs."""
        # Flatten list of lists, and serialize each Document so it is hashable
        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
        # Deduplicate
        unique_docs = list(set(flattened_docs))
        # Deserialize back to Documents
        return [loads(doc) for doc in unique_docs]

    _search_query = get_search_query()
    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
    print(modified_question)

    generate_queries = multi_query_chain(llm)

    retrieval_chain = generate_queries | retriever.map() | get_unique_union
    docs = retrieval_chain.invoke({"question": modified_question})

    # Note: the retrieval chain is invoked again inside multi_query_rag_prompt
    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
    return answer, docs
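
# A minimal usage sketch (assumes a LangChain retriever such as the FAISS
# retriever built in __main__; chat_history is a list of (human, ai) tuples):
#
#   answer, docs = multi_query(
#       "What penalties does it impose?",  # follow-up; condensed against history first
#       retriever=global_retriever,
#       chat_history=[("What is the Climate Change Response Act?", "It is ...")],
#   )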
def multi_query_rag_prompt(retrieval_chain, question):
    # RAG: answer from the deduplicated multi-query context.
    # The first quoted phrase below is zh-tw for "according to the provided
    # documents"; the fallback sentence tells the user (in zh-tw): "Sorry, I
    # cannot answer your question right now; please send your inquiry to
    # test@systex.com for further assistance. Thank you."
    template = """Answer the following question based on this context:
{context}
Question: {question}
Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. \n
You should not mention anything about "根據提供的文件內容" or other similar terms.
Use three sentences maximum and keep the answer concise.
If you don't know the answer, just say that "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
"""
    prompt = ChatPromptTemplate.from_template(template)

    # llm = ChatOpenAI(temperature=0)
    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)

    final_rag_chain = (
        {"context": retrieval_chain,
         "question": itemgetter("question")}
        | prompt
        | llm
        | StrOutputParser()
    )

    # answer = final_rag_chain.invoke({"question": question})
    # Stream the answer token by token so it can be printed as it arrives
    answer = ""
    for text in final_rag_chain.stream({"question": question}):
        print(text, end="", flush=True)
        answer += text
    return answer
########################################################################################################################
def get_search_query():
    # Condense a chat history and follow-up question into a standalone question
    #
    # _template = """Given the following conversation and a follow up question,
    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
    # Generate standalone question in its original language.
    # Chat History:
    # {chat_history}
    # Follow Up Input: {question}
    # Hint:
    # * Refer to chat history and add the subject to the question
    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
    # Standalone question:"""  # noqa: E501
    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
The rewritten query should:
- Preserve the core intent and meaning of the original query
- Expand and clarify the query to make it more specific and informative for retrieving relevant context
- Avoid introducing new topics or queries that deviate from the original query
- DON'T EVER ANSWER the original query; instead, focus on rephrasing and expanding it into a new query
- Be in the original query's language
Return ONLY the rewritten query text, without any additional formatting or explanations.
Conversation History:
{chat_history}
Original query: [{question}]
Rewritten query:
"""
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer

    _search_query = RunnableBranch(
        # If input includes chat_history, condense it with the follow-up question
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),  # Condense follow-up question and chat into a standalone question
            RunnablePassthrough.assign(
                chat_history=lambda x: _format_chat_history(x["chat_history"])
            )
            | CONDENSE_QUESTION_PROMPT
            | ChatOpenAI(temperature=0)
            | StrOutputParser(),
        ),
        # Else, we have no chat history, so just pass through the question
        RunnableLambda(lambda x: x["question"]),
    )
    return _search_query
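
# A minimal sketch of how the branch above behaves (hypothetical values):
#
#   _search_query = get_search_query()
#   # chat_history present -> history is formatted into messages and the
#   # follow-up is condensed into a standalone question by ChatOpenAI:
#   _search_query.invoke({"question": "How is it enforced?",
#                         "chat_history": [("What is the Act?", "It is ...")]})
#   # chat_history empty -> the question passes through unchanged:
#   _search_query.invoke({"question": "What is the Act?", "chat_history": []})
#   # -> "What is the Act?"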
########################################################################################################################
def naive_rag(question, retriever, chat_history):
    _search_query = get_search_query()
    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
    print(modified_question)

    #### RETRIEVAL and GENERATION ####
    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    # llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Note: documents are fetched once here (returned as `reference`) and
    # again inside the chain below via `retriever | format_docs`
    reference = retriever.get_relevant_documents(modified_question)

    # Chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Question
    answer = rag_chain.invoke(modified_question)
    return answer, reference
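
# A minimal usage sketch (assumes the FAISS retriever from faiss_index):
#
#   answer, reference = naive_rag(
#       "台灣為什麼要制定氣候變遷因應法?",  # "Why did Taiwan enact the Climate Change Response Act?"
#       retriever=global_retriever,
#       chat_history=[],
#   )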
################################################################################################
if __name__ == "__main__":
    from faiss_index import create_faiss_retriever, faiss_query
    global_retriever = create_faiss_retriever()

    generate_queries = multi_query_chain(llm)
    # "Why did Taiwan enact the Climate Change Response Act?"
    question = "台灣為什麼要制定氣候變遷因應法?"
    questions = generate_queries.invoke(question)
    questions = [item for item in questions if item != ""]
    # print(questions)

    results = list(map(global_retriever.get_relevant_documents, questions))
    results = [item for sublist in results for item in sublist]
    print(len(results))
    print(results)

    # retrieval_chain = generate_queries | global_retriever.map
    # docs = retrieval_chain.invoke(question)
    # print(docs)
    # print(len(docs))
    # print(len(results))

    # for doc in results[:10]:
    #     print(doc)
    #     print("-----------------------------------------------------------------------")

    # results = get_unique_union(results)
    # print(len(results))

    # retrieval_chain = generate_queries | global_retriever.map | get_unique_union
    # docs = retrieval_chain.invoke(question)
    # print(len(docs))