# main.py — interactive RAG question-answering entry point.
import time
import os
from dotenv import load_dotenv
from config import system_prompt
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from embeddings import load_embeddings
from rag_chain import simple_rag_prompt, get_context
import pandas as pd
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma

# Load environment variables (presumably the OpenAI API key — confirm) from
# the local env file.
load_dotenv('environment.env')

# Cache LLM responses in a local SQLite file so repeated prompts skip the API.
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

# Curated standard question/answer pairs; the 'question' and 'answer' columns
# are read by search_similarity below.
qa_df = pd.read_csv('log_record_rows.csv')

# Embedding model plus a persistent Chroma vector store over the standard
# questions (assumes the store was pre-built under ./chroma_db — verify).
embeddings_model = OpenAIEmbeddings()
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings_model)
  21. def search_similarity(query, SIMILARITY_THRESHOLD=0.82):
  22. docs_and_scores = vectorstore.similarity_search_with_relevance_scores(query, k=1)
  23. if docs_and_scores:
  24. doc, score = docs_and_scores[0]
  25. if score >= SIMILARITY_THRESHOLD:
  26. # Find the corresponding answer in the qa_df
  27. question = doc.page_content
  28. answer = qa_df[qa_df['question'] == question]['answer'].values
  29. if len(answer) > 0:
  30. return answer[0], score
  31. return None, 0
  32. def main():
  33. # Load embeddings and index
  34. embeddings, docs, df, index = load_embeddings()
  35. # Define retrieval chain
  36. retrieval_chain = lambda q: get_context(q, index, docs)
  37. print("RAG system initialized. You can start asking questions.")
  38. print("Type 'quit' to exit the program.")
  39. while True:
  40. user_input = input("\nEnter your question: ")
  41. if user_input.lower() == 'quit':
  42. break
  43. start_time = time.time()
  44. try:
  45. # First, search in the standard QA
  46. standard_answer, similarity_score = search_similarity(user_input)
  47. if standard_answer is not None:
  48. print(f"\nAnswer (from standard QA): {standard_answer}")
  49. # print(f"Similarity Score: {similarity_score:.4f}")
  50. else:
  51. # If not found in standard QA, use RAG
  52. rag_answer, rag_similarity_score = simple_rag_prompt(retrieval_chain, user_input)
  53. print(f"\nAnswer (from RAG): {rag_answer}")
  54. # print(f"Retrieval Similarity Score: {rag_similarity_score:.4f}")
  55. end_time = time.time()
  56. response_time = end_time - start_time
  57. print(f"Response Time: {response_time:.2f} seconds")
  58. except Exception as e:
  59. print(f"Error processing question: {str(e)}")
  60. # Add a small delay to avoid rate limiting
  61. time.sleep(1)
  62. if __name__ == "__main__":
  63. main()