RAG_app_copy.py

from dotenv import load_dotenv
load_dotenv('environment.env')

from fastapi import FastAPI, Request, HTTPException, status, Body
# from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi import Depends
from contextlib import asynccontextmanager
from pydantic import BaseModel
from typing import List, Optional

import uvicorn
import sqlparse
from sqlalchemy import create_engine
import pandas as pd
# from retrying import retry
import datetime
import json
from json import loads
import time
from langchain.callbacks import get_openai_callback
from langchain_community.vectorstores import Chroma
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.cache import RedisSemanticCache  # updated import path
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAIEmbeddings

from RAG_strategy import multi_query, naive_rag, naive_rag_for_qapairs
from Indexing_Split import create_retriever as split_retriever
from Indexing_Split import gen_doc_from_database, gen_doc_from_history
from add_vectordb import GetVectorStore

from supabase.client import Client, create_client

import os
import openai
# Log through uvicorn's error logger so messages appear in the server output.
import logging
logger = logging.getLogger("uvicorn.error")

openai_api_key = os.getenv("OPENAI_API_KEY")
URI = os.getenv("SUPABASE_URI")
openai.api_key = openai_api_key

global_retriever = None
# FastAPI lifespan manager: runs setup on startup and teardown on shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    global global_retriever
    global vector_store

    start = time.time()

    # global_retriever = split_retriever(path='./Documents', extension="docx")
    # global_retriever = raptor_retriever(path='../Documents', extension="txt")
    # global_retriever = unstructured_retriever(path='../Documents')

    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_KEY")
    document_table = "documents"
    supabase: Client = create_client(supabase_url, supabase_key)

    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    vector_store = GetVectorStore(embeddings, supabase, document_table)
    global_retriever = vector_store.as_retriever(search_kwargs={"k": 4})

    print(time.time() - start)
    yield
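
# Note: GetVectorStore comes from the project-local add_vectordb module; given the
# SupabaseVectorStore import above, it is assumed to wrap a Supabase-backed
# langchain vector store over the "documents" table, hence the as_retriever()
# call above.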
# Dependency-injection helpers that hand the startup globals to request handlers.
def get_retriever():
    return global_retriever

def get_vector_store():
    return vector_store
# Create the FastAPI application instance and configure middleware.
app = FastAPI(lifespan=lifespan)
# templates = Jinja2Templates(directory="temp")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
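
# Note: this CORS policy is wide open. Per the CORS spec, browsers reject
# credentialed requests when the allowed origin is the wildcard "*", so
# allow_origins should list explicit origins if cookie-based auth is ever needed.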
# API routes and handlers.

# Answer an incoming question and return the response.
@app.get("/answer2")
def multi_query_answer(question, retriever=Depends(get_retriever)):
    try:
        start = time.time()

        with get_openai_callback() as cb:
            # qa_doc = gen_doc_from_database()
            # qa_history_doc = gen_doc_from_history()
            # qa_doc.extend(qa_history_doc)
            # vectorstore = Chroma.from_documents(documents=qa_doc, embedding=OpenAIEmbeddings(), collection_name="qa_pairs")
            # retriever_qa = vectorstore.as_retriever(search_kwargs={"k": 3})
            # final_answer, reference_docs = naive_rag_for_qapairs(question, retriever_qa)
            final_answer = 'False'
            if final_answer == 'False':
                final_answer, reference_docs = multi_query(question, retriever, chat_history=[])
            # print(CHAT_HISTORY)
            # with get_openai_callback() as cb:
            #     final_answer, reference_docs = multi_query(question, retriever)

        processing_time = time.time() - start
        print(processing_time)
        save_history(question, final_answer, reference_docs, cb, processing_time)

        return {"Answer": final_answer}
    except Exception as e:
        logger.error(f"Error in /answer2 endpoint: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
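
# Example call (a sketch; assumes the server is running locally as configured in
# the __main__ block below, and the question text is illustrative):
#   curl "http://127.0.0.1:8081/answer2?question=example+question"
# Response shape: {"Answer": "..."}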
class ChatHistoryItem(BaseModel):
    q: str
    a: str
# Answer a question together with prior chat history.
@app.post("/answer_with_history")
def multi_query_answer_with_history(question: Optional[str] = '', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
    start = time.time()

    # Keep only turns that already have an answer.
    chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
    print(chat_history)

    # TODO: similarity search
    with get_openai_callback() as cb:
        final_answer, reference_docs = multi_query(question, retriever, chat_history)

    processing_time = time.time() - start
    print(processing_time)
    save_history(question, final_answer, reference_docs, cb, processing_time)

    return {"Answer": final_answer}
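
# Example call (a sketch; question and history values are illustrative). The chat
# history is posted as a bare JSON array matching List[ChatHistoryItem], while the
# question travels as a query parameter:
#   curl -X POST "http://127.0.0.1:8081/answer_with_history?question=example+question" \
#        -H "Content-Type: application/json" \
#        -d '[{"q": "earlier question", "a": "earlier answer"}]'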
# Answer a question with chat history, filtering retrieval by file extension.
@app.post("/answer_with_history2")
def multi_query_answer_with_history_filter(question: Optional[str] = '', extension: Optional[str] = 'pdf', chat_history: List[ChatHistoryItem] = Body(...), retriever=Depends(get_retriever)):
    start = time.time()

    # Override the injected retriever with one that filters on document extension.
    retriever = vector_store.as_retriever(
        search_kwargs={"k": 4, 'filter': {'extension': extension}}
    )

    # Keep only turns that already have an answer.
    chat_history = [(item.q, item.a) for item in chat_history if item.a != ""]
    print(chat_history)

    # TODO: similarity search
    with get_openai_callback() as cb:
        final_answer, reference_docs = multi_query(question, retriever, chat_history)

    processing_time = time.time() - start
    print(processing_time)
    save_history(question, final_answer, reference_docs, cb, processing_time)

    return {"Answer": final_answer}
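
# Example call (a sketch; values are illustrative). Same body as
# /answer_with_history, plus an extension query parameter that narrows retrieval
# to matching source files:
#   curl -X POST "http://127.0.0.1:8081/answer_with_history2?question=example+question&extension=docx" \
#        -H "Content-Type: application/json" \
#        -d '[{"q": "earlier question", "a": "earlier answer"}]'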
# Save history: persist the Q&A result, token usage, and timing to the database.
def save_history(question, answer, reference, cb, processing_time):
    # reference = [doc.dict() for doc in reference]
    record = {
        'Question': [question],
        'Answer': [answer],
        'Total_Tokens': [cb.total_tokens],
        'Total_Cost': [cb.total_cost],
        'Processing_time': [processing_time],
        'Contexts': [str(reference)]
    }
    df = pd.DataFrame(record)

    engine = create_engine(URI)
    df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
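
# A minimal sketch of a Postgres schema compatible with the insert above (an
# assumption; the real table may differ). history_output below also reads a Time
# column, so the actual table presumably populates a timestamp by default:
#   CREATE TABLE IF NOT EXISTS systex_records (
#       "Question"        TEXT,
#       "Answer"          TEXT,
#       "Total_Tokens"    BIGINT,
#       "Total_Cost"      DOUBLE PRECISION,
#       "Processing_time" DOUBLE PRECISION,
#       "Contexts"        TEXT,
#       "Time"            TIMESTAMP DEFAULT now()
#   );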
class history_output(BaseModel):
    Question: str
    Answer: str
    Contexts: str
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float
    Time: datetime.datetime
# Return all saved Q&A records.
@app.get('/history', response_model=List[history_output])
async def get_history():
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("systex_records", engine.connect())
    df.fillna('', inplace=True)
    result = df.to_json(orient='index', force_ascii=False)
    result = loads(result)
    return list(result.values())
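
# Example response shape (a sketch; field values are illustrative):
#   [{"Question": "...", "Answer": "...", "Contexts": "...", "Total_Tokens": 1234,
#     "Total_Cost": 0.002, "Processing_time": 3.1, "Time": "2024-01-01T00:00:00"}]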
@app.get("/")
def read_root():
    return {"message": "Welcome to the Carbon Chatbot API"}
if __name__ == "__main__":
    uvicorn.run("RAG_app_copy:app", host='127.0.0.1', port=8081, reload=True)

# if __name__ == "__main__":
#     uvicorn.run("RAG_app:app", host='cmm.ai', port=8081, reload=True,
#                 ssl_keyfile="/etc/letsencrypt/live/cmm.ai/privkey.pem",
#                 ssl_certfile="/etc/letsencrypt/live/cmm.ai/fullchain.pem")