from operator import itemgetter
from typing import List, Tuple

from dotenv import load_dotenv
from langchain import hub
from langchain.load import dumps, loads
from langchain.prompts import ChatPromptTemplate
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnablePassthrough,
)
from langchain_openai import ChatOpenAI

load_dotenv()

# The original code referenced a module-level `llm` that was never defined;
# a shared ChatOpenAI instance is assumed here (swap in any other LangChain
# chat model if preferred).
llm = ChatOpenAI(temperature=0)
def multi_query_chain(llm):
    """Build a chain that expands one question into four search queries."""
    template = """You are an AI language model assistant. Your task is to generate three
    different versions of the given user question to retrieve relevant documents from a vector
    database. By generating multiple perspectives on the user question, your goal is to help
    the user overcome some of the limitations of distance-based similarity search.
    Provide these alternative questions separated by newlines.
    You must also return the original question, so the output contains 4 questions in total
    (1 original + 3 alternative versions).

    Original question: {question}"""
    prompt_perspectives = ChatPromptTemplate.from_template(template)

    # Prompt -> LLM -> plain text -> one query per line.
    generate_queries = (
        prompt_perspectives
        | llm
        | StrOutputParser()
        | (lambda x: x.split("\n"))
    )
    return generate_queries
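

# Hedged usage sketch (defined but not run on import): the expanded-query
# chain should yield four non-empty lines — the original question plus three
# rewrites. Any LangChain chat model can be passed in as `model`.
def _demo_multi_query_chain(model) -> None:
    chain = multi_query_chain(model)
    queries = chain.invoke({"question": "台灣為什麼要制定氣候變遷因應法?"})
    queries = [q for q in queries if q.strip()]  # drop blank lines
    print(len(queries), queries)  # expected: 4 queries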
def multi_query(question, retriever, chat_history):
    def get_unique_union(documents: List[list]):
        """Unique union of retrieved docs across all generated queries."""
        # Serialize each Document so a set can deduplicate them, then
        # deserialize the survivors back into Document objects.
        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
        unique_docs = list(set(flattened_docs))
        return [loads(doc) for doc in unique_docs]

    # Condense the question with the chat history before query expansion.
    _search_query = get_search_query()
    modified_question = _search_query.invoke(
        {"question": question, "chat_history": chat_history}
    )
    print(modified_question)

    generate_queries = multi_query_chain(llm)
    retrieval_chain = generate_queries | retriever.map() | get_unique_union
    docs = retrieval_chain.invoke({"question": modified_question})

    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
    return answer, docs
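

# Minimal sketch of the dedup step above, assuming langchain_core Documents:
# dumps() serializes each Document so a set can compare them; loads() turns
# the unique survivors back into Documents.
def _demo_unique_union() -> None:
    from langchain_core.documents import Document

    doc = Document(page_content="氣候變遷因應法")
    nested = [[doc, doc], [doc]]  # duplicates across per-query result lists
    flattened = [dumps(d) for sublist in nested for d in sublist]
    unique = [loads(d) for d in set(flattened)]
    print(len(unique))  # 1 — duplicates collapse to a single Document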
def multi_query_rag_prompt(retrieval_chain, question):
    template = """Answer the following question based on this context:
    {context}
    Question: {question}
    Output in the user's language. If the question is in zh-tw, then the output must be in zh-tw.
    You should not mention anything like "根據提供的文件內容" or other similar phrases.
    Use three sentences maximum and keep the answer concise.
    If you don't know the answer, just say "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
    """
    prompt = ChatPromptTemplate.from_template(template)

    final_rag_chain = (
        {"context": retrieval_chain,
         "question": itemgetter("question")}
        | prompt
        | llm  # module-level model defined above
        | StrOutputParser()
    )

    # Stream tokens to stdout while accumulating the full answer.
    answer = ""
    for text in final_rag_chain.stream({"question": question}):
        print(text, end="", flush=True)
        answer += text
    return answer
def get_search_query():
    # Condense a follow-up question plus chat history into a standalone query.
    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
    The rewritten query should:

    - Preserve the core intent and meaning of the original query
    - Expand and clarify the query to make it more specific and informative for retrieving relevant context
    - Avoid introducing new topics or queries that deviate from the original query
    - NEVER answer the original query; instead, focus on rephrasing and expanding it into a new query
    - Keep the rewritten query in the original query's language

    Return ONLY the rewritten query text, without any additional formatting or explanations.

    Conversation History:
    {chat_history}

    Original query: [{question}]

    Rewritten query:
    """
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
        """Convert (human, ai) turn tuples into LangChain message objects."""
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer

    _search_query = RunnableBranch(
        # If chat history exists, condense the question with the LLM ...
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),
            RunnablePassthrough.assign(
                chat_history=lambda x: _format_chat_history(x["chat_history"])
            )
            | CONDENSE_QUESTION_PROMPT
            | ChatOpenAI(temperature=0)
            | StrOutputParser(),
        ),
        # ... otherwise pass the question through unchanged.
        RunnableLambda(lambda x: x["question"]),
    )
    return _search_query
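

# Hedged sketch of the branch behaviour: with an empty history the question
# passes through verbatim (no LLM call); with history it is condensed by
# ChatOpenAI, so that path needs an OPENAI_API_KEY.
def _demo_search_query() -> None:
    sq = get_search_query()
    # No-history path: returns the question unchanged.
    print(sq.invoke({"question": "台灣為什麼要制定氣候變遷因應法?", "chat_history": []}))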
def naive_rag(question, retriever, chat_history):
    _search_query = get_search_query()
    modified_question = _search_query.invoke(
        {"question": question, "chat_history": chat_history}
    )
    print(modified_question)

    # Standard RAG prompt pulled from the LangChain hub.
    prompt = hub.pull("rlm/rag-prompt")

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Retrieve once up front so the source documents can be returned.
    reference = retriever.get_relevant_documents(modified_question)

    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    answer = rag_chain.invoke(modified_question)
    return answer, reference
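

# Hedged end-to-end sketch, assuming create_faiss_retriever from the local
# faiss_index module (also used in __main__ below) builds the retriever.
def _demo_naive_rag() -> None:
    from faiss_index import create_faiss_retriever

    retriever = create_faiss_retriever()
    answer, references = naive_rag(
        "台灣為什麼要制定氣候變遷因應法?", retriever, chat_history=[]
    )
    print(answer)
    print(len(references))  # number of source documents returned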
if __name__ == "__main__":
    from faiss_index import create_faiss_retriever, faiss_query

    global_retriever = create_faiss_retriever()
    generate_queries = multi_query_chain(llm)
    question = "台灣為什麼要制定氣候變遷因應法?"

    questions = generate_queries.invoke({"question": question})
    questions = [item for item in questions if item != ""]

    results = list(map(global_retriever.get_relevant_documents, questions))
    results = [item for sublist in results for item in sublist]
    print(len(results))
    print(results)