# RAG_app.py — FastAPI service exposing RAG question-answering endpoints.
  1. import re
  2. import threading
  3. from fastapi import FastAPI, Request, HTTPException, Response, status, Body
  4. # from fastapi.templating import Jinja2Templates
  5. from fastapi.middleware.cors import CORSMiddleware
  6. from fastapi.responses import FileResponse, JSONResponse
  7. from fastapi import Depends
  8. from contextlib import asynccontextmanager
  9. from pydantic import BaseModel
  10. from typing import List, Optional
  11. import uvicorn
  12. import requests
  13. from typing import List, Optional
  14. import sqlparse
  15. from sqlalchemy import create_engine
  16. import pandas as pd
  17. #from retrying import retry
  18. import datetime
  19. import json
  20. from json import loads
  21. import pandas as pd
  22. import time
  23. from langchain_community.chat_models import ChatOllama
  24. from langchain.callbacks import get_openai_callback
  25. from langchain_community.vectorstores import Chroma
  26. from langchain_openai import OpenAIEmbeddings
  27. # from RAG_strategy import multi_query, naive_rag
  28. # from Indexing_Split import create_retriever as split_retriever
  29. # from Indexing_Split import gen_doc_from_database, gen_doc_from_history
  30. from semantic_search import semantic_cache
  31. from dotenv import load_dotenv
  32. import os
  33. from langchain_community.vectorstores import SupabaseVectorStore
  34. from langchain_openai import OpenAIEmbeddings
  35. from supabase.client import Client, create_client
  36. from file_loader.add_vectordb import GetVectorStore
  37. from faiss_index import create_faiss_retriever, faiss_query
  38. from local_llm import ollama_, hf
  39. from ai_agent import main
  40. # from local_llm import ollama_, taide_llm, hf
  41. # llm = hf()
# Load environment configuration (.env) and build module-level singletons.
load_dotenv()
URI = os.getenv("SUPABASE_URI")  # SQLAlchemy URI for the history database
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_KEY")
# NOTE(review): create_client will fail if SUPABASE_URL/SUPABASE_KEY are
# unset — confirm the deployment environment provides both.
supabase: Client = create_client(supabase_url, supabase_key)

# Populated during application startup (see lifespan); read via the
# get_retriever / get_llm dependency helpers below.
global_retriever = None
llm = None
  49. @asynccontextmanager
  50. async def lifespan(app: FastAPI):
  51. global global_retriever
  52. global llm
  53. global vector_store
  54. start = time.time()
  55. document_table = "documents"
  56. embeddings = OpenAIEmbeddings()
  57. # vector_store = GetVectorStore(embeddings, supabase, document_table)
  58. # global_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
  59. global_retriever = create_faiss_retriever()
  60. local_llm = "llama3-groq-tool-use:latest"
  61. # llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
  62. llm = ChatOllama(model=local_llm, temperature=0)
  63. print(time.time() - start)
  64. yield
def get_retriever():
    """FastAPI dependency: the retriever built in lifespan() (None before startup)."""
    return global_retriever
def get_llm():
    """FastAPI dependency: the ChatOllama model built in lifespan() (None before startup)."""
    return llm
def get_vector_store():
    """FastAPI dependency: the module-level Supabase vector store.

    NOTE(review): `vector_store` is never assigned anywhere in this file —
    its creation in lifespan() is commented out — so calling this raises
    NameError. Confirm before wiring it into a route.
    """
    return vector_store
app = FastAPI(lifespan=lifespan)
# templates = Jinja2Templates(directory="temp")

# Wide-open CORS (any origin/method/header) — development configuration.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ChatHistoryItem(BaseModel):
    """One prior conversation turn: user question `q` and model answer `a`."""

    q: str
    a: str
  83. def replace_unicode_escapes(match):
  84. return chr(int(match.group(1), 16))
  85. @app.post("/answer_with_history")
  86. def multi_query_answer(question: Optional[str] = '什麼是逸散排放源?', chat_history: List[ChatHistoryItem] = Body(...),
  87. retriever=Depends(get_retriever), llm=Depends(get_llm)):
  88. start = time.time()
  89. chat_history = [(item.q, item.a) for item in chat_history if item.a != "" and item.a != "string"]
  90. print(chat_history)
  91. # TODO: similarity search
  92. with get_openai_callback() as cb:
  93. # cache_question, cache_answer = semantic_cache(supabase, question)
  94. # if cache_answer:
  95. # processing_time = time.time() - start
  96. # save_history(question, cache_answer, cache_question, cb, processing_time)
  97. # return {"Answer": cache_answer}
  98. # final_answer, reference_docs = multi_query(question, retriever, chat_history)
  99. # final_answer, reference_docs = naive_rag(question, retriever, chat_history)
  100. final_answer = faiss_query(question, global_retriever, llm)
  101. decoded_string = re.sub(r'\\u([0-9a-fA-F]{4})', replace_unicode_escapes, final_answer)
  102. print(decoded_string )
  103. reference_docs = global_retriever.get_relevant_documents(question)
  104. processing_time = time.time() - start
  105. print(processing_time)
  106. save_history(question, decoded_string , reference_docs, cb, processing_time)
  107. # print(response)
  108. response_content = json.dumps({"Answer": decoded_string }, ensure_ascii=False)
  109. # Manually create a Response object if using Flask
  110. return JSONResponse(content=response_content)
  111. # response_content = json.dumps({"Answer": final_answer}, ensure_ascii=False)
  112. # print(response_content)
  113. # return json.loads(response_content)
  114. @app.get("/agents")
  115. def agent(question: str):
  116. answer = main(question)
  117. return {"answer": answer}
  118. def save_history(question, answer, reference, cb, processing_time):
  119. # reference = [doc.dict() for doc in reference]
  120. record = {
  121. 'Question': [question],
  122. 'Answer': [answer],
  123. 'Total_Tokens': [cb.total_tokens],
  124. 'Total_Cost': [cb.total_cost],
  125. 'Processing_time': [processing_time],
  126. 'Contexts': [str(reference)] if isinstance(reference, list) else [reference]
  127. }
  128. df = pd.DataFrame(record)
  129. engine = create_engine(URI)
  130. df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
class history_output(BaseModel):
    """Response schema for GET /history — one row of ``systex_records``."""

    Question: str
    Answer: str
    Contexts: str
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float
    Time: datetime.datetime
  139. @app.get('/history', response_model=List[history_output])
  140. async def get_history():
  141. engine = create_engine(URI, echo=True)
  142. df = pd.read_sql_table("systex_records", engine.connect())
  143. df.fillna('', inplace=True)
  144. result = df.to_json(orient='index', force_ascii=False)
  145. result = loads(result)
  146. return result.values()
  147. def send_heartbeat(url):
  148. while True:
  149. try:
  150. response = requests.get(url)
  151. if response.status_code != 200:
  152. print(f"Failed to send heartbeat, status code: {response.status_code}")
  153. except requests.RequestException as e:
  154. print(f"Error occurred: {e}")
  155. # 等待 60 秒
  156. time.sleep(600)
  157. def start_heartbeat(url):
  158. heartbeat_thread = threading.Thread(target=send_heartbeat, args=(url,))
  159. heartbeat_thread.daemon = True
  160. heartbeat_thread.start()
if __name__ == "__main__":
    # url = 'http://db.ptt.cx:3001/api/push/luX7WcY3Gz?status=up&msg=OK&ping='
    # start_heartbeat(url)
    # Dev server with auto-reload; the app must be importable as "RAG_app:app".
    uvicorn.run("RAG_app:app", host='0.0.0.0', reload=True, port=8080)