from operator import itemgetter
from typing import List, Optional, Tuple

from dotenv import load_dotenv
from langchain import hub
from langchain.globals import set_llm_cache
from langchain.load import dumps, loads
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import Ollama
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_openai import ChatOpenAI

load_dotenv()
########################################################################################################################
# Optional local-model / cache backends (left commented out; enable as needed):
# from local_llm import ollama_, hf
# llm = hf()
# llm = taide_llm
# llm = ollama_()
# from langchain.cache import SQLiteCache
# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
########################################################################################################################

# Module-level LLM used by the RAG chains below; swap in one of the commented alternatives if preferred.
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
def multi_query_chain(llm):
    """Build a chain that rewrites a user question into multiple query variants.

    Multi Query: generate different perspectives on the same question so that
    distance-based similarity search retrieves a broader set of documents.
    """
    template = """
<|begin_of_text|>

<|start_header_id|>system<|end_header_id|>
你是一個來自台灣的ESG的AI助理,請用繁體中文回答
You are an AI language model assistant. Your task is to generate three
different versions of the given user question to retrieve relevant documents from a vector
database. By generating multiple perspectives on the user question, your goal is to help
the user overcome some of the limitations of the distance-based similarity search.
Provide these alternative questions separated by newlines.
You must return the original question as well, which means you return 1 original version + 3 different versions = 4 questions.
<|eot_id|>

<|start_header_id|>user<|end_header_id|>

Original question: {question}
請用繁體中文
<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>
"""
    prompt_perspectives = ChatPromptTemplate.from_template(template)

    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)

    # One query per line: parse the LLM output into a list of question strings.
    generate_queries = (
        prompt_perspectives
        | llm
        | StrOutputParser()
        | (lambda x: x.split("\n"))
    )
    return generate_queries
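# Usage sketch (assumes the module-level `llm` above and an OpenAI key loaded from .env):
#
#   query_gen = multi_query_chain(llm)
#   queries = query_gen.invoke({"question": "台灣為什麼要制定氣候變遷因應法?"})
#   # `queries` is a list of strings: the original question plus three rephrasings,
#   # one per line of the model output (empty lines may need to be filtered out).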
def multi_query(question, retriever, chat_history):
    """Multi-query RAG: condense the question, fan out query variants, retrieve, and answer."""

    def get_unique_union(documents: List[list]):
        """Unique union of retrieved docs."""
        # Flatten list of lists, and serialize each Document so it can be deduplicated.
        flattened_docs = [dumps(doc) for sublist in documents for doc in sublist]
        # Get unique documents
        unique_docs = list(set(flattened_docs))
        # Deserialize back into Document objects
        return [loads(doc) for doc in unique_docs]

    # Rewrite the question into a standalone query using the chat history.
    _search_query = get_search_query()
    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
    print(modified_question)

    generate_queries = multi_query_chain(llm)
    retrieval_chain = generate_queries | retriever.map() | get_unique_union
    docs = retrieval_chain.invoke({"question": modified_question})

    answer = multi_query_rag_prompt(retrieval_chain, modified_question)
    return answer, docs
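# Usage sketch (hypothetical first-turn call; any LangChain retriever with .map() works):
#
#   answer, docs = multi_query(
#       "台灣為什麼要制定氣候變遷因應法?",
#       retriever=global_retriever,   # e.g. the FAISS retriever built in __main__
#       chat_history=[],              # list of (human, ai) tuples; empty on the first turn
#   )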
def multi_query_rag_prompt(retrieval_chain, question):
    """Answer `question` with RAG, using `retrieval_chain` to build the context."""
    template = """Answer the following question based on this context:
{context}
Question: {question}
Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. \n
You should not mention anything about "根據提供的文件內容" or other similar terms.
Use three sentences maximum and keep the answer concise.
If you don't know the answer, just say that "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
"""
    prompt = ChatPromptTemplate.from_template(template)
    # llm = ChatOpenAI(temperature=0)
    # llm = ChatOpenAI(temperature=0, model="gpt-4-1106-preview")
    # llm = ChatOllama(model="llama3", num_gpu=1, temperature=0)

    final_rag_chain = (
        {"context": retrieval_chain,
         "question": itemgetter("question")}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Stream the answer chunk by chunk and accumulate it into a single string.
    # answer = final_rag_chain.invoke({"question": question})
    answer = ""
    for text in final_rag_chain.stream({"question": question}):
        print(text, end="", flush=True)
        answer += text
    return answer
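# Note (sketch of the expected input shape): `retrieval_chain` receives the same
# {"question": ...} dict that is passed to `final_rag_chain.stream(...)`, so any runnable
# that maps {"question": str} to a list of documents can be plugged in here.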
########################################################################################################################
def get_search_query():
    """Build a runnable that condenses the chat history and follow-up question into a standalone query."""
    # Condense a chat history and follow-up question into a standalone question
    #
    # _template = """Given the following conversation and a follow up question,
    # rephrase the follow up question to be a standalone question to help others understand the question without having to go back to the conversation transcript.
    # Generate standalone question in its original language.
    # Chat History:
    # {chat_history}
    # Follow Up Input: {question}
    # Hint:
    # * Refer to chat history and add the subject to the question
    # * Replace the pronouns in the question with the correct person or thing, please refer to chat history
    #
    # Standalone question:"""  # noqa: E501
    _template = """Rewrite the following query by incorporating relevant context from the conversation history.
The rewritten query should:

- Preserve the core intent and meaning of the original query
- Expand and clarify the query to make it more specific and informative for retrieving relevant context
- Avoid introducing new topics or queries that deviate from the original query
- NEVER ANSWER the original query; instead focus on rephrasing and expanding it into a new query
- Be written in the original query's language

Return ONLY the rewritten query text, without any additional formatting or explanations.

Conversation History:
{chat_history}

Original query: [{question}]

Rewritten query:
"""
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

    def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
        """Convert (human, ai) tuples into alternating HumanMessage / AIMessage objects."""
        buffer = []
        for human, ai in chat_history:
            buffer.append(HumanMessage(content=human))
            buffer.append(AIMessage(content=ai))
        return buffer

    _search_query = RunnableBranch(
        # If the input includes chat_history, condense it with the follow-up question
        (
            RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
                run_name="HasChatHistoryCheck"
            ),  # Condense follow-up question and chat into a standalone question
            RunnablePassthrough.assign(
                chat_history=lambda x: _format_chat_history(x["chat_history"])
            )
            | CONDENSE_QUESTION_PROMPT
            | ChatOpenAI(temperature=0)
            | StrOutputParser(),
        ),
        # Else, there is no chat history, so just pass the question through unchanged
        RunnableLambda(lambda x: x["question"]),
    )
    return _search_query
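# Usage sketch (the branch falls through to the raw question when chat_history is empty):
#
#   condense = get_search_query()
#   standalone = condense.invoke({
#       "question": "那碳費呢?",                        # hypothetical follow-up question
#       "chat_history": [("什麼是碳權?", "碳權是...")],  # hypothetical (human, ai) turns
#   })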
########################################################################################################################
def naive_rag(question, retriever, chat_history):
    """Single-query RAG: condense the question, retrieve once, and answer with the rlm/rag-prompt."""
    _search_query = get_search_query()
    modified_question = _search_query.invoke({"question": question, "chat_history": chat_history})
    print(modified_question)

    #### RETRIEVAL and GENERATION ####
    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    # llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Retrieve once up front so the source documents can be returned alongside the answer.
    reference = retriever.get_relevant_documents(modified_question)

    # Chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Question
    answer = rag_chain.invoke(modified_question)
    return answer, reference
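# Usage sketch (same call shape as multi_query; the retriever is assumed to be the
# FAISS retriever created in __main__):
#
#   answer, reference = naive_rag("台灣為什麼要制定氣候變遷因應法?", global_retriever, chat_history=[])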
################################################################################################
if __name__ == "__main__":
    from faiss_index import create_faiss_retriever, faiss_query

    global_retriever = create_faiss_retriever()
    generate_queries = multi_query_chain(llm)
    question = "台灣為什麼要制定氣候變遷因應法?"

    questions = generate_queries.invoke(question)
    questions = [item for item in questions if item != ""]
    # print(questions)

    results = list(map(global_retriever.get_relevant_documents, questions))
    results = [item for sublist in results for item in sublist]
    print(len(results))
    print(results)

    # retrieval_chain = generate_queries | global_retriever.map
    # docs = retrieval_chain.invoke(question)
    # print(docs)
    # print(len(docs))
    # print(len(results))
    # for doc in results[:10]:
    #     print(doc)
    #     print("-----------------------------------------------------------------------")
    # results = get_unique_union(results)
    # print(len(results))
    # retrieval_chain = generate_queries | global_retriever.map | get_unique_union
    # docs = retrieval_chain.invoke(question)
    # print(len(docs))