Browse Source

Added standard QA similarity check before RAG

SherryLiu 5 months ago
parent
commit
4e94a20c77
7 changed files with 138 additions and 12 deletions
  1. 6 5
      RAG/README.md
  2. 0 0
      RAG/__Init__.py
  3. 45 0
      RAG/chromadb_generate.py
  4. 1 1
      RAG/config.py
  5. 1 1
      RAG/embeddings.py
  6. 80 0
      RAG/main.py
  7. 5 5
      RAG/rag_chain.py

+ 6 - 5
RAG/README.md

@@ -1,9 +1,10 @@
-#### 101 RAG chatbot
-- Using OpenAI Embeddings
-- Using FAISS index
-- Using TAIDE LLM model (llama 3 version)
+### 101 RAG chatbot
+
+#### Prerequisite
+Save API tokens in `environment.env`
 
 #### To run the code
 `./run.sh`
 
-There are sample outputs in RAG_output.txt
+#### To exit the conversation
+`quit`

+ 0 - 0
RAG/__Init__.py


+ 45 - 0
RAG/chromadb_generate.py

@@ -0,0 +1,45 @@
import os

import openai
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_openai import OpenAIEmbeddings

# Pull API credentials from the project-level env file.
load_dotenv('../environment.env')

# Configure the OpenAI client; fail fast when no key is provided.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("No OpenAI API key found in environment variables")
openai.api_key = openai_api_key

# Embedding model used to vectorize the standard-QA questions.
embeddings_model = OpenAIEmbeddings()
def extract_field(doc, field_name):
    """Return the value of *field_name* from a CSVLoader document.

    CSVLoader renders each row as newline-separated "key: value" lines;
    scan them and return the trimmed value for the requested key, or
    None when the key is absent.
    """
    prefix = f"{field_name}:"
    for row_line in doc.page_content.split('\n'):
        if row_line.startswith(prefix):
            _, value = row_line.split(':', 1)
            return value.strip()
    return None
# Build the Chroma store once; skip when a persisted DB is already present.
if not os.path.exists("./chroma_db"):
    try:
        # Load the QA log; CSVLoader turns each CSV row into one document.
        loader = CSVLoader(file_path="log_record_rows.csv")
        data = loader.load()

        # Extract the question text from every row, dropping rows where the
        # field is missing or empty -- extract_field returns None for rows
        # without a "question" key, and Chroma.from_texts cannot embed None.
        questions = [q for q in (extract_field(doc, "question") for doc in data) if q]

        if not questions:
            raise ValueError("No 'question' fields found in log_record_rows.csv")

        # Embed the questions and persist the vector store to disk.
        vectorstore = Chroma.from_texts(
            texts=questions,
            embedding=embeddings_model,
            persist_directory="./chroma_db",
        )
        print("Chroma database created successfully.")
    except Exception as e:
        # Best-effort script: report the failure instead of crashing.
        print(f"An error occurred while creating the Chroma database: {e}")
else:
    print("Chroma database already exists.")

+ 1 - 1
RAG/config.py

@@ -10,4 +10,4 @@ EMBEDDINGS_FILE = 'qa_embeddings.pkl'
 FAISS_INDEX_FILE = 'qa_faiss_index.bin'
 CSV_FILE = 'log_record_rows.csv'
 
-system_prompt = "你是一個來自台灣的AI助理,你的名字是 TAIDE,樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"
+system_prompt = "你是台北101大樓的AI助理,你樂於以台灣人的立場幫助使用者,會用繁體中文回答問題。"

+ 1 - 1
RAG/embeddings.py

@@ -93,7 +93,7 @@ def load_embeddings():
     else:
         raise FileNotFoundError("Embeddings or FAISS index file not found. Please run embeddings.py first.")
 
-def similarity_search(query, index, docs, k=3, threshold=0.83, method='logistic', sigma=1.0):
+def similarity_search(query, index, docs, k=3, threshold=0.8, method='logistic', sigma=1.0):
     embeddings = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY"))
     query_vector = embeddings.embed_query(query)
     

+ 80 - 0
RAG/main.py

@@ -0,0 +1,80 @@
import os
import time

import pandas as pd
from dotenv import load_dotenv
from langchain.globals import set_llm_cache
from langchain_chroma import Chroma
from langchain_community.cache import SQLiteCache
from langchain_openai import OpenAIEmbeddings

from config import system_prompt
from embeddings import load_embeddings
from rag_chain import simple_rag_prompt, get_context

# API credentials and other environment settings.
load_dotenv('environment.env')

# Cache LLM responses on disk so repeated questions answer instantly.
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

# Standard QA pairs; answers are looked up here after a vector match.
qa_df = pd.read_csv('log_record_rows.csv')

# Vector store of standard-QA questions built by chromadb_generate.py.
embeddings_model = OpenAIEmbeddings()
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings_model)
+
def search_similarity(query, SIMILARITY_THRESHOLD=0.82):
    """Look *query* up against the standard-QA vector store.

    Returns ``(answer, score)`` when the closest stored question scores at
    least ``SIMILARITY_THRESHOLD`` and has a matching answer row in
    ``qa_df``; otherwise returns ``(None, 0)``.
    """
    hits = vectorstore.similarity_search_with_relevance_scores(query, k=1)
    if not hits:
        return None, 0

    top_doc, top_score = hits[0]
    if top_score < SIMILARITY_THRESHOLD:
        return None, 0

    # The stored text is the question itself; recover its canonical answer.
    matches = qa_df.loc[qa_df['question'] == top_doc.page_content, 'answer'].values
    if len(matches) > 0:
        return matches[0], top_score
    return None, 0
+
def main():
    """Interactive QA loop: answer from standard QA when a stored question
    is similar enough, otherwise fall back to the RAG chain."""
    # FAISS index + documents backing the RAG fallback path.
    embeddings, docs, df, index = load_embeddings()

    # Retrieval step handed to the RAG chain.
    def retrieval_chain(question):
        return get_context(question, index, docs)

    print("RAG system initialized. You can start asking questions.")
    print("Type 'quit' to exit the program.")

    while True:
        user_input = input("\nEnter your question: ")
        if user_input.lower() == 'quit':
            break

        start_time = time.time()
        try:
            # Prefer a curated answer from the standard QA set.
            standard_answer, similarity_score = search_similarity(user_input)
            if standard_answer is None:
                # No sufficiently similar stored question -- use RAG.
                rag_answer, rag_similarity_score = simple_rag_prompt(retrieval_chain, user_input)
                print(f"\nAnswer (from RAG): {rag_answer}")
            else:
                print(f"\nAnswer (from standard QA): {standard_answer}")

            response_time = time.time() - start_time
            print(f"Response Time: {response_time:.2f} seconds")
        except Exception as e:
            print(f"Error processing question: {str(e)}")

        # Small pause between questions to stay under API rate limits.
        time.sleep(1)


if __name__ == "__main__":
    main()

+ 5 - 5
RAG/rag_chain.py

@@ -17,11 +17,11 @@ def get_context(query, index, docs):
 
     context = "\n".join([doc.page_content for doc, _ in results])
 
-    print(f"Question: {query}")
-    print("Retrieved documents:")
-    for i, (doc, similarity) in enumerate(results):
-        print(f"Doc {i+1} (similarity: {similarity:.4f}): {doc.page_content[:50]}...")
-    print("-" * 50)
+    # print(f"Question: {query}")
+    # print("Retrieved documents:")
+    # for i, (doc, similarity) in enumerate(results):
+    #     print(f"Doc {i+1} (similarity: {similarity:.4f}): {doc.page_content[:50]}...")
+    # print("-" * 50)
     
     return context, results[0][1]  # Return context and top similarity score