embeddings.py

import os
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from json import loads
from sqlalchemy import create_engine
import pandas as pd
import numpy as np
import faiss
import pickle

# Get the current script's directory
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
env_path = os.path.join(parent_dir, 'environment.env')
load_dotenv(env_path)

URI = os.getenv("SUPABASE_URI")
openai_api_key = os.getenv("OPENAI_API_KEY")

EMBEDDINGS_FILE = 'qa_embeddings.pkl'
FAISS_INDEX_FILE = 'qa_faiss_index.bin'
CSV_FILE = 'log_record_rows.csv'

def gen_doc_from_database():
    """Load question/answer pairs from the log_record table as LangChain Documents."""
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("log_record", engine.connect())
    # Round-trip through JSON keyed by row index, then rebuild the frame
    result = df[['question', 'answer']].to_json(orient='index', force_ascii=False)
    result = loads(result)
    df = pd.DataFrame(result).T
    df.drop_duplicates(subset=['question', 'answer'], keep='first', inplace=True)
    print(f"Number of records after removing duplicates: {len(df)}")
    qa_doc = []
    for i in range(len(df)):
        question = df.iloc[i]['question']
        answer = df.iloc[i]['answer']
        context = f'question: {question}\nanswer: {answer}'
        doc = Document(page_content=context)
        qa_doc.append(doc)
    return qa_doc, df

def gen_doc_from_csv(csv_filename=CSV_FILE):
    """Load question/answer pairs from a CSV export as LangChain Documents."""
    csv_path = os.path.join(current_dir, csv_filename)
    df = pd.read_csv(csv_path)
    df.drop_duplicates(subset=['question', 'answer'], keep='first', inplace=True)
    print(f"Number of records after removing duplicates: {len(df)}")
    qa_doc = []
    for _, row in df.iterrows():
        question = row['question']
        answer = row['answer']
        context = f'question: {question}\nanswer: {answer}'
        doc = Document(page_content=context)
        qa_doc.append(doc)
    return qa_doc, df

def create_embeddings(docs):
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    print("Creating embeddings...")
    doc_embeddings = embeddings.embed_documents([doc.page_content for doc in docs])
    print(f"Created {len(doc_embeddings)} embeddings.")
    print(f"Each embedding is a vector of length {len(doc_embeddings[0])}.")
    return doc_embeddings

def create_faiss_index(embeddings):
    dimension = len(embeddings[0])
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(embeddings).astype('float32'))
    return index

def save_embeddings(embeddings, docs, df):
    with open(os.path.join(current_dir, EMBEDDINGS_FILE), 'wb') as f:
        pickle.dump({'embeddings': embeddings, 'docs': docs, 'df': df}, f)
    index = create_faiss_index(embeddings)
    faiss.write_index(index, os.path.join(current_dir, FAISS_INDEX_FILE))
    print(f"Saved embeddings to {EMBEDDINGS_FILE} and FAISS index to {FAISS_INDEX_FILE}")

def load_embeddings():
    embeddings_path = os.path.join(current_dir, EMBEDDINGS_FILE)
    faiss_path = os.path.join(current_dir, FAISS_INDEX_FILE)
    if os.path.exists(embeddings_path) and os.path.exists(faiss_path):
        with open(embeddings_path, 'rb') as f:
            data = pickle.load(f)
        index = faiss.read_index(faiss_path)
        print("Loaded existing embeddings and FAISS index from files")
        return data['embeddings'], data['docs'], data['df'], index
    else:
        raise FileNotFoundError("Embeddings or FAISS index file not found. Please run embeddings.py first.")

def similarity_search(query, index, docs, k=3, threshold=0.8, method='logistic', sigma=1.0):
    embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
    query_vector = embeddings.embed_query(query)
    # Search with a larger k to improve the chance of finding enough unique documents
    D, I = index.search(np.array([query_vector]).astype('float32'), k * 2)
    results = []
    seen_docs = set()  # Track documents that have already been returned
    for dist, idx in zip(D[0], I[0]):
        # Convert the FAISS L2 distance into a similarity score
        if method == 'logistic':
            similarity = 1 / (1 + dist)
        elif method == 'exponential':
            similarity = np.exp(-dist)
        elif method == 'gaussian':
            similarity = np.exp(-dist**2 / (2 * sigma**2))
        else:
            raise ValueError("Unknown similarity method")
        if similarity >= threshold:
            doc_content = docs[idx].page_content
            if doc_content not in seen_docs:  # Skip documents that were already added
                results.append((docs[idx], similarity))
                seen_docs.add(doc_content)
                if len(results) == k:  # Stop once k unique documents have been found
                    break
    return results

def main():
    # Choose which method to use for generating documents
    use_csv = True  # Set this to False if you want to use the database instead
    if use_csv:
        docs, df = gen_doc_from_csv()
    else:
        docs, df = gen_doc_from_database()
    embeddings = create_embeddings(docs)
    save_embeddings(embeddings, docs, df)


if __name__ == "__main__":
    main()
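
# Example of how a downstream module might load the saved artifacts and run a
# query (a minimal sketch, not part of the build step above; the import path
# and the sample question are illustrative assumptions):
#
#     from embeddings import load_embeddings, similarity_search
#
#     _, docs, _, index = load_embeddings()
#     for doc, score in similarity_search("What is the refund policy?", index, docs):
#         print(f"{score:.3f}\t{doc.page_content}")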