RAG_strategy.py

from langchain.prompts import ChatPromptTemplate
from langchain.load import dumps, loads
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_community.llms import Ollama
from langchain_community.chat_models import ChatOllama
from operator import itemgetter
from langchain import hub
from langchain.globals import set_llm_cache
from langchain import PromptTemplate
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from typing import Tuple, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from dotenv import load_dotenv

load_dotenv()

# from local_llm import ollama_, hf
# llm = hf()
# llm = taide_llm
########################################################################################################################
# Optional on-disk LLM cache:
# from langchain.cache import SQLiteCache
# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
########################################################################################################################
# Module-level LLM shared by multi_query_rag_prompt() and naive_rag().
# Every definition was left commented out, so `llm` was undefined at call
# time; one option is enabled here.
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
# llm = ollama_()
def multi_query_chain(llm):
    # Multi Query: rewrite one user question from several different perspectives.
    # The template uses Llama 3 chat-format markers; the Chinese lines read
    # "You are an ESG AI assistant from Taiwan; please answer in Traditional
    # Chinese" and "Please use Traditional Chinese".
    template = """
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
你是一個來自台灣的ESG的AI助理,請用繁體中文回答
You are an AI language model assistant. Your task is to generate three
different versions of the given user question to retrieve relevant documents from a vector
database. By generating multiple perspectives on the user question, your goal is to help
the user overcome some of the limitations of distance-based similarity search.
Provide these alternative questions separated by newlines.
You must also return the original question, i.e. 1 original version + 3 different versions = 4 questions in total.
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Original question: {question}
請用繁體中文
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
    prompt_perspectives = ChatPromptTemplate.from_template(template)

    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)

    generate_queries = (
        prompt_perspectives
        | llm
        | StrOutputParser()
        | (lambda x: x.split("\n"))  # one query per line
    )
    return generate_queries
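# Usage sketch (a minimal, hypothetical example; assumes the module-level `llm`
# defined above, and the example question is our own):
#     generate_queries = multi_query_chain(llm)
#     queries = generate_queries.invoke({"question": "什麼是碳盤查?"})  # "What is a carbon inventory?"
#     # -> list of 4 strings: the original question plus 3 alternative phrasings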
def multi_query(question, retriever, chat_history):
    def get_unique_union(documents: List[list]):
        """Unique union of retrieved docs."""
        # Flatten the list of lists and serialize each Document to a string
        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
        # Deduplicate
        unique_docs = list(set(flattened_docs))
        # Deserialize back into Document objects
        return [loads(doc) for doc in unique_docs]

    _search_query = get_search_query()
    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
    print(modified_question)

    generate_queries = multi_query_chain(llm)  # fixed: multi_query_chain() requires the llm argument
    retrieval_chain = generate_queries | retriever.map() | get_unique_union
    docs = retrieval_chain.invoke({"question": modified_question})

    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
    return answer, docs
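# Usage sketch (hypothetical; works with any LangChain retriever, e.g. the FAISS
# retriever built in the __main__ block below):
#     answer, docs = multi_query(
#         "公司需要揭露哪些氣候相關資訊?",  # "What climate-related information must a company disclose?"
#         retriever=global_retriever,
#         chat_history=[],
#     )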
def multi_query_rag_prompt(retrieval_chain, question):
    # RAG: answer the question from the multi-query retrieval results.
    # The quoted Chinese phrase means "according to the provided documents";
    # the fallback sentence reads roughly "Sorry, I cannot answer your question
    # right now. Please send your inquiry to test@systex.com for further
    # assistance. Thank you."
    template = """Answer the following question based on this context:
{context}
Question: {question}
Output in the user's language. If the question is in zh-tw, then the output will be in zh-tw. \n
You should not mention anything about "根據提供的文件內容" or other similar terms.
Use three sentences maximum and keep the answer concise.
If you don't know the answer, just say "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
"""
    prompt = ChatPromptTemplate.from_template(template)

    # llm = ChatOpenAI(temperature=0)
    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)

    final_rag_chain = (
        # retrieval_chain receives the whole input dict and re-runs retrieval;
        # itemgetter pulls the raw question string out of the same dict.
        {"context": retrieval_chain,
         "question": itemgetter("question")}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Stream the answer token by token (the commented line is the
    # non-streaming equivalent).
    # answer = final_rag_chain.invoke({"question": question})
    answer = ""
    for text in final_rag_chain.stream({"question": question}):
        print(text, end="", flush=True)
        answer += text
    return answer
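# Usage sketch (illustrative only: get_unique_union is local to multi_query(),
# so any runnable that maps {"question": ...} to a list of documents will do):
#     chain = multi_query_chain(llm) | retriever.map()
#     answer = multi_query_rag_prompt(chain, "什麼是範疇三排放?")  # "What are Scope 3 emissions?"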
########################################################################################################################
def get_search_query():
    # Condense a chat history and follow-up question into a standalone question.
    #
    # Earlier version of the prompt, kept for reference:
    # _template = """Given the following conversation and a follow up question,
    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
    # Generate standalone question in its original language.
    # Chat History:
    # {chat_history}
    # Follow Up Input: {question}
    # Hint:
    # * Refer to chat history and add the subject to the question
    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
    # Standalone question:"""  # noqa: E501
    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
The rewritten query should:
- Preserve the core intent and meaning of the original query
- Expand and clarify the query to make it more specific and informative for retrieving relevant context
- Avoid introducing new topics or queries that deviate from the original query
- NEVER answer the original query; instead, focus on rephrasing and expanding it into a new query
- Be in the same language as the original query
Return ONLY the rewritten query text, without any additional formatting or explanations.
Conversation History:
{chat_history}
Original query: [{question}]
Rewritten query:
"""
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
        """Convert (human, ai) tuples into alternating message objects."""
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer

    _search_query = RunnableBranch(
        # If the input includes chat_history, condense it with the follow-up question
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),
            # Condense follow-up question and chat into a standalone question
            RunnablePassthrough.assign(
                chat_history=lambda x: _format_chat_history(x["chat_history"])
            )
            | CONDENSE_QUESTION_PROMPT
            | ChatOpenAI(temperature=0)
            | StrOutputParser(),
        ),
        # Else, there is no chat history, so just pass the question through
        RunnableLambda(lambda x: x["question"]),
    )
    return _search_query
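# Usage sketch (hypothetical strings; chat_history is a list of (human, ai) tuples):
#     condense = get_search_query()
#     standalone = condense.invoke({
#         "question": "它什麼時候上路?",  # "When does it take effect?"
#         "chat_history": [("什麼是氣候變遷因應法?", "這是一部台灣的法律...")],
#     })
#     # With an empty chat_history, the branch falls through and the question
#     # is returned unchanged.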
########################################################################################################################
def naive_rag(question, retriever, chat_history):
    _search_query = get_search_query()
    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
    print(modified_question)

    #### RETRIEVAL and GENERATION ####
    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    # llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Retrieved separately so the source documents can be returned alongside
    # the answer (the chain below runs its own retrieval pass).
    reference = retriever.get_relevant_documents(modified_question)

    # Chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Question
    answer = rag_chain.invoke(modified_question)
    return answer, reference
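# Usage sketch (assumes create_faiss_retriever() from faiss_index, as used in
# the __main__ block below):
#     retriever = create_faiss_retriever()
#     answer, refs = naive_rag("台灣為什麼要制定氣候變遷因應法?", retriever, chat_history=[])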
################################################################################################
if __name__ == "__main__":
    from faiss_index import create_faiss_retriever, faiss_query
    global_retriever = create_faiss_retriever()

    generate_queries = multi_query_chain(llm)  # fixed: pass the module-level llm
    # "Why did Taiwan enact the Climate Change Response Act?"
    question = "台灣為什麼要制定氣候變遷因應法?"
    questions = generate_queries.invoke({"question": question})  # the prompt expects a "question" key
    questions = [item for item in questions if item != ""]
    # print(questions)

    results = list(map(global_retriever.get_relevant_documents, questions))
    results = [item for sublist in results for item in sublist]
    print(len(results))
    print(results)

    # retrieval_chain = generate_queries | global_retriever.map
    # docs = retrieval_chain.invoke(question)
    # print(docs)
    # print(len(docs))
    # print(len(results))
    # for doc in results[:10]:
    #     print(doc)
    #     print("-----------------------------------------------------------------------")
    # results = get_unique_union(results)
    # print(len(results))
    # retrieval_chain = generate_queries | global_retriever.map | get_unique_union
    # docs = retrieval_chain.invoke(question)
    # print(len(docs))