main.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import time
  2. import pandas as pd
  3. from config import (
  4. current_dir, CSV_FILE, system_prompt,
  5. EMBEDDINGS_FILE, FAISS_INDEX_FILE
  6. )
  7. from langchain.globals import set_llm_cache
  8. from langchain_community.cache import SQLiteCache
  9. from embeddings import load_embeddings
  10. from rag_chain import get_context, simple_rag_prompt, calculate_similarity
  11. # Set up cache
  12. set_llm_cache(SQLiteCache(database_path=".langchain.db"))
  13. def main():
  14. # 測試前N個問題
  15. n = 8
  16. embeddings, docs, df, index = load_embeddings()
  17. retrieval_chain = lambda q: get_context(q, index, docs)
  18. csv_path = f"{current_dir}/{CSV_FILE}"
  19. qa_df = pd.read_csv(csv_path)
  20. output_file = 'rag_output.txt'
  21. with open(output_file, 'w', encoding='utf-8') as f:
  22. for i in range(n):
  23. try:
  24. question = qa_df.iloc[i]['question']
  25. original_answer = qa_df.iloc[i]['answer']
  26. start_time = time.time()
  27. rag_answer, similarity_score = simple_rag_prompt(retrieval_chain, question)
  28. end_time = time.time()
  29. response_time = end_time - start_time
  30. answer_similarity = calculate_similarity(original_answer, rag_answer)
  31. f.write(f"Question {i+1}: {question}\n")
  32. f.write(f"Original Answer: {original_answer}\n")
  33. f.write(f"RAG Answer: {rag_answer}\n")
  34. f.write(f"Response Time: {response_time:.2f} seconds\n")
  35. f.write(f"Retrieval Similarity Score: {similarity_score:.4f}\n")
  36. f.write(f"Answer Similarity Score: {answer_similarity:.4f}\n")
  37. f.write("-" * 50 + "\n")
  38. f.flush()
  39. print(f"Processed question {i+1}")
  40. time.sleep(1)
  41. except Exception as e:
  42. print(f"Error processing question {i+1}: {str(e)}")
  43. f.write(f"Error processing question {i+1}: {str(e)}\n")
  44. f.write("-" * 50 + "\n")
  45. f.flush()
  46. print(f"Output has been saved to {output_file}")
  47. if __name__ == "__main__":
  48. main()