# rag_chain.py
  1. from langchain.prompts import ChatPromptTemplate
  2. from langchain_core.output_parsers import StrOutputParser
  3. from langchain_core.runnables import RunnablePassthrough
  4. from models import OllamaChatModel
  5. from embeddings import similarity_search
  6. from text_processing import remove_unwanted_content
  7. from langchain_openai import OpenAIEmbeddings
  8. from sklearn.metrics.pairwise import cosine_similarity
  9. import os
# Module-level chat model shared by the RAG chain below.
# NOTE(review): "taide-local-3" looks like a locally served TAIDE model tag
# exposed through OllamaChatModel — confirm against models.py.
taide_llm = OllamaChatModel(model_name="taide-local-3")
  11. def get_context(query, index, docs):
  12. results = similarity_search(query, index, docs)
  13. if not results:
  14. return "", 0 # Return empty context and zero similarity when no results are found
  15. context = "\n".join([doc.page_content for doc, _ in results])
  16. # print(f"Question: {query}")
  17. # print("Retrieved documents:")
  18. # for i, (doc, similarity) in enumerate(results):
  19. # print(f"Doc {i+1} (similarity: {similarity:.4f}): {doc.page_content[:50]}...")
  20. # print("-" * 50)
  21. return context, results[0][1] # Return context and top similarity score
  22. def remove_repetitions(text):
  23. sentences = text.split('。')
  24. unique_sentences = list(dict.fromkeys(sentences))
  25. return '。'.join(unique_sentences)
def simple_rag_prompt(retrieval_chain, question):
    """Answer *question* via a simple retrieval-augmented generation chain.

    Args:
        retrieval_chain: Callable taking the question and returning a
            ``(context, similarity_score)`` tuple (e.g. get_context bound to
            an index — TODO confirm against callers).
        question: The user question, in English or zh-tw.

    Returns:
        An ``(answer, similarity_score)`` tuple. With no retrieved context the
        answer is a canned apology and the score is 0; on LLM failure the
        answer is an error string and the retrieved score is kept.
    """
    # Prompt instructs the model to answer in the question's language and to
    # avoid meta-commentary about the retrieved documents or the prompt.
    template = """Answer the following question based on this context:
{context}
Question: {question}
Output in user's language. If the question is in zh-tw, then the output will be in zh-tw. If the question is in English, then the output will be in English.
Do not repeat the question in your response.
For each individual answer, try to not provide duplicated sentences.
Do not start the response with "我的回答是:" or anything similar.
You should not mention anything about "根據提供的文件內容" or other similar terms.
Do not mention anything relate with the Documents or context.
DO not mention anything relate with the prompt, such as "這個回答是根據所提供的對話上下文而產生的,假如對話內容有改變,則回答內容也需隨之調整。若不確定答案,應說:「很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@email.com 以便獲得更進一步的幫助,謝謝」。若沒有必要,則不需在回答中提及「根據提供的文件內容」或類似的字樣。若對話是以英語進行,則輸出應為英文;否則,則為繁體中文。" or anything similar.
If you are unsure of the answer, say: "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@email.com 以便獲得更進一步的幫助,謝謝"
"""
    prompt = ChatPromptTemplate.from_template(template)
    context, similarity_score = retrieval_chain(question)
    if not context:
        # Nothing retrieved: short-circuit with the canned fallback instead
        # of calling the LLM.
        return "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@email.com 以便獲得更進一步的幫助,謝謝。", 0
    # LCEL pipeline: the lambdas ignore the piped-in value and close over the
    # precomputed context; the question passes through unchanged.
    final_rag_chain = (
        {"context": lambda x: context,
         "question": lambda x: x}
        | prompt
        | taide_llm
        | StrOutputParser()
    )
    try:
        answer = final_rag_chain.invoke(question)
        # Post-process: strip unwanted boilerplate, then dedupe sentences.
        answer = remove_unwanted_content(answer)
        answer = remove_repetitions(answer)
        return answer, similarity_score
    except Exception as e:
        # Broad catch: degrade to an error string rather than crash the caller.
        print(f"Error in simple_rag_prompt: {e}")
        return f"Error occurred while processing the question: {str(e)}", similarity_score
  58. def calculate_similarity(text1, text2):
  59. embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
  60. emb1 = embeddings.embed_query(text1)
  61. emb2 = embeddings.embed_query(text2)
  62. return cosine_similarity([emb1], [emb2])[0][0]