- import re
- import threading
- from fastapi import FastAPI, Request, HTTPException, Response, status, Body
- # from fastapi.templating import Jinja2Templates
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import FileResponse, JSONResponse
- from fastapi import Depends
- from contextlib import asynccontextmanager
- from pydantic import BaseModel
- from typing import List, Optional
- import uvicorn
- import requests
- from typing import List, Optional
- import sqlparse
- from sqlalchemy import create_engine
- import pandas as pd
- #from retrying import retry
- import datetime
- import json
- from json import loads
- import pandas as pd
- import time
- from langchain_community.chat_models import ChatOllama
- from langchain.callbacks import get_openai_callback
- from langchain_community.vectorstores import Chroma
- from langchain_openai import OpenAIEmbeddings
- # from RAG_strategy import multi_query, naive_rag
- # from Indexing_Split import create_retriever as split_retriever
- # from Indexing_Split import gen_doc_from_database, gen_doc_from_history
- from semantic_search import semantic_cache
- from dotenv import load_dotenv
- import os
- from langchain_community.vectorstores import SupabaseVectorStore
- from langchain_openai import OpenAIEmbeddings
- from supabase.client import Client, create_client
- from file_loader.add_vectordb import GetVectorStore
- from faiss_index import create_faiss_retriever, faiss_query
- from local_llm import ollama_, hf
- from ai_agent import main
- # from local_llm import ollama_, taide_llm, hf
- # llm = hf()
# Load environment configuration and initialize shared module-level state.
load_dotenv()
# SQLAlchemy connection string for the records database (used by save_history / get_history).
URI = os.getenv("SUPABASE_URI")
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_KEY")
# NOTE(review): create_client fails if SUPABASE_URL/SUPABASE_KEY are unset — assumes both env vars exist.
supabase: Client = create_client(supabase_url, supabase_key)
# Populated by the lifespan() startup hook before the app serves requests.
global_retriever = None
llm = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup hook.

    Builds the FAISS retriever and the local Ollama chat model once at startup
    and stores them in module-level globals consumed by the dependency
    providers (get_retriever / get_llm).
    """
    global global_retriever
    global llm

    start = time.time()
    # FAISS-based retriever; the Supabase vector-store path is currently disabled:
    # vector_store = GetVectorStore(OpenAIEmbeddings(), supabase, "documents")
    # global_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
    global_retriever = create_faiss_retriever()
    # Local tool-use Llama 3 served by Ollama; temperature 0 for deterministic answers.
    # llm_json variant (format="json") was removed as unused.
    llm = ChatOllama(model="llama3-groq-tool-use:latest", temperature=0)
    # Removed: unused `document_table`, unused `embeddings = OpenAIEmbeddings()`
    # (its only consumers were the commented-out lines above), and a
    # `global vector_store` declaration for a name this function never assigns.
    print(f"lifespan startup took {time.time() - start:.2f}s")
    yield
def get_retriever():
    """FastAPI dependency provider for the shared document retriever."""
    retriever = global_retriever
    return retriever
def get_llm():
    """FastAPI dependency provider for the shared chat LLM."""
    model = llm
    return model
def get_vector_store():
    """FastAPI dependency provider for the Supabase vector store.

    BUG FIX: `vector_store` is never assigned anywhere in this module (the
    assignment inside lifespan() is commented out), so the original body raised
    NameError whenever the dependency was resolved. Look the name up
    defensively and return None when it has not been initialized.
    """
    return globals().get("vector_store")
# FastAPI application; the lifespan hook initializes the retriever and LLM.
app = FastAPI(lifespan=lifespan)
# templates = Jinja2Templates(directory="temp")
# Allow cross-origin requests from any origin (development-grade CORS policy).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ChatHistoryItem(BaseModel):
    """One question/answer turn of prior chat history sent by the client."""
    q: str  # user question
    a: str  # assistant answer; "" or "string" mark empty placeholders (filtered out by the endpoint)
def replace_unicode_escapes(match):
    """re.sub callback: turn a matched literal ``\\uXXXX`` escape into its character.

    Expects group(1) to hold the four hex digits of the code point.
    """
    code_point = int(match.group(1), 16)
    return chr(code_point)
@app.post("/answer_with_history")
def multi_query_answer(question: Optional[str] = '什麼是逸散排放源?', chat_history: List[ChatHistoryItem] = Body(...),
                       retriever=Depends(get_retriever), llm=Depends(get_llm)):
    """Answer `question` via RAG over the FAISS retriever and log the exchange.

    Filters placeholder turns out of `chat_history`, queries the LLM, decodes
    any literal \\uXXXX escapes the model emitted, persists the record with
    save_history, and returns {"Answer": <text>}.
    """
    start = time.time()

    # Keep only turns that actually carry an answer ("" / "string" are placeholders).
    chat_history = [(item.q, item.a) for item in chat_history if item.a != "" and item.a != "string"]
    print(chat_history)
    # TODO: similarity search

    with get_openai_callback() as cb:
        # Semantic-cache and multi-query/naive-RAG paths are currently disabled.
        # BUG FIX: use the injected `retriever` dependency instead of reaching
        # for the module-level global, so Depends(get_retriever) is honored.
        final_answer = faiss_query(question, retriever, llm)

        # The model sometimes emits literal \uXXXX escapes; decode them back to characters.
        decoded_string = re.sub(r'\\u([0-9a-fA-F]{4})', replace_unicode_escapes, final_answer)
        print(decoded_string)
        reference_docs = retriever.get_relevant_documents(question)
        processing_time = time.time() - start
        print(processing_time)
        save_history(question, decoded_string, reference_docs, cb, processing_time)

    # BUG FIX: the original passed a json.dumps() string to JSONResponse, which
    # serialized it a second time, so clients received a JSON-encoded *string*
    # instead of an object. Pass the dict directly; starlette's JSONResponse
    # dumps with ensure_ascii=False, so non-ASCII text is preserved.
    return JSONResponse(content={"Answer": decoded_string})
@app.get("/agents")
def agent(question: str):
    """Delegate `question` to the ai_agent pipeline and wrap its reply."""
    return {"answer": main(question)}
def save_history(question, answer, reference, cb, processing_time):
    """Append one Q/A exchange to the `systex_records` table.

    Args:
        question: user question text.
        answer: final (decoded) answer text.
        reference: retrieved context docs; a list is stringified, anything else stored as-is.
        cb: langchain OpenAI callback carrying total_tokens / total_cost counters.
        processing_time: seconds spent producing the answer.
    """
    record = {
        'Question': [question],
        'Answer': [answer],
        'Total_Tokens': [cb.total_tokens],
        'Total_Cost': [cb.total_cost],
        'Processing_time': [processing_time],
        'Contexts': [str(reference)] if isinstance(reference, list) else [reference]
    }
    df = pd.DataFrame(record)
    engine = create_engine(URI)
    try:
        df.to_sql(name='systex_records', con=engine, index=False, if_exists='append')
    finally:
        # BUG FIX: a fresh engine (and its connection pool) was created on every
        # call and never released; dispose it so connections don't leak.
        engine.dispose()
class history_output(BaseModel):
    """Response schema for GET /history — one logged Q/A record."""
    Question: str
    Answer: str
    Contexts: str
    Total_Tokens: int
    Total_Cost: float
    Processing_time: float
    # NOTE(review): save_history never writes `Time` — presumably filled by a
    # database default; confirm against the systex_records schema.
    Time: datetime.datetime
-
@app.get('/history', response_model=List[history_output])
async def get_history():
    """Return every logged Q/A record from the `systex_records` table."""
    engine = create_engine(URI, echo=True)
    try:
        # BUG FIX: the original leaked the connection returned by
        # engine.connect() and never disposed the per-request engine.
        with engine.connect() as conn:
            df = pd.read_sql_table("systex_records", conn)
    finally:
        engine.dispose()
    df.fillna('', inplace=True)
    # Round-trip through JSON keeps non-ASCII text intact (force_ascii=False).
    result = loads(df.to_json(orient='index', force_ascii=False))
    # BUG FIX: return a concrete list, not a dict_values view, so the
    # List[history_output] response model can validate/serialize it.
    return list(result.values())
def send_heartbeat(url, interval=600, timeout=10):
    """Ping `url` forever, sleeping `interval` seconds between attempts.

    Intended to run on a daemon thread (see start_heartbeat); failures are
    printed and retried, the loop never raises.

    Args:
        url: heartbeat endpoint to GET.
        interval: seconds between pings (default 600, matching the original code;
            the original comment claiming 60s was stale).
        timeout: per-request timeout in seconds.

    Both new parameters default to backward-compatible values.
    """
    while True:
        try:
            # BUG FIX: the original request had no timeout, so a stalled server
            # could hang this thread forever.
            response = requests.get(url, timeout=timeout)
            if response.status_code != 200:
                print(f"Failed to send heartbeat, status code: {response.status_code}")
        except requests.RequestException as e:
            print(f"Error occurred: {e}")
        time.sleep(interval)
def start_heartbeat(url):
    """Spawn a daemon thread that periodically pings `url` via send_heartbeat."""
    worker = threading.Thread(target=send_heartbeat, args=(url,), daemon=True)
    worker.start()
-
if __name__ == "__main__":

    # Optional uptime-monitor heartbeat (currently disabled).
    # url = 'http://db.ptt.cx:3001/api/push/luX7WcY3Gz?status=up&msg=OK&ping='
    # start_heartbeat(url)

    # Serve this module's app on all interfaces, port 8080, with auto-reload.
    uvicorn.run("RAG_app:app", host='0.0.0.0', reload=True, port=8080)