# RAG_app.py — FastAPI service exposing RAG question-answering endpoints.
  1. from dotenv import load_dotenv
  2. load_dotenv('environment.env')
  3. from fastapi import FastAPI, Request, HTTPException, status, Body
  4. # from fastapi.templating import Jinja2Templates
  5. from fastapi.middleware.cors import CORSMiddleware
  6. from fastapi.responses import FileResponse
  7. from fastapi import Depends
  8. from contextlib import asynccontextmanager
  9. from pydantic import BaseModel
  10. from typing import List, Optional
  11. import uvicorn
  12. from typing import List, Optional
  13. import sqlparse
  14. from sqlalchemy import create_engine
  15. import pandas as pd
  16. #from retrying import retry
  17. import datetime
  18. import json
  19. from json import loads
  20. import pandas as pd
  21. import time
  22. from langchain.callbacks import get_openai_callback
  23. from langchain_community.vectorstores import Chroma
  24. from langchain_openai import OpenAIEmbeddings
  25. from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs
  26. from Indexing_Split import create_retriever as split_retriever
  27. from Indexing_Split import gen_doc_from_database, gen_doc_from_history
  28. import os
  29. from langchain_community.vectorstores import SupabaseVectorStore
  30. from langchain_openai import OpenAIEmbeddings
  31. from supabase.client import Client, create_client
  32. from add_vectordb import GetVectorStore
  33. from langchain_community.cache import RedisSemanticCache # 更新导入路径
  34. from langchain_core.prompts import PromptTemplate
  35. import openai
  36. openai_api_key = os.getenv("OPENAI_API_KEY")
  37. openai.api_key = openai_api_key
  38. URI = os.getenv("SUPABASE_URI")
  39. global_retriever = None
  40. @asynccontextmanager
  41. async def lifespan(app: FastAPI):
  42. global global_retriever
  43. global vector_store
  44. start = time.time()
  45. # global_retriever = split_retriever(path='./Documents', extension="docx")
  46. # global_retriever = raptor_retriever(path='../Documents', extension="txt")
  47. # global_retriever = unstructured_retriever(path='../Documents')
  48. supabase_url = os.getenv("SUPABASE_URL")
  49. supabase_key = os.getenv("SUPABASE_KEY")
  50. document_table = "documents"
  51. supabase: Client = create_client(supabase_url, supabase_key)
  52. embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
  53. vector_store = GetVectorStore(embeddings, supabase, document_table)
  54. global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})
  55. print(time.time() - start)
  56. yield
  57. def get_retriever():
  58. return global_retriever
  59. def get_vector_store():
  60. return vector_store
  61. app = FastAPI(lifespan=lifespan)
  62. # templates = Jinja2Templates(directory="temp")
  63. app.add_middleware(
  64. CORSMiddleware,
  65. allow_origins=["*"],
  66. allow_credentials=True,
  67. allow_methods=["*"],
  68. allow_headers=["*"],
  69. )
  70. @app.get("/answer2")
  71. def multi_query_answer(question, retriever=Depends(get_retriever)):
  72. start = time.time()
  73. with get_openai_callback() as cb:
  74. # qa_doc = gen_doc_from_database()
  75. # qa_history_doc = gen_doc_from_history()
  76. # qa_doc.extend(qa_history_doc)
  77. # vectorstore = Chroma.from_documents(documents=qa_doc, embedding=OpenAIEmbeddings(), collection_name="qa_pairs")
  78. # retriever_qa = vectorstore.as_retriever(search_kwargs={"k": 3})
  79. # final_answer, reference_docs = naive_rag_for_qapairs(question, retriever_qa)
  80. final_answer = 'False'
  81. if final_answer == 'False':
  82. final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
  83. # print(CHAT_HISTORY)
  84. # with get_openai_callback() as cb:
  85. # final_answer, reference_docs = multi_query(question, retriever)
  86. processing_time = time.time() - start
  87. print(processing_time)
  88. save_history(question, final_answer, reference_docs, cb, processing_time)
  89. return {"Answer": final_answer}
  90. class ChatHistoryItem(BaseModel):
  91. q: str
  92. a: str
  93. @app.post("/answer_with_history")
  94. def multi_query_answer(question: Optional[str] = '', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
  95. start = time.time()
  96. chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
  97. print(chat_history)
  98. # TODO: similarity search
  99. with get_openai_callback() as cb:
  100. final_answer, reference_docs = multi_query(question, retriever, chat_history)
  101. processing_time = time.time() - start
  102. print(processing_time)
  103. save_history(question, final_answer, reference_docs, cb, processing_time)
  104. return {"Answer": final_answer}
  105. @app.post("/answer_with_history2")
  106. def multi_query_answer(question: Optional[str] = '', extension: Optional[str] = 'pdf', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
  107. start = time.time()
  108. retriever = vector_store.as_retriever(search_kwargs={"k": 4,
  109. 'filter': {'extension':extension}})
  110. chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
  111. print(chat_history)
  112. # TODO: similarity search
  113. with get_openai_callback() as cb:
  114. final_answer, reference_docs = multi_query(question, retriever, chat_history)
  115. processing_time = time.time() - start
  116. print(processing_time)
  117. save_history(question, final_answer, reference_docs, cb, processing_time)
  118. return {"Answer": final_answer}
  119. def save_history(question, answer, reference, cb, processing_time):
  120. # reference = [doc.dict() for doc in reference]
  121. record = {
  122. 'Question': [question],
  123. 'Answer': [answer],
  124. 'Total_Tokens': [cb.total_tokens],
  125. 'Total_Cost': [cb.total_cost],
  126. 'Processing_time': [processing_time],
  127. 'Contexts': [str(reference)]
  128. }
  129. df = pd.DataFrame(record)
  130. engine = create_engine(URI)
  131. df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
  132. class history_output(BaseModel):
  133. Question: str
  134. Answer: str
  135. Contexts: str
  136. Total_Tokens: int
  137. Total_Cost: float
  138. Processing_time: float
  139. Time: datetime.datetime
  140. @app.get('/history', response_model=List[history_output])
  141. async def get_history():
  142. engine = create_engine(URI, echo=True)
  143. df = pd.read_sql_table("systex_records", engine.connect())
  144. df.fillna('', inplace=True)
  145. result = df.to_json(orient='index', force_ascii=False)
  146. result = loads(result)
  147. return result.values()
  148. @app.get("/")
  149. def read_root():
  150. return {"message": "Welcome to the SYSTEX API"}
  151. if __name__ == "__main__":
  152. uvicorn.run("RAG_app:app", host='127.0.0.1', port=8081, reload=True)
  153. # if __name__ == "__main__":
  154. # uvicorn.run("RAG_app:app", host='cmm.ai', port=8081, reload=True, ssl_keyfile="/etc/letsencrypt/live/cmm.ai/privkey.pem",
  155. # ssl_certfile="/etc/letsencrypt/live/cmm.ai/fullchain.pem")