"""FastAPI service exposing multi-query RAG question-answering endpoints.

Startup (lifespan) builds a Supabase-backed vector store and retriever once;
endpoints answer questions (optionally with chat history), record each
interaction to the ``systex_records`` SQL table, and expose that history.
"""

from dotenv import load_dotenv

# Environment must be loaded before any os.getenv() below.
load_dotenv('environment.env')

import datetime
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from json import loads
from typing import List, Optional

import openai
import pandas as pd
import uvicorn
from fastapi import FastAPI, HTTPException, status, Body, Depends
from fastapi.middleware.cors import CORSMiddleware
from langchain.callbacks import get_openai_callback
from langchain_openai import OpenAIEmbeddings
from pydantic import BaseModel
from sqlalchemy import create_engine
from supabase.client import Client, create_client

from add_vectordb import GetVectorStore
from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs

# Route errors through uvicorn's error logger so they show up in server logs.
logger = logging.getLogger("uvicorn.error")

openai_api_key = os.getenv("OPENAI_API_KEY")
URI = os.getenv("SUPABASE_URI")  # SQLAlchemy URI for the records database
openai.api_key = openai_api_key

# Populated once during application startup (see lifespan); None until then.
global_retriever = None
vector_store = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the Supabase vector store and default retriever at startup.

    Runs once before the app starts serving requests; the retriever and
    vector store are published through module globals for the dependency
    functions below.
    """
    global global_retriever
    global vector_store

    start = time.time()
    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_KEY")
    document_table = "documents"
    supabase: Client = create_client(supabase_url, supabase_key)

    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    vector_store = GetVectorStore(embeddings, supabase, document_table)
    # Default retriever: top-4 nearest documents, no metadata filter.
    global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})

    print(f"Initialization time: {time.time() - start}")
    yield


def get_retriever():
    """FastAPI dependency: the default retriever built at startup."""
    return global_retriever


def get_vector_store():
    """FastAPI dependency: the vector store built at startup."""
    return vector_store


app = FastAPI(lifespan=lifespan)

# Wide-open CORS — presumably intentional for a demo deployment; tighten
# allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/answer2")
def multi_query_answer(question, retriever=Depends(get_retriever)):
    """Answer a single question (no chat history) via multi-query RAG.

    Records the interaction (tokens, cost, timing) via save_history and
    returns ``{"Answer": ...}``; failures surface as HTTP 500.
    """
    try:
        start = time.time()

        with get_openai_callback() as cb:
            final_answer, reference_docs = multi_query(
                question, retriever, chat_history=[]
            )
        processing_time = time.time() - start
        print(f"Processing time: {processing_time}")
        save_history(question, final_answer, reference_docs, cb, processing_time)

        return {"Answer": final_answer}
    except Exception as e:
        logger.error(f"Error in /answer2 endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))


class ChatHistoryItem(BaseModel):
    # One prior conversation turn: question (q) and its answer (a).
    q: str
    a: str


@app.post("/answer_with_history")
def multi_query_answer_with_history(
    question: Optional[str] = '',
    chat_history: List[ChatHistoryItem] = Body(...),
    retriever=Depends(get_retriever),
):
    """Answer *question* using the supplied chat history for context.

    History items with an empty answer are dropped; the rest are passed to
    multi_query as (question, answer) tuples.
    """
    try:
        start = time.time()

        chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
        print(f"Chat history: {chat_history}")

        with get_openai_callback() as cb:
            final_answer, reference_docs = multi_query(question, retriever, chat_history)
        processing_time = time.time() - start
        print(f"Processing time: {processing_time}")
        save_history(question, final_answer, reference_docs, cb, processing_time)

        return {"Answer": final_answer}
    except Exception as e:
        logger.error(f"Error in /answer_with_history endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/answer_with_history2")
def multi_query_answer_with_history_filtered(
    question: Optional[str] = '',
    extension: Optional[str] = 'pdf',
    chat_history: List[ChatHistoryItem] = Body(...),
    vector_store=Depends(get_vector_store),
):
    """Like /answer_with_history, but restricts retrieval by file extension.

    A per-request retriever is built with a metadata filter on ``extension``
    (e.g. 'pdf') instead of using the unfiltered startup retriever.
    """
    try:
        start = time.time()

        retriever = vector_store.as_retriever(
            search_kwargs={"k": 4, 'filter': {'extension': extension}}
        )

        chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
        print(f"Chat history: {chat_history}")

        with get_openai_callback() as cb:
            final_answer, reference_docs = multi_query(question, retriever, chat_history)
        processing_time = time.time() - start
        print(f"Processing time: {processing_time}")
        save_history(question, final_answer, reference_docs, cb, processing_time)

        return {"Answer": final_answer}
    except Exception as e:
        logger.error(f"Error in /answer_with_history2 endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))


def save_history(question, answer, reference, cb, processing_time):
    """Append one Q&A record to the ``systex_records`` table.

    Stores the question/answer pair, token usage and cost from the OpenAI
    callback, wall-clock processing time, and the stringified reference docs.
    NOTE(review): the 'Time' column read back by /history is not written here
    — presumably a database-side default; confirm against the table schema.
    """
    record = {
        'Question': [question],
        'Answer': [answer],
        'Total_Tokens': [cb.total_tokens],
        'Total_Cost': [cb.total_cost],
        'Processing_time': [processing_time],
        'Contexts': [str(reference)]
    }
    df = pd.DataFrame(record)

    engine = create_engine(URI)
    df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')


class history_output(BaseModel):
    # Response schema for /history, mirroring the systex_records columns.
    Question: str
    Answer: str
    Contexts: str
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float
    Time: datetime.datetime


@app.get('/history', response_model=List[history_output])
async def get_history():
    """Return every stored Q&A record from ``systex_records``."""
    engine = create_engine(URI, echo=True)
    # Context-manage the connection so it is released even if the read fails
    # (the original leaked the connection returned by engine.connect()).
    with engine.connect() as conn:
        df = pd.read_sql_table("systex_records", conn)
    df.fillna('', inplace=True)

    # Round-trip through JSON to get plain dicts FastAPI can serialize.
    result = df.to_json(orient='index', force_ascii=False)
    result = loads(result)
    return list(result.values())


@app.get("/")
def read_root():
    """Liveness / welcome endpoint."""
    return {"message": "Welcome to the Carbon Chatbot API"}


if __name__ == "__main__":
    uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)