# RAG_app.py — FastAPI service exposing retrieval-augmented-generation (RAG)
# question-answering endpoints backed by a Supabase vector store.
  1. from fastapi import FastAPI, Request, HTTPException, status, Body
  2. # from fastapi.templating import Jinja2Templates
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from fastapi.responses import FileResponse
  5. from fastapi import Depends
  6. from contextlib import asynccontextmanager
  7. from pydantic import BaseModel
  8. from typing import List, Optional
  9. import uvicorn
  10. from typing import List, Optional
  11. import sqlparse
  12. from sqlalchemy import create_engine
  13. import pandas as pd
  14. #from retrying import retry
  15. import datetime
  16. import json
  17. from json import loads
  18. import pandas as pd
  19. import time
  20. from langchain.callbacks import get_openai_callback
  21. from langchain_community.vectorstores import Chroma
  22. from langchain_openai import OpenAIEmbeddings
  23. from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs
  24. from Indexing_Split import create_retriever as split_retriever
  25. from Indexing_Split import gen_doc_from_database, gen_doc_from_history
  26. from dotenv import load_dotenv
  27. import os
  28. from langchain_community.vectorstores import SupabaseVectorStore
  29. from langchain_openai import OpenAIEmbeddings
  30. from supabase.client import Client, create_client
  31. from add_vectordb import GetVectorStore
# Load environment variables from a local .env file (SUPABASE_* settings,
# presumably also OpenAI credentials for OpenAIEmbeddings — confirm in .env).
load_dotenv()

# Database URI used by save_history() / get_history() for the records table.
URI = os.getenv("SUPABASE_URI")

# Shared retriever; populated by the lifespan handler at startup, None until then.
global_retriever = None
  35. @asynccontextmanager
  36. async def lifespan(app: FastAPI):
  37. global global_retriever
  38. global vector_store
  39. start = time.time()
  40. # global_retriever = split_retriever(path='./Documents', extension="docx")
  41. # global_retriever = raptor_retriever(path='../Documents', extension="txt")
  42. # global_retriever = unstructured_retriever(path='../Documents')
  43. supabase_url = os.environ.get("SUPABASE_URL")
  44. supabase_key = os.environ.get("SUPABASE_KEY")
  45. document_table = "documents"
  46. supabase: Client = create_client(supabase_url, supabase_key)
  47. embeddings = OpenAIEmbeddings()
  48. vector_store = GetVectorStore(embeddings, supabase, document_table)
  49. global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})
  50. print(time.time() - start)
  51. yield
def get_retriever():
    """FastAPI dependency: return the retriever built during startup.

    Returns None if called before the lifespan handler has run.
    """
    return global_retriever
def get_vector_store():
    """FastAPI dependency: return the shared vector store.

    NOTE(review): `vector_store` is only created inside the lifespan handler;
    calling this before startup raises NameError — confirm intended.
    """
    return vector_store
# FastAPI application; `lifespan` builds the shared retriever/vector store at startup.
app = FastAPI(lifespan=lifespan)
# templates = Jinja2Templates(directory="temp")

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# disallowed by the CORS spec (browsers reject wildcard origins with
# credentials) and is overly permissive — consider listing exact origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  65. @app.get("/answer2")
  66. def multi_query_answer(question, retriever=Depends(get_retriever)):
  67. start = time.time()
  68. with get_openai_callback() as cb:
  69. # qa_doc = gen_doc_from_database()
  70. # qa_history_doc = gen_doc_from_history()
  71. # qa_doc.extend(qa_history_doc)
  72. # vectorstore = Chroma.from_documents(documents=qa_doc, embedding=OpenAIEmbeddings(), collection_name="qa_pairs")
  73. # retriever_qa = vectorstore.as_retriever(search_kwargs={"k": 3})
  74. # final_answer, reference_docs = naive_rag_for_qapairs(question, retriever_qa)
  75. final_answer = 'False'
  76. if final_answer == 'False':
  77. final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
  78. # print(CHAT_HISTORY)
  79. # with get_openai_callback() as cb:
  80. # final_answer, reference_docs = multi_query(question, retriever)
  81. processing_time = time.time() - start
  82. print(processing_time)
  83. save_history(question, final_answer, reference_docs, cb, processing_time)
  84. return {"Answer": final_answer}
class ChatHistoryItem(BaseModel):
    """One prior exchange in the chat history."""
    q: str  # user question
    a: str  # assistant answer; "" marks an unanswered item (filtered out by handlers)
  88. @app.post("/answer_with_history")
  89. def multi_query_answer(question: Optional[str] = '', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
  90. start = time.time()
  91. chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
  92. print(chat_history)
  93. # TODO: similarity search
  94. with get_openai_callback() as cb:
  95. final_answer, reference_docs = multi_query(question, retriever, chat_history)
  96. processing_time = time.time() - start
  97. print(processing_time)
  98. save_history(question, final_answer, reference_docs, cb, processing_time)
  99. return {"Answer": final_answer}
  100. @app.post("/answer_with_history2")
  101. def multi_query_answer(question: Optional[str] = '', extension: Optional[str] = 'pdf', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
  102. start = time.time()
  103. retriever = vector_store.as_retriever(search_kwargs={"k": 4,
  104. 'filter': {'extension':extension}})
  105. chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
  106. print(chat_history)
  107. # TODO: similarity search
  108. with get_openai_callback() as cb:
  109. final_answer, reference_docs = multi_query(question, retriever, chat_history)
  110. processing_time = time.time() - start
  111. print(processing_time)
  112. save_history(question, final_answer, reference_docs, cb, processing_time)
  113. return {"Answer": final_answer}
  114. def save_history(question, answer, reference, cb, processing_time):
  115. # reference = [doc.dict() for doc in reference]
  116. record = {
  117. 'Question': [question],
  118. 'Answer': [answer],
  119. 'Total_Tokens': [cb.total_tokens],
  120. 'Total_Cost': [cb.total_cost],
  121. 'Processing_time': [processing_time],
  122. 'Contexts': [str(reference)]
  123. }
  124. df = pd.DataFrame(record)
  125. engine = create_engine(URI)
  126. df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
class history_output(BaseModel):
    """Response schema for GET /history: one saved record from systex_records."""
    Question: str
    Answer: str
    Contexts: str  # str()-serialized reference documents
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float  # seconds
    # NOTE(review): Time is not written by save_history — presumably a DB-side
    # default timestamp column; confirm against the table schema.
    Time: datetime.datetime
  135. @app.get('/history', response_model=List[history_output])
  136. async def get_history():
  137. engine = create_engine(URI, echo=True)
  138. df = pd.read_sql_table("systex_records", engine.connect())
  139. df.fillna('', inplace=True)
  140. result = df.to_json(orient='index', force_ascii=False)
  141. result = loads(result)
  142. return result.values()
if __name__ == "__main__":
    # Serve over HTTPS on cmm.ai:8081 using that host's Let's Encrypt
    # certificate; reload=True is a development convenience.
    uvicorn.run("RAG_app:app", host='cmm.ai', port=8081, reload=True, ssl_keyfile="/etc/letsencrypt/live/cmm.ai/privkey.pem",
                ssl_certfile="/etc/letsencrypt/live/cmm.ai/fullchain.pem")