from operator import itemgetter
from typing import List, Optional, Tuple

from dotenv import load_dotenv
from langchain import PromptTemplate, hub
from langchain.globals import set_llm_cache
from langchain.load import dumps, loads
from langchain.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import Ollama
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_openai import ChatOpenAI

load_dotenv()

# Optional: cache LLM responses on disk.
# from langchain.cache import SQLiteCache
# set_llm_cache(SQLiteCache(database_path=".langchain.db"))

# The generation model is chosen per call via the ``llm=`` parameter
# (e.g. ChatOpenAI(model_name="gpt-4o-mini", temperature=0), ChatOllama(...),
# or a local model from ``local_llm``).


def _resolve_llm(llm=None):
    """Return the chat model to use for answer/query generation.

    Resolution order:
      1. the explicit ``llm`` argument;
      2. a module-level ``llm`` attribute, if a caller assigned one
         (``this_module.llm = ...``), so such callers keep working;
      3. ``ChatOpenAI(temperature=0)``.
    """
    if llm is not None:
        return llm
    module_llm = globals().get("llm")
    if module_llm is not None:
        return module_llm
    return ChatOpenAI(temperature=0)


def _split_queries(text: str) -> List[str]:
    """Split LLM output into one query per non-blank line, whitespace-stripped.

    Blank lines are dropped so they are never sent to the retriever.
    """
    return [line.strip() for line in text.splitlines() if line.strip()]


def get_unique_union(documents: List[list]) -> list:
    """Return the de-duplicated union of several retrieval results.

    Args:
        documents: One list of documents per generated query.

    Returns:
        Each distinct document once, in first-seen order. Documents are
        compared by their serialized form (``langchain.load.dumps``).
    """
    flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
    # dict.fromkeys de-duplicates while keeping insertion order (a set does not).
    return [loads(doc) for doc in dict.fromkeys(flattened_docs)]


def multi_query_chain(llm=None):
    """Build a runnable that expands a question into several search queries.

    The LLM is asked for the original question plus three rephrasings, one
    per line; the output is split into a list of non-empty strings.

    Args:
        llm: Chat model to use; falls back as described in ``_resolve_llm``.

    Returns:
        A runnable taking ``{"question": str}`` (``__main__`` also passes a
        bare question string) and producing ``List[str]``.
    """
    template = """You are an AI language model assistant. Your task is to generate three different versions of the given user question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question, your goal is to help the user overcome some of the limitations of the distance-based similarity search.
Provide these alternative questions separated by newlines.
You must return original question also, which means that you return 1 original version + 3 different versions = 4 questions.
Original question: {question}"""
    prompt_perspectives = ChatPromptTemplate.from_template(template)
    return prompt_perspectives | _resolve_llm(llm) | StrOutputParser() | _split_queries


def multi_query(question, retriever, chat_history, llm=None):
    """Answer ``question`` using multi-query retrieval.

    Steps: rewrite the question using ``chat_history``, expand it into several
    queries, retrieve for each query, de-duplicate the results, and stream an
    answer grounded in those documents.

    Args:
        question: The user's question.
        retriever: A LangChain retriever (must support ``.map()``).
        chat_history: List of ``(human, ai)`` message pairs; may be empty.
        llm: Chat model for query expansion and answering; see ``_resolve_llm``.

    Returns:
        ``(answer, docs)``: the answer text and the de-duplicated documents
        that were given to the LLM as context.
    """
    llm = _resolve_llm(llm)

    search_query = get_search_query()
    modified_question = search_query.invoke(
        {"question": question, "chat_history": chat_history}
    )
    print(modified_question)

    generate_queries = multi_query_chain(llm)
    retrieval_chain = generate_queries | retriever.map() | get_unique_union
    docs = retrieval_chain.invoke({"question": modified_question})

    # Feed the already-retrieved documents to the answer chain instead of
    # letting it regenerate queries and search again: this saves an LLM call
    # plus one retrieval per query, and guarantees the returned ``docs`` are
    # exactly the context the answer was based on.
    answer = multi_query_rag_prompt(
        RunnableLambda(lambda _: docs), modified_question, llm=llm
    )
    return answer, docs


def multi_query_rag_prompt(retrieval_chain, question, llm=None):
    """Generate an answer to ``question`` from retrieved context, streaming it to stdout.

    Args:
        retrieval_chain: Runnable invoked with ``{"question": question}`` that
            returns the context documents.
        question: The (already rewritten) question.
        llm: Chat model for answering; see ``_resolve_llm``.

    Returns:
        The full answer text (also printed incrementally as it streams).
    """
    template = """Answer the following question based on this context:

{context}

Question: {question}

Output in user's language. If the question is in zh-tw, then the output will be in zh-tw.
You should not mention anything about "根據提供的文件內容" or other similar terms.
Use three sentences maximum and keep the answer concise.
If you don't know the answer, just say that "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
"""
    prompt = ChatPromptTemplate.from_template(template)

    final_rag_chain = (
        {"context": retrieval_chain, "question": itemgetter("question")}
        | prompt
        | _resolve_llm(llm)
        | StrOutputParser()
    )

    # Print tokens as they arrive while collecting the complete answer.
    chunks = []
    for text in final_rag_chain.stream({"question": question}):
        print(text, end="", flush=True)
        chunks.append(text)
    return "".join(chunks)


def get_search_query(llm=None):
    """Build a runnable that rewrites a follow-up question into a standalone search query.

    Input: ``{"question": str, "chat_history": List[Tuple[str, str]]}``, where
    each history item is a ``(human, ai)`` pair. With a non-empty history the
    question is rewritten by an LLM using that history; otherwise the question
    is returned unchanged.

    Args:
        llm: Chat model used for the rewrite. Defaults to
            ``ChatOpenAI(temperature=0)``.

    Returns:
        A runnable producing the search query as a ``str``.
    """
    template = """Rewrite the following query by incorporating relevant context from the conversation history.
The rewritten query should:

- Preserve the core intent and meaning of the original query
- Expand and clarify the query to make it more specific and informative for retrieving relevant context
- Avoid introducing new topics or queries that deviate from the original query
- DONT EVER ANSWER the Original query, but instead focus on rephrasing and expanding it into a new query
- The rewritten query should be in its original language.

Return ONLY the rewritten query text, without any additional formatting or explanations.

Conversation History:
{chat_history}

Original query: [{question}]

Rewritten query: """
    condense_question_prompt = PromptTemplate.from_template(template)
    rewriter = llm if llm is not None else ChatOpenAI(temperature=0)

    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> str:
        """Render ``(human, ai)`` pairs as a plain-text transcript."""
        messages = []
        for human, ai in chat_history:
            messages.append(HumanMessage(content=human))
            messages.append(AIMessage(content=ai))
        # The prompt is a plain string template, so render "Human: ...\nAI: ..."
        # text rather than interpolating the message objects' reprs.
        return get_buffer_string(messages)

    return RunnableBranch(
        # With chat history: condense history + follow-up into a standalone query.
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),
            RunnablePassthrough.assign(
                chat_history=lambda x: _format_chat_history(x["chat_history"])
            )
            | condense_question_prompt
            | rewriter
            | StrOutputParser(),
        ),
        # Without chat history: pass the question through unchanged.
        RunnableLambda(lambda x: x["question"]),
    )


def naive_rag(question, retriever, chat_history, llm=None):
    """Answer ``question`` with single-query retrieval and the ``rlm/rag-prompt`` prompt.

    Args:
        question: The user's question.
        retriever: A LangChain retriever.
        chat_history: List of ``(human, ai)`` message pairs; may be empty.
        llm: Chat model for answering; see ``_resolve_llm``.

    Returns:
        ``(answer, reference)``: the answer text and the documents used as context.
    """
    search_query = get_search_query()
    modified_question = search_query.invoke(
        {"question": question, "chat_history": chat_history}
    )
    print(modified_question)

    prompt = hub.pull("rlm/rag-prompt")

    # Retrieve once and reuse the result both as returned reference and as
    # prompt context, so the two can never disagree.
    reference = retriever.get_relevant_documents(modified_question)
    context = "\n\n".join(doc.page_content for doc in reference)

    rag_chain = prompt | _resolve_llm(llm) | StrOutputParser()
    answer = rag_chain.invoke({"context": context, "question": modified_question})
    return answer, reference


if __name__ == "__main__":
    from faiss_index import create_faiss_retriever, faiss_query

    global_retriever = create_faiss_retriever()
    generate_queries = multi_query_chain()

    question = "台灣為什麼要制定氣候變遷因應法?"
    questions = generate_queries.invoke(question)  # blank lines already removed

    results = [
        doc
        for query in questions
        for doc in global_retriever.get_relevant_documents(query)
    ]
    print(len(results))
    print(results)