# RAG_app.py — FastAPI entry point for the RAG question-answering service
  1. import re
  2. import threading
  3. from fastapi import FastAPI, Request, HTTPException, Response, status, Body
  4. # from fastapi.templating import Jinja2Templates
  5. from fastapi.middleware.cors import CORSMiddleware
  6. from fastapi.responses import FileResponse, JSONResponse
  7. from fastapi import Depends
  8. from contextlib import asynccontextmanager
  9. from pydantic import BaseModel
  10. from typing import List, Optional
  11. import uvicorn
  12. import requests
  13. from typing import List, Optional
  14. import sqlparse
  15. from sqlalchemy import create_engine
  16. import pandas as pd
  17. #from retrying import retry
  18. import datetime
  19. import json
  20. from json import loads
  21. import pandas as pd
  22. import time
  23. from langchain.callbacks import get_openai_callback
  24. from langchain_community.vectorstores import Chroma
  25. from langchain_openai import OpenAIEmbeddings
  26. # from RAG_strategy import multi_query, naive_rag
  27. # from Indexing_Split import create_retriever as split_retriever
  28. # from Indexing_Split import gen_doc_from_database, gen_doc_from_history
  29. from semantic_search import semantic_cache
  30. from dotenv import load_dotenv
  31. import os
  32. from langchain_community.vectorstores import SupabaseVectorStore
  33. from langchain_openai import OpenAIEmbeddings
  34. from supabase.client import Client, create_client
  35. from file_loader.add_vectordb import GetVectorStore
  36. from faiss_index import create_faiss_retriever, faiss_query
  37. from local_llm import ollama_, hf
  38. # from local_llm import ollama_, taide_llm, hf
  39. # llm = hf()
  40. load_dotenv()
  41. URI = os.getenv("SUPABASE_URI")
  42. supabase_url = os.environ.get("SUPABASE_URL")
  43. supabase_key = os.environ.get("SUPABASE_KEY")
  44. supabase: Client = create_client(supabase_url, supabase_key)
  45. global_retriever = None
  46. llm = None
  47. @asynccontextmanager
  48. async def lifespan(app: FastAPI):
  49. global global_retriever
  50. global llm
  51. global vector_store
  52. start = time.time()
  53. document_table = "documents"
  54. embeddings = OpenAIEmbeddings()
  55. # vector_store = GetVectorStore(embeddings, supabase, document_table)
  56. # global_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
  57. global_retriever = create_faiss_retriever()
  58. llm = hf()
  59. print(time.time() - start)
  60. yield
  61. def get_retriever():
  62. return global_retriever
  63. def get_llm():
  64. return llm
  65. def get_vector_store():
  66. return vector_store
  67. app = FastAPI(lifespan=lifespan)
  68. # templates = Jinja2Templates(directory="temp")
  69. app.add_middleware(
  70. CORSMiddleware,
  71. allow_origins=["*"],
  72. allow_credentials=True,
  73. allow_methods=["*"],
  74. allow_headers=["*"],
  75. )
  76. class ChatHistoryItem(BaseModel):
  77. q: str
  78. a: str
  79. def replace_unicode_escapes(match):
  80. return chr(int(match.group(1), 16))
  81. @app.post("/answer_with_history")
  82. def multi_query_answer(question: Optional[str] = '什麼是逸散排放源?', chat_history: List[ChatHistoryItem] = Body(...),
  83. retriever=Depends(get_retriever), llm=Depends(get_llm)):
  84. start = time.time()
  85. chat_history = [(item.q, item.a) for item in chat_history if item.a != "" and item.a != "string"]
  86. print(chat_history)
  87. # TODO: similarity search
  88. with get_openai_callback() as cb:
  89. # cache_question, cache_answer = semantic_cache(supabase, question)
  90. # if cache_answer:
  91. # processing_time = time.time() - start
  92. # save_history(question, cache_answer, cache_question, cb, processing_time)
  93. # return {"Answer": cache_answer}
  94. # final_answer, reference_docs = multi_query(question, retriever, chat_history)
  95. # final_answer, reference_docs = naive_rag(question, retriever, chat_history)
  96. final_answer = faiss_query(question, global_retriever, llm)
  97. decoded_string = re.sub(r'\\u([0-9a-fA-F]{4})', replace_unicode_escapes, final_answer)
  98. print(decoded_string )
  99. reference_docs = global_retriever.get_relevant_documents(question)
  100. processing_time = time.time() - start
  101. print(processing_time)
  102. save_history(question, decoded_string , reference_docs, cb, processing_time)
  103. # print(response)
  104. response_content = json.dumps({"Answer": decoded_string }, ensure_ascii=False)
  105. # Manually create a Response object if using Flask
  106. return JSONResponse(content=response_content)
  107. # response_content = json.dumps({"Answer": final_answer}, ensure_ascii=False)
  108. # print(response_content)
  109. # return json.loads(response_content)
  110. def save_history(question, answer, reference, cb, processing_time):
  111. # reference = [doc.dict() for doc in reference]
  112. record = {
  113. 'Question': [question],
  114. 'Answer': [answer],
  115. 'Total_Tokens': [cb.total_tokens],
  116. 'Total_Cost': [cb.total_cost],
  117. 'Processing_time': [processing_time],
  118. 'Contexts': [str(reference)] if isinstance(reference, list) else [reference]
  119. }
  120. df = pd.DataFrame(record)
  121. engine = create_engine(URI)
  122. df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
  123. class history_output(BaseModel):
  124. Question: str
  125. Answer: str
  126. Contexts: str
  127. Total_Tokens: int
  128. Total_Cost: float
  129. Processing_time: float
  130. Time: datetime.datetime
  131. @app.get('/history', response_model=List[history_output])
  132. async def get_history():
  133. engine = create_engine(URI, echo=True)
  134. df = pd.read_sql_table("systex_records", engine.connect())
  135. df.fillna('', inplace=True)
  136. result = df.to_json(orient='index', force_ascii=False)
  137. result = loads(result)
  138. return result.values()
  139. def send_heartbeat(url):
  140. while True:
  141. try:
  142. response = requests.get(url)
  143. if response.status_code != 200:
  144. print(f"Failed to send heartbeat, status code: {response.status_code}")
  145. except requests.RequestException as e:
  146. print(f"Error occurred: {e}")
  147. # 等待 60 秒
  148. time.sleep(600)
  149. def start_heartbeat(url):
  150. heartbeat_thread = threading.Thread(target=send_heartbeat, args=(url,))
  151. heartbeat_thread.daemon = True
  152. heartbeat_thread.start()
  153. if __name__ == "__main__":
  154. # url = 'http://db.ptt.cx:3001/api/push/luX7WcY3Gz?status=up&msg=OK&ping='
  155. # start_heartbeat(url)
  156. uvicorn.run("RAG_app:app", host='0.0.0.0', reload=True, port=8080)