semantic_search.py

# Python 3.9
import os

import openai
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_openai import OpenAIEmbeddings
from supabase import Client, create_client

# Load credentials from .env, then configure the OpenAI, embeddings, and Supabase clients.
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

embeddings_model = OpenAIEmbeddings()

supabase_url = os.getenv("SUPABASE_URL")
supabase_key = os.getenv("SUPABASE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)
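# Expected .env entries (variable names taken from the os.getenv calls above;
# the values shown are placeholders):
# OPENAI_API_KEY=sk-...
# SUPABASE_URL=https://<project>.supabase.co
# SUPABASE_KEY=<anon-or-service-role-key>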

############# Load data #############
# Earlier CSV-based loading path, kept for reference:
# def extract_field(doc, field_name):
#     for line in doc.page_content.split('\n'):
#         if line.startswith(f"{field_name}:"):
#             return line.split(':', 1)[1].strip()
#     return None
# loader = CSVLoader(file_path="video_cache_rows.csv")
# data = loader.load()
# field_name = "question"
# question = [extract_field(doc, field_name) for doc in data]

####### Load data from Supabase #######
response = supabase.table("video_cache").select("question", "id").order("id").execute()
data = response.data
question = [item["question"] for item in data if "question" in item]
ids = [item["id"] for item in data if "id" in item]
# Map each question string back to its row id so a vector-store hit can be
# resolved to the original Supabase row.
question_id_map = {item["question"]: item["id"] for item in data if "id" in item and "question" in item}


def get_id_by_question(question):
    return question_id_map.get(question)

# print(question)
# created_at = []
# question = []
# ids = []
# answer = []
# video_url = []
# for item in data:
#     ids.append(item['id'])
#     created_at.append(item['created_at'])
#     question.append(item['question'])
#     answer.append(item['answer'])
#     video_url.append(item['video_url'])

########## Generate embeddings ###########
# embed_documents returns one embedding vector per question (list[list[float]]).
# Note that Chroma.from_texts below embeds the texts itself; this explicit call
# is only needed if the vectors are written back to Supabase.
embedding = embeddings_model.embed_documents(question)

########## Write embeddings back to the Supabase table ##########
# Update each row with its own embedding vector (kept disabled):
# for row_id, new_embedding in zip(ids, embedding):
#     supabase.table("video_cache_rows_duplicate").update({"embedding": new_embedding}).eq("id", row_id).execute()

######### Vector store ##########
# Index the questions in Chroma and persist the collection to disk.
vectorstore = Chroma.from_texts(
    texts=question,
    embedding=embeddings_model,
    persist_directory="./chroma_db",
)
# Reload the persisted collection so the rest of the script reads from disk.
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings_model)
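# Alternative indexing sketch (an assumption, not part of the original flow):
# attach the Supabase row id as Chroma metadata so each hit carries its own id
# and get_id_by_question becomes unnecessary. "source_id" is a hypothetical key.
# vectorstore = Chroma.from_texts(
#     texts=question,
#     embedding=embeddings_model,
#     metadatas=[{"source_id": row_id} for row_id in ids],
#     persist_directory="./chroma_db",
# )
# A hit would then expose doc.metadata["source_id"] directly.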

def ask_question(question: str, similarity_threshold: float = 0.83):
    # Relevance scores are normalized to [0, 1]; 1 means most similar.
    docs_and_scores = vectorstore.similarity_search_with_relevance_scores(question, k=1)
    doc, score = docs_and_scores[0]
    print(doc, score)
    if score >= similarity_threshold:
        row_id = get_id_by_question(doc.page_content)
        response = supabase.table("video_cache").select("*").eq("id", row_id).execute()
        rows = response.data
        if not rows or rows[0]["answer"] is None:
            return None
        return rows
    return None

if __name__ == "__main__":
    ####### Load from disk and query #######
    query = "美食街在哪裡"  # "Where is the food court?"
    docs = vectorstore.similarity_search(query)
    print(f"Query: {query} | Closest match: {docs[0].page_content}")

    query = "101可以帶狗嗎"  # "Can I bring a dog into Taipei 101?"
    docs = vectorstore.similarity_search(query)
    print(f"Query: {query} | Closest match: {docs[0].page_content}")