main.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import time
  2. import pandas as pd
  3. import os
  4. from dotenv import load_dotenv
  5. from config import (
  6. current_dir, CSV_FILE, system_prompt,
  7. EMBEDDINGS_FILE, FAISS_INDEX_FILE
  8. )
  9. from langchain.globals import set_llm_cache
  10. from langchain_community.cache import SQLiteCache
  11. from embeddings import load_embeddings
  12. from rag_chain import simple_rag_prompt, calculate_similarity, get_context
  13. # Load environment variables
  14. load_dotenv('environment.env')
  15. # Set up cache
  16. set_llm_cache(SQLiteCache(database_path=".langchain.db"))
  17. def main():
  18. # Number of questions to test
  19. n = 10
  20. # Load embeddings and index
  21. embeddings, docs, df, index = load_embeddings()
  22. # Define retrieval chain
  23. retrieval_chain = lambda q: get_context(q, index, docs)
  24. # Load questions from CSV
  25. csv_path = os.path.join(current_dir, CSV_FILE)
  26. qa_df = pd.read_csv(csv_path)
  27. # Output file
  28. output_file = 'rag_output.txt'
  29. with open(output_file, 'w', encoding='utf-8') as f:
  30. for i in range(n):
  31. try:
  32. question = qa_df.iloc[i]['question']
  33. original_answer = qa_df.iloc[i]['answer']
  34. print(f"Processing question {i+1}: {question}")
  35. start_time = time.time()
  36. rag_answer, similarity_score = simple_rag_prompt(retrieval_chain, question)
  37. end_time = time.time()
  38. response_time = end_time - start_time
  39. # answer_similarity = calculate_similarity(original_answer, rag_answer)
  40. # Check if rag_answer is a string before calculating similarity
  41. if isinstance(rag_answer, str):
  42. answer_similarity = calculate_similarity(original_answer, rag_answer)
  43. else:
  44. answer_similarity = 0
  45. print(f"Warning: RAG answer for question {i+1} is not a string. Answer: {rag_answer}")
  46. # Write results to file
  47. f.write(f"Question {i+1}: {question}\n")
  48. f.write(f"Original Answer: {original_answer}\n")
  49. f.write(f"RAG Answer: {rag_answer}\n")
  50. f.write(f"Response Time: {response_time:.2f} seconds\n")
  51. f.write(f"Retrieval Similarity Score: {similarity_score:.4f}\n")
  52. f.write(f"Answer Similarity Score: {answer_similarity:.4f}\n")
  53. f.write("-" * 50 + "\n")
  54. f.flush()
  55. print(f"Processed question {i+1}")
  56. # Add a small delay to avoid rate limiting
  57. time.sleep(1)
  58. except Exception as e:
  59. print(f"Error processing question {i+1}: {str(e)}")
  60. f.write(f"Error processing question {i+1}: {str(e)}\n")
  61. f.write("-" * 50 + "\n")
  62. f.flush()
  63. print(f"Output has been saved to {output_file}")
  64. if __name__ == "__main__":
  65. main()