semantic_search.py

# Python 3.9
import os

from dotenv import load_dotenv
load_dotenv()

import openai
openai_api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = openai_api_key

from langchain_openai import OpenAIEmbeddings
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_chroma import Chroma
from supabase import create_client, Client

embeddings_model = OpenAIEmbeddings()

supabase_url = os.getenv("SUPABASE_URL")
supabase_key = os.getenv("SUPABASE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)
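
# The credentials above come from a local .env file. A minimal example
# (the values below are placeholders, not real keys):
#
#   OPENAI_API_KEY=sk-...
#   SUPABASE_URL=https://your-project.supabase.co
#   SUPABASE_KEY=your-anon-or-service-role-key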

############# Load data #############
# Earlier approach: load rows from a CSV export and pull one field per document.
# def extract_field(doc, field_name):
#     for line in doc.page_content.split('\n'):
#         if line.startswith(f"{field_name}:"):
#             return line.split(':', 1)[1].strip()
#     return None
#
# loader = CSVLoader(file_path="video_cache_rows.csv")
# data = loader.load()
# field_name = "question"
# question = [extract_field(doc, field_name) for doc in data]

####### load data from supabase #######
def generated(language: str = "ch"):
    """Fetch cached questions from Supabase and (re)build the Chroma vector store."""
    global question, ids, question_id_map, vectorstore
    response = (
        supabase.table("video_cache")
        .select("question", "id")
        .eq("language", language)
        .order("id")
        .execute()
    )
    data = response.data
    question = [item["question"] for item in data if "question" in item]
    ids = [item["id"] for item in data if "id" in item]
    question_id_map = {item["question"]: item["id"] for item in data if "id" in item and "question" in item}

    ########## generate embedding ###########
    # Only needed for the write-back below; Chroma.from_texts re-embeds on its own.
    embedding = embeddings_model.embed_documents(question)

    ########## Write embedding to the supabase table #######
    # for row_id, new_embedding in zip(ids, embedding):
    #     supabase.table("video_cache_rows_duplicate").update(
    #         {"embedding": new_embedding}
    #     ).eq("id", row_id).execute()

    ######### Vector Store ##########
    # Embed the questions into the vector store and save it to disk.
    # Note: calling from_texts repeatedly appends duplicate rows to the same
    # persisted collection; clear ./chroma_db first when rebuilding often.
    vectorstore = Chroma.from_texts(
        texts=question,
        embedding=embeddings_model,
        persist_directory="./chroma_db",
    )
    # Reopen the persisted store from disk.
    vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings_model)
    print("generated")
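
# Sketch: reusing the vectors from embed_documents() instead of letting Chroma
# re-embed, by writing through the raw chromadb client. This helper is
# illustrative and not called anywhere; the collection name "questions" and the
# path "./chroma_db_precomputed" are assumptions, not used by the rest of the
# script.
def _store_precomputed_embeddings(texts, vectors, path="./chroma_db_precomputed"):
    import chromadb

    client = chromadb.PersistentClient(path=path)
    collection = client.get_or_create_collection("questions")
    # chromadb requires string ids; the list index serves as one here.
    collection.add(
        ids=[str(i) for i in range(len(texts))],
        documents=texts,
        embeddings=vectors,
    )
    return collection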
def get_id_by_question(question):
    return question_id_map.get(question)
# print(question)
# Earlier column-by-column unpacking of the Supabase rows, kept for reference:
# created_at = []
# question = []
# ids = []
# answer = []
# video_url = []
# for item in data:
#     ids.append(item['id'])
#     created_at.append(item['created_at'])
#     question.append(item['question'])
#     answer.append(item['answer'])
#     video_url.append(item['video_url'])
def ask_question(question: str, similarity_threshold: float = 0.83, language: str = "ch"):
    generated(language=language)
    docs_and_scores = vectorstore.similarity_search_with_relevance_scores(question, k=1)
    if not docs_and_scores:
        return None
    doc, score = docs_and_scores[0]
    print(doc, score)
    if score >= similarity_threshold:
        row_id = get_id_by_question(doc.page_content)
        response = supabase.table("video_cache").select("*").eq("id", row_id).execute()
        if not response.data or response.data[0]["answer"] is None:
            return None
        return response.data
    else:
        return None
if __name__ == "__main__":
    ####### load from disk #######
    generated()  # build/refresh the vector store before querying it

    query = "美食街在哪裡"  # "Where is the food court?"
    docs = vectorstore.similarity_search(query)
    print(f"Query: {query} | closest match: {docs[0].page_content}")

    ####### Query it #########
    query = "101可以帶狗嗎"  # "Can I bring a dog into 101?"
    docs = vectorstore.similarity_search(query)
    print(f"Query: {query} | closest match: {docs[0].page_content}")
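
    ####### Full cache lookup #########
    # End-to-end example of the cached-answer lookup; it hits the video_cache
    # table, so it assumes the Supabase rows queried above actually exist.
    result = ask_question("美食街在哪裡")
    if result is None:
        print("No cached answer above the similarity threshold.")
    else:
        print(f"Cached answer: {result[0]['answer']}")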