systex_app.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import datetime
  2. from json import loads
  3. import threading
  4. import time
  5. from typing import List
  6. from fastapi import Body, FastAPI
  7. from fastapi.middleware.cors import CORSMiddleware
  8. import pandas as pd
  9. from pydantic import BaseModel
  10. import requests
  11. import uvicorn
  12. from dotenv import load_dotenv
  13. import os
  14. from supabase.client import Client, create_client
  15. from langchain.callbacks import get_openai_callback
  16. from ai_agent import main, rag_main
  17. from ai_agent_llama import main as llama_main
  18. from semantic_search import semantic_cache, grandson_semantic_cache
  19. from RAG_strategy import get_search_query
  # Pull variables from a local .env file into the environment before they
  # are read below.
  load_dotenv()

  # Connection settings; each of these is None when the variable is unset.
  URI = os.getenv("SUPABASE_URI")
  supabase_url = os.getenv("SUPABASE_URL")
  supabase_key = os.getenv("SUPABASE_KEY")

  # Module-wide Supabase client (used for the ``agent_records`` table).
  supabase: Client = create_client(supabase_url, supabase_key)

  app = FastAPI()

  # NOTE(review): wildcard origins/methods/headers combined with
  # allow_credentials=True is a very permissive CORS policy -- confirm this is
  # intended before exposing the service publicly.
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )
  class ChatHistoryItem(BaseModel):
      """One earlier exchange of a conversation, as sent in a request body."""

      # presumably the user's earlier question -- confirm with the client
      q: str
      # presumably the answer that was given to ``q`` -- confirm with the client
      a: str
  @app.post("/agents")
  def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
      """Answer ``question``, preferring a cached answer over running the agent.

      The cache is consulted through ``grandson_semantic_cache``; on a miss the
      answer is ``main(question)["generation"]``. A cached answer containing
      "孫子" is returned immediately together with a video URL and is *not*
      written to the history table. Every other answer is recorded with
      ``save_history`` and, when it contains the support e-mail address, is
      replaced by a fixed apology message before being returned.

      Args:
          question: The user's question.
          chat_history: Earlier exchanges from the request body; currently
              unused by this handler.

      Returns:
          ``{"Answer": ...}``, plus a ``"video_cache"`` URL in the cached-video
          case.
      """
      print(question)
      started_at = time.time()

      # TODO: rewrite the query from the most recent chat_history turns with
      # get_search_query() before answering (not wired in yet).

      with get_openai_callback() as cb:
          _, cached = grandson_semantic_cache(question)

          # Cached answers that mention "孫子" short-circuit with a canned video.
          if cached and "孫子" in cached:
              video_base = "https://cmm.ai/systex-ai-chatbot/video_cache/"
              video_file = "grandson2.mp4"
              return {"Answer": cached, "video_cache": video_base + video_file}

          # Only run the agent when the cache had nothing for this question.
          answer = cached if cached else main(question)["generation"]
          elapsed = time.time() - started_at
          # The raw answer is what gets recorded, before any fallback rewrite.
          save_history(question, answer, cb, elapsed)

      if "test@systex.com" in answer:
          answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
      print(answer)
      return {"Answer": answer}
  @app.post("/knowledge")
  def rag(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
      """Answer ``question`` with the RAG pipeline and record the exchange.

      The answer is ``rag_main(question)["generation"]``. It is stored with
      ``save_history`` together with the OpenAI callback totals and the elapsed
      time; if it contains the support e-mail address it is then replaced by a
      fixed apology message before being returned.

      Args:
          question: The user's question.
          chat_history: Earlier exchanges from the request body; currently
              unused by this handler.

      Returns:
          ``{"Answer": answer}``.
      """
      print(question)
      start = time.time()
      with get_openai_callback() as cb:
          # NOTE: the semantic-cache lookup (semantic_cache(supabase, question))
          # is disabled for this endpoint, so the former ``if cache_answer``
          # branch could never run and has been removed.
          answer = rag_main(question)["generation"]
          processing_time = time.time() - start
          # The raw answer is what gets recorded, before any fallback rewrite.
          save_history(question, answer, cb, processing_time)

      if "test@systex.com" in answer:
          answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
      print(answer)
      return {"Answer": answer}
  @app.post("/local_agents")
  def local_agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
      """Answer ``question`` with the llama-based agent and record the exchange.

      The answer is ``llama_main(question)["generation"]``; it is stored with
      ``save_history`` together with the OpenAI callback totals and the elapsed
      time, then returned unchanged.

      Args:
          question: The user's question.
          chat_history: Earlier exchanges from the request body; currently
              unused by this handler.

      Returns:
          ``{"Answer": answer}``.
      """
      print(question)
      start = time.time()
      with get_openai_callback() as cb:
          # NOTE: the semantic-cache lookup (semantic_cache(supabase, question))
          # is disabled for this endpoint, so the former ``if cache_answer``
          # branch could never run and has been removed.
          answer = llama_main(question)["generation"]
          processing_time = time.time() - start
          save_history(question, answer, cb, processing_time)
      return {"Answer": answer}
  def save_history(question, answer, cb, processing_time):
      """Insert one question/answer record into the ``agent_records`` table.

      Args:
          question: The question text to store.
          answer: The answer text to store.
          cb: Callback object exposing ``total_tokens`` and ``total_cost``
              (the handlers pass the ``get_openai_callback()`` context value).
          processing_time: Elapsed seconds as measured by the caller.
      """
      row = dict(
          Question=question,
          Answer=answer,
          Total_Tokens=cb.total_tokens,
          Total_Cost=cb.total_cost,
          Processing_time=processing_time,
      )
      # The insert result is not needed by any caller, so it is not kept.
      supabase.table("agent_records").insert(row).execute()
  class history_output(BaseModel):
      """Response schema for one row returned by the ``/history`` endpoint."""

      # These five fields use the same keys that save_history() writes.
      Question: str
      Answer: str
      Total_Tokens: int
      Total_Cost: float
      Processing_time: float
      # presumably a timestamp filled in by the database -- confirm the column
      Time: datetime.datetime
  def _load_history_records():
      """Fetch every ``agent_records`` row and return them as a list of dicts."""
      response = supabase.table("agent_records").select("*").execute()
      frame = pd.DataFrame(response.data)
      # Round-trip through JSON so missing values (NaN/None) come back as plain
      # ``None`` and every value is a JSON-compatible Python object.
      rows_by_index = loads(frame.to_json(orient='index', force_ascii=False))
      # Return a real list rather than a ``dict_values`` view so the value
      # matches the declared ``List[history_output]`` response model.
      return list(rows_by_index.values())


  @app.get('/history', response_model=List[history_output])
  async def get_history():
      """Return all stored question/answer records.

      The Supabase query and the pandas conversion are blocking calls, so they
      run in a worker thread instead of on the event loop.

      Returns:
          A list with one dict per ``agent_records`` row.
      """
      # Imported here so this block does not depend on a new file-level import.
      from fastapi.concurrency import run_in_threadpool

      return await run_in_threadpool(_load_history_records)
  def cleanup_files():
      """Delete the FAISS index and metadata files from the working directory.

      Each file is handled on its own: a missing file is skipped silently, and
      a failure to delete one file is reported without preventing the attempt
      on the other. Nothing is raised to the caller.
      """
      for path in ("faiss_index.bin", "faiss_metadata.pkl"):
          try:
              os.remove(path)
          except FileNotFoundError:
              # Nothing to clean up for this file.
              continue
          except OSError as e:
              print(f"刪除檔案時出錯: {e}")
          else:
              print(f"{path} 已刪除")
  def send_heartbeat(url, sec=600, timeout=10):
      """Ping ``url`` forever, pausing ``sec`` seconds between requests.

      Failures are reported on stdout and never stop the loop. This function
      does not return; it is meant to run in a background thread.

      Args:
          url: Address requested with HTTP GET on every iteration.
          sec: Seconds to sleep between two pings.
          timeout: Seconds to wait for the server before giving up on a ping.
              Without it a stalled connection would block the loop forever.
      """
      while True:
          try:
              response = requests.get(url, timeout=timeout)
              if response.status_code != 200:
                  print(f"Failed to send heartbeat, status code: {response.status_code}")
          except requests.RequestException as e:
              # Covers connection errors and timeouts alike.
              print(f"Error occurred: {e}")
          # Wait ``sec`` seconds before the next ping.
          time.sleep(sec)
  def start_heartbeat(url, sec=600):
      """Run ``send_heartbeat(url, sec)`` in a background daemon thread.

      Being a daemon, the thread does not keep the process alive on exit.

      Args:
          url: Address handed to ``send_heartbeat``.
          sec: Interval in seconds handed to ``send_heartbeat``.
      """
      worker = threading.Thread(
          target=send_heartbeat,
          args=(url, sec),
          daemon=True,
      )
      worker.start()
  if __name__ == "__main__":
      # Start pinging the monitoring endpoint every 10 minutes in the background.
      url = 'http://db.ptt.cx:3001/api/push/luX7WcY3Gz?status=up&msg=OK&ping='
      start_heartbeat(url, sec=600)

      # HTTPS server settings; the key and certificate paths are fixed on disk.
      server_options = {
          "host": '0.0.0.0',
          "reload": True,
          "port": 8080,
          "ssl_keyfile": "/etc/ssl_file/key.pem",
          "ssl_certfile": "/etc/ssl_file/cert.pem",
      }
      try:
          uvicorn.run("systex_app:app", **server_options)
      except KeyboardInterrupt:
          print("收到 KeyboardInterrupt,正在清理...")
      finally:
          # Always remove the on-disk FAISS files when the server stops.
          cleanup_files()