# semantic_search.py

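"""Semantic FAQ search over the Supabase `video_cache` table.

Questions are embedded with OpenAI embeddings and indexed per language in a
local Chroma store; APScheduler rebuilds each language index every two hours.
`ask_question` looks up the cached row whose stored question is most similar
to an incoming query, provided the relevance score clears a threshold.
"""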
# Python 3.9
import os

import openai
from apscheduler.schedulers.background import BackgroundScheduler
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.document_loaders.csv_loader import CSVLoader  # only needed for the commented-out CSV path below
from langchain_openai import OpenAIEmbeddings
from supabase import Client, create_client

# Load environment variables before anything reads them.
load_dotenv()

openai_api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = openai_api_key

embeddings_model = OpenAIEmbeddings()

supabase_url = os.getenv("SUPABASE_URL")
supabase_key = os.getenv("SUPABASE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)

scheduler = BackgroundScheduler(timezone="Asia/Taipei")
############# Load data #############
# def extract_field(doc, field_name):
#     for line in doc.page_content.split('\n'):
#         if line.startswith(f"{field_name}:"):
#             return line.split(':', 1)[1].strip()
#     return None
# loader = CSVLoader(file_path="video_cache_rows.csv")
# data = loader.load()
# field_name = "question"
# question = [extract_field(doc, field_name) for doc in data]

####### load data from supabase #######
# embeddings_model = OpenAIEmbeddings()
# Per-language caches rebuilt by generated(): language -> Chroma store / question->id map.
vectorstore_list = {}
question_id_map_list = {}


def generated(language: str = "ch"):
    """Rebuild the question->id map and the Chroma vector store for one language."""
    response = (
        supabase.table("video_cache")
        .select("question", "id")
        .eq("language", language)
        .order("id")
        .execute()
    )
    data = response.data
    question = [item["question"] for item in data if "question" in item]
    # print(question)
    ids = [item["id"] for item in data if "id" in item]
    question_id_map = {item["question"]: item["id"] for item in data if "id" in item and "question" in item}
    question_id_map_list[language] = question_id_map
    ########## generate embedding ###########
    # Only needed for the commented-out write-back below; Chroma.from_texts()
    # computes its own embeddings.
    # embedding = embeddings_model.embed_documents(question)
    ########## Write embedding to the supabase table #######
    # for row_id, new_embedding in zip(ids, embedding):
    #     supabase.table("video_cache_rows_duplicate").update({"embedding": new_embedding}).eq("id", row_id).execute()
    ######### Vector Store ##########
    # Put the (deduplicated) questions into the vector store and save it to disk.
    persist_directory = f"./chroma_db_{language}"
    vectorstore = Chroma.from_texts(
        texts=list(question_id_map),
        embedding=embeddings_model,
        persist_directory=persist_directory,
    )
    # vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings_model)
    vectorstore_list[language] = vectorstore
    print(f"generated {language}")
    # print(question_id_map_list)
# Build each language index once at startup, then refresh it every two hours.
for lang in ("ch", "en", "jp", "ko"):
    generated(lang)
    scheduler.add_job(generated, "cron", hour="*/2", kwargs={"language": lang})
scheduler.start()
def get_id_by_question(question, language):
    """Map a matched question string back to its video_cache row id."""
    return question_id_map_list[language].get(question)
# print(question)
# created_at = []
# question = []
# ids = []
# answer = []
# video_url = []
# for item in data:
#     ids.append(item['id'])
#     created_at.append(item['created_at'])
#     question.append(item['question'])
#     answer.append(item['answer'])
#     video_url.append(item['video_url'])
def ask_question(question: str, SIMILARITY_THRESHOLD: float = 0.83, language: str = "ch"):
    """Return the cached video_cache rows whose question best matches `question`,
    or None if nothing clears the similarity threshold or the row has no answer."""
    # generated(language=language)
    print(language)
    vectorstore = vectorstore_list[language]
    print(vectorstore)
    docs_and_scores = vectorstore.similarity_search_with_relevance_scores(question, k=1)
    doc, score = docs_and_scores[0]
    print(doc, score)
    if score >= SIMILARITY_THRESHOLD:
        row_id = get_id_by_question(doc.page_content, language)
        response = supabase.table("video_cache").select("*").eq("id", row_id).execute()
        rows = response.data
        if not rows or rows[0]["answer"] is None:
            return None
        return rows
    else:
        return None
def ask_question_find_brand(question: str):
    # Use the OpenAI chat model to extract search keywords from the question.
    response = openai.chat.completions.create(
        model="gpt-4",  # use the GPT-4 model
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Extract keywords from the following text for a database search: {question}"},
        ],
        max_tokens=50,
        temperature=0.5,
    )
    # Split the comma-separated keywords returned by the model.
    keywords = response.choices[0].message.content.strip().split(", ")
    return keywords
if __name__ == "__main__":
    ####### load from disk #######
    # The per-language stores are persisted by generated(); reload the Chinese one here.
    vectorstore = Chroma(persist_directory="./chroma_db_ch", embedding_function=embeddings_model)
    ####### Query it #########
    query = "美食街在哪裡"
    docs = vectorstore.similarity_search(query)
    print(f"Query: {query} | closest match: {docs[0].page_content}")
    query = "101可以帶狗嗎"
    docs = vectorstore.similarity_search(query)
    print(f"Query: {query} | closest match: {docs[0].page_content}")
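    # Minimal end-to-end sketch using the helpers defined above; this assumes the
    # Supabase video_cache table contains a "ch" row similar enough to clear the
    # default 0.83 relevance threshold, otherwise ask_question() returns None.
    hit = ask_question(query, language="ch")
    if hit is not None:
        print(f"cached answer: {hit[0]['answer']} | video: {hit[0].get('video_url')}")
    else:
        print("no cached answer above the similarity threshold")
    # Keyword extraction via the chat model; assumes the model replies with a
    # comma-separated keyword list as prompted.
    print(ask_question_find_brand(query))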