import re
import threading
import os
import json
import time
import datetime
from json import loads
from contextlib import asynccontextmanager
from typing import List, Optional

import uvicorn
import requests
import pandas as pd
from fastapi import FastAPI, Body, Depends
# from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy import create_engine
# from retrying import retry
from dotenv import load_dotenv

from langchain_community.chat_models import ChatOllama
from langchain.callbacks import get_openai_callback
from langchain_openai import OpenAIEmbeddings
from supabase.client import Client, create_client

# from RAG_strategy import multi_query, naive_rag
# from Indexing_Split import create_retriever as split_retriever
# from Indexing_Split import gen_doc_from_database, gen_doc_from_history
from semantic_search import semantic_cache
from file_loader.add_vectordb import GetVectorStore
from faiss_index import create_faiss_retriever, faiss_query
from local_llm import ollama_, hf
from ai_agent import main
# from local_llm import ollama_, taide_llm, hf
# llm = hf()

load_dotenv()
URI = os.getenv("SUPABASE_URI")
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)

global_retriever = None
llm = None
vector_store = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the retriever and LLM once at startup; the Depends getters below
    hand them to the endpoints."""
    global global_retriever
    global llm
    global vector_store

    start = time.time()
    document_table = "documents"
    embeddings = OpenAIEmbeddings()
    # vector_store = GetVectorStore(embeddings, supabase, document_table)
    # global_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
    global_retriever = create_faiss_retriever()

    local_llm = "llama3-groq-tool-use:latest"
    # llm_json = ChatOllama(model=local_llm, format="json", temperature=0)
    llm = ChatOllama(model=local_llm, temperature=0)
    print(f"startup took {time.time() - start:.2f}s")
    yield


def get_retriever():
    return global_retriever


def get_llm():
    return llm


def get_vector_store():
    return vector_store


app = FastAPI(lifespan=lifespan)
# templates = Jinja2Templates(directory="temp")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ChatHistoryItem(BaseModel):
    q: str
    a: str


def replace_unicode_escapes(match):
    """Turn a literal \\uXXXX escape emitted by the model back into its character."""
    return chr(int(match.group(1), 16))


@app.post("/answer_with_history")
def multi_query_answer(
    question: Optional[str] = '什麼是逸散排放源?',  # "What are fugitive emission sources?"
    chat_history: List[ChatHistoryItem] = Body(...),
    retriever=Depends(get_retriever),
    llm=Depends(get_llm),
):
    start = time.time()

    # Drop empty answers and the Swagger placeholder "string" from the history.
    chat_history = [(item.q, item.a) for item in chat_history
                    if item.a != "" and item.a != "string"]
    print(chat_history)
    # TODO: similarity search

    with get_openai_callback() as cb:
        # cache_question, cache_answer = semantic_cache(supabase, question)
        # if cache_answer:
        #     processing_time = time.time() - start
        #     save_history(question, cache_answer, cache_question, cb, processing_time)
        #     return {"Answer": cache_answer}

        # final_answer, reference_docs = multi_query(question, retriever, chat_history)
        # final_answer, reference_docs = naive_rag(question, retriever, chat_history)
        final_answer = faiss_query(question, retriever, llm)
        # Un-escape any literal \uXXXX sequences left in the model output.
        decoded_string = re.sub(r'\\u([0-9a-fA-F]{4})', replace_unicode_escapes, final_answer)
        print(decoded_string)
        reference_docs = retriever.get_relevant_documents(question)

    processing_time = time.time() - start
    print(processing_time)
    save_history(question, decoded_string, reference_docs, cb, processing_time)

    # JSONResponse serializes the dict itself (with ensure_ascii=False), so pass
    # the dict directly; pre-dumping with json.dumps would double-encode it into
    # a quoted string.
    return JSONResponse(content={"Answer": decoded_string})


@app.get("/agents")
def agent(question: str):
    answer = main(question)
    return {"answer": answer}


def save_history(question, answer, reference, cb, processing_time):
    """Append one Q/A record (with token usage and timing) to systex_records."""
    # reference = [doc.dict() for doc in reference]
    record = {
        'Question': [question],
        'Answer': [answer],
        'Total_Tokens': [cb.total_tokens],
        'Total_Cost': [cb.total_cost],
        'Processing_time': [processing_time],
        'Contexts': [str(reference)] if isinstance(reference, list) else [reference],
    }
    df = pd.DataFrame(record)
    engine = create_engine(URI)
    df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')


class HistoryOutput(BaseModel):
    Question: str
    Answer: str
    Contexts: str
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float
    Time: datetime.datetime


@app.get('/history', response_model=List[HistoryOutput])
async def get_history():
    engine = create_engine(URI, echo=True)
    df = pd.read_sql_table("systex_records", engine.connect())
    df.fillna('', inplace=True)
    result = loads(df.to_json(orient='index', force_ascii=False))
    return list(result.values())


def send_heartbeat(url):
    while True:
        try:
            response = requests.get(url)
            if response.status_code != 200:
                print(f"Failed to send heartbeat, status code: {response.status_code}")
        except requests.RequestException as e:
            print(f"Error occurred: {e}")
        # Wait 10 minutes between heartbeats.
        time.sleep(600)


def start_heartbeat(url):
    heartbeat_thread = threading.Thread(target=send_heartbeat, args=(url,))
    heartbeat_thread.daemon = True
    heartbeat_thread.start()


if __name__ == "__main__":
    # url = 'http://db.ptt.cx:3001/api/push/luX7WcY3Gz?status=up&msg=OK&ping='
    # start_heartbeat(url)
    uvicorn.run("RAG_app:app", host='0.0.0.0', reload=True, port=8080)
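
# Example client call for /answer_with_history (illustrative sketch only: the
# question and payload values below are made up, and the host/port assume the
# uvicorn settings above). The endpoint reads `question` from the query string
# and the chat history from the JSON body as a list matching ChatHistoryItem:
#
#   curl -X POST 'http://localhost:8080/answer_with_history?question=What%20is%20a%20fugitive%20emission%20source%3F' \
#        -H 'Content-Type: application/json' \
#        -d '[{"q": "previous question", "a": "previous answer"}]'
#
# Entries whose answer is empty or the placeholder "string" are filtered out
# before the history is used.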