123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- from dotenv import load_dotenv
- load_dotenv('environment.env')
- from fastapi import FastAPI, HTTPException, status, Body, Depends
- from fastapi.middleware.cors import CORSMiddleware
- from contextlib import asynccontextmanager
- from pydantic import BaseModel
- from typing import List, Optional
- import uvicorn
- from sqlalchemy import create_engine
- import pandas as pd
- import datetime
- import json
- from json import loads
- import time
- from langchain.callbacks import get_openai_callback
- from langchain_openai import OpenAIEmbeddings
- from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs
- import os
- from supabase.client import Client, create_client
- from add_vectordb import GetVectorStore
- import openai
# Get API log
import logging
logger = logging.getLogger("uvicorn.error")  # reuse uvicorn's error logger so messages land in the server log

# Credentials / connection strings come from the environment (loaded from environment.env above).
openai_api_key = os.getenv("OPENAI_API_KEY")
URI = os.getenv("SUPABASE_URI")  # SQLAlchemy database URI for the Q/A history table
openai.api_key = openai_api_key

# Retriever built during FastAPI lifespan startup; None until the app has started.
global_retriever = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: build the Supabase vector store and retriever
    once at startup and publish them via module globals.

    Side effects:
        Sets ``global_retriever`` and ``vector_store`` for the dependency
        providers below.

    NOTE(review): ``vector_store`` has no module-level default, so
    ``get_vector_store()`` raises NameError if called before startup —
    confirm this is acceptable.
    """
    global global_retriever
    global vector_store

    start = time.time()
    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_KEY")
    document_table = "documents"
    supabase: Client = create_client(supabase_url, supabase_key)
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    vector_store = GetVectorStore(embeddings, supabase, document_table)
    # Default retrieval: top-4 nearest documents, no metadata filter.
    global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})
    # Use the uvicorn logger (defined at module level) instead of print so the
    # timing shows up in the server log.
    logger.info("Initialization time: %s", time.time() - start)
    yield
def get_retriever():
    """Dependency provider for the retriever created during app startup."""
    retriever = global_retriever
    return retriever
def get_vector_store():
    """Dependency provider for the vector store created during app startup."""
    store = vector_store
    return store
app = FastAPI(lifespan=lifespan)
# Wide-open CORS: any origin, method, and header is accepted.
# NOTE(review): tighten allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/answer2")
def answer_without_history(question: str, retriever=Depends(get_retriever)):
    """Answer ``question`` with the multi-query RAG strategy, no chat history.

    Renamed from ``multi_query_answer``: all three handlers shared that name,
    so later definitions shadowed earlier ones at module level. The route path
    and response shape are unchanged.

    Returns:
        {"Answer": <final answer string>}

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        start = time.time()
        with get_openai_callback() as cb:
            final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
            processing_time = time.time() - start
            logger.info("Processing time: %s", processing_time)
            # Persist the exchange plus token/cost accounting from the callback.
            save_history(question, final_answer, reference_docs, cb, processing_time)
        return {"Answer": final_answer}
    except Exception as e:
        logger.error(f"Error in /answer2 endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))
class ChatHistoryItem(BaseModel):
    # One prior exchange in the conversation.
    q: str  # user question
    a: str  # assistant answer; "" marks an unanswered turn (filtered out by the handlers)
@app.post("/answer_with_history")
def answer_with_history(question: Optional[str] = '', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
    """Answer ``question`` with the multi-query RAG strategy, conditioned on
    the supplied chat history.

    Renamed from ``multi_query_answer`` (duplicate name across handlers) and
    given the same error handling as ``/answer2`` for consistency; route path
    and response shape are unchanged.

    Returns:
        {"Answer": <final answer string>}

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        start = time.time()
        # Drop turns without an answer; the RAG layer expects (q, a) pairs.
        history_pairs = [(item.q, item.a) for item in chat_history if item.a != ""]
        logger.info("Chat history: %s", history_pairs)

        with get_openai_callback() as cb:
            final_answer, reference_docs = multi_query(question, retriever, history_pairs)
            processing_time = time.time() - start
            logger.info("Processing time: %s", processing_time)
            save_history(question, final_answer, reference_docs, cb, processing_time)
        return {"Answer": final_answer}
    except Exception as e:
        logger.error(f"Error in /answer_with_history endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/answer_with_history2")
def answer_with_history_filtered(question: Optional[str] = '', extension: Optional[str] = 'pdf', chat_history: List[ChatHistoryItem] = Body(...)):
    """Answer ``question`` with chat history, restricting retrieval to
    documents whose metadata ``extension`` matches the given value.

    Changes from the original: renamed from ``multi_query_answer`` (duplicate
    name), removed the injected ``retriever`` dependency that was immediately
    overwritten and never used, and added the same error handling as
    ``/answer2``. Route path and response shape are unchanged.

    Returns:
        {"Answer": <final answer string>}

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        start = time.time()
        # Build a per-request retriever filtered by file extension.
        retriever = vector_store.as_retriever(search_kwargs={"k": 4, 'filter': {'extension': extension}})

        history_pairs = [(item.q, item.a) for item in chat_history if item.a != ""]
        logger.info("Chat history: %s", history_pairs)

        with get_openai_callback() as cb:
            final_answer, reference_docs = multi_query(question, retriever, history_pairs)
            processing_time = time.time() - start
            logger.info("Processing time: %s", processing_time)
            save_history(question, final_answer, reference_docs, cb, processing_time)
        return {"Answer": final_answer}
    except Exception as e:
        logger.error(f"Error in /answer_with_history2 endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))
def save_history(question, answer, reference, cb, processing_time):
    """Append one Q/A exchange to the ``systex_records`` table.

    Args:
        question: the user's question.
        answer: the generated answer.
        reference: retrieved context documents (stringified for storage).
        cb: LangChain OpenAI callback carrying total_tokens / total_cost.
        processing_time: wall-clock seconds spent answering.

    The engine is created per call and explicitly disposed so connections are
    not leaked (the original never disposed it).
    NOTE(review): creating an engine per request is wasteful — consider a
    module-level engine.
    """
    record = {
        'Question': [question],
        'Answer': [answer],
        'Total_Tokens': [cb.total_tokens],
        'Total_Cost': [cb.total_cost],
        'Processing_time': [processing_time],
        'Contexts': [str(reference)]
    }
    df = pd.DataFrame(record)
    engine = create_engine(URI)
    try:
        df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
    finally:
        # Release pooled connections even if the insert fails.
        engine.dispose()
class history_output(BaseModel):
    # Response schema for GET /history; field names mirror the
    # systex_records table columns written by save_history.
    Question: str
    Answer: str
    Contexts: str
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float
    Time: datetime.datetime  # NOTE(review): not written by save_history — presumably a DB-side default; confirm
@app.get('/history', response_model=List[history_output])
async def get_history():
    """Return every saved Q/A record from ``systex_records``.

    Fixes from the original: removed ``echo=True`` (debug SQL echo on every
    request), closed the connection via a context manager, and disposed the
    engine — the original leaked both.

    Returns:
        An iterable of record dicts; FastAPI validates them against
        ``history_output``.
    """
    engine = create_engine(URI)
    try:
        with engine.connect() as conn:
            df = pd.read_sql_table("systex_records", conn)
    finally:
        engine.dispose()
    df.fillna('', inplace=True)
    # Round-trip through JSON so pandas types (timestamps, numerics) become
    # plain JSON-serializable values.
    result = loads(df.to_json(orient='index', force_ascii=False))
    return result.values()
@app.get("/")
def read_root():
    """Landing / health-check route."""
    greeting = {"message": "Welcome to the Carbon Chatbot API"}
    return greeting
if __name__ == "__main__":
    # Dev entry point; the module must be importable as "RAG_app_copy" for
    # uvicorn's reload worker to re-import the app.
    uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)
|