- import time
- import os
- from dotenv import load_dotenv
- from config import system_prompt
- from langchain.globals import set_llm_cache
- from langchain_community.cache import SQLiteCache
- from embeddings import load_embeddings
- from rag_chain import simple_rag_prompt, get_context
- import pandas as pd
- from langchain_openai import OpenAIEmbeddings
- from langchain_chroma import Chroma
# Load environment variables (API keys etc.) from the local env file.
load_dotenv('environment.env')
# Cache LLM responses in SQLite so repeated identical prompts are answered
# without re-calling the model.
set_llm_cache(SQLiteCache(database_path=".langchain.db"))
# Curated question/answer pairs; the 'question' and 'answer' columns are
# read by search_similarity() below.
qa_df = pd.read_csv('log_record_rows.csv')
# Embedding model + persisted Chroma store backing the standard-QA lookup.
# NOTE(review): assumes ./chroma_db was built from the same questions as
# log_record_rows.csv — confirm, otherwise answers will never match.
embeddings_model = OpenAIEmbeddings()
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings_model)
def search_similarity(query, SIMILARITY_THRESHOLD=0.82):
    """Look up *query* among the curated standard-QA questions.

    Returns ``(answer, relevance_score)`` when the closest stored question
    scores at least ``SIMILARITY_THRESHOLD`` and has an answer row in
    ``qa_df``; otherwise returns ``(None, 0)``.
    """
    hits = vectorstore.similarity_search_with_relevance_scores(query, k=1)
    if not hits:
        return None, 0

    top_doc, relevance = hits[0]
    if relevance < SIMILARITY_THRESHOLD:
        return None, 0

    # The stored document text is the canonical question; map it back to
    # its answer row in the QA table.
    answers = qa_df.loc[qa_df['question'] == top_doc.page_content, 'answer'].values
    if len(answers) == 0:
        return None, 0
    return answers[0], relevance
def main():
    """Interactive QA loop.

    Answers each question from the curated standard QA when a sufficiently
    similar question exists; otherwise falls back to the RAG pipeline.
    Type 'quit' to exit.
    """
    # Load embeddings plus the documents/index backing the RAG fallback.
    embeddings, docs, df, index = load_embeddings()

    # Retrieval function bound to the loaded index and documents.
    def retrieval_chain(question):
        return get_context(question, index, docs)

    print("RAG system initialized. You can start asking questions.")
    print("Type 'quit' to exit the program.")

    while True:
        user_input = input("\nEnter your question: ")
        if user_input.lower() == 'quit':
            break

        start_time = time.time()
        try:
            # Cheap path first: curated standard-QA lookup.
            standard_answer, similarity_score = search_similarity(user_input)
            if standard_answer is None:
                # No close curated question — fall back to full RAG.
                rag_answer, rag_similarity_score = simple_rag_prompt(retrieval_chain, user_input)
                print(f"\nAnswer (from RAG): {rag_answer}")
            else:
                print(f"\nAnswer (from standard QA): {standard_answer}")

            elapsed = time.time() - start_time
            print(f"Response Time: {elapsed:.2f} seconds")
        except Exception as e:
            # Top-level CLI boundary: report and keep the loop alive.
            print(f"Error processing question: {str(e)}")

        # Small delay between questions to avoid API rate limiting.
        time.sleep(1)
# Script entry point: run the interactive loop only when executed directly.
if __name__ == "__main__":
    main()