# RAG_app_copy.py — FastAPI RAG question-answering service
  1. from dotenv import load_dotenv
  2. load_dotenv('environment.env')
  3. from fastapi import FastAPI, HTTPException, status, Body, Depends
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from contextlib import asynccontextmanager
  6. from pydantic import BaseModel
  7. from typing import List, Optional
  8. import uvicorn
  9. from sqlalchemy import create_engine
  10. import pandas as pd
  11. import datetime
  12. import json
  13. from json import loads
  14. import time
  15. from langchain.callbacks import get_openai_callback
  16. from langchain_openai import OpenAIEmbeddings
  17. from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs
  18. import os
  19. from supabase.client import Client, create_client
  20. from add_vectordb import GetVectorStore
  21. import openai
  22. # Get API log
  23. import logging
  24. logger = logging.getLogger("uvicorn.error")
  25. openai_api_key = os.getenv("OPENAI_API_KEY")
  26. URI = os.getenv("SUPABASE_URI")
  27. openai.api_key = openai_api_key
  28. global_retriever = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: build the Supabase vector store and retriever once.

    Runs before the server accepts requests; anything after ``yield``
    would run at shutdown (nothing to clean up here).
    """
    global global_retriever
    global vector_store
    start = time.time()
    # Supabase credentials come from environment.env (loaded at import time).
    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_KEY")
    document_table = "documents"
    supabase: Client = create_client(supabase_url, supabase_key)
    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    # Project helper wrapping the Supabase table as a LangChain vector store.
    vector_store = GetVectorStore(embeddings, supabase, document_table)
    # Default retriever: top-4 most similar chunks, no metadata filter.
    global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})
    print(f"Initialization time: {time.time() - start}")
    yield
  43. def get_retriever():
  44. return global_retriever
  45. def get_vector_store():
  46. return vector_store
app = FastAPI(lifespan=lifespan)

# Wide-open CORS: any origin/method/header may call the API.
# NOTE(review): development-friendly; tighten allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/answer2")
def multi_query_answer(question, retriever=Depends(get_retriever)):
    """Answer a single question (no chat history) via the multi-query chain.

    Token usage/cost are captured with get_openai_callback and persisted,
    together with answer and retrieved contexts, by save_history().
    """
    try:
        start = time.time()
        with get_openai_callback() as cb:
            final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
            processing_time = time.time() - start
            print(f"Processing time: {processing_time}")
            save_history(question, final_answer, reference_docs, cb, processing_time)
        return {"Answer": final_answer}
    except Exception as e:
        logger.error(f"Error in /answer2 endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))
class ChatHistoryItem(BaseModel):
    # One prior conversation turn: q = user question, a = assistant answer.
    q: str
    a: str
  71. @app.post("/answer_with_history")
  72. def multi_query_answer(question: Optional[str] = '', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
  73. start = time.time()
  74. chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
  75. print(f"Chat history: {chat_history}")
  76. with get_openai_callback() as cb:
  77. final_answer, reference_docs = multi_query(question, retriever, chat_history)
  78. processing_time = time.time() - start
  79. print(f"Processing time: {processing_time}")
  80. save_history(question, final_answer, reference_docs, cb, processing_time)
  81. return {"Answer": final_answer}
  82. @app.post("/answer_with_history2")
  83. def multi_query_answer(question: Optional[str] = '', extension: Optional[str] = 'pdf', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
  84. start = time.time()
  85. retriever = vector_store.as_retriever(search_kwargs={"k": 4, 'filter': {'extension':extension}})
  86. chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
  87. print(f"Chat history: {chat_history}")
  88. with get_openai_callback() as cb:
  89. final_answer, reference_docs = multi_query(question, retriever, chat_history)
  90. processing_time = time.time() - start
  91. print(f"Processing time: {processing_time}")
  92. save_history(question, final_answer, reference_docs, cb, processing_time)
  93. return {"Answer": final_answer}
  94. def save_history(question, answer, reference, cb, processing_time):
  95. record = {
  96. 'Question': [question],
  97. 'Answer': [answer],
  98. 'Total_Tokens': [cb.total_tokens],
  99. 'Total_Cost': [cb.total_cost],
  100. 'Processing_time': [processing_time],
  101. 'Contexts': [str(reference)]
  102. }
  103. df = pd.DataFrame(record)
  104. engine = create_engine(URI)
  105. df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
class history_output(BaseModel):
    # Response schema for GET /history; mirrors the systex_records columns.
    # (Lowercase class name kept: referenced by the /history response_model.)
    Question: str
    Answer: str
    Contexts: str
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float
    # NOTE(review): save_history() never writes a Time column — presumably
    # populated by a DB-side default timestamp; verify against the schema.
    Time: datetime.datetime
  114. @app.get('/history', response_model=List[history_output])
  115. async def get_history():
  116. engine = create_engine(URI, echo=True)
  117. df = pd.read_sql_table("systex_records", engine.connect())
  118. df.fillna('', inplace=True)
  119. result = df.to_json(orient='index', force_ascii=False)
  120. result = loads(result)
  121. return result.values()
  122. @app.get("/")
  123. def read_root():
  124. return {"message": "Welcome to the Carbon Chatbot API"}
if __name__ == "__main__":
    # Dev entry point: local-only bind; reload=True restarts on code changes.
    uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)