Parcourir la source

added semantic search threshold

Sherry il y a 9 mois
Parent
commit
81c67cba0c
1 fichier modifié avec 42 ajouts et 27 suppressions
  1. 101_semantic_search.py  +42 −27

+ 42 - 27
101_semantic_search.py

@@ -1,7 +1,7 @@
 ### Python = 3.9
 import os
 from dotenv import load_dotenv
-load_dotenv('environment.env')
+load_dotenv('../environment.env')
 
 import openai 
 openai_api_key = os.getenv("OPENAI_API_KEY")
@@ -18,17 +18,17 @@ from langchain_chroma import Chroma
 # supabase_key = os.getenv("SUPABASE_KEY")
 # supabase: Client = create_client(supabase_url, supabase_key)
 
-############# Load data #############
-def extract_field(doc, field_name):
-    for line in doc.page_content.split('\n'):
-        if line.startswith(f"{field_name}:"):
-            return line.split(':', 1)[1].strip()
-    return None
+# ############# Load data #############
+# def extract_field(doc, field_name):
+#     for line in doc.page_content.split('\n'):
+#         if line.startswith(f"{field_name}:"):
+#             return line.split(':', 1)[1].strip()
+#     return None
 
-loader = CSVLoader(file_path="video_cache_rows.csv")
-data = loader.load()
-field_name = "question"
-question = [extract_field(doc, field_name) for doc in data]
+# loader = CSVLoader(file_path="../video_cache_rows.csv")
+# data = loader.load()
+# field_name = "question"
+# question = [extract_field(doc, field_name) for doc in data]
 
 # ####### load data from supabase #######
 # embeddings_model = OpenAIEmbeddings()
@@ -49,28 +49,43 @@ question = [extract_field(doc, field_name) for doc in data]
 
 
 ########## generate embedding ###########
-embedding = embeddings_model.embed_documents(question)
+# embedding = embeddings_model.embed_documents(question)
 
 ########## Write embedding to the supabase table  #######
 # for id, new_embedding in zip(ids, embedding):
 #     supabase.table("video_cache_rows_duplicate").insert({"embedding": embedding.tolist()}).eq("id", id).execute()
 
-######### Vector Store ##########
-# Put pre-compute embeddings to vector store. ## save to disk
-vectorstore = Chroma.from_texts(
-    texts=question,
-    embedding=embeddings_model,
-    persist_directory="./chroma_db"
-    )
+# ######### Vector Store ##########
+# # Put pre-compute embeddings to vector store. ## save to disk
+# vectorstore = Chroma.from_texts(
+#     texts=question,
+#     embedding=embeddings_model,
+#     persist_directory="./chroma_db"
+#     )
 
-####### load from disk  #######
-query = "101可以帶狗嗎"
-vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings_model)
-docs = vectorstore.similarity_search(query)
-print(f"Query: {query}  | 最接近文檔:{docs[0].page_content}")
+# ####### load from disk  #######
+# query = ["狗狗可以進101嗎", "哪裡有賣珍奶", "101幾點關門", "101星期天有開嗎", "球鞋哪裡有賣"]
+vectorstore = Chroma(persist_directory="../chroma_db", embedding_function=embeddings_model)
+# docs = vectorstore.similarity_search(query)
+# print(f"Query: {query}  | 最接近文檔:{docs[0].page_content}")
 
 ####### Query it #########
-query = "101可以帶狗嗎"
-docs = vectorstore.similarity_search(query)
-print(f"Query: {query}  | 最接近文檔:{docs[0].page_content}")
+def search_similarity(query, SIMILARITY_THRESHOLD):
+    docs_and_scores = vectorstore.similarity_search_with_relevance_scores(query, k=1)
+    doc, score = docs_and_scores[0]
+    if score >= SIMILARITY_THRESHOLD:
+            print(f"Query: {query}  | 最接近文檔:{doc.page_content} | score:{round(score, 2)}" )
+    else:
+        print(f"Query: {query} | 沒有相關資訊 | score:{round(score, 2)}")
+
+
+query = ["狗狗可以進101嗎", "哪裡有賣珍奶",  "遺失物品哪裡找", "嬰兒車可以進電梯嗎", "101幾點關門", "101星期天有開嗎", "球鞋哪裡有賣", "殘障人士租用輪椅", "停車多少錢", "觀景台導覽", "觀景台電梯速度", "我去哪裡買觀景台的票", "觀景台的票多少錢", "101有透明地板嗎", "如何辦退稅", "紀念品可以退稅嗎", "哪裡可以退稅", "101網路可以訂票嗎"]
+SIMILARITY_THRESHOLD = 0.83
+for i in query:
+    search_similarity(i, SIMILARITY_THRESHOLD)
+
+# Define a similarity threshold
+# print('docs', docs_and_scores)
+
+