123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- import datetime
- from json import loads
- import threading
- import time
- from typing import List
- from fastapi import Body, FastAPI
- from fastapi.middleware.cors import CORSMiddleware
- import pandas as pd
- from pydantic import BaseModel
- import requests
- import uvicorn
- from dotenv import load_dotenv
- import os
- from supabase.client import Client, create_client
- from langchain.callbacks import get_openai_callback
- from ai_agent import main, rag_main
- from ai_agent_llama import main as llama_main
- from semantic_search import semantic_cache, grandson_semantic_cache
- from RAG_strategy import get_search_query
# Pull variables from a local .env file into os.environ before any lookup below.
load_dotenv()

URI = os.getenv("SUPABASE_URI")
supabase_url = os.getenv("SUPABASE_URL")
supabase_key = os.getenv("SUPABASE_KEY")
# Shared client used by save_history / get_history for the agent_records table.
supabase: Client = create_client(supabase_url, supabase_key)

app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed requests; consider listing the real front-end origins.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class ChatHistoryItem(BaseModel):
    """One earlier conversation turn, as sent in an endpoint's request body."""

    q: str  # the question asked in that turn
    a: str  # the answer returned for that turn
@app.post("/agents")
def agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
    """Answer ``question`` with the main agent, consulting the grandson semantic cache first.

    ``chat_history`` is accepted in the request body but not used yet (see TODO).

    Returns:
        ``{"Answer": ...}``. When a cached answer contains "孫子", the response also
        carries a ``video_cache`` URL, and that interaction is not saved to history.
    """
    print(question)
    t0 = time.time()
    # TODO: rewrite the query via RAG_strategy.get_search_query(), feeding it the
    # last few non-empty chat_history turns, before searching.
    with get_openai_callback() as usage:
        _, cached = grandson_semantic_cache(question)
        if cached and "孫子" in cached:
            # These cached answers are served together with a pre-rendered video.
            video_base = "https://cmm.ai/systex-ai-chatbot/video_cache/"
            return {"Answer": cached, "video_cache": video_base + "grandson2.mp4"}
        answer = cached if cached else main(question)["generation"]
        save_history(question, answer, usage, time.time() - t0)
    # Answers that mention the support mailbox are replaced with a fixed,
    # polite referral message (history above keeps the raw answer).
    if "test@systex.com" in answer:
        answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
    print(answer)
    return {"Answer": answer}
@app.post("/knowledge")
def rag(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
    """Answer ``question`` with the RAG pipeline (``rag_main``).

    ``chat_history`` is accepted in the request body but not used. The semantic
    cache is disabled for this endpoint, so every request reaches ``rag_main``.

    Returns:
        ``{"Answer": ...}``, with support-mailbox answers replaced by a fixed message.
    """
    print(question)
    t0 = time.time()
    with get_openai_callback() as usage:
        # Semantic-cache lookup (semantic_cache(supabase, question)) is disabled here.
        answer = rag_main(question)["generation"]
        save_history(question, answer, usage, time.time() - t0)
    # History above keeps the raw answer; the client gets the canned referral.
    if "test@systex.com" in answer:
        answer = "很抱歉,目前我無法回答您的問題,請將您的詢問發送至 test@systex.com 以便獲得更進一步的幫助,謝謝。"
    print(answer)
    return {"Answer": answer}
@app.post("/local_agents")
def local_agent(question: str, chat_history: List[ChatHistoryItem] = Body(...)):
    """Answer ``question`` with the local Llama agent (``llama_main``).

    ``chat_history`` is accepted in the request body but not used, and the semantic
    cache is disabled. The answer is returned as-is (no support-mailbox rewrite).
    """
    print(question)
    t0 = time.time()
    with get_openai_callback() as usage:
        # Semantic-cache lookup is disabled here; always run the local agent.
        # NOTE(review): an OpenAI callback presumably records zero tokens/cost for a
        # local model — verify what ends up in agent_records.
        answer = llama_main(question)["generation"]
        save_history(question, answer, usage, time.time() - t0)
    return {"Answer": answer}
def save_history(question, answer, cb, processing_time):
    """Insert one question/answer interaction into the ``agent_records`` table.

    Args:
        question: The user's question.
        answer: The answer produced for it.
        cb: Usage tracker exposing ``total_tokens`` and ``total_cost`` (the endpoints
            pass the handler yielded by ``get_openai_callback``).
        processing_time: Seconds spent handling the request.
    """
    row = dict(
        Question=question,
        Answer=answer,
        Total_Tokens=cb.total_tokens,
        Total_Cost=cb.total_cost,
        Processing_time=processing_time,
    )
    supabase.table("agent_records").insert(row).execute()
class history_output(BaseModel):
    """Response schema for one ``agent_records`` row returned by ``GET /history``."""

    Question: str
    Answer: str
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float  # seconds, as measured by the endpoints
    # Not written by save_history; presumably a DB-side default timestamp — verify.
    Time: datetime.datetime
@app.get('/history', response_model=List[history_output])
def get_history():
    """Return every row of the ``agent_records`` table.

    Declared with plain ``def`` rather than ``async def`` because the Supabase
    client call is synchronous. FastAPI runs sync path operations in a worker
    thread, so this blocking request does not stall the event loop.

    Returns:
        A list of row dicts, validated by FastAPI against ``history_output``.
    """
    response = supabase.table("agent_records").select("*").execute()
    # Rows are already JSON-decoded dicts, so they are returned directly. Passing
    # them through pandas' to_json would round floats such as Total_Cost to 10
    # decimal places (its default double_precision) for no benefit.
    return list(response.data or [])
def cleanup_files(paths=("faiss_index.bin", "faiss_metadata.pkl")):
    """Delete the FAISS index and metadata files on shutdown.

    Each path is handled independently, so failing to delete one file does not
    prevent the others from being removed. Missing files are skipped silently.
    Any other OS error is printed and cleanup continues with the next path.

    Args:
        paths: Iterable of file paths to delete. Defaults to ``faiss_index.bin``
            and ``faiss_metadata.pkl`` in the current working directory.
    """
    for path in paths:
        # Try the removal directly instead of checking os.path.exists first,
        # which avoids a race between the check and the delete.
        try:
            os.remove(path)
        except FileNotFoundError:
            continue  # nothing to clean up for this path
        except OSError as e:
            print(f"刪除檔案時出錯: {e}")
        else:
            print(f"{path} 已刪除")
def send_heartbeat(url, sec=600, timeout=10):
    """Send a GET request to ``url`` forever, once every ``sec`` seconds.

    Non-200 responses and request errors are printed and never raised, so the
    loop keeps running.

    Args:
        url: Heartbeat/push endpoint to request.
        sec: Seconds to sleep between heartbeats.
        timeout: Seconds to wait for the server before giving up on one attempt.
            Without a timeout, a server that never responds would block this
            loop indefinitely and silently stop all future heartbeats.
    """
    while True:
        try:
            response = requests.get(url, timeout=timeout)
            if response.status_code != 200:
                print(f"Failed to send heartbeat, status code: {response.status_code}")
        except requests.RequestException as e:
            # requests.Timeout is a RequestException, so timeouts land here too.
            print(f"Error occurred: {e}")
        # Wait `sec` seconds (600 = 10 minutes by default) before the next ping.
        time.sleep(sec)
def start_heartbeat(url, sec=600):
    """Launch ``send_heartbeat(url, sec)`` on a background daemon thread.

    As a daemon, the thread does not keep the process alive when the main
    thread exits.
    """
    threading.Thread(target=send_heartbeat, args=(url, sec), daemon=True).start()
if __name__ == "__main__":
    # NOTE(review): the push token is hard-coded in this URL; consider reading it
    # from the environment instead.
    url = 'http://db.ptt.cx:3001/api/push/luX7WcY3Gz?status=up&msg=OK&ping='
    start_heartbeat(url, sec=600)
    try:
        # Serve over HTTPS on port 8080 with auto-reload enabled.
        uvicorn.run(
            "systex_app:app",
            host='',
            port=8080,
            reload=True,
            ssl_keyfile="/etc/ssl_file/key.pem",
            ssl_certfile="/etc/ssl_file/cert.pem",
        )
    except KeyboardInterrupt:
        # NOTE(review): uvicorn installs its own SIGINT handling, so this branch
        # may rarely fire — verify; `finally` runs cleanup either way.
        print("收到 KeyboardInterrupt,正在清理...")
    finally:
        cleanup_files()